"""
# Stage 02 Processing
"""

from pathlib import Path
# Load the stage configuration from the template config file; its entries are
# read throughout this file as `config.<NAME>`.
configfile: Path('configs') / 'config_template.yaml'
# Pull in the shared Snakefile from the sibling `utils` folder.
# NOTE(review): presumably this is where `template`, `template_all`, `params`,
# `prev_rule_output`, `OUTPUT_DIR` and `SCRIPTS` are defined (none of them is
# defined in this file) -- confirm.
include: Path() / '..' / 'utils' / 'Snakefile'

#### Housekeeping ####

def input_file(wildcards):
    """Input function resolving the `data` input of a rule.

    Hands the wildcards and the block sequence from ``config.BLOCK_ORDER``
    to ``prev_rule_output`` (defined outside this file) and returns its
    result unchanged.
    """
    block_sequence = config.BLOCK_ORDER
    return prev_rule_output(wildcards, rule_list=block_sequence)

def is_clear(wildcards):
    """Input function adding the optional `clear.done` dependency.

    Returns an empty list (no additional input) unless ``config.RERUN_MODE``
    is truthy; in that case the path ``<wildcards.dir>/clear.done`` is
    returned, which matches the output pattern of the `clear` rule.
    """
    if not config.RERUN_MODE:
        return []
    return Path(wildcards.dir) / "clear.done"

#### UTILITY BLOCKS ####

# Target rule of this stage, specializing the inherited `template_all` rule.
use rule template_all as all with:
    input:
        check = OUTPUT_DIR / 'input.check',
        data = input_file,
        # Directory name carries the configured plot window; it matches the
        # `processed_traces_{t_start}-{t_stop}s` output pattern of
        # `plot_processed_traces`.
        img = OUTPUT_DIR / ("processed_traces_%s-%ss"
                            % (config.PLOT_TSTART, config.PLOT_TSTOP)),
        # configfile = Path('configs') / f'config_{PROFILE}.yaml'

# Removes the folders of all blocks listed in `config.BLOCK_ORDER` below
# `{path}` and then touches a temporary `clear.done` marker. The marker is
# what `is_clear` requests as an input when `config.RERUN_MODE` is set.
rule clear:
    output:
        temp(Path('{path}') / 'clear.done')
    params:
        # Entries are converted to plain strings on purpose.
        # NOTE(review): Snakemake fills in wildcards such as `{path}` for
        # string-valued params (and string items of list params); Path
        # objects appear to be passed through untouched, which would leave a
        # literal '{path}' in the `rm` command -- confirm against the
        # Snakemake version in use.
        block_folder = [str(Path('{path}') / block) for block in config.BLOCK_ORDER]
    shell:
        """
        rm -rf {params.block_folder:q}
        touch {output:q}
        """


# Plot rule specializing the inherited `template` rule with
# SCRIPTS/plot_processed_trace.py as its `script` input. The output directory
# name takes the plot window as the wildcards `t_start` and `t_stop`; the
# `all` rule requests it with the values of config.PLOT_TSTART/PLOT_TSTOP.
use rule template as plot_processed_traces with:
    input:
        data = input_file,
        script = SCRIPTS / 'plot_processed_trace.py'
    params:
        params(channels=config.PLOT_CHANNELS,
               img_name="processed_trace_channel0." + config.PLOT_FORMAT,
               original_data=config.STAGE_INPUT)
    output:
        # The closing parenthesis of `directory(...)` was missing, which made
        # the whole Snakefile fail to parse.
        img_dir = directory(OUTPUT_DIR /
                            "processed_traces_{t_start}-{t_stop}s")


# Plot rule specializing the inherited `template` rule with
# SCRIPTS/plot_power_spectrum.py as its `script` input. Its output pattern is
# what `frequency_filter` requests as its `img` input.
use rule template as plot_power_spectrum with:
    input:
        data = input_file,
        script = SCRIPTS / 'plot_power_spectrum.py'
    params:
        params('highpass_frequency',
               'lowpass_frequency',
               'psd_frequency_resolution',
               'psd_overlap',
               config=config)
    output:
        Path('{dir}') / '{rule_name}' / ("power_spectrum." + config.PLOT_FORMAT)

#### PROCESSING BLOCKS (choose any) ####

# Processing block specializing the inherited `template` rule with
# SCRIPTS/background_subtraction.py as its `script` input.
use rule template as background_subtraction with:
    input:
        # Resolves to the `clear.done` marker only when config.RERUN_MODE is set.
        is_clear = is_clear,
        data = input_file,
        script = SCRIPTS / 'background_subtraction.py'
    params:
        params()
    output:
        # Main data file, followed by the background image and array.
        Path('{dir}') / '{rule_name}' / ("background_subtraction." + config.NEO_FORMAT),
        output_img = Path('{dir}') / '{rule_name}' / ("background." + config.PLOT_FORMAT),
        output_array = Path('{dir}') / '{rule_name}' / 'background.npy'


# Processing block specializing the inherited `template` rule with
# SCRIPTS/spatial_downsampling.py as its `script` input.
use rule template as spatial_downsampling with:
    input:
        # Resolves to the `clear.done` marker only when config.RERUN_MODE is set.
        is_clear = is_clear,
        data = input_file,
        script = SCRIPTS / 'spatial_downsampling.py'
    params:
        params(macro_pixel_dim=config.MACRO_PIXEL_DIM)
    output:
        # Main data file, followed by an accompanying image.
        Path('{dir}') / '{rule_name}' / ("spatial_downsampling." + config.NEO_FORMAT),
        output_img = Path('{dir}') / '{rule_name}' / ("spatial_downsampling." + config.PLOT_FORMAT)


# Processing block specializing the inherited `template` rule with
# SCRIPTS/normalization.py as its `script` input.
use rule template as normalization with:
    input:
        # Resolves to the `clear.done` marker only when config.RERUN_MODE is set.
        is_clear = is_clear,
        data = input_file,
        script = SCRIPTS / 'normalization.py'
    params:
        params('normalize_by',
               config=config)
    output:
        Path('{dir}') / '{rule_name}' / ("normalization." + config.NEO_FORMAT)


# Processing block specializing the inherited `template` rule with
# SCRIPTS/detrending.py as its `script` input.
use rule template as detrending with:
    input:
        # Resolves to the `clear.done` marker only when config.RERUN_MODE is set.
        is_clear = is_clear,
        data = input_file,
        script = SCRIPTS / 'detrending.py'
    params:
        params(order=config.DETRENDING_ORDER,
               plot_channels=config.PLOT_CHANNELS,
               img_name="detrending_channel0." + config.PLOT_FORMAT)
    output:
        # Main data file, followed by a directory for the plots.
        Path('{dir}') / '{rule_name}' / ("detrending." + config.NEO_FORMAT),
        img_dir = directory(Path('{dir}') / '{rule_name}' / 'detrending_plots')


# Processing block specializing the inherited `template` rule with
# SCRIPTS/frequency_filter.py as its `script` input.
use rule template as frequency_filter with:
    input:
        # Resolves to the `clear.done` marker only when config.RERUN_MODE is set.
        is_clear = is_clear,
        data = input_file,
        # Same path pattern as the output of `plot_power_spectrum`, so the
        # power-spectrum plot is required as part of this block.
        img = Path('{dir}') / '{rule_name}' / ("power_spectrum." + config.PLOT_FORMAT),
        script = SCRIPTS / 'frequency_filter.py'
    params:
        params('highpass_frequency',
               'lowpass_frequency',
               'filter_function',
               order=config.FILTER_ORDER,
               config=config)
    output:
        Path('{dir}') / '{rule_name}' / ("frequency_filter." + config.NEO_FORMAT)


# Processing block specializing the inherited `template` rule with
# SCRIPTS/roi_selection.py as its `script` input.
use rule template as roi_selection with:
    input:
        # Resolves to the `clear.done` marker only when config.RERUN_MODE is set.
        is_clear = is_clear,
        data = input_file,
        script = SCRIPTS / 'roi_selection.py'
    params:
        params('intensity_threshold',
               'crop_to_selection',
               config=config)
    output:
        # Main data file, followed by an accompanying image.
        Path('{dir}') / '{rule_name}' / ("roi_selection." + config.NEO_FORMAT),
        output_img = Path('{dir}') / '{rule_name}' / ("roi_selection." + config.PLOT_FORMAT)


# Processing block specializing the inherited `template` rule with
# SCRIPTS/logMUA_estimation.py as its `script` input.
use rule template as logMUA_estimation with:
    input:
        # Resolves to the `clear.done` marker only when config.RERUN_MODE is set.
        is_clear = is_clear,
        data = input_file,
        script = SCRIPTS / 'logMUA_estimation.py'
    params:
        # The two frequency bounds are taken from the MUA-specific config
        # entries (MUA_HIGHPASS_FREQUENCY / MUA_LOWPASS_FREQUENCY).
        params('fft_slice',
               'psd_overlap',
               'plot_tstart',
               'plot_tstop',
               'plot_channels',
               config=config,
               highpass_frequency=config.MUA_HIGHPASS_FREQUENCY,
               lowpass_frequency=config.MUA_LOWPASS_FREQUENCY,
               logMUA_rate=config.logMUA_RATE,
               img_name="logMUA_trace_channel0." + config.PLOT_FORMAT)
    output:
        # Main data file, followed by a directory for the plots.
        Path('{dir}') / '{rule_name}' / ("logMUA_estimation." + config.NEO_FORMAT),
        img_dir = directory(Path('{dir}') / '{rule_name}' / 'logMUA_estimation_plots')


# Processing block specializing the inherited `template` rule with
# SCRIPTS/phase_transform.py as its `script` input.
use rule template as phase_transform with:
    input:
        # Resolves to the `clear.done` marker only when config.RERUN_MODE is set.
        is_clear = is_clear,
        data = input_file,
        script = SCRIPTS / 'phase_transform.py'
    params:
        params()
    output:
        Path('{dir}') / '{rule_name}' / ("phase_transform." + config.NEO_FORMAT)


# Processing block specializing the inherited `template` rule with
# SCRIPTS/zscore.py as its `script` input.
use rule template as zscore with:
    input:
        # Resolves to the `clear.done` marker only when config.RERUN_MODE is set.
        is_clear = is_clear,
        data = input_file,
        script = SCRIPTS / 'zscore.py'
    params:
        params()
    output:
        Path('{dir}') / '{rule_name}' / ("zscore." + config.NEO_FORMAT)


# Processing block specializing the inherited `template` rule with
# SCRIPTS/subsampling.py as its `script` input.
use rule template as subsampling with:
    input:
        # Resolves to the `clear.done` marker only when config.RERUN_MODE is set.
        is_clear = is_clear,
        data = input_file,
        script = SCRIPTS / 'subsampling.py'
    params:
        params(target_rate=config.TARGET_RATE)
    output:
        Path('{dir}') / '{rule_name}' / ("subsampling." + config.NEO_FORMAT)
