from snakemake.utils import min_version
# Refuse to run on Snakemake releases older than 6.0.
min_version("6.0")

import sys
import os

# Load config
def resolve_config_path(argv, default="config/config.yml"):
    """Return the config file named on the command line, or *default*.

    Recognises both spellings of the Snakemake option (``--configfile`` and
    ``--configfiles``) in the space-separated form (``--configfile path``)
    and the ``=`` form (``--configfile=path``). When several files are
    given, only the first one is returned.

    Args:
        argv: Command-line tokens to scan (normally ``sys.argv``).
        default: Path returned when no config option with a value is found.

    Returns:
        The path of the first config file on the command line, else *default*.
    """
    flags = ("--configfile", "--configfiles")
    for pos, token in enumerate(argv):
        if token in flags:
            # Space-separated form: the path is the next token, if there is one.
            if pos + 1 < len(argv):
                return argv[pos + 1]
        else:
            for flag in flags:
                prefix = flag + "="
                if token.startswith(prefix):
                    return token[len(prefix):]
    return default


# config_path is kept as a module-level name because rule copy_config uses it
# as its input.
config_path = resolve_config_path(sys.argv)
configfile: config_path

# Directory and logging settings, read from the loaded workflow config.
# A key missing from the config raises KeyError here.
(
    DATA_DIR,
    ANALYSIS_DIR,
    PLOTS_DIR,
    LOG_DIR,
    LOG_LEVEL,
) = (
    config[key]
    for key in ("DATA_DIR", "ANALYSIS_DIR", "PLOTS_DIR", "LOG_DIR", "LOG_LEVEL")
)

# Marker sets and classifier settings; the workflow reads them through
# .keys(), so each is expected to be a mapping.
(
    COMMON_GENE_SEN_MARKERS,
    COMMON_PROTEIN_SEN_MARKERS,
    ML_CLASSIFIER,
) = (
    config[key]
    for key in (
        "COMMON_GENE_SEN_MARKERS",
        "COMMON_PROTEIN_SEN_MARKERS",
        "ML_CLASSIFIER",
    )
)

# Processing input counts data for ML applications
# The rule definitions live in separate files. They are included only after
# the config values above have been assigned (presumably because the included
# rules refer to them -- TODO confirm).
include: "rules/data.smk"
include: "rules/plots.smk"
include: "rules/cls.smk"

# Global wildcard constraints. Each value is a regular expression built as an
# alternation of the accepted names, so a "." that is part of a name is
# escaped; an unescaped "." would match any character.
wildcard_constraints:
    input_counts = "|".join([
        "transcriptomics",
        "proteomics",
        "IMR90_fibroblast_transcriptomics",
        r"IMR90_fibroblast_transcriptomics\.normalized",
        "IMR90_fibroblast_proteomics",
        r"IMR90_fibroblast_proteomics\.normalized",
        # Currently disabled datasets:
        #"IMR90_fibroblast_proteomics_DIA", "IMR90_fibroblast_proteomics_DIA.normalized",
        #"monocyte_proteomics", "monocyte_proteomics.normalized",
        "SenCat_HREC_proteomics",
    ]),
    counts_type_markers = "|".join(["transcriptomics", "proteomics"]),
    # One alternative per classifier name configured in ML_CLASSIFIER.
    ml_classifier = "|".join(ML_CLASSIFIER.keys()),

# When both rules could produce the same file, prefer ml_classifier over
# marker_classifier.
ruleorder: ml_classifier > marker_classifier

# Default target rule: lists every file the workflow is expected to produce.
# Every input_counts value requested here must be permitted by the global
# input_counts wildcard constraint, otherwise no rule can produce the file.
rule all:
    input:
        # copy config file to the analysis directory for reproducibility
        os.path.join(ANALYSIS_DIR, "config.yml"),

        # prepare input data
        expand(
            os.path.join(ANALYSIS_DIR, "{input_counts}.h5ad"),
            input_counts=[
                # main input data, loaded using anndata_from_excel rule
                "transcriptomics",
                "proteomics",
                # data for testing transcriptomics markers, loaded using anndata_from_txt rule
                "IMR90_fibroblast_transcriptomics",
                # data for testing proteomics markers, loaded using anndata_from_csv rule
                "IMR90_fibroblast_proteomics",
                # data for testing proteomics markers, loaded using anndata_from_excel_SenCat_HREC_proteomics rule
                "SenCat_HREC_proteomics",
            ],
        ),
        expand(
            os.path.join(ANALYSIS_DIR, "{input_counts}.for_SenCID.txt"),
            input_counts=["transcriptomics", "proteomics", "IMR90_fibroblast_transcriptomics"],
        ),

        # qc plots on raw SenCat counts
        expand(
            os.path.join(PLOTS_DIR, "{input_counts}.library_size.pdf"),
            input_counts=["transcriptomics", "proteomics"],
        ),
        expand(
            os.path.join(PLOTS_DIR, "{input_counts}.mito_plot.pdf"),
            input_counts=["transcriptomics", "proteomics"],
        ),
        expand(
            os.path.join(PLOTS_DIR, "{input_counts}.ribo_plot.pdf"),
            input_counts=["transcriptomics", "proteomics"],
        ),
        expand(
            os.path.join(PLOTS_DIR, "{input_counts}.hb_plot.pdf"),
            input_counts=["transcriptomics", "proteomics"],
        ),

        # normalized data
        expand(
            os.path.join(ANALYSIS_DIR, "{input_counts}.normalized.h5ad"),
            input_counts=[
                "IMR90_fibroblast_transcriptomics",
                "IMR90_fibroblast_proteomics",
            ],
        ),

        # classification results
        expand(
            os.path.join(ANALYSIS_DIR, "{input_counts}.{ml_classifier}.classification_results.csv"),
            input_counts=["transcriptomics", "proteomics"],
            ml_classifier=ML_CLASSIFIER.keys(),
        ),
        expand(
            os.path.join(ANALYSIS_DIR, "{input_counts}.{ml_classifier}.tuned_classification_results.csv"),
            input_counts=["transcriptomics", "proteomics"],
            ml_classifier=ML_CLASSIFIER.keys(),
        ),
        expand(
            os.path.join(ANALYSIS_DIR, "{input_counts}_{ml_classifier}_common_features.csv"),
            input_counts=["transcriptomics", "proteomics"],
            ml_classifier=ML_CLASSIFIER.keys(),
        ),
        expand(
            os.path.join(ANALYSIS_DIR, "{input_counts}_{ml_classifier}_tuned_common_features.csv"),
            input_counts=["transcriptomics", "proteomics"],
            ml_classifier=ML_CLASSIFIER.keys(),
        ),

        # classification with markers results
        expand(
            os.path.join(ANALYSIS_DIR, "{input_counts}.{counts_type_markers}_{ml_classifier}_{markers_type}.results.csv"),
            input_counts=["transcriptomics", "IMR90_fibroblast_transcriptomics.normalized"],
            ml_classifier=ML_CLASSIFIER.keys(),
            counts_type_markers=["transcriptomics"],
            markers_type=["tuned_common_features"],
        ),
        expand(
            os.path.join(ANALYSIS_DIR, "{input_counts}.{counts_type_markers}_{ml_classifier}_{markers_type}.results.csv"),
            # "IMR90_fibroblast_proteomics_DIA.normalized" is intentionally not
            # requested: the DIA dataset is disabled in the input_counts
            # wildcard constraint, so this target could not be resolved.
            input_counts=["proteomics", "IMR90_fibroblast_proteomics.normalized", "SenCat_HREC_proteomics"],
            ml_classifier=ML_CLASSIFIER.keys(),
            counts_type_markers=["proteomics"],
            markers_type=["tuned_common_features"],
        ),

        # classification results plots
        expand(
            os.path.join(PLOTS_DIR, "{input_counts}.{prediction_method}.results.pdf"),
            input_counts=["transcriptomics", "proteomics"],
            # "<classifier>.classification" and "<classifier>.tuned_classification"
            prediction_method=[
                f"{clf}.{mode}"
                for clf in ML_CLASSIFIER.keys()
                for mode in ["classification", "tuned_classification"]
            ],
        ),
        expand(
            os.path.join(PLOTS_DIR, "{input_counts}.{prediction_method}.results.pdf"),
            input_counts=["transcriptomics", "proteomics", "IMR90_fibroblast_transcriptomics"],
            prediction_method="SenCID",
        ),
        expand(
            os.path.join(PLOTS_DIR, "{input_counts}.{counts_type_markers}_{ml_classifier}_{markers_type}.results.pdf"),
            input_counts=["transcriptomics", "IMR90_fibroblast_transcriptomics.normalized"],
            ml_classifier=ML_CLASSIFIER.keys(),
            counts_type_markers=["transcriptomics"],
            markers_type=["common_features", "tuned_common_features"],
        ),
        expand(
            os.path.join(PLOTS_DIR, "{input_counts}.{counts_type_markers}_{ml_classifier}_{markers_type}.results.pdf"),
            input_counts=["proteomics", "IMR90_fibroblast_proteomics.normalized", "SenCat_HREC_proteomics"],
            ml_classifier=ML_CLASSIFIER.keys(),
            counts_type_markers=["proteomics"],
            markers_type=["common_features", "tuned_common_features"],
        ),

        # plot gene markers
        expand(
            os.path.join(PLOTS_DIR, "{input_counts}_counts.{counts_file_markers}_{ml_classifier}_gene_markers.pdf"),
            input_counts=["transcriptomics"],
            counts_file_markers=["transcriptomics"],
            ml_classifier=ML_CLASSIFIER.keys(),
        ),
        expand(
            os.path.join(PLOTS_DIR, "{input_counts}_counts.{counts_file_markers}_{ml_classifier}_gene_markers.pdf"),
            input_counts=["proteomics"],
            counts_file_markers=["proteomics"],
            ml_classifier=ML_CLASSIFIER.keys(),
        ),
        expand(
            os.path.join(PLOTS_DIR, "{input_counts}_counts.{gene_markers}_gene_markers.pdf"),
            input_counts=["transcriptomics"],
            gene_markers=COMMON_GENE_SEN_MARKERS.keys(),
        ),
        expand(
            os.path.join(PLOTS_DIR, "{input_counts}_counts.{gene_markers}_gene_markers.pdf"),
            input_counts=["proteomics"],
            gene_markers=COMMON_PROTEIN_SEN_MARKERS.keys(),
        ),

        # plot gene markers with ml markers
        expand(
            os.path.join(
                PLOTS_DIR,
                "{input_counts}_counts.{gene_markers}_gene_markers.{counts_type_markers}_{ml_classifier}_{markers_type}_ml_markers.pdf",
            ),
            input_counts=["transcriptomics"],
            # classifier-derived marker sets first, then the configured common marker sets
            gene_markers=[f"transcriptomics_{clf}" for clf in ML_CLASSIFIER.keys()]
            + list(COMMON_GENE_SEN_MARKERS.keys()),
            counts_type_markers=["transcriptomics"],
            ml_classifier=ML_CLASSIFIER.keys(),
            markers_type=["tuned_common_features"],
        ),
        expand(
            os.path.join(
                PLOTS_DIR,
                "{input_counts}_counts.{gene_markers}_gene_markers.{counts_type_markers}_{ml_classifier}_{markers_type}_ml_markers.pdf",
            ),
            input_counts=["proteomics"],
            # classifier-derived marker sets first, then the configured common marker sets
            gene_markers=[f"proteomics_{clf}" for clf in ML_CLASSIFIER.keys()]
            + list(COMMON_PROTEIN_SEN_MARKERS.keys()),
            counts_type_markers=["proteomics"],
            ml_classifier=ML_CLASSIFIER.keys(),
            markers_type=["tuned_common_features"],
        ),

        # plot ml marker
        expand(
            os.path.join(PLOTS_DIR, "{input_counts}.{counts_type_markers}_{ml_classifier}_{markers_type}_ml_markers.pdf"),
            input_counts=["transcriptomics", "IMR90_fibroblast_transcriptomics.normalized"],
            counts_type_markers=["transcriptomics"],
            ml_classifier=ML_CLASSIFIER.keys(),
            markers_type=["tuned_common_features"],
        ),
        expand(
            os.path.join(PLOTS_DIR, "{input_counts}.{counts_type_markers}_{ml_classifier}_{markers_type}_ml_markers.pdf"),
            input_counts=["proteomics", "IMR90_fibroblast_proteomics.normalized", "SenCat_HREC_proteomics"],
            counts_type_markers=["proteomics"],
            ml_classifier=ML_CLASSIFIER.keys(),
            markers_type=["tuned_common_features"],
        ),

# Copy the config file (config_path) to the analysis directory so the settings
# used for a run are stored next to its results.
# The :q modifier shell-quotes both paths, so spaces or shell metacharacters
# in either path do not break the cp command.
rule copy_config:
    input:
        config_path
    output:
        os.path.join(ANALYSIS_DIR, "config.yml")
    shell:
        """
        cp {input:q} {output:q}
        """