from src import WorkflowPaths, Args, validate_config

# Workflow-wide setup: every rule below reads `config`, `paths` and `args`.
# `config` is the global config object provided by Snakemake; it is replaced
# by whatever `validate_config` returns.
config = validate_config(config) # validate the Snakemake config
paths = WorkflowPaths(config) # object handling the paths to the files that will be created
args = Args(paths, config) # object handling the arguments passed to the workflow

# `all` only echoes a summary, so it is declared as a local rule
# (executed on the host instead of being submitted as a cluster job).
localrules: all

# Target rule: requests the final outputs of the workflow (explorer experiment
# file, explorer image, report and, when `args.annotate` is truthy, the
# annotations), then prints where the results were written.
# It is the first rule declared in this file, which is what Snakemake uses as
# the default target when none is given on the command line.
rule all:
    input:
        # `[]` contributes no input, so annotations are only requested when enabled
        paths.annotations if args.annotate else [],
        paths.explorer_experiment,
        paths.explorer_image,
        paths.report,
    params:
        # only used to format the messages echoed below
        sdata_path = paths.sdata_path,
        explorer_directory = paths.explorer_directory,
        explorer_experiment = paths.explorer_experiment,
    shell:
        """
        echo 🎉 Successfully run sopa
        echo → SpatialData output directory: {params.sdata_path}
        echo → Explorer output directory: {params.explorer_directory}
        echo → Open the result in the explorer: 'open {params.explorer_experiment}'
        """

# Runs `sopa convert` on `paths.data_path`, writing to `paths.sdata_path`,
# with the reader options taken from the "read" section of the arguments.
rule to_spatialdata:
    input:
        # no input file is required when the config flags a toy reader
        [] if config["is_toy_reader"] else paths.data_path,
    output:
        # the rule declares no output when `paths.data_path` is falsy
        # NOTE(review): presumably the SpatialData object already exists in that case - confirm
        paths.sdata_zgroup if paths.data_path else [],
    conda:
        "sopa"
    resources:
        mem_mb=128_000,
    params:
        reader = args['read'].as_cli(),
        data_path = paths.data_path,
        sdata_path = paths.sdata_path,
    shell:
        """
        sopa convert {params.data_path} --sdata-path {params.sdata_path} {params.reader}
        """

# Runs `sopa segmentation tissue` on the SpatialData directory.
# Completion is recorded by touching the `paths.smk_tissue_segmentation` flag file.
rule tissue_segmentation:
    input:
        # depends on the converted data only when `paths.data_path` is truthy
        paths.sdata_zgroup if paths.data_path else [],
    output:
        touch(paths.smk_tissue_segmentation),
    conda:
        "sopa"
    params:
        tissue_segmentation = args["segmentation"]["tissue"].as_cli(),
        sdata_path = paths.sdata_path,
    shell:
        """
        sopa segmentation tissue {params.sdata_path} {params.tissue_segmentation}
        """

# Checkpoint running `sopa patchify image`.
# Declared as a checkpoint because `get_input_resolve("image", ...)` reads the
# `patches_file` output to decide which per-patch files downstream rules need.
checkpoint patchify_image:
    input:
        paths.sdata_zgroup,
        # wait for the tissue segmentation only when it is enabled
        paths.smk_tissue_segmentation if args.tissue_segmentation else [],
    output:
        patches_file = paths.smk_patches_file_image,
        patches = touch(paths.smk_patches),
    params:
        # NOTE(review): `contains="pixel"` presumably keeps only the patchify
        # options whose name contains "pixel" - confirm in `Args.as_cli`
        patchify_image = args["patchify"].as_cli(contains="pixel"),
        sdata_path = paths.sdata_path,
    conda:
        "sopa"
    shell:
        """
        sopa patchify image {params.sdata_path} {params.patchify_image}
        """

# Checkpoint running `sopa patchify transcripts`.
# Declared as a checkpoint because `get_input_resolve("transcripts", ...)` reads
# the `patches_file` output to decide which per-patch files downstream rules need.
checkpoint patchify_transcripts:
    input:
        paths.sdata_zgroup,
        # when cellpose is enabled, the resolved cellpose boundaries must exist first
        paths.smk_cellpose_boundaries if args.cellpose else [],
        # wait for the tissue segmentation only when it is enabled
        paths.smk_tissue_segmentation if args.tissue_segmentation else [],
    output:
        directory(paths.smk_transcripts_temp_dir),
        patches_file = paths.smk_patches_file_transcripts,
    params:
        patchify_transcripts = args.patchify_transcripts(),
        sdata_path = paths.sdata_path,
    conda:
        "sopa"
    shell:
        """
        sopa patchify transcripts {params.sdata_path} {params.patchify_transcripts}
        """

# Runs `sopa segmentation cellpose` on a single patch.
# The `{index}` wildcard selects the patch and names the produced parquet file,
# so one job is created per requested index.
rule patch_segmentation_cellpose:
    input:
        paths.smk_patches_file_image,
        paths.smk_patches,
    output:
        paths.smk_cellpose_temp_dir / "{index}.parquet",
    conda:
        "sopa"
    params:
        cellpose = args["segmentation"]["cellpose"].as_cli(),
        sdata_path = paths.sdata_path,
    shell:
        """
        sopa segmentation cellpose {params.sdata_path} --patch-index {wildcards.index} {params.cellpose}
        """

# Runs `sopa segmentation baysor` on a single transcript patch.
# The `{index}` wildcard selects the patch and the sub-directory that receives
# `segmentation_counts.loom`, so one job is created per requested index.
rule patch_segmentation_baysor:
    input:
        paths.smk_patches_file_transcripts,
    output:
        paths.smk_transcripts_temp_dir / "{index}" / "segmentation_counts.loom",
    conda:
        "sopa"
    resources:
        mem_mb=128_000,
        runtime=120, # in minutes
    params:
        # NOTE(review): `keys=["config"]` presumably restricts the CLI arguments
        # to the "config" entry - confirm in `Args.as_cli`
        baysor_config = args["segmentation"]["baysor"].as_cli(keys=["config"]),
        sdata_path = paths.sdata_path,
    # The shell block first runs `module purge` when a `module` command is available
    # NOTE(review): presumably to prevent loaded cluster modules from interfering - confirm
    shell:
        """
        if command -v module &> /dev/null; then
            module purge
        fi

        sopa segmentation baysor {params.sdata_path} --patch-index {wildcards.index} {params.baysor_config}
        """

# Runs `sopa segmentation comseg` on a single transcript patch.
# The `{index}` wildcard selects the patch and the sub-directory that receives
# the two declared outputs (polygons JSON and counts h5ad).
rule patch_segmentation_comseg:
    input:
        paths.smk_patches_file_transcripts,
    output:
        paths.smk_transcripts_temp_dir / "{index}" / "segmentation_polygons.json",
        paths.smk_transcripts_temp_dir / "{index}" / "segmentation_counts.h5ad",
    conda:
        "sopa"
    resources:
        mem_mb=128_000,
    params:
        # NOTE(review): `keys=["config"]` presumably restricts the CLI arguments
        # to the "config" entry - confirm in `Args.as_cli`
        comseg_config = args["segmentation"]["comseg"].as_cli(keys=["config"]),
        sdata_path = paths.sdata_path,
    shell:
        """
        sopa segmentation comseg {params.sdata_path} --patch-index {wildcards.index} {params.comseg_config}
        """

def get_input_resolve(name: str, method_name: str):
    """Build the Snakemake input function used by a `resolve_*` rule.

    Args:
        name: Suffix of the checkpoint to query, i.e. the checkpoint looked up
            is `patchify_{name}`.
        method_name: Forwarded to `paths.temporary_boundaries_paths` together
            with the content of the checkpoint's `patches_file` output.

    Returns:
        A function taking the rule wildcards and returning the list of
        per-patch paths (as POSIX strings) that the resolve rule depends on.
    """
    def _resolve_inputs(wildcards):
        checkpoint = getattr(checkpoints, f"patchify_{name}")
        patches_file = checkpoint.get(**wildcards).output.patches_file

        with patches_file.open() as handle:
            content = handle.read()

        boundaries = paths.temporary_boundaries_paths(content, method_name)

        # snakemake uses posix paths (fix issue #64)
        return [str(boundary.as_posix()) for boundary in boundaries]

    return _resolve_inputs

# Runs `sopa resolve cellpose` once the per-patch cellpose files exist.
# The inputs are computed by `get_input_resolve("image", "cellpose")` from the
# `patchify_image` checkpoint; completion is recorded by touching
# `paths.smk_cellpose_boundaries`.
rule resolve_cellpose:
    input:
        get_input_resolve("image", "cellpose"),
    output:
        touch(paths.smk_cellpose_boundaries),
    conda:
        "sopa"
    params:
        sdata_path = paths.sdata_path,
    shell:
        """
        sopa resolve cellpose {params.sdata_path}
        """

# Runs `sopa resolve baysor` once the per-patch baysor files exist, then deletes
# the transcripts temporary directory.
# The inputs are computed by `get_input_resolve("transcripts", "baysor")` from
# the `patchify_transcripts` checkpoint.
# NOTE(review): `resolve_comseg` declares the same `paths.smk_table` output;
# presumably only one of the two rules is part of the DAG for a given config - confirm
rule resolve_baysor:
    input:
        files = get_input_resolve("transcripts", "baysor"),
    output:
        # both outputs are flag files created by `touch`
        touch(paths.smk_baysor_boundaries),
        touch(paths.smk_table),
    conda:
        "sopa"
    params:
        resolve = args.resolve_transcripts(),
        sdata_path = paths.sdata_path,
        smk_transcripts_temp_dir = paths.smk_transcripts_temp_dir,
    shell:
        """
        sopa resolve baysor {params.sdata_path} {params.resolve}

        rm -r {params.smk_transcripts_temp_dir}    # cleanup large baysor files
        """

# Runs `sopa resolve comseg` once the per-patch comseg files exist, then deletes
# the transcripts temporary directory.
# The inputs are computed by `get_input_resolve("transcripts", "comseg")` from
# the `patchify_transcripts` checkpoint.
# NOTE(review): `resolve_baysor` declares the same `paths.smk_table` output;
# presumably only one of the two rules is part of the DAG for a given config - confirm
rule resolve_comseg:
    input:
        files = get_input_resolve("transcripts", "comseg"),
    output:
        # both outputs are flag files created by `touch`
        touch(paths.smk_comseg_boundaries),
        touch(paths.smk_table),
    conda:
        "sopa"
    params:
        resolve = args.resolve_transcripts(),
        sdata_path = paths.sdata_path,
        smk_transcripts_temp_dir = paths.smk_transcripts_temp_dir,
    shell:
        """
        sopa resolve comseg {params.sdata_path} {params.resolve}

        rm -r {params.smk_transcripts_temp_dir}    # cleanup large comseg files
        """

def get_smk_boundaries(args):
    """Return the boundaries flag path of the enabled segmentation method.

    The methods are checked in a fixed order of precedence (baysor, then
    comseg, then cellpose); the first one whose flag on `args` is truthy wins
    and the matching `paths.smk_<method>_boundaries` attribute is returned.

    Raises:
        ValueError: If none of the three segmentation flags is set.
    """
    # Order matters: it defines which method wins when several are enabled.
    for method in ("baysor", "comseg", "cellpose"):
        if getattr(args, method):
            return getattr(paths, f"smk_{method}_boundaries")

    raise ValueError("No segmentation method selected")

# Runs `sopa aggregate` once the boundaries of the selected segmentation method
# (see `get_smk_boundaries`) are available. Completion is recorded by touching
# `paths.smk_aggregation`.
rule aggregate:
    input:
        get_smk_boundaries(args),
    output:
        touch(paths.smk_aggregation),
    conda:
        "sopa"
    resources:
        mem_mb=64_000,
    params:
        aggregate = args["aggregate"].as_cli(),
        sdata_path = paths.sdata_path,
    shell:
        """
        sopa aggregate {params.sdata_path} {params.aggregate}
        """

# Runs `sopa annotate <method>` after aggregation and produces the
# `paths.annotations` directory.
rule annotate:
    input:
        paths.smk_aggregation,
    output:
        directory(paths.annotations),
    conda:
        "sopa"
    resources:
        # the "tangram" method requests a GPU partition and one GPU, any other method does not
        # NOTE(review): "gpgpuq", "shortq" and "a100:1" look like site-specific cluster names - confirm
        partition="gpgpuq" if args['annotation']['method'] == "tangram" else "shortq",
        gpu="a100:1" if args['annotation']['method'] == "tangram" else 0,
    params:
        method_name = args['annotation']['method'],
        annotation = args['annotation']['args'].as_cli(),
        sdata_path = paths.sdata_path,
    shell:
        """
        sopa annotate {params.method_name} {params.sdata_path} {params.annotation}
        """

# Runs `sopa explorer write` with `--mode "+i" --no-save-h5ad` to produce
# `paths.explorer_image`.
# Its only input is the SpatialData zgroup, so it does not wait for
# segmentation or aggregation.
# NOTE(review): `--mode "+i"` presumably writes only the image - confirm against
# the `sopa explorer write` documentation
rule image_write:
    input:
        paths.sdata_zgroup,
    output:
        paths.explorer_image,
    conda:
        "sopa"
    resources:
        mem_mb=64_000,
        runtime=1500, # in minutes
    params:
        # NOTE(review): `keys=[...]` presumably restricts the explorer options to
        # the listed names - confirm in `Args.as_cli`
        explorer = args["explorer"].as_cli(keys=['lazy', 'ram_threshold_gb', 'pixel_size', 'pixelsize']),
        sdata_path = paths.sdata_path,
        explorer_directory = paths.explorer_directory,
    shell:
        """
        sopa explorer write {params.sdata_path} --output-path {params.explorer_directory} {params.explorer} --mode "+i" --no-save-h5ad
        """

# Runs `sopa report` to write `paths.report`, once the boundaries, the
# aggregation and (when enabled) the annotations are available.
rule report:
    input:
        get_smk_boundaries(args),
        paths.smk_aggregation,
        # `[]` contributes no input, so annotations are only awaited when enabled
        paths.annotations if args.annotate else [],
    output:
        paths.report,
    params:
        sdata_path = paths.sdata_path,
        report = paths.report,
    conda:
        "sopa"
    shell:
        """
        sopa report {params.sdata_path} {params.report}
        """

# Runs `sopa explorer write` with `--mode "-i"` to produce
# `paths.explorer_experiment`, once the boundaries, the aggregation and (when
# enabled) the annotations are available.
# NOTE(review): `--mode "-i"` presumably writes everything except the image,
# which is handled by the `image_write` rule - confirm against the
# `sopa explorer write` documentation
rule explorer:
    input:
        get_smk_boundaries(args),
        paths.smk_aggregation,
        # `[]` contributes no input, so annotations are only awaited when enabled
        paths.annotations if args.annotate else [],
    output:
        paths.explorer_experiment,
    conda:
        "sopa"
    resources:
        mem_mb=256_000,
    params:
        explorer = args["explorer"].as_cli(),
        sdata_path = paths.sdata_path,
        explorer_directory = paths.explorer_directory,
    shell:
        """
        sopa explorer write {params.sdata_path} --output-path {params.explorer_directory} {params.explorer} --mode "-i"
        """
