from utils import WorkflowPaths, Args

# Project helpers built from the Snakemake `config`: the rules below read their
# file/directory locations from `paths` and their command-line argument strings from `args`.
paths = WorkflowPaths(config)
args = Args(paths, config)

# `all` is declared as a local rule (Snakemake runs local rules on the host, not as submitted jobs).
localrules: all

# Target rule: requests the final pipeline outputs, then prints a short summary of where they are.
rule all:
    input:
        # the annotation output is only requested when `args.annotate` is truthy
        paths.annotations if args.annotate else [],
        paths.explorer_experiment,
        paths.explorer_image,
        paths.report,
    shell:
        """
        echo 🎉 Successfully run sopa
        echo → SpatialData output directory: {paths.sdata_path}
        echo → Explorer output directory: {paths.explorer_directory}
        echo → Open the result in the explorer: 'open {paths.explorer_experiment}'
        """

# Runs `sopa read` on the raw data path and writes the result to the SpatialData path.
rule to_spatialdata:
    input:
        # no raw input is required when the configured technology is "uniform"
        paths.data_path if config["read"]["technology"] != "uniform" else [],
    output:
        # the rule only declares an output when a raw data path is set
        paths.sdata_zgroup if paths.data_path else [],
    params:
        # reader options rendered as a command-line string
        args_reader = str(args["read"])
    resources:
        mem_mb=128_000,
    conda:
        "sopa"
    shell:
        """
        sopa read {paths.data_path} --sdata-path {paths.sdata_path} {params.args_reader}
        """

# Checkpoint that runs `sopa patchify image` on the SpatialData path.
# Being a checkpoint, its `patches_file` output is read back by `get_input_resolve("cellpose")`
# to decide which per-patch files must exist before the cellpose results are resolved.
checkpoint patchify_cellpose:
    input:
        paths.sdata_zgroup,
    output:
        patches_file = paths.smk_patches_file_image,
        # marker file touched by Snakemake once the job has succeeded
        patches = touch(paths.smk_patches),
    params:
        # subset of the "patchify" arguments selected with `where(contains="pixel")`
        # (presumably the pixel-unit options — confirm in `utils.Args`)
        args_patchify = str(args["patchify"].where(contains="pixel")),
    conda:
        "sopa"
    shell:
        """
        sopa patchify image {paths.sdata_path} {params.args_patchify}
        """

# Checkpoint that runs `sopa patchify baysor` on the SpatialData path.
# Its `patches_file` output is read back by `get_input_resolve("baysor")`.
checkpoint patchify_baysor:
    input:
        paths.sdata_zgroup,
        # wait for the cellpose boundaries when cellpose is enabled (they are used as a prior, see `arg_prior`)
        paths.smk_cellpose_boundaries if args.cellpose else [],
    output:
        patches_file = paths.smk_patches_file_baysor,
        smk_baysor_temp_dir = directory(paths.smk_baysor_temp_dir),
    params:
        # subset of the "patchify" arguments selected with `where(contains="micron")`
        # (presumably the micron-unit options — confirm in `utils.Args`)
        args_patchify = str(args["patchify"].where(contains="micron")),
        # Baysor-specific patchify arguments, empty when Baysor is disabled
        args_baysor = args.dump_baysor_patchify() if args.baysor else "",
        # `--use-prior` is passed only when cellpose is enabled
        arg_prior = "--use-prior" if args.cellpose else "",
    conda:
        "sopa"
    shell:
        """
        sopa patchify baysor {paths.sdata_path} {params.args_patchify} {params.args_baysor} {params.arg_prior}
        """

# Runs `sopa segmentation cellpose` on a single patch, selected by the `index` wildcard.
rule patch_segmentation_cellpose:
    input:
        paths.smk_patches_file_image,
        paths.smk_patches,
    output:
        # one parquet file per patch index inside the cellpose temporary directory
        f"{paths.smk_cellpose_temp_dir}/{{index}}.parquet",
    params:
        # cellpose segmentation options rendered as a command-line string
        args_cellpose = str(args["segmentation"]["cellpose"]),
    conda:
        "sopa"
    shell:
        """
        sopa segmentation cellpose {paths.sdata_path} --patch-dir {paths.smk_cellpose_temp_dir} --patch-index {wildcards.index} {params.args_cellpose}
        """

# Runs the Baysor executable inside a single patch directory, selected by the `index` wildcard.
# The executable is taken from `config["executables"]["baysor"]`.
rule patch_segmentation_baysor:
    input:
        patches_file = paths.smk_patches_file_baysor,
        # per-patch working directory; the command below reads `config.toml` and `transcripts.csv` from it
        baysor_patch = f"{paths.smk_baysor_temp_dir}/{{index}}",
    output:
        f"{paths.smk_baysor_temp_dir}/{{index}}/segmentation_polygons.json",
        f"{paths.smk_baysor_temp_dir}/{{index}}/segmentation_counts.loom",
    params:
        args_baysor_prior_seg = args.baysor_prior_seg,
    resources:
        mem_mb=128_000,
    shell:
        # `:q` shell-quotes the patch directory so that a path containing spaces or
        # shell metacharacters does not break the `cd`.
        """
        if command -v module &> /dev/null; then
            module purge
        fi

        cd {input.baysor_patch:q}
        {config[executables][baysor]} run --save-polygons GeoJSON -c config.toml transcripts.csv {params.args_baysor_prior_seg}
        """

def get_input_resolve(name, dirs=False):
    """Build a Snakemake input function for the `resolve_<name>` rules.

    The returned function looks up the `patchify_<name>` checkpoint, reads the
    content of its `patches_file` output and hands that content to
    `paths.cells_paths` to obtain the rule inputs.

    Args:
        name: Suffix of the checkpoint to query (`patchify_{name}`); also
            forwarded to `paths.cells_paths`.
        dirs: Forwarded to `paths.cells_paths` as its `dirs` keyword.

    Returns:
        A one-argument callable (taking the job wildcards) suitable for use
        in a rule's `input` section.
    """
    checkpoint_name = f"patchify_{name}"

    def _resolve_inputs(wildcards):
        # `.get()` is what lets Snakemake defer evaluation until the checkpoint has run
        checkpoint = getattr(checkpoints, checkpoint_name)
        patches_file = checkpoint.get(**wildcards).output.patches_file
        with patches_file.open() as handle:
            content = handle.read()
            return paths.cells_paths(content, name, dirs=dirs)

    return _resolve_inputs

# Runs `sopa resolve cellpose` on the per-patch results stored in the cellpose temporary directory.
rule resolve_cellpose:
    input:
        # input function: the list of inputs is derived from the `patchify_cellpose` checkpoint output
        get_input_resolve("cellpose"),
    output:
        # marker file touched by Snakemake once the job has succeeded
        touch(paths.smk_cellpose_boundaries),
    conda:
        "sopa"
    shell:
        """
        sopa resolve cellpose {paths.sdata_path} --patch-dir {paths.smk_cellpose_temp_dir}
        """

# Runs `sopa resolve baysor` on the per-patch Baysor results, then deletes the
# Baysor temporary directory.
rule resolve_baysor:
    input:
        # both input functions derive their values from the `patchify_baysor` checkpoint output
        files = get_input_resolve("baysor"),
        dirs = get_input_resolve("baysor", dirs=True),
    output:
        # marker files touched by Snakemake once the job has succeeded
        touch(paths.smk_baysor_boundaries),
        touch(paths.smk_table),
    conda:
        "sopa"
    params:
        # one `--patches-dirs <dir>` option per resolved patch directory
        # (the loop variable is not called `directory`, to avoid shadowing Snakemake's `directory()` helper)
        args_patches_dirs = lambda _, input: " ".join(f"--patches-dirs {patch_dir}" for patch_dir in input.dirs),
        args_min_area = args.min_area("baysor"),
    shell:
        # The recursive delete is shell-quoted (`:q`) and preceded by `--` so that a
        # temporary directory containing spaces or starting with `-` cannot make
        # `rm -r` act on unintended targets.
        """
        sopa resolve baysor {paths.sdata_path} --gene-column {args.gene_column} {params.args_min_area} {params.args_patches_dirs}

        rm -r -- {paths.smk_baysor_temp_dir:q}    # cleanup large baysor files
        """

# Runs `sopa aggregate` once the segmentation boundaries are available.
rule aggregate:
    input:
        # Baysor boundaries when Baysor is enabled, otherwise the cellpose boundaries
        paths.smk_baysor_boundaries if args.baysor else paths.smk_cellpose_boundaries,
    output:
        # marker file touched by Snakemake once the job has succeeded
        touch(paths.smk_aggregation),
    params:
        # aggregation options as a command-line string (empty when none are configured)
        args_aggregate = str(args["aggregate"] or ""),
    resources:
        mem_mb=64_000,
    conda:
        "sopa"
    shell:
        """
        sopa aggregate {paths.sdata_path} {params.args_aggregate}
        """

# Runs `sopa annotate <method>` on the aggregated SpatialData object.
rule annotate:
    input:
        paths.smk_aggregation,
    output:
        directory(paths.annotations),
    params:
        method_name = args["annotation"]["method"],
        args_annotation = str(args["annotation"]["args"]),
    resources:
        # the "tangram" method is sent to the `gpgpuq` partition with one a100 GPU;
        # any other method uses `shortq` with no GPU
        partition="gpgpuq" if args["annotation"]["method"] == "tangram" else "shortq",
        gpu="a100:1" if args["annotation"]["method"] == "tangram" else 0,
    conda:
        "sopa"
    shell:
        """
        sopa annotate {params.method_name} {paths.sdata_path} {params.args_annotation}
        """

# Runs `sopa explorer write` with `--mode "+i"` and `--no-save-h5ad` to produce the explorer image.
# NOTE(review): `+i` presumably restricts the export to the image — confirm against the sopa CLI.
rule image_write:
    input:
        paths.sdata_zgroup,
    output:
        paths.explorer_image,
    params:
        # only the explorer options listed in `keys` are forwarded to the command
        args_explorer = str(args["explorer"].where(keys=["lazy", "ram_threshold_gb", "pixel_size", "pixelsize"])),
    resources:
        mem_mb=64_000,
        partition="longq",
    conda:
        "sopa"
    shell:
        """
        sopa explorer write {paths.sdata_path} --output-path {paths.explorer_directory} {params.args_explorer} --mode "+i" --no-save-h5ad
        """

# Runs `sopa report` on the SpatialData path and writes the report to `paths.report`.
rule report:
    input:
        # Baysor boundaries when Baysor is enabled, otherwise the cellpose boundaries
        paths.smk_baysor_boundaries if args.baysor else paths.smk_cellpose_boundaries,
        paths.smk_aggregation,
        # also wait for the annotation output when `args.annotate` is truthy
        paths.annotations if args.annotate else [],
    output:
        paths.report,
    conda:
        "sopa"
    shell:
        """
        sopa report {paths.sdata_path} {paths.report}
        """

# Runs `sopa explorer write` with `--mode "-i"` to produce the explorer experiment file.
# NOTE(review): `-i` presumably exports everything except the image (see `image_write`) — confirm against the sopa CLI.
rule explorer:
    input:
        # Baysor boundaries when Baysor is enabled, otherwise the cellpose boundaries
        paths.smk_baysor_boundaries if args.baysor else paths.smk_cellpose_boundaries,
        paths.smk_aggregation,
        # also wait for the annotation output when `args.annotate` is truthy
        paths.annotations if args.annotate else [],
    output:
        paths.explorer_experiment,
    params:
        # all explorer options rendered as a command-line string
        args_explorer = str(args["explorer"]),
    resources:
        mem_mb=256_000,
    conda:
        "sopa"
    shell:
        """
        sopa explorer write {paths.sdata_path} --output-path {paths.explorer_directory} {params.args_explorer} --mode "-i"
        """
