import os
import snakePipes.common_functions as cf
import itertools
import warnings


### snakemake_workflows initialization ########################################
# NOTE: the statements in this section are order dependent. Each
# globals().update(...) call copies the entries of the returned mapping into
# this module's namespace, and later lines read names created that way.

# maindir is the directory two levels above this workflow's own directory.
maindir = os.path.dirname(os.path.dirname(workflow.basedir))
# Location of the shared R scripts below maindir.
workflow_rscripts=os.path.join(maindir, "shared", "rscripts")

# load conda ENVs (path is relative to "shared/rules" directory)
globals().update(cf.set_env_yamls())

# load config file (the first config file handed to the workflow)
globals().update(cf.load_configfile(workflow.overwrite_configfiles[0], config["verbose"]))
# load organism-specific data, i.e. genome indices, annotation, etc.
# NOTE(review): `genome` is not assigned in this file; presumably it is one of
# the names injected by the config load just above -- confirm.
globals().update(cf.load_organism_data(genome, maindir, config["verbose"]))
# return the pipeline version in the log
cf.get_version()

# do workflow specific stuff now
include: os.path.join(workflow.basedir, "internals.snakefile")

### include modules of other snakefiles ########################################
################################################################################
# NOTE(review): the switches tested below (fromBAM, useSpikeInForNorm,
# allele_info, sampleSheet, isMultipleComparison) are not assigned in this
# file; presumably they come from the loaded config or internals.snakefile.

# deeptools cmds
include: os.path.join(maindir, "shared", "tools" , "deeptools_cmds.snakefile")
# fromBAM: link the input BAM files; the generic deepTools QC rules are only
# included when no spike-in normalization is requested
if fromBAM:
    include: os.path.join(maindir, "shared", "rules", "LinkBam.snakefile")
    if not useSpikeInForNorm:
        include: os.path.join(maindir, "shared", "rules", "deepTools_qc.snakefile")

# spike-in normalization uses its own variants of the BAM splitting,
# deepTools, peak calling and GTF filtering rule files
if useSpikeInForNorm:
        include: os.path.join(maindir, "shared", "rules", "split_bam_ops_ChIP_spikein.snakefile")
        include: os.path.join(maindir, "shared", "rules", "deepTools_ChIP_spikein.snakefile")
        include: os.path.join(maindir, "shared", "rules", "ChIP_peak_calling_spikein.snakefile")
        include: os.path.join(maindir, "shared", "rules", "filterGTF_spikein.snakefile")
else:
    # deepTools ChIP
    include: os.path.join(maindir, "shared", "rules", "deepTools_ChIP.snakefile")
    # MACS2, MACS2 peak QC, and Genrich
    include: os.path.join(maindir, "shared", "rules", "ChIP_peak_calling.snakefile")

# QC report for all samples
include: os.path.join(maindir, "shared", "rules", "ChiP-seq_qc_report.snakefile")


# histoneHMM if the mode is not allele-specific, otherwise the deepTools rules
# for allelic BAM files
if not allele_info:
    include: os.path.join(maindir, "shared", "rules", "histoneHMM.snakefile")
else:
    include: os.path.join(maindir, "shared", "rules", "deepTools_ChIP_allelic.snakefile")

# CSAW for differential binding (only if a sample sheet was specified and
# cf.check_replicates() accepts it)
if sampleSheet and cf.check_replicates(sampleSheet):
    if not isMultipleComparison:
        # a single comparison
        include: os.path.join(maindir, "shared", "rules", "CSAW.singleComp.snakefile")
        include: os.path.join(maindir, "shared", "rules", "nearestGene.singleComp.snakefile")
    else:
        # several comparison groups
        include: os.path.join(maindir, "shared", "rules", "CSAW.multiComp.snakefile")
        include: os.path.join(maindir, "shared", "rules", "nearestGene.multiComp.snakefile")
    # GTF filtering is included for both CSAW variants
    include: os.path.join(maindir, "shared", "rules", "filterGTF.snakefile")

def run_histoneHMM(allele_info,peakCaller): # TODO what is a good practice for broad peaks of cutntag?
    """Return the histoneHMM target files for ``rule all``.

    Targets are only requested when the run is not allele-specific
    (``allele_info`` is falsy) and ``peakCaller`` equals ``"histoneHMM"``;
    in every other case an empty list is returned.

    The result is the list of per-sample region files for the module-level
    ``broad_samples``, followed by the nested list of per-sample
    ``_avgp0.5.bed`` files and the histoneHMM ChIP-QC session info file.
    """
    # Guard clause: nothing to do for allele-specific runs or other callers.
    if allele_info or peakCaller != "histoneHMM":
        return []

    ## broad enrichment calling only for samples annotated as *broad*
    targets = expand("histoneHMM/{sample}.filtered.histoneHMM-regions.gff.gz", sample=broad_samples)
    avgp_beds = expand("histoneHMM/{sample}_avgp0.5.bed", sample=broad_samples)
    # The bed files are kept as one nested list element, as before.
    targets.append(avgp_beds)
    targets.append("histoneHMM_chipqc/sessionInfo.txt")
    return targets

def run_deepTools_qc(fromBAM,useSpikeInForNorm):
    """Return the deepTools QC target files for ``rule all``.

    Parameters:
        fromBAM: truthy when the workflow starts from BAM files.
        useSpikeInForNorm: truthy when spike-in normalization is requested.

    Returns:
        A list of output paths. Some elements are themselves lists (results
        of ``expand`` or grouped literals); the consumer is expected to
        flatten them. An empty list is returned when neither mode applies.

    Modes:
        * ``fromBAM`` without spike-in: fragment size metrics, per-sample
          depth-normalized bigWigs, read filtering estimates and, depending
          on the number of entries in the module-level ``samples``, the
          coverage (<= 20 samples) and correlation/PCA (2..20 samples)
          tables.
        * spike-in: fragment size metrics per ``part`` plus the targets
          selected by the module-level ``getSizeFactorsFrom``.
    """
    if fromBAM and not useSpikeInForNorm:
        file_list = ["deepTools_qc/bamPEFragmentSize/fragmentSize.metric.tsv"]
        file_list.append([expand("bamCoverage/{sample}.filtered.seq_depth_norm.bw", sample = samples)])
        if len(samples) <= 20:
            file_list.append( ["deepTools_qc/plotCoverage/read_coverage.tsv"] )
        # correlation and PCA need at least two samples
        if 1 < len(samples) <= 20:
            file_list.append( [
                "deepTools_qc/plotCorrelation/correlation.pearson.read_coverage.tsv",
                "deepTools_qc/plotCorrelation/correlation.spearman.read_coverage.tsv",
                "deepTools_qc/plotPCA/PCA.read_coverage.tsv" ])
        file_list.append(expand("deepTools_qc/estimateReadFiltering/{sample}_filtering_estimation.txt",sample = samples))
    elif useSpikeInForNorm:
        file_list = expand("split_deepTools_qc/bamPEFragmentSize/{part}.fragmentSize.metric.tsv",part=part)
        if getSizeFactorsFrom == "genome":
            # Append rather than rebind: the previous code assigned a new
            # list here and thereby dropped the fragment size targets above.
            file_list.append(expand("split_deepTools_qc/multiBamSummary/{part}.ChIP_read_coverage.bins.npz",part=part))
            file_list.append(expand("bamCoverage/{sample}.host_scaled.BY{part}.bw",sample=chip_samples,part=part))
        elif getSizeFactorsFrom == "TSS":
            file_list.append(expand("bamCoverage_TSS/{sample}.host_scaled.BYspikein.bw",sample=chip_samples))
        elif getSizeFactorsFrom == "input" and chip_samples_w_ctrl:
            file_list.append(expand("bamCoverage_input/{sample}.host_scaled.BYspikein.bw",sample=chip_samples_w_ctrl))
            file_list.append(expand("split_deepTools_qc/multiBamSummary/{part}.input_read_coverage.bins.npz",part=part))
    else:
        file_list = []
    return file_list

def run_deepTools_ChIP():
    """Return the deepTools ChIP target files for ``rule all``.

    The first element is always a one-item list holding the plotFingerprint
    metrics file (below ``split_deepTools_ChIP`` when spike-in normalization
    is used, below ``deepTools_ChIP`` otherwise).

    For every sample in the module-level ``chip_samples_w_ctrl`` the
    bamCompare bigWigs selected by ``bigWigType`` ("subtract", "log2ratio" or
    "both") are added. With spike-in normalization this only happens when
    ``getSizeFactorsFrom`` is "genome"; those files are expanded over ``part``.
    """
    if useSpikeInForNorm:
        fingerprint = "split_deepTools_ChIP/plotFingerprint/plotFingerprint.metrics.txt"
    else:
        fingerprint = "deepTools_ChIP/plotFingerprint/plotFingerprint.metrics.txt"
    targets = [[fingerprint]]

    # `or []` keeps the original behaviour of skipping the loop entirely
    # when there are no ChIP samples with a control.
    for chip_sample in chip_samples_w_ctrl or []:
        control_name = get_control_name(chip_sample)
        if not useSpikeInForNorm:
            stem = "deepTools_ChIP/bamCompare/" + chip_sample
            # get bigwigtype
            if bigWigType in ("subtract", "both"):
                targets.append([stem + ".filtered.subtract." + control_name + ".bw"])
            if bigWigType in ("log2ratio", "both"):
                targets.append([stem + ".filtered.log2ratio.over_" + control_name + ".bw"])
        elif getSizeFactorsFrom == "genome":
            stem = "split_deepTools_ChIP/bamCompare/" + chip_sample
            # get bigwigtype
            if bigWigType in ("subtract", "both"):
                targets.append(expand(stem + ".subtract." + control_name + ".scaledBY{part}.bw", part=part))
            if bigWigType in ("log2ratio", "both"):
                targets.append(expand(stem + ".log2ratio.over_" + control_name + ".scaledBY{part}.bw", part=part))
    return targets

def run_deepTools_allelic():
    """Return the allele-specific deepTools target files for ``rule all``.

    Targets are only requested when the directory ``allelic_bams`` exists in
    the current working directory and is not empty; otherwise an empty list
    is returned.

    For every sample in the module-level ``chip_samples_w_ctrl`` a pair of
    log2ratio bigWigs (genome1 / genome2) is added as one nested list,
    followed by one nested list with the allelic plotEnrichment and
    plotFingerprint outputs.
    """
    allelic_dir = 'allelic_bams'
    if not (os.path.isdir(allelic_dir) and os.listdir(allelic_dir)):
        return []

    targets = []
    for chip_sample in chip_samples_w_ctrl:
        control_name = get_control_name(chip_sample)
        prefix = "deepTools_ChIP/bamCompare/allele_specific/" + chip_sample
        suffix = ".log2ratio.over_" + control_name + ".bw"
        targets.append([prefix + ".genome1" + suffix, prefix + ".genome2" + suffix])

    targets.append([
        "deepTools_ChIP/plotEnrichment/plotEnrichment.gene_features_allelic.png",
        "deepTools_ChIP/plotEnrichment/plotEnrichment.gene_features_allelic.tsv",
        "deepTools_ChIP/plotFingerprint/plotFingerprint.metrics_allelic.txt",
    ])
    return targets

def run_CSAW():
    """Return the CSAW differential binding target files for ``rule all``.

    An empty list is returned unless the module-level ``sampleSheet`` is set
    and ``cf.check_replicates(sampleSheet)`` is truthy.

    With a single comparison the outputs live in
    ``CSAW_<peakCaller>_<sample_name>``; with multiple comparisons one
    directory per comparison group is requested, named
    ``CSAW_<peakCaller>_<sample_name>.<compGroup>`` with the groups taken from
    ``cf.returnComparisonGroups(sampleSheet)``.

    The session info and ``Full.results.bed`` are always requested. The
    annotation files, the stats report, the coverage matrices/heatmaps and the
    nearest-gene annotated results are only added when ``allele_info`` is
    falsy; the log2 ratio matrices/heatmaps additionally require
    ``chip_samples_w_ctrl`` and no spike-in normalization.

    The returned list contains nested lists; the consumer is expected to
    flatten them.

    NOTE(review): ``sample_name`` and ``change_direction`` are not assigned
    in this file -- presumably they are provided by the loaded config or by
    internals.snakefile; confirm.
    """
    if sampleSheet and cf.check_replicates(sampleSheet):
        if not isMultipleComparison:
            # single comparison: one CSAW output directory
            file_list=["CSAW_{}_{}/CSAW.session_info.txt".format(peakCaller, sample_name),"CSAW_{}_{}/Full.results.bed".format(peakCaller, sample_name)]
            if not allele_info :
                # annotation files and the stats report
                file_list.append(["Annotation/genes.filtered.symbol","Annotation/genes.filtered.t2g","CSAW_{}_{}/CSAW.Stats_report.html".format(peakCaller, sample_name)])
                # coverage matrix and heatmap for every change direction
                file_list.append([os.path.join("CSAW_{}_{}".format(peakCaller, sample_name), x) for x in list(itertools.chain.from_iterable([expand("CSAW.{change_dir}.cov.matrix",change_dir=change_direction),expand("CSAW.{change_dir}.cov.heatmap.png",change_dir=change_direction)]))])
                # filtered results annotated with the nearest gene
                file_list.append([os.path.join("AnnotatedResults_{}_{}".format(peakCaller,sample_name), x) for x in list(itertools.chain.from_iterable([expand("Filtered.results.{change_dir}_withNearestGene.txt",change_dir=change_direction)]))])
                if chip_samples_w_ctrl and not useSpikeInForNorm:
                    # log2 ratio matrix and heatmap (needs ChIP samples with a control)
                    file_list.append([os.path.join("CSAW_{}_{}".format(peakCaller, sample_name), x) for x in list(itertools.chain.from_iterable([expand("CSAW.{change_dir}.log2r.matrix",change_dir=change_direction),expand("CSAW.{change_dir}.log2r.heatmap.png",change_dir=change_direction)]))])
        else:
            # multiple comparisons: str.format() fills in the peak caller and
            # "<sample_name>.{compGroup}" first, leaving the {compGroup}
            # (and {change_dir}) wildcards for expand()
            file_list=expand("CSAW_{}_{}/CSAW.session_info.txt".format(peakCaller, sample_name + ".{compGroup}"),compGroup=cf.returnComparisonGroups(sampleSheet))
            file_list.append(expand("CSAW_{}_{}/Full.results.bed".format(peakCaller, sample_name + ".{compGroup}"),compGroup=cf.returnComparisonGroups(sampleSheet)))
            if not allele_info :
                # annotation files and one stats report per comparison group
                file_list.append(["Annotation/genes.filtered.symbol","Annotation/genes.filtered.t2g",expand("CSAW_{}_{}/CSAW.Stats_report.html".format(peakCaller, sample_name + ".{compGroup}"),compGroup=cf.returnComparisonGroups(sampleSheet))])
                # coverage matrix and heatmap per change direction and comparison group
                file_list.append([expand("CSAW_{}_{}".format(peakCaller, sample_name + ".{compGroup}") + "/CSAW.{change_dir}.cov.matrix",change_dir=change_direction,compGroup=cf.returnComparisonGroups(sampleSheet)),
                  expand("CSAW_{}_{}".format(peakCaller, sample_name + ".{compGroup}") +"/CSAW.{change_dir}.cov.heatmap.png",change_dir=change_direction,compGroup=cf.returnComparisonGroups(sampleSheet)
                  )])
                # filtered results annotated with the nearest gene
                file_list.append(expand("AnnotatedResults_{}_{}".format(peakCaller,sample_name + ".{compGroup}") + "/Filtered.results.{change_dir}_withNearestGene.txt",change_dir=change_direction, compGroup=cf.returnComparisonGroups(sampleSheet)))
                if chip_samples_w_ctrl and not useSpikeInForNorm:
                    # log2 ratio matrix and heatmap (needs ChIP samples with a control)
                    file_list.append([expand("CSAW_{}_{}".format(peakCaller, sample_name + ".{compGroup}") + "/CSAW.{change_dir}.log2r.matrix",change_dir=change_direction,compGroup=cf.returnComparisonGroups(sampleSheet)),
                      expand("CSAW_{}_{}".format(peakCaller, sample_name + ".{compGroup}") + "/CSAW.{change_dir}.log2r.heatmap.png",change_dir=change_direction,compGroup=cf.returnComparisonGroups(sampleSheet))
                      ])
        return( file_list )
    else:
        # no sample sheet or the replicate check failed: no CSAW targets
        return([])


def run_genrich():
    """Return the Genrich peak files for ``rule all``.

    Returns an empty list unless the module-level ``peakCaller`` is
    ``"Genrich"``. Otherwise:

    * without a sample sheet: the single ``all_samples`` peak file;
    * single comparison: one peak file per key of ``genrichDict``;
    * multiple comparisons: one peak file per (outer key, inner key) pair of
      ``genrichDict``, named ``Genrich/<inner key>.<outer key>.narrowPeak``,
      with each pair listed once.
    """
    if peakCaller != "Genrich":
        return []
    if not sampleSheet:
        return ["Genrich/all_samples.narrowPeak"]
    if not isMultipleComparison:
        return ["Genrich/{}.narrowPeak".format(x) for x in genrichDict.keys()]

    # Track the pairs already emitted as hashable tuples. The previous code
    # stored zip() objects, which compare by identity, so its membership
    # test never matched and the duplicate check had no effect.
    seen_pairs = set()
    file_list = []
    for key, value in genrichDict.items():
        for x in value.keys():
            pair = (key, x)
            if pair not in seen_pairs:
                seen_pairs.add(pair)
                file_list.append("Genrich/{}.{}.narrowPeak".format(x, key))
    return file_list


def run_macs2():
    """Return the MACS2 peak tables for ``rule all``.

    Returns an empty list unless the module-level ``peakCaller`` is
    ``"MACS2"``. Otherwise one ``*_peaks.xls`` file per entry of
    ``chip_samples`` is requested, using the ``_host`` file name when
    spike-in normalization is enabled and the ``.filtered`` name otherwise.
    """
    if peakCaller != "MACS2":
        return []
    if useSpikeInForNorm:
        pattern = "MACS2/{chip_sample}_host.BAM_peaks.xls"
    else:
        pattern = "MACS2/{chip_sample}.filtered.BAM_peaks.xls"
    return expand(pattern, chip_sample=chip_samples)


def run_seacr():
    """Return the SEACR peak files for ``rule all``.

    Returns an empty list unless the module-level ``peakCaller`` is
    ``"SEACR"``. Otherwise one "stringent" bed file per entry of
    ``chip_samples`` is requested, using the ``_host`` file name when
    spike-in normalization is enabled and the ``.filtered`` name otherwise.
    """
    if peakCaller != "SEACR":
        return []
    if useSpikeInForNorm:
        pattern = "SEACR/{chip_sample}_host.{mode}.bed"
    else:
        pattern = "SEACR/{chip_sample}.filtered.{mode}.bed"
    return expand(pattern, chip_sample=chip_samples, mode=["stringent"])


def run_chipqc():
    """Return the ChIP-QC session info target for ``rule all``.

    The result is the string ``"<peakCaller>_chipqc/sessionInfo.txt"`` for
    every peak caller except ``"Genrich"``, for which an empty list is
    returned.
    """
    if peakCaller == "Genrich":
        return []
    return "{}_chipqc/sessionInfo.txt".format(peakCaller)



### execute before workflow starts #############################################
################################################################################
onstart:
    if "verbose" in config and config["verbose"]:
        print("--- Workflow parameter ---------------------------------------------------------")
        print("control samples:", control_samples)
        print("ChIP samples w ctrl:", chip_samples_w_ctrl)
        print("ChIP samples wo ctrl:", chip_samples_wo_ctrl)
        print("paired:", pairedEnd)
        print("-" * 80, "\n")

        print("--- Environment ----------------------------------------------------------------")
        print("$TMPDIR: ",os.getenv('TMPDIR', ""))
        print("$HOSTNAME: ",os.getenv('HOSTNAME', ""))
        print("-" * 80, "\n")

        print("--- Genome ---------------------------------------------------------------------")
        print("Genome:", genome)
        print("Effective genome size:", genome_size)
        print("Genome FASTA:", genome_fasta)
        print("Genome index:", genome_index)
        print("Genome 2bit:", genome_2bit)
        print("Bowtie2 index:", bowtie2_index)
        print("Gene annotation BED:", genes_bed)
        print("Gene annotation GTF:", genes_gtf)
        print("Blacklist regions BED:", blacklist_bed)
        print("Ignore for normalization (bigwigs):", ignoreForNormalization)
        print("-" * 80, "\n")

    if toolsVersion:
        usedEnvs = [CONDA_SHARED_ENV, CONDA_CHIPSEQ_ENV]
        cf.writeTools(usedEnvs, workingdir, "ChIPseq", maindir)
    if sampleSheet:
        cf.copySampleSheet(sampleSheet, workingdir)


### main rule ##################################################################
################################################################################
# Default target rule: it only collects the outputs requested by the run_*
# helper functions of this file plus the QC report. Each helper returns an
# empty list when its part of the workflow is not enabled.
rule all:
    input:
        # deepTools QC on the BAM files (fromBAM and/or spike-in mode)
        run_deepTools_qc(fromBAM,useSpikeInForNorm),
        # plotFingerprint metrics and ChIP vs. control bigWigs
        run_deepTools_ChIP(),
        # peak calling: only the helper matching peakCaller returns targets
        run_macs2(),
        run_genrich(),
        run_seacr(),
        # run histoneHMM if allelic_bams are absent (since it gives index error without allele_specific index)
        run_histoneHMM(allele_info,peakCaller),
        ## run deeptools-allelic only if dir "allelic_bams" present and non empty
        run_deepTools_allelic(),
        ## run csaw if asked for
        run_CSAW(),
        # QC report for all samples
        "QC_report/QC_report_all.tsv",
        # ChIP-QC session info (not requested for Genrich)
        run_chipqc()


### execute after workflow finished ############################################
################################################################################
onsuccess:
    if "verbose" in config and config["verbose"]:
        print("\n--- ChIPseq workflow finished successfully! -----------------------------------\n")

onerror:
    print("\n !!! ERROR in ChIPseq workflow! !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n")


