
##### global workflow dependencies #####
# conda: "envs/global.yaml"

# libraries
import yaml
import pandas as pd
import os
from snakemake.utils import validate, min_version
import json
import csv
import sys
import subprocess

##### module name #####
# used below as the sub-directory of config["result_path"] that holds all results
module_name = "spilterlize_integrate"

##### set minimum snakemake version #####
# abort early if the running Snakemake is older than the required version
min_version("8.20.1")

##### setup report #####
# global description shown in the Snakemake report
report: "report/workflow.rst"

##### load config and annotation #####
# populates the global `config` dictionary that everything below reads from
configfile: os.path.join("config","config.yaml")

### annotation
# CSV file located at config["annotation"]; its first column becomes the index
# NOTE(review): presumably one row per sample -- confirm against the config documentation
annotation = pd.read_csv(config["annotation"], index_col=0)

##### set global variables
# root directory of all outputs: <config["result_path"]>/<module_name>
result_path = os.path.join(config["result_path"],module_name)

# make splits
# 'all' is always present; every configured split variable that exists as an
# annotation column adds one split per unique value, named "<variable>_<value>"
splits = ['all']
for split_var in config["split_by"]:
    # variables that are not annotation columns are skipped
    if split_var not in annotation.columns:
        continue
    splits.extend(f'{split_var}_{entry}' for entry in annotation[split_var].unique())

# normalization methods
# every configured edgeR method becomes "norm<method>"; the method "none" is
# labelled with the configured quantification instead (e.g., "CPM"/"RPKM")
norm_edgeR_methods = [
    "norm" + (config["edgeR_parameters"]["quantification"] if method == "none" else method)
    for method in config["edgeR_parameters"]["method"]
]

# CQN and voom are optional: an empty string in their config entry disables them
norm_cqn = []
if config["cqn_parameters"]["x"] != "":
    norm_cqn = ["normCQN"]

norm_voom = []
if config["voom_parameters"]["normalize.method"] != "":
    norm_voom = ["normVOOM"]

# all requested normalization outputs, edgeR first, then CQN, then voom
norm_methods = [*norm_edgeR_methods, *norm_cqn, *norm_voom]

# integration methods
# only requested when at least one "desired" entry is configured for
# removeBatchEffect; each configured normalization method then yields
# "norm<method>_integrated"
integration_data = []
if len(config["removeBatchEffect_parameters"]["desired"]) > 0:
    integration_data = [
        "norm" + method + "_integrated"
        for method in config["removeBatchEffect_parameters"]["norm_methods"]
    ]

# data for HVF selection
# integrated data take precedence; normalized data are used only when nothing
# is integrated. If both lists are empty no HVF data are requested (there is
# no fallback to 'filtered').
transformed_data = [
    method + "_HVF"
    for method in (integration_data or norm_methods)
]

# visualization data
# every data set that gets plotted: raw and filtered counts followed by all
# normalized, integrated and HVF-selected data sets
all_data = [
    'counts',
    'filtered',
    *norm_methods,
    *integration_data,
    *transformed_data,
]

##### target rules #####
# Target rule of the workflow: it has no output or command of its own and only
# lists the files that should exist at the end. As the first rule in this file
# it is the default target when Snakemake is invoked without explicit targets.
rule all:
    input:
        # The commented-out expand() calls below name intermediate results
        # (split, filter, normalize, integrate, HVF); they are not requested
        # directly here.
        # SPLIT
#         expand(os.path.join(result_path, '{split}','counts.csv'), split=splits),
#         expand(os.path.join(result_path, '{split}','annotation.csv'), split=splits),
        # FILTER
#         expand(os.path.join(result_path, '{split}','filtered.csv'), split=splits),
        # NORMALIZE
#         expand(os.path.join(result_path, '{split}','{method}.csv'), split=splits, method=norm_methods),
        # INTEGRATE
#         expand(os.path.join(result_path, '{split}','{method}.csv'), split=splits, method=integration_data),
        # HVF
#         expand(os.path.join(result_path, '{split}','{transformed_data}.csv'), split=splits, transformed_data=transformed_data),
        # VISUALIZE
        # two plots for every (split, data set) combination plus one CFA plot per split
        expand(os.path.join(result_path, '{split}','plots','{all_data}.png'), split=splits, all_data=all_data),
        expand(os.path.join(result_path, '{split}','plots','{all_data}_heatmap_clustered.png'), split=splits, all_data=all_data),
        expand(os.path.join(result_path,'{split}','plots','CFA.png'), split=splits),
        # config
        # copies of the environment specifications, the configuration and the
        # annotation stored under result_path
        # NOTE(review): presumably produced by rules in rules/envs_export.smk -- confirm
        envs = expand(os.path.join(result_path,'envs','{env}.yaml'),env=["python_basics","edgeR","cqn","ggplot"]),
        configs = os.path.join(result_path,'configs','{}_config.yaml'.format(config["project_name"])),
        annot = os.path.join(result_path,'configs','{}_annot.csv'.format(config["project_name"])),
    # falls back to a single thread when config has no "threads" entry
    threads: config.get("threads", 1)
    resources:
        # NOTE(review): the fallback is the string "16000" whereas the threads
        # fallback is an int -- confirm that a string value is intended here
        mem_mb=config.get("mem", "16000"),
    log:
        os.path.join("logs","rules","all.log")

##### load rules #####
# The rule files are included after the global variables and the target rule
# above have been defined.
# NOTE(review): the included files presumably rely on globals such as
# result_path, splits and annotation -- confirm before reordering anything.
include: os.path.join("rules", "process.smk")
include: os.path.join("rules", "normalize.smk")
include: os.path.join("rules", "integrate.smk")
include: os.path.join("rules", "visualize.smk")
include: os.path.join("rules", "envs_export.smk")


