from src import WorkflowPaths, Args, validate_config

# `config` on the right-hand side is the configuration mapping that Snakemake
# makes available to the Snakefile; the three objects below are used by the
# rules of this file.
config = validate_config(config) # validate the Snakemake config
paths = WorkflowPaths(config) # object handling the paths to the files that will be created
args = Args(paths, config) # object handling the arguments passed to the workflow

### Segmentation rules
# NOTE(review): the included files are not visible here; judging by their names
# they hold shared utilities plus one rule file per segmentation method.
# They are included after `config`, `paths` and `args` are defined, so those
# names already exist when the included rules are parsed.
# Includes placed before the first rule do not change the default target:
# Snakemake still uses the first rule of this file (`all`).
include: "rules/utils.smk"
include: "rules/cellpose.smk"
include: "rules/comseg.smk"
include: "rules/baysor.smk"
include: "rules/proseg.smk"
include: "rules/stardist.smk"

# Run the `all` rule locally instead of submitting it as a cluster/cloud job.
localrules: all

# Default target rule: requests the final files of the workflow (the
# annotations only when `args.annotate` is truthy) and, once they exist,
# echoes where the results can be found.
rule all:
    input:
        paths.annotations if args.annotate else [],
        paths.explorer_experiment,
        paths.smk_explorer_raw,
        paths.report,
    params:
        # Values below are only interpolated into the echo messages.
        sdata_path=paths.sdata_path,
        explorer_directory=paths.explorer_directory,
        report=paths.report,
        explorer_experiment=paths.explorer_experiment,
    shell:
        """
        echo 🎉 Successfully run sopa
        echo → SpatialData output directory: {params.sdata_path}
        echo → Explorer output directory: {params.explorer_directory}
        echo → Open the sopa report: 'open {params.report}'
        echo → Open the result in the explorer: 'open {params.explorer_experiment}'
        """

# Runs `sopa convert` on `paths.data_path`, passing `paths.sdata_path` as
# `--sdata-path` together with the CLI options built from the "read" arguments.
# No input is declared when `config["is_toy_reader"]` is truthy, and no output
# is declared when `paths.data_path` is falsy.
rule to_spatialdata:
    conda:
        "sopa"
    input:
        [] if config["is_toy_reader"] else paths.data_path,
    output:
        paths.sdata_json_attrs if paths.data_path else [],
    params:
        data_path=paths.data_path,
        sdata_path=paths.sdata_path,
        reader=args["read"].as_cli(),
    shell:
        """
        sopa convert {params.data_path} --sdata-path {params.sdata_path} {params.reader}
        """

# Runs `sopa segmentation tissue` on the SpatialData directory and, on success,
# touches the "tissue" segmentation flag file.
# The SpatialData attrs file is only required as input when `paths.data_path`
# is truthy.
rule tissue_segmentation:
    conda:
        "sopa"
    input:
        paths.sdata_json_attrs if paths.data_path else [],
    output:
        touch(paths.segmentation_done("tissue")),
    params:
        sdata_path=paths.sdata_path,
        tissue_segmentation=args["segmentation"]["tissue"].as_cli(),
    shell:
        """
        sopa segmentation tissue {params.sdata_path} {params.tissue_segmentation}
        """

# Checkpoint running `sopa patchify image` on the SpatialData directory.
# Being a checkpoint, Snakemake re-evaluates the DAG after it completes.
# Waits for the tissue segmentation flag only when `args.use("tissue")` is truthy.
checkpoint patchify_image:
    conda:
        "sopa"
    input:
        paths.sdata_json_attrs,
        paths.segmentation_done("tissue") if args.use("tissue") else [],
    output:
        patches_file=paths.smk_patches_file_image,
        patches=touch(paths.smk_patches),
    params:
        sdata_path=paths.sdata_path,
        # NOTE(review): `contains="pixel"` presumably keeps only the patchify
        # options whose name contains "pixel" - confirm against `as_cli`.
        patchify_image=args["patchify"].as_cli(contains="pixel"),
    shell:
        """
        sopa patchify image {params.sdata_path} {params.patchify_image}
        """

# Checkpoint running `sopa patchify transcripts` on the SpatialData directory.
# Being a checkpoint, Snakemake re-evaluates the DAG after it completes.
# Waits for the cellpose / tissue segmentation flags only when the matching
# `args.use(...)` call is truthy.
checkpoint patchify_transcripts:
    conda:
        "sopa"
    input:
        paths.sdata_json_attrs,
        paths.segmentation_done("cellpose") if args.use("cellpose") else [],
        paths.segmentation_done("tissue") if args.use("tissue") else [],
    output:
        directory(paths.smk_transcripts_temp_dir),
        patches_file=paths.smk_patches_file_transcripts,
    params:
        sdata_path=paths.sdata_path,
        patchify_transcripts=args.patchify_transcripts(),
    shell:
        """
        sopa patchify transcripts {params.sdata_path} {params.patchify_transcripts}
        """

# Runs `sopa aggregate` on the SpatialData directory once the files returned by
# `args.segmentation_boundaries()` exist; on success touches the aggregation
# flag file.
rule aggregate:
    conda:
        "sopa"
    input:
        args.segmentation_boundaries(),
    output:
        touch(paths.smk_aggregation),
    params:
        sdata_path=paths.sdata_path,
        aggregate=args["aggregate"].as_cli(),
    shell:
        """
        sopa aggregate {params.sdata_path} {params.aggregate}
        """

# Annotation method, looked up once and shared by the resources and params of
# the `annotate` rule below.
_annotation_method = args["annotation"]["method"]
_annotation_is_tangram = _annotation_method == "tangram"

# Runs `sopa annotate <method>` on the SpatialData directory after aggregation.
# The "tangram" method requests the "gpgpuq" partition with one a100 GPU;
# any other method requests the "shortq" partition without GPU.
rule annotate:
    conda:
        "sopa"
    input:
        paths.smk_aggregation,
    output:
        directory(paths.annotations),
    resources:
        partition="gpgpuq" if _annotation_is_tangram else "shortq",
        gpu="a100:1" if _annotation_is_tangram else 0,
    params:
        method_name=_annotation_method,
        sdata_path=paths.sdata_path,
        annotation=args["annotation"]["args"].as_cli(),
    shell:
        """
        sopa annotate {params.method_name} {params.sdata_path} {params.annotation}
        """

# Runs `sopa scanpy-preprocess` on the SpatialData directory once segmentation
# boundaries and aggregation are done (and the annotations too, when
# `args.annotate` is truthy).
rule scanpy_preprocess:
    conda:
        "sopa"
    input:
        args.segmentation_boundaries(),
        paths.smk_aggregation,
        paths.annotations if args.annotate else [],
    output:
        paths.smk_scanpy_preprocess,
    params:
        sdata_path=paths.sdata_path,
        scanpy_preprocess=args["scanpy_preprocess"].as_cli(),
    shell:
        """
        sopa scanpy-preprocess {params.sdata_path} {params.scanpy_preprocess}
        """

# Runs `sopa explorer write` with `--mode "+it"` and `--no-save-h5ad` as soon as
# the SpatialData attrs file exists; on success touches the raw-explorer flag.
# Only the listed keys of the "explorer" arguments are forwarded to the CLI.
rule explorer_raw:
    conda:
        "sopa"
    input:
        paths.sdata_json_attrs,
    output:
        touch(paths.smk_explorer_raw),
    params:
        sdata_path=paths.sdata_path,
        explorer_directory=paths.explorer_directory,
        explorer=args["explorer"].as_cli(
            keys=["lazy", "ram_threshold_gb", "pixel_size", "pixelsize"]
        ),
    shell:
        """
        sopa explorer write {params.sdata_path} --output-path {params.explorer_directory} {params.explorer} --mode "+it" --no-save-h5ad
        """

# Runs `sopa explorer write` with `--mode "-it"` once segmentation boundaries
# and aggregation are done. Annotations are awaited only when `args.annotate`
# is truthy, and the scanpy preprocessing output only when the config has a
# "scanpy_preprocess" entry.
rule explorer:
    conda:
        "sopa"
    input:
        args.segmentation_boundaries(),
        paths.smk_aggregation,
        paths.annotations if args.annotate else [],
        paths.smk_scanpy_preprocess if "scanpy_preprocess" in config else [],
    output:
        paths.explorer_experiment,
    params:
        sdata_path=paths.sdata_path,
        explorer_directory=paths.explorer_directory,
        explorer=args["explorer"].as_cli(),
    shell:
        """
        sopa explorer write {params.sdata_path} --output-path {params.explorer_directory} {params.explorer} --mode "-it"
        """

# Runs `sopa report` to write `paths.report` from the SpatialData directory.
# Depends on the same inputs as the `explorer` rule: segmentation boundaries,
# aggregation, and optionally annotations / scanpy preprocessing.
rule report:
    conda:
        "sopa"
    input:
        args.segmentation_boundaries(),
        paths.smk_aggregation,
        paths.annotations if args.annotate else [],
        paths.smk_scanpy_preprocess if "scanpy_preprocess" in config else [],
    output:
        paths.report,
    params:
        sdata_path=paths.sdata_path,
        report=paths.report,
    shell:
        """
        sopa report {params.sdata_path} {params.report}
        """
