"""
# Stage 05 Channel-wise Wave Characterization
"""

from pathlib import Path
configfile: Path('configs') / 'config_template.yaml'
include: Path() / '..' / 'utils' / 'Snakefile'

#### Housekeeping ####

def measures_output(wildcards):
    """Return the per-measure CSV paths collected by the ``all`` rule.

    One path is built for every entry of ``config.MEASURES``, laid out as
    ``OUTPUT_DIR/<measure>/<EVENT_NAME>_<measure>.csv``.

    ``wildcards`` is not used; the parameter is kept because the function is
    handed to a rule as an input function.
    """
    csv_paths = []
    for name in config.MEASURES:
        csv_name = '_'.join([config.EVENT_NAME, name]) + '.csv'
        csv_paths.append(OUTPUT_DIR / name / csv_name)
    return csv_paths

# For 'wavemodes' events, drop the two measures named below from the list of
# measures to compute (presumably they are not applicable to wave modes --
# TODO confirm). The order of the remaining measures is preserved.
if config.EVENT_NAME == 'wavemodes':
    config.MEASURES = list(filter(
        lambda measure: measure not in ('inter_wave_interval_local',
                                        'flow_direction_local'),
        config.MEASURES))

def input(wildcards):
    """Choose the input data for the measure named by ``wildcards.measure``.

    The measures 'velocity_local' and 'direction_local' read the declared
    outputs of the ``spatial_derivative`` rule; every other measure reads
    ``config.STAGE_INPUT`` directly.
    """
    # NOTE(review): this name shadows Python's builtin ``input``. It is kept
    # because the ``compute_measure`` rule refers to the function by this name.
    needs_derivative = wildcards.measure in ('velocity_local',
                                             'direction_local')
    return (rules.spatial_derivative.output if needs_derivative
            else config.STAGE_INPUT)

# The output pattern of compute_measure ('{dir}/{measure}/<EVENT_NAME>_{measure}.csv')
# also matches the files produced by spatial_derivative (with
# measure='spatial_derivative'). Prefer the dedicated rule in that case.
ruleorder: spatial_derivative > compute_measure

#### UTILITY BLOCKS ####

# Final rule of the stage: derived from the `template` rule (defined outside
# this file) with its input, params and output replaced. It takes the CSV of
# every configured measure and runs 'merge_dataframes.py' on them, keyed on
# the channel id and the event id.
use rule template as all with:
    input:
        check = OUTPUT_DIR / 'input.check',
        # one CSV per entry in config.MEASURES (see measures_output)
        data = measures_output,
        script = SCRIPTS / 'merge_dataframes.py',
        # configfile = Path('configs') / f'config_{config.PROFILE}.yaml'
    params:
        # the event id column is named '<EVENT_NAME>_id'
        params(merge_key = ['channel_id', str(config.EVENT_NAME + "_id")])
    output:
        OUTPUT_DIR / config.STAGE_OUTPUT,
        output_img = OUTPUT_DIR / 'overview_measures.html'

#### CHARACTERIZATION BLOCKS ####

# Generic rule computing a single measure, derived from the `template` rule
# (defined outside this file). The `{measure}` wildcard selects the script
# '<measure>.py', the output sub-directory and the output file names.
use rule template as compute_measure with:
    input:
        # `input` is the function defined above in this file (not the Python
        # builtin); it picks the data source depending on the measure
        data = input,
        script = SCRIPTS / str('{measure}' + ".py")
    params:
        # hands every attribute of the config object to the `params` helper
        params(config.__dict__)
    output:
        # '{dir}/{measure}/<EVENT_NAME>_{measure}.csv' plus a plot with the
        # same stem and the extension given by config.PLOT_FORMAT
        Path('{dir}') / '{measure}' / str(config.EVENT_NAME + "_" + '{measure}' + ".csv"),
        output_img = Path('{dir}') / '{measure}'
                      / str(config.EVENT_NAME + "_" + '{measure}' + "." + config.PLOT_FORMAT)


# Variant of compute_measure that runs 'spatial_derivative.py' on the stage
# input. Only input and output are overridden here; the remaining sections are
# taken over from compute_measure. Its outputs are what the `input` function
# above returns for 'velocity_local' and 'direction_local'.
use rule compute_measure as spatial_derivative with:
    input:
        data = config.STAGE_INPUT,
        script = SCRIPTS / 'spatial_derivative.py'
    output:
        # the sub-directory is a free wildcard `{rule_name}`; the file stem is
        # fixed to '<EVENT_NAME>_spatial_derivative'
        Path('{dir}') / '{rule_name}'
                      / str(config.EVENT_NAME + "_spatial_derivative.csv"),
        output_img = Path('{dir}') / '{rule_name}'
               / str(config.EVENT_NAME + "_spatial_derivative." + config.PLOT_FORMAT)
