import os
import snakePipes.common_functions as cf
import itertools


### snakemake_workflows initialization ########################################
# Everything in this section runs while the Snakefile is parsed, and the order
# matters: several statements inject names into the module namespace that the
# statements after them read.

# Two directory levels above this workflow's own directory; the "shared" tree
# used throughout this file is resolved relative to it.
maindir = os.path.dirname(os.path.dirname(workflow.basedir))
# "<maindir>/shared/rscripts"
workflow_rscripts=os.path.join(maindir, "shared", "rscripts")

# load conda ENVs (path is relative to "shared/rules" directory);
# every entry of the returned mapping becomes a module-level name
globals().update(cf.set_env_yamls())

# load config file (the first file of workflow.overwrite_configfiles);
# its entries also become module-level names
globals().update(cf.load_configfile(workflow.overwrite_configfiles[0], config["verbose"]))
# load organism-specific data, i.e. genome indices, annotation, etc.
# NOTE(review): `genome` is not assigned anywhere in this file; presumably it is
# one of the names injected from the config file just above - confirm.
globals().update(cf.load_organism_data(genome, maindir, config["verbose"]))
# return the pipeline version in the log
cf.get_version()

# Output sub-directories (relative paths, each with a trailing slash) that the
# target functions and `rule all` below build their file names from.
outdir_MACS2 = 'MACS2/'
short_bams = 'short_bams/'
outdir_ATACqc = 'MACS2_QC/'
deeptools_ATAC='deepTools_ATAC/'

# do workflow specific stuff now
# NOTE(review): names such as `samples`, `fromBAM`, `peakCaller` and
# `genrichDict` are used further down but never assigned in this file;
# presumably they come from this include or from the loaded config - confirm.
include: os.path.join(workflow.basedir, "internals.snakefile")

### include modules of other snakefiles ########################################
################################################################################
# Import deeptools cmds
include: os.path.join(maindir, "shared", "tools", "deeptools_cmds.snakefile")

# BAM linking and the general deepTools QC rules are only pulled in when the
# `fromBAM` flag is truthy.
if fromBAM:
    include: os.path.join(maindir, "shared", "rules", "LinkBam.snakefile")
    include: os.path.join(maindir, "shared", "rules", "deepTools_qc.snakefile")

# Import deeptools ATAC
include: os.path.join(maindir, "shared", "rules", "deepTools_ATAC.snakefile")

# Import multiQC
include: os.path.join(maindir, "shared", "rules", "multiQC.snakefile")

# ATACseq open chromatin
include: os.path.join(maindir, "shared", "rules", "ATAC.snakefile")

# ATAC QC open chromatin
include: os.path.join(maindir, "shared", "rules", "ATAC_qc.snakefile")

# Allelic-mode flag: True only when "<workingdir>/allelic_bams" is an existing
# directory that contains at least one entry, False otherwise.
allele_info = (
    os.path.isdir(os.path.join(workingdir, 'allelic_bams'))
    and os.listdir(os.path.join(workingdir, 'allelic_bams')) != []
)

# CSAW for differential binding (only when a sample sheet was specified)
# The two names below are assigned before the includes and are not read again
# in this file; per the original note the CSAW rules need a "ChIP-sample"
# variable, and here all samples are ChIP-samples.
chip_samples = samples
# Hard-coded to True in this workflow.
pairedEnd = True
if sampleSheet:
    # Choose the single- or multi-comparison flavour of the CSAW and
    # nearest-gene rules depending on `isMultipleComparison`.
    if not isMultipleComparison:
        include: os.path.join(maindir, "shared", "rules", "CSAW.singleComp.snakefile")
        include: os.path.join(maindir, "shared", "rules", "nearestGene.singleComp.snakefile")
    else:
        include: os.path.join(maindir, "shared", "rules", "CSAW.multiComp.snakefile")
        include: os.path.join(maindir, "shared", "rules", "nearestGene.multiComp.snakefile")
    # Included for both flavours.
    include: os.path.join(maindir, "shared", "rules", "filterGTF.snakefile")

## add outputs as asked
def run_deepTools_allelic():
    """Return the allelic plotFingerprint target, if allelic BAMs are present.

    The check is made against the relative path ``allelic_bams``, i.e. it is
    resolved against the current working directory.

    Returns:
        ``[]`` when ``allelic_bams`` is missing or empty; otherwise a list
        holding one nested list with the allelic plotFingerprint metrics file
        located under ``deeptools_ATAC``.
    """
    allelic_dir = 'allelic_bams'
    # Nothing to request unless the directory exists and has at least one entry.
    if not os.path.isdir(allelic_dir) or not os.listdir(allelic_dir):
        return []
    metrics_file = os.path.join(
        deeptools_ATAC, "plotFingerprint/plotFingerprint.metrics_allelic.txt"
    )
    return [[metrics_file]]

def run_CSAW():
    """Return the CSAW-related targets for ``rule all``.

    Reads the module-level names ``sampleSheet``, ``isMultipleComparison``,
    ``peakCaller``, ``sample_name`` and ``change_direction``.

    Returns:
        ``[]`` when no sample sheet is set. Otherwise a list that mixes plain
        file names with nested lists of file names (the nesting of the
        original implementation is kept as is):

        * single comparison: session info, the two filtered annotation files
          and the stats report as plain entries, followed by one nested list
          of per-direction coverage matrices/heatmaps and one nested list of
          per-direction annotated result tables;
        * multiple comparisons: the same targets, expanded once per
          comparison group reported by ``cf.returnComparisonGroups``.
    """
    if not sampleSheet:
        return []

    annotation_files = ["Annotation/genes.filtered.symbol", "Annotation/genes.filtered.t2g"]

    if not isMultipleComparison:
        csaw_dir = "CSAW_{}_{}".format(peakCaller, sample_name)
        annotated_dir = "AnnotatedResults_{}_{}".format(peakCaller, sample_name)

        file_list = [csaw_dir + "/CSAW.session_info.txt"]
        file_list += annotation_files
        file_list.append(csaw_dir + "/CSAW.Stats_report.html")

        # All coverage matrices first, then all heatmaps (one per direction).
        coverage_names = itertools.chain(
            expand("CSAW.{change_dir}.cov.matrix", change_dir=change_direction),
            expand("CSAW.{change_dir}.cov.heatmap.png", change_dir=change_direction),
        )
        file_list.append([os.path.join(csaw_dir, name) for name in coverage_names])

        result_names = expand("Filtered.results.{change_dir}_withNearestGene.txt",
                              change_dir=change_direction)
        file_list.append([os.path.join(annotated_dir, name) for name in result_names])
        return file_list

    # Look the comparison groups up once instead of once per expand() call.
    comp_groups = cf.returnComparisonGroups(sampleSheet)
    # "{compGroup}" is left in the pattern on purpose; expand() fills it in.
    group_suffix = sample_name + ".{compGroup}"
    csaw_dir = "CSAW_{}_{}".format(peakCaller, group_suffix)
    annotated_dir = "AnnotatedResults_{}_{}".format(peakCaller, group_suffix)

    file_list = expand(csaw_dir + "/CSAW.session_info.txt", compGroup=comp_groups)
    file_list.append(
        annotation_files
        + [expand(csaw_dir + "/CSAW.Stats_report.html", compGroup=comp_groups)]
    )
    file_list.append([
        expand(csaw_dir + "/CSAW.{change_dir}.cov.matrix",
               change_dir=change_direction, compGroup=comp_groups),
        expand(csaw_dir + "/CSAW.{change_dir}.cov.heatmap.png",
               change_dir=change_direction, compGroup=comp_groups),
    ])
    file_list.append(
        expand(annotated_dir + "/Filtered.results.{change_dir}_withNearestGene.txt",
               change_dir=change_direction, compGroup=comp_groups)
    )
    return file_list

def run_deepTools_qc(fromBAM):
    """Return the deepTools QC targets requested when starting from BAM files.

    Args:
        fromBAM: truthy when the workflow starts from BAM files.

    Returns:
        ``[]`` when ``fromBAM`` is falsy. Otherwise a list that starts with
        the fragment-size metrics file and contains, as nested lists, the
        per-sample bigWig files, the coverage table (at most 20 samples), the
        correlation/PCA tables (2 to 20 samples) and finally the per-sample
        read-filtering estimates.
    """
    if not fromBAM:
        return []

    sample_count = len(samples)
    targets = [
        "deepTools_qc/bamPEFragmentSize/fragmentSize.metric.tsv",
        [expand("bamCoverage/{sample}.filtered.seq_depth_norm.bw", sample = samples)],
    ]
    # The multi-sample summaries are only requested for up to 20 samples.
    if sample_count <= 20:
        targets.append(["deepTools_qc/plotCoverage/read_coverage.tsv"])
        # Correlation and PCA additionally need more than one sample.
        if sample_count > 1:
            targets.append([
                "deepTools_qc/plotCorrelation/correlation.pearson.read_coverage.tsv",
                "deepTools_qc/plotCorrelation/correlation.spearman.read_coverage.tsv",
                "deepTools_qc/plotPCA/PCA.read_coverage.tsv",
            ])
    targets.append(
        expand("deepTools_qc/estimateReadFiltering/{sample}_filtering_estimation.txt",
               sample = samples)
    )
    return targets


# This needs to be optional
def run_HMMRATAC():
    """Return the HMMRATAC peak files, one per sample.

    Returns:
        ``[]`` unless the module-level ``peakCaller`` equals ``"HMMRATAC"``;
        otherwise the expanded list of ``HMMRATAC/<sample>_peaks.gappedPeak``.
    """
    if peakCaller != "HMMRATAC":
        return []
    return expand("HMMRATAC/{sample}_peaks.gappedPeak", sample=samples)


def run_genrich():
    """Return the Genrich targets for ``rule all``.

    Reads the module-level names ``peakCaller``, ``sampleSheet``,
    ``isMultipleComparison``, ``genrichDict``, ``short_bams`` and ``samples``.

    Returns:
        ``[]`` unless ``peakCaller`` equals ``"Genrich"``. Otherwise a list of
        narrowPeak files followed by one nested list with the name-sorted
        short BAM of every sample:

        * no sample sheet: the single ``Genrich/all_samples.narrowPeak``;
        * single comparison: one file per key of ``genrichDict``;
        * multiple comparisons: one ``Genrich/<inner>.<outer>.narrowPeak`` per
          (outer key, inner key) pair of ``genrichDict``, which is iterated as
          a mapping of mappings here.
    """
    if peakCaller != "Genrich":
        return []

    if not sampleSheet:
        file_list = ["Genrich/all_samples.narrowPeak"]
    elif not isMultipleComparison:
        file_list = ["Genrich/{}.narrowPeak".format(group) for group in genrichDict]
    else:
        # Track emitted (outer, inner) pairs as tuples in a set. The previous
        # code stored zip() objects, which only compare equal to themselves,
        # so its duplicate check could never match.
        seen_pairs = set()
        file_list = []
        for comp_group, groups in genrichDict.items():
            for group in groups:
                pair = (comp_group, group)
                if pair in seen_pairs:
                    continue
                seen_pairs.add(pair)
                file_list.append("Genrich/{}.{}.narrowPeak".format(group, comp_group))

    # Appended as one nested list, as before.
    file_list.append(expand(short_bams + "{sample}.short.namesorted.bam", sample=samples))
    return file_list


### execute before workflow starts #############################################
################################################################################
onstart:
    # Parameter dump; printed only when the "verbose" config entry is present
    # and truthy.
    if config.get("verbose"):
        separator = "-" * 80

        print("--- Workflow parameters --------------------------------------------------------")
        print("samples:", samples)
        print("ATAC fragment cutoff: {}-{}".format(minFragmentSize, maxFragmentSize))
        print(separator, "\n")

        print("--- Environment ----------------------------------------------------------------")
        # Unset environment variables are shown as an empty string.
        for label, variable in (("$TMPDIR: ", 'TMPDIR'), ("$HOSTNAME: ", 'HOSTNAME')):
            print(label, os.getenv(variable, ""))
        print(separator, "\n")

        print("--- Genome ---------------------------------------------------------------------")
        genome_report = (
            ("Genome:", genome),
            ("Effective genome size:", genome_size),
            ("Genome FASTA:", genome_fasta),
            ("Genome index:", genome_index),
            ("Genome 2bit:", genome_2bit),
            ("Bowtie2 index:", bowtie2_index),
            ("Gene annotation BED:", genes_bed),
            ("Gene annotation GTF:", genes_gtf),
            ("Blacklist regions BED:", blacklist_bed),
            ("Ignore for normalization (bigwigs):", ignoreForNormalization),
        )
        for label, value in genome_report:
            print(label, value)
        print(separator, "\n")

    # The two steps below run independently of the verbose setting.
    if toolsVersion:
        # Pass the conda environments used by this workflow to cf.writeTools.
        cf.writeTools([CONDA_SHARED_ENV, CONDA_ATAC_ENV, CONDA_RMD_ENV], workingdir, "ATACseq", maindir)
    if sampleSheet:
        cf.copySampleSheet(sampleSheet, workingdir)

### main rule ##################################################################
################################################################################
# Default target rule: it only lists inputs, i.e. every file the workflow is
# asked to produce. The run_*() helpers return an empty list when their
# feature is not active, so they can be listed unconditionally.
rule all:
    input:
        # per-sample peak table and peak QC report
        expand(os.path.join(outdir_MACS2, "{sample}.filtered.short.BAM_peaks.xls"), sample = samples),
        expand(os.path.join(outdir_ATACqc,"{sample}.filtered.BAM_peaks.qc.txt"), sample = samples),
        os.path.join(deeptools_ATAC, "plotFingerprint/plotFingerprint.metrics.txt"),
        ## run deeptools-allelic only if dir "allelic_bams" present and non empty
        run_deepTools_allelic(),
        ## run HMMRATAC (only when it is the selected peak caller)
        run_HMMRATAC(),
        ## run Genrich (only when it is the selected peak caller)
        run_genrich(),
        ## run csaw if asked for (i.e. a sample sheet was given)
        run_CSAW(),
        expand("deepTools_ATAC/bamCompare/{sample}.filtered.bw",sample=samples),
        ## general deepTools QC, only when starting from BAM files
        run_deepTools_qc(fromBAM),
        # per-sample cleaned short-fragment BAM files and their indices
        expand(os.path.join(short_bams, "{sample}.short.cleaned.bam"),sample=samples),
        expand(os.path.join(short_bams, "{sample}.short.cleaned.bam.bai"),sample=samples)

### execute after workflow finished ############################################
################################################################################
onsuccess:
    # The success banner is only shown when the "verbose" config entry is
    # present and truthy.
    if config.get("verbose"):
        success_banner = "\n--- The ATACseq open chromatin workflow finished successfully! --------------------------------\n"
        print(success_banner)

onerror:
    # Failures are always reported, independent of the verbose setting.
    failure_banner = "\n !!! ERROR in the ATACseq open chromatin workflow! !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n"
    print(failure_banner)
