"""
# Stage 04 Wave Detection
"""

from pathlib import Path
configfile: Path('configs') / 'config_template.yaml'
include: Path() / '..' / 'utils' / 'Snakefile'

#### UTILITY BLOCKS ####

# Default target of this stage. The rule inherits everything from
# `template_all` (not defined in this file; presumably provided by the
# included utils Snakefile) and overrides only the `input` directive:
#   check - the 'input.check' file located directly in OUTPUT_DIR
#   data  - the file written by the `merge_wave_definitions` rule when its
#           `{dir}` wildcard resolves to OUTPUT_DIR
# The two commented-out entries are optional inputs that are currently
# disabled.
use rule template_all as all with:
    input:
        check = OUTPUT_DIR / 'input.check',
        data = OUTPUT_DIR / 'merge_wave_definitions' / config.STAGE_OUTPUT,
        # img = OUTPUT_DIR / 'wave_plots'
        # configfile = Path('configs') / f'config_{PROFILE}.yaml',


def additional_properties(wildcards):
    """Return the input files of all configured additional properties.

    Each entry ``prop`` of ``config.ADDITIONAL_PROPERTIES`` is mapped to
    ``<wildcards.dir>/<prop>/<prop>.<config.NEO_FORMAT>``.

    An unset (``None``) or empty setting yields an empty list, and a single
    property given as a plain string is treated as a one-element list
    instead of being iterated character by character.
    """
    properties = config.ADDITIONAL_PROPERTIES
    if not properties:
        return []
    if isinstance(properties, str):
        properties = [properties]
    extension = "." + config.NEO_FORMAT
    return [Path(wildcards.dir) / prop / (prop + extension)
            for prop in properties]

# Runs merge_wave_definitions.py on the output of the configured detection
# block (config.DETECTION_BLOCK) together with the files returned by
# additional_properties(), and writes the result to
# <dir>/merge_wave_definitions/<config.STAGE_OUTPUT>.
rule merge_wave_definitions:
    output:
        Path('{dir}', 'merge_wave_definitions', config.STAGE_OUTPUT)
    input:
        data = Path('{dir}', config.DETECTION_BLOCK, config.STAGE_OUTPUT),
        additional_data = additional_properties,
        script = SCRIPTS / 'merge_wave_definitions.py'
    shell:
        """
        {ADD_UTILS}
        python3 {input.script:q} --data {input.data:q} \
                                 --properties {input.additional_data:q} \
                                 --output {output:q}
        """

#### DETECTION BLOCK ####

# Detection block: inherits from `template_plus_plot_script` (defined outside
# this file) and runs trigger_clustering.py on the stage input, with
# plot_clustering.py as the plot script. Besides the data file it declares
# the plot <dir>/trigger_clustering/wave_cluster.<PLOT_FORMAT> as named
# output `img`.
use rule template_plus_plot_script as trigger_clustering with:
    params:
        params('metric', 'time_space_ratio', 'neighbour_distance',
               config=config, time_slice=config.PLOT_TSTOP,
               min_samples=config.MIN_SAMPLES_PER_WAVE)
    input:
        data = config.STAGE_INPUT,
        script = SCRIPTS / 'trigger_clustering.py',
        plot_script = SCRIPTS / 'plot_clustering.py'
    output:
        Path('{dir}', 'trigger_clustering', config.STAGE_OUTPUT),
        img = Path('{dir}', 'trigger_clustering',
                   "wave_cluster." + config.PLOT_FORMAT)

#### ADDITIONAL PROPERTIES ####

# Additional property: inherits from `template` (defined outside this file)
# and runs optical_flow.py on the stage input. The output directory is kept
# as the `{rule_name}` wildcard; the data file and the plot share the base
# name "optical_flow" and differ only in their extension.
use rule template as optical_flow with:
    params:
        params('alpha', 'max_Niter', 'convergence_limit', 'gaussian_sigma',
               'derivative_filter', 'use_phases', config=config)
    input:
        data = config.STAGE_INPUT,
        script = SCRIPTS / 'optical_flow.py'
    output:
        Path('{dir}', '{rule_name}', "optical_flow." + config.NEO_FORMAT),
        output_img = Path('{dir}', '{rule_name}',
                          "optical_flow." + config.PLOT_FORMAT)


# Additional property: inherits from `template_plus_plot_script` (defined
# outside this file) and runs critical_points.py, with
# plot_critical_points.py as plot script, on the data file of the
# `optical_flow` block. frame_id and skip_step are fixed here and are not
# read from the config.
use rule template_plus_plot_script as critical_points with:
    params:
        params(frame_id=0, skip_step=1)
    input:
        data = Path('{dir}', 'optical_flow',
                    "optical_flow." + config.NEO_FORMAT),
        script = SCRIPTS / 'critical_points.py',
        plot_script = SCRIPTS / 'plot_critical_points.py',
    output:
        Path('{dir}', '{rule_name}', "critical_points." + config.NEO_FORMAT),
        img = Path('{dir}', '{rule_name}',
                   "critical_points." + config.PLOT_FORMAT)


# use rule template as critical_points_clustering with:
#     # ToDo
#     input:
#         data = Path('{dir}') / 'critical_points' / config.STAGE_OUTPUT
#     params:
#     output:
#         Path('{dir}') / 'critical_points_clustering' / config.STAGE_OUTPUT
#     shell:
#         """
#         {ADD_UTILS}
#         cp {input.data:q} {output.data:q}
#         """


# Additional property: inherits from `template` (defined outside this file)
# and runs wave_mode_clustering.py on the data file produced by
# `trigger_clustering`.
#
# Only the first (unnamed) output of `trigger_clustering` is used as `data`:
# the complete output list of that rule also contains the `img` plot, which
# would otherwise turn `input.data` into two files and make the plot an
# input of this rule.
use rule template as wave_mode_clustering with:
    input:
        data = rules.trigger_clustering.output[0],
        script = SCRIPTS / 'wave_mode_clustering.py'
    params:
        params('min_trigger_fraction', 'num_wave_neighbours', 'pca_dims',
               'wave_outlier_quantile', 'num_kmeans_cluster',
               'interpolation_step_size', 'interpolation_smoothing', config=config)
    output:
        Path('{dir}') / '{rule_name}' / str("wave_mode_clustering." + config.NEO_FORMAT),
        output_img = Path('{dir}') / '{rule_name}' / str("wave_mode_clustering." + config.PLOT_FORMAT)

#### PLOTTING ####

# Plotting: inherits from `template` (defined outside this file) and runs
# plot_waves.py on the output of `merge_wave_definitions`. The output is the
# directory <dir>/wave_plots; time window, colormap and image name are fixed
# here and are not read from the config.
use rule template as plot_waves with:
    output:
        directory(Path('{dir}', 'wave_plots'))
    input:
        data = Path('{dir}', 'merge_wave_definitions', config.STAGE_OUTPUT),
        script = SCRIPTS / 'plot_waves.py'
    params:
        params(time_window=0.4,  # in s
               colormap="viridis",
               img_name="wave_plot_id0." + config.PLOT_FORMAT)


#### MOVIE PLOTTING ####

# Movie plotting helper: inherits from `template` (defined outside this file)
# and runs time_slice.py on <dir>/<data_name>.<NEO_FORMAT>. The requested
# slice is encoded in the output file name through the `{t_start}` and
# `{t_stop}` wildcards; no explicit parameters are passed to params().
use rule template as time_slice with:
    output:
        Path('{dir}', '{data_name}_{t_start}-{t_stop}s.' + config.NEO_FORMAT)
    input:
        data = Path('{dir}', '{data_name}.' + config.NEO_FORMAT),
        script = SCRIPTS / 'time_slice.py'
    params:
        params()


# Movie plotting: inherits from `template` (defined outside this file) and
# runs plot_movie_frames.py on <dir>/<data_name>.<NEO_FORMAT>. The frames go
# into the directory <dir>/<data_name>_frames; frame name ('frame') and frame
# format ('png') are fixed here and are not read from the config.
use rule template as plot_movie_frames with:
    output:
        frame_folder = directory(Path('{dir}', '{data_name}_frames'))
    input:
        data = Path('{dir}', '{data_name}.' + config.NEO_FORMAT),
        script = SCRIPTS / 'plot_movie_frames.py'
    params:
        params('colormap', 'frame_rate', 'marker_color', 'plot_event',
               config=config, frame_name='frame', frame_format='png')


# Encodes the frames in <dir>/<data_name>_frames into <dir>/<data_name>.mp4
# with ffmpeg (libx264). Quality, scaling, bitrate and frame rate are taken
# from the config.
#
# The frame pattern is 'frame_%05d.png': `plot_movie_frames` requests its
# frames with frame_name='frame' and frame_format='png', so the extension
# must not follow config.PLOT_FORMAT, which may be set to a different format.
rule plot_movie:
    input:
        Path('{dir}') / str('{data_name}' + "_frames")
    output:
        Path('{dir}') / str('{data_name}' + ".mp4")
    params:
        frame_path = lambda wildcards, input: Path(input[0], 'frame_%05d.png'),
        quality = config.QUALITY,
        scale_x = config.SCALE_X,
        scale_y = config.SCALE_Y,
        bitrate = config.BITRATE,
        fps = config.FPS
    shell:
        """
        ffmpeg -y -framerate {params.fps} \
               -i {params.frame_path:q} \
               -crf {params.quality} \
               -vb {params.bitrate} \
               -vcodec libx264 \
               -vf scale={params.scale_x}:{params.scale_y} \
               {output:q}
        """

# Converts <path>.mp4 into <path>.webm with ffmpeg (libvpx video, libvorbis
# audio), using the quality, bitrate and scaling values from the config.
rule mp4_to_webm:
    input:
        Path('{path}.mp4')
    output:
        Path('{path}.webm')
    params:
        quality = config.QUALITY,
        scale_x = config.SCALE_X,
        scale_y = config.SCALE_Y,
        bitrate = config.BITRATE
    shell:
        """
        ffmpeg -i {input:q} \
               -vcodec libvpx \
               -acodec libvorbis \
               -crf {params.quality} \
               -vb {params.bitrate} \
               -vf scale={params.scale_x}:{params.scale_y} \
               {output:q}
        """

# Converts <path>.mp4 into the looping GIF <path>_<scale>px.gif with ffmpeg.
# The `{scale}` wildcard from the requested file name is passed as the target
# width to the scale filter (height -1); the filter chain generates a palette
# from the video and applies it to the output.
rule mp4_to_gif:
    output:
        Path('{path}_{scale}px.gif')
    input:
        Path('{path}.mp4')
    shell:
        """
        ffmpeg -i {input:q} \
        -vf "scale={wildcards.scale}:-1:flags=lanczos,split[s0][s1];[s0]palettegen[p];[s1][p]paletteuse" \
        -loop 0 {output:q}
        """
