import os
import snakePipes.common_functions as cf


### snakemake_workflows initialization ########################################
# Directory two levels above this workflow's own directory (workflow.basedir);
# the shared rules are included from "<maindir>/shared/rules" further below.
maindir = os.path.dirname(os.path.dirname(workflow.basedir))

# load conda ENVs (path is relative to "shared/rules" directory)
# The mapping returned by cf.set_env_yamls() is injected into the module globals
# (presumably this defines CONDA_SHARED_ENV / CONDA_HIC_ENV used in onstart -- confirm).
globals().update(cf.set_env_yamls())

# load config file
# Only the first entry of workflow.overwrite_configfiles is read; the returned
# mapping is injected into the module globals as well.
# NOTE(review): config["verbose"] is indexed directly here, while onstart/onsuccess
# guard the lookup with '"verbose" in config' -- confirm the key is always present.
globals().update(cf.load_configfile(workflow.overwrite_configfiles[0], config["verbose"]))
# load organism-specific data, i.e. genome indices, annotation, etc.
# `genome` is not assigned in this file; presumably it comes from the config
# loaded just above, so this call must stay after it -- confirm.
globals().update(cf.load_organism_data(genome, maindir, config["verbose"]))
# return the pipeline version in the log
cf.get_version()

# do workflow specific stuff now
# internals.snakefile lives next to this file (workflow.basedir), not in the shared rules.
include: os.path.join(workflow.basedir, "internals.snakefile")

### include modules of other snakefiles ########################################
################################################################################

# FASTQ: either downsample FASTQ files or create symlinks to input files
include: os.path.join(maindir, "shared", "rules", "FASTQ.snakefile")

# FastQC (only included when the `fastqc` option is set)
if fastqc:
    include: os.path.join(maindir, "shared", "rules", "FastQC.snakefile")

# trimming (only included when the `trim` option is set)
if trim:
    include: os.path.join(maindir, "shared", "rules", "trimming.snakefile")

# umi_tools (always included)
include: os.path.join(maindir, "shared", "rules", "umi_tools.snakefile")

# HiCExplorer (always included)
include: os.path.join(maindir, "shared", "rules", "hicexplorer.snakefile")

# multiQC (always included)
include: os.path.join(maindir, "shared", "rules", "multiQC.snakefile")

### conditional/optional rules #################################################
################################################################################
def run_FastQC(fastqc):
    """Return the FastQC report targets, or an empty list when FastQC is off.

    Args:
        fastqc: Truthy to request one "FastQC/{sample}{read}_fastqc.html"
            file per combination of the global ``samples`` and ``reads``.

    Returns:
        The expanded list of report paths, or ``[]`` if ``fastqc`` is falsy.
    """
    # Guard clause: nothing to request when FastQC is disabled.
    if not fastqc:
        return []
    return expand("FastQC/{sample}{read}_fastqc.html", sample=samples, read=reads)

def run_Trimming(trim, fastqc):
    """Return the targets produced by trimming, or an empty list when it is off.

    Args:
        trim: Truthy to request the "{sample}{read}.fastq.gz" files under the
            global ``fastq_dir``.
        fastqc: Truthy to additionally request the
            "FastQC_trimmed/{sample}{read}_fastqc.html" reports. Ignored when
            ``trim`` is falsy.

    Returns:
        A list of target paths expanded over the global ``samples`` and
        ``reads``; ``[]`` if ``trim`` is falsy.
    """
    # Without trimming there is nothing to request, whatever `fastqc` says.
    if not trim:
        return []

    targets = expand(fastq_dir + "/{sample}{read}.fastq.gz", sample=samples, read=reads)
    if fastqc:
        # FastQC reports follow the FASTQ files in the result.
        targets = targets + expand(
            "FastQC_trimmed/{sample}{read}_fastqc.html", sample=samples, read=reads
        )
    return targets

# get samples to merge
# sample_dict is whatever get_sampleSheet() returns for the configured sample sheet
# (get_sampleSheet is not defined in this file -- presumably it comes from
# internals.snakefile; confirm). It is printed in verbose mode by onstart.
sample_dict = get_sampleSheet(sampleSheet)
# The keys serve as the {group} wildcard values of the "mergedSamples_{group}"
# targets built in run_build_matrices().
GROUPS = sample_dict.keys()

def run_build_matrices():
    """Collect the Hi-C matrix, QC, corrected-matrix and TAD targets.

    Reads the module-level options ``mergeSamples``, ``nBinsToMerge``,
    ``matrixFile_suffix``, ``matrix_format``, ``noCorrect`` and ``noTAD``.

    * With ``mergeSamples`` set, targets are named "mergedSamples_{group}" and
      expanded over ``GROUPS``; otherwise they are named "{sample}" and
      expanded over ``samples``.
    * With ``nBinsToMerge > 0`` the name gets an extra "_Mbins<nBinsToMerge>"
      part in front of "_<matrixFile_suffix>".

    Returns:
        A list whose first element is a list of three expanded lists (matrix
        files, diagnostic plots, MAD threshold files). Unless ``noCorrect`` is
        set, the expanded list of corrected matrices follows, and unless
        ``noTAD`` is also set, the expanded list of TAD boundary files.
    """
    # Choose the filename stem and the wildcard to expand over.
    if mergeSamples:
        stem = "mergedSamples_{group}"
        wildcard_values = {"group": GROUPS}
    else:
        stem = "{sample}"
        wildcard_values = {"sample": samples}

    # Bin merging only adds a marker in front of the common suffix.
    bin_tag = "_Mbins" + str(nBinsToMerge) if nBinsToMerge > 0 else ""
    stem += bin_tag + "_" + matrixFile_suffix

    # The raw-matrix targets are kept as one nested entry (a list of three lists).
    targets = [[
        expand("HiC_matrices/" + stem + matrix_format, **wildcard_values),
        expand("HiC_matrices/QCplots/" + stem + "_diagnostic_plot.pdf", **wildcard_values),
        expand("HiC_matrices/QCplots/" + stem + "_mad_threshold.out", **wildcard_values),
    ]]

    # TAD targets are only requested together with corrected matrices.
    if noCorrect:
        return targets

    targets.append(
        expand("HiC_matrices_corrected/" + stem + ".corrected" + matrix_format, **wildcard_values)
    )
    if not noTAD:
        targets.append(expand("TADs/" + stem + "_boundaries.bed", **wildcard_values))

    return targets


def run_dist_vs_count():
    """Return the distance-vs-counts plot target if it was requested.

    Returns:
        ``["dist_vs_counts.png"]`` when the module-level ``distVsCount`` option
        is truthy, otherwise ``[]``.
    """
    return ["dist_vs_counts.png"] if distVsCount else []


### execute before workflow starts #############################################
################################################################################
onstart:
    # Verbose mode: dump the workflow parameters and a few environment variables.
    verbose = config.get("verbose", False)
    if verbose:
        separator = "-" * 80

        print("--- Workflow parameters --------------------------------------------------------")
        print("samples:", samples)
        print("fastq dir:", fastq_dir)
        print(separator, "\n")

        print("--- Environment ----------------------------------------------------------------")
        print("$TMPDIR: ", os.getenv('TMPDIR', ""))
        print("$HOSTNAME: ", os.getenv('HOSTNAME', ""))
        print(separator, "\n")
        print(sample_dict)

    # Record the tools of the conda environments this workflow uses.
    if toolsVersion:
        cf.writeTools([CONDA_SHARED_ENV, CONDA_HIC_ENV], outdir, "HiC", maindir)

    # Copy the sample sheet into the output directory when one was given.
    if sampleSheet:
        cf.copySampleSheet(sampleSheet, outdir)

### main rule ##################################################################
################################################################################

# Target rule: its inputs are the files the workflow is asked to produce.
rule all:
    input:
        # FASTQ files, one per sample/read combination
        expand("FASTQ/{sample}{read}.fastq.gz", sample = samples, read = reads),
        # optional FastQC reports (empty list when fastqc is off)
        run_FastQC(fastqc),
        # optional trimmed FASTQ files and their FastQC reports (empty list when trim is off)
        run_Trimming(trim, fastqc),
        # BAM files in the directory named after the configured aligner
        expand(aligner + "/{sample}{read}.bam", sample = samples, read = reads),
        # Hi-C matrices, QC plots and, depending on the options, corrected matrices and TADs
        run_build_matrices(),
        # per-sample QC log
        expand("HiC_matrices/QCplots/{sample}_QC/QC.log", sample = samples),
        # optional distance-vs-counts plot (empty list when distVsCount is off)
        run_dist_vs_count(),
        # multiQC summary report
        "multiQC/multiqc_report.html"

### execute after workflow finished ############################################
################################################################################
onsuccess:
    # The success banner is only shown in verbose mode.
    if config.get("verbose", False):
        print("\n--- Hi-C workflow finished successfully! --------------------------------\n")

onerror:
    # Unlike the success banner, the error banner is printed regardless of the verbose setting.
    print("\n !!! ERROR in HI-C workflow! !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n")
