import os
import snakePipes.common_functions as cf
import warnings


### snakemake_workflows initialization ########################################
# Repository root: two directory levels above `workflow.basedir`.
maindir = os.path.dirname(os.path.dirname(workflow.basedir))

# load conda ENVs (path is relative to "shared/rules" directory)
# Every key of the returned mapping becomes a module-level variable.
globals().update(cf.set_env_yamls())

# load config file
# Only the first entry of `workflow.overwrite_configfiles` is read; its content
# is promoted to module-level variables as well. Presumably names used further
# down (`genome`, `mode`, `fromBAM`, `sampleSheet`, ...) originate here or in
# internals.snakefile -- TODO confirm against the config defaults.
globals().update(cf.load_configfile(workflow.overwrite_configfiles[0], config["verbose"]))
# load organism-specific data, i.e. genome indices, annotation, etc.
# Must run after the config is loaded because it needs `genome`.
globals().update(cf.load_organism_data(genome, maindir, config["verbose"]))
# return the pipeline version in the log
cf.get_version()

# do workflow specific stuff now
# internals.snakefile lives next to this Snakefile (in `workflow.basedir`).
include: os.path.join(workflow.basedir, "internals.snakefile")

### include modules of other snakefiles ########################################
################################################################################
# All shared snakefiles are resolved relative to `maindir`.

# deeptools cmds
include: os.path.join(maindir, "shared", "tools" , "deeptools_cmds.snakefile")

## filtered annotation (GTF)
include: os.path.join(maindir, "shared", "rules", "filterGTF.snakefile")

# The non-allelic deepTools rules are left out only when "allelic-counting"
# is part of `mode`.
if not "allelic-counting" in mode:
    ## bamCoverage_RPKM
    include: os.path.join(maindir, "shared", "rules", "deepTools_RNA.snakefile")

# Two entry points: start from FASTQ files (`fromBAM` falsy) or from BAM files.
if not fromBAM:
    ## FASTQ: either downsample FASTQ files or create symlinks to input files
    include: os.path.join(maindir, "shared", "rules", "FASTQ.snakefile")

    ## FastQC
    if fastqc:
        include: os.path.join(maindir, "shared", "rules", "FastQC.snakefile")

    ## Trim
    if trim:
        include: os.path.join(maindir, "shared", "rules", "trimming.snakefile")

    ## sambamba
    include:os.path.join(maindir, "shared", "rules", "sambamba.snakefile")

    ## bam filtering
    include: os.path.join(maindir, "shared", "rules", "bam_filtering.snakefile")
    
    #umi_tools
    include: os.path.join(maindir, "shared", "rules", "umi_tools.snakefile")

    ## Allele-specific JOBS
    if "allelic-mapping" in mode:
        # Genome file name without directory and extension; used in SNP file names.
        genome_alias = os.path.splitext(os.path.basename(genome))[0]
        # Updated global vars if mode = "allelic-mapping"
        if allele_mode == 'create_and_map':
            # The N-masked STAR index is created inside the working directory.
            star_index_allelic = 'snp_genome/star_Nmasked/Genome'
            # NOTE(review): there is no branch for any other number of strains;
            # `allele_hybrid` and `SNPFile` stay undefined in that case --
            # confirm that the strain count is validated upstream.
            if len(strains) == 1:
                allele_hybrid = 'single'
                SNPFile = "snp_genome/all_SNPs_" + strains[0] + "_" + genome_alias + ".txt.gz"
            elif len(strains) == 2:
                allele_hybrid = 'dual'
                SNPFile = "snp_genome/all_" + strains[1] + "_SNPs_" + strains[0] + "_reference.based_on_" + genome_alias + ".txt"

            include: os.path.join(maindir, "shared", "rules", "masked_genomeIndex.snakefile")
        elif allele_mode == 'map_only':
            # Reuse an existing N-masked index and SNP file; `SNPfile` (lower-case
            # "f") is re-exported under the `SNPFile` spelling used by the
            # 'create_and_map' branch and by make_nmasked_genome().
            star_index_allelic = NMaskedIndex
            SNPFile = SNPfile
        ## mapping rules
        include: os.path.join(maindir, "shared", "rules", "RNA_mapping_allelic.snakefile")
        ## SNPsplit
        include: os.path.join(maindir, "shared", "rules", "SNPsplit.snakefile")
        # deepTools QC
        include: os.path.join(maindir, "shared", "rules", "deepTools_RNA_allelic.snakefile")
        if "alignment-free" in mode:
            include: os.path.join(maindir, "shared", "rules", "Salmon_allelic.snakefile")
            include: os.path.join(maindir, "shared", "rules", "sleuth.singleComp.snakefile")
#            include: os.path.join(maindir, "shared", "rules", "DESeq2.singleComp.snakefile")
    elif "allelic-whatshap" in mode:
        # whatshap-based allelic mode: regular mapping rules plus whatshap.
        include: os.path.join(maindir, "shared", "rules", "RNA_mapping.snakefile")
        include: os.path.join(maindir, "shared", "rules", "whatshap.snakefile")
        include: os.path.join(maindir, "shared", "rules", "deepTools_RNA_allelic.snakefile")
        if "alignment-free" in mode:
            include: os.path.join(maindir, "shared", "rules", "Salmon_allelic.snakefile")
            include: os.path.join(maindir, "shared", "rules", "sleuth.singleComp.snakefile")            
    else:
        # HISAT2/STAR
        include: os.path.join(maindir, "shared", "rules", "RNA_mapping.snakefile")
        ## Salmon
        if "alignment-free" in mode:
            include: os.path.join(maindir, "shared", "rules", "Salmon.snakefile")
            ## Sleuth (on Salmon)
            # Sleuth is only included when a sample sheet is given; the variant
            # depends on whether several comparisons are requested.
            if sampleSheet:
                if isMultipleComparison:
                    include: os.path.join(maindir, "shared", "rules", "sleuth.multiComp.snakefile")
                else:
                    include: os.path.join(maindir, "shared", "rules", "sleuth.singleComp.snakefile")

else:
    # Starting from BAM files: the FASTQ-level options are switched off.
    fastqc=False
    trim=False
    downsample=None
    if "allelic-mapping" in mode:
        allele_mode = "from_bam"
        star_index_allelic = NMaskedIndex
        SNPFile = SNPfile
        # SNPsplit
        include: os.path.join(maindir, "shared", "rules", "SNPsplit.snakefile")
        # deepTools QC
        include: os.path.join(maindir, "shared", "rules", "deepTools_RNA_allelic.snakefile")
    elif "allelic-counting" in mode:
        # No SNPsplit include in this branch (the mode name suggests the BAM
        # files are already allele-split -- TODO confirm).
        allele_mode = "from_split_bam"
        include: os.path.join(maindir, "shared", "rules", "deepTools_RNA_allelic.snakefile")
    elif "allelic-whatshap" in mode:
        include: os.path.join(maindir, "shared", "rules", "whatshap.snakefile")
        include: os.path.join(maindir, "shared", "rules", "deepTools_RNA_allelic.snakefile")
    # LinkBam is included for every BAM-start mode.
    include: os.path.join(maindir, "shared", "rules", "LinkBam.snakefile")


# Exactly one featureCounts flavour is included: the allelic one as soon as
# any of the three allelic modes is requested, the regular one otherwise.
if "allelic-mapping" in mode or "allelic-counting" in mode or "allelic-whatshap" in mode:
    ## featureCounts_allelic
    include: os.path.join(maindir, "shared", "rules", "featureCounts_allelic.snakefile")
else:
    ## non-allelic featureCounts
    include: os.path.join(maindir, "shared", "rules", "featureCounts.snakefile")

# Three Prime Sequencing mode
if "three-prime-seq" in mode:
    include: os.path.join(maindir, "shared", "rules", "three_prime_seq.snakefile")

##Genomic_contamination
include: os.path.join(maindir, "shared", "rules", "GenomicContamination.snakefile")


## DESeq2
# Differential expression needs a sample sheet and is not included here in
# "three-prime-seq" mode. rMats is additionally skipped for the
# "allelic-mapping" and "allelic-counting" modes.
# NOTE(review): "allelic-whatshap" is not part of the rMats exclusion -- confirm
# that this is intended.
if sampleSheet and not "three-prime-seq" in mode:
    if isMultipleComparison :
        include: os.path.join(maindir, "shared", "rules", "DESeq2.multipleComp.snakefile")
        if rMats and not "allelic-mapping" in mode and not "allelic-counting" in mode:
            include: os.path.join(maindir, "shared", "rules", "rMats.multipleComp.snakefile")
    else:
        include: os.path.join(maindir, "shared", "rules", "DESeq2.singleComp.snakefile")
        if rMats and not "allelic-mapping" in mode and not "allelic-counting" in mode:
            include: os.path.join(maindir, "shared", "rules", "rMats.singleComp.snakefile")

## MultiQC
include: os.path.join(maindir, "shared", "rules", "multiQC.snakefile")

### conditional/optional rules #################################################
################################################################################
def run_FastQC(fastqc):
    """Return the FastQC report targets requested from the main rule.

    fastqc -- truthy when FastQC reports should be produced.

    One HTML report is requested per sample/read-suffix combination. For
    single-end data only the first entry of the global `reads` is used.
    An empty list is returned when `fastqc` is falsy.
    """
    read_suffixes = reads if pairedEnd else [reads[0]]
    if not fastqc:
        return []
    return expand("FastQC/{sample}{read}_fastqc.html", sample=samples, read=read_suffixes)


def run_Trimming(trim, fastqc):
    """Return the targets of the trimming step requested from the main rule.

    trim   -- truthy when reads are trimmed.
    fastqc -- truthy when FastQC reports are wanted as well.

    With trimming enabled the FASTQ files below the global `fastq_dir` are
    requested; if FastQC is enabled too, the reports under "FastQC_trimmed"
    are added. For single-end data only the first entry of the global
    `reads` is used. Without trimming an empty list is returned.
    """
    read_suffixes = reads if pairedEnd else [reads[0]]
    if not trim:
        return []

    targets = expand(fastq_dir + "/{sample}{read}.fastq.gz", sample=samples, read=read_suffixes)
    if fastqc:
        targets = targets + expand("FastQC_trimmed/{sample}{read}_fastqc.html", sample=samples, read=read_suffixes)
    return targets


def run_alignment_free():
    """Return the targets of the alignment-free (Salmon) branch.

    Nothing is requested unless "alignment-free" is part of `mode`. With
    "allelic-mapping" or "allelic-whatshap" in `mode` the "SalmonAllelic"
    outputs are requested, otherwise the plain "Salmon" outputs. When a
    sample sheet is set, DESeq2 targets (in an "_LRT" directory if `LRT` is
    set) and sleuth targets are added; rMats targets are added only in the
    non-allelic case and only if `rMats` is set.

    The result mixes plain strings and nested lists.
    """
    if "alignment-free" not in mode:
        return []

    if "allelic-mapping" in mode or "allelic-whatshap" in mode:
        targets = [
            expand("SalmonAllelic/{sample}.{allelic_suffix}.quant.sf", sample=samples, allelic_suffix=allelic_suffix),
            "SalmonAllelic/TPM.transcripts.tsv",
            "SalmonAllelic/counts.transcripts.tsv",
            expand("SalmonAllelic/{sample}.{allelic_suffix}/abundance.h5", sample=samples, allelic_suffix=allelic_suffix),
        ]
        if sampleSheet:
            sheet_name = os.path.splitext(os.path.basename(sampleSheet))[0]
            # Only the single-comparison layout adds targets in the allelic case.
            if not isMultipleComparison:
                if LRT:
                    targets.append(["DESeq2_SalmonAllelic_{}_LRT/DESeq2.session_info.txt".format(sheet_name)])
                else:
                    targets.append(["DESeq2_SalmonAllelic_{}/DESeq2.session_info.txt".format(sheet_name)])
                targets.append(["sleuth_SalmonAllelic_{}/so.rds".format(sheet_name)])
        return targets

    targets = [
        expand("Salmon/{sample}.quant.sf", sample=samples),
        "Salmon/TPM.transcripts.tsv",
        "Salmon/counts.transcripts.tsv",
        expand("Salmon/{sample}/abundance.h5", sample=samples),
    ]
    if not sampleSheet:
        return targets

    sheet_name = os.path.splitext(os.path.basename(sampleSheet))[0]
    if isMultipleComparison:
        # One output directory per comparison group of the sample sheet.
        if LRT:
            targets.append(expand("DESeq2_Salmon_{}".format(sheet_name) + ".{compGroup}_LRT/DESeq2.session_info.txt", compGroup=cf.returnComparisonGroups(sampleSheet)))
        else:
            targets.append(expand("DESeq2_Salmon_{}".format(sheet_name) + ".{compGroup}/DESeq2.session_info.txt", compGroup=cf.returnComparisonGroups(sampleSheet)))
        targets.append(expand("sleuth_Salmon_{}".format(sheet_name) + ".{compGroup}/so.rds", compGroup=cf.returnComparisonGroups(sampleSheet)))
        if rMats:
            targets.append(expand("rMats_{}".format(sheet_name) + ".{compGroup}/RI.MATS.JCEC.txt", compGroup=cf.returnComparisonGroups(sampleSheet)))
            targets.append(expand("rMats_{}".format(sheet_name) + ".{compGroup}/b2.csv", compGroup=cf.returnComparisonGroups(sampleSheet)))
    else:
        if LRT:
            targets.append(["DESeq2_Salmon_{}_LRT/DESeq2.session_info.txt".format(sheet_name)])
        else:
            targets.append(["DESeq2_Salmon_{}/DESeq2.session_info.txt".format(sheet_name)])
        # sleuth is requested only if the replicate check on the sheet passes.
        if cf.check_replicates(sampleSheet):
            targets.append(["sleuth_Salmon_{}/so.rds".format(sheet_name)])
        if rMats:
            targets.append(["rMats_{}/RI.MATS.JCEC.txt".format(sheet_name)])
            targets.append(["rMats_{}/b2.csv".format(sheet_name)])
    return targets


def run_alignment():
    """Return the targets of the alignment-based branch.

    Nothing is requested unless "alignment" is found in `mode`. Otherwise
    the filtered BAM files with their indices and the featureCounts outputs
    are requested. When a sample sheet is set, DESeq2 targets are added (in
    an "_LRT" directory if `LRT` is set) and, if `rMats` is set, rMats
    targets; with multiple comparisons there is one directory per
    comparison group.

    The result mixes plain strings and nested lists.
    """
    if "alignment" not in mode:
        return []

    targets = [
        expand("filtered_bam/{sample}.filtered.bam", sample=samples),
        expand("filtered_bam/{sample}.filtered.bam.bai", sample=samples),
        expand("featureCounts/{sample}.counts.txt", sample=samples),
        "featureCounts/counts.tsv",
    ]
    if not sampleSheet:
        return targets

    sheet_name = os.path.splitext(os.path.basename(sampleSheet))[0]
    if isMultipleComparison:
        if LRT:
            targets.append(expand("DESeq2_{}".format(sheet_name) + ".{compGroup}_LRT/DESeq2.session_info.txt", compGroup=cf.returnComparisonGroups(sampleSheet)))
        else:
            targets.append(expand("DESeq2_{}".format(sheet_name) + ".{compGroup}/DESeq2.session_info.txt", compGroup=cf.returnComparisonGroups(sampleSheet)))
        if rMats:
            targets.append(expand("rMats_{}".format(sheet_name) + ".{compGroup}/RI.MATS.JCEC.txt", compGroup=cf.returnComparisonGroups(sampleSheet)))
            targets.append(expand("rMats_{}".format(sheet_name) + ".{compGroup}/b2.csv", compGroup=cf.returnComparisonGroups(sampleSheet)))
    else:
        if LRT:
            targets.append(["DESeq2_{}_LRT/DESeq2.session_info.txt".format(sheet_name)])
        else:
            targets.append(["DESeq2_{}/DESeq2.session_info.txt".format(sheet_name)])
        if rMats:
            targets.append(["rMats_{}/RI.MATS.JCEC.txt".format(sheet_name)])
            targets.append(["rMats_{}/b2.csv".format(sheet_name)])
    return targets

def make_nmasked_genome():
    """Return the targets for building the N-masked genome.

    Only when `allele_mode` equals 'create_and_map' are the SNP filtering
    report of the first strain, the SNP file and the allelic STAR index
    requested; in every other case an empty list is returned.
    """
    if allele_mode != 'create_and_map':
        return []

    filtering_report = "snp_genome/" + strains[0] + '_SNP_filtering_report.txt'
    return [filtering_report, SNPFile, star_index_allelic]

def run_allelesp_mapping():
    """Return the targets of the allele-specific branch.

    Nothing is requested unless `mode` contains "allelic-mapping",
    "allelic-counting" or "allelic-whatshap". Otherwise the allele-split
    BAM files with their indices, the per-genome RPKM bigWig tracks and the
    allelic featureCounts outputs are requested. When a sample sheet is
    set, DESeq2 targets are added (in an "_LRT" directory if `LRT` is set).

    The result mixes plain strings and nested lists.
    """
    is_allelic = "allelic-mapping" in mode or "allelic-counting" in mode or "allelic-whatshap" in mode
    if not is_allelic:
        return []

    bam_suffixes = ['allele_flagged', 'genome1', 'genome2', 'unassigned']
    targets = [
        expand("allelic_bams/{sample}.{suffix}.sorted.bam", sample=samples, suffix=bam_suffixes),
        expand("allelic_bams/{sample}.{suffix}.sorted.bam.bai", sample=samples, suffix=bam_suffixes),
        # Coverage tracks exist only for the two assigned genomes.
        expand("bamCoverage/allele_specific/{sample}.{suffix}.RPKM.bw", sample=samples, suffix=['genome1', 'genome2']),
        expand("featureCounts/{sample}.allelic_counts.txt", sample=samples),
        "featureCounts/counts_allelic.tsv",
    ]
    if not sampleSheet:
        return targets

    sheet_name = os.path.splitext(os.path.basename(sampleSheet))[0]
    if isMultipleComparison:
        if LRT:
            targets.append(expand("DESeq2_{}".format(sheet_name) + ".{compGroup}_LRT/DESeq2.session_info.txt", compGroup=cf.returnComparisonGroups(sampleSheet)))
        else:
            targets.append(expand("DESeq2_{}".format(sheet_name) + ".{compGroup}/DESeq2.session_info.txt", compGroup=cf.returnComparisonGroups(sampleSheet)))
    else:
        if LRT:
            targets.append(["DESeq2_{}_LRT/DESeq2.session_info.txt".format(sheet_name)])
        else:
            targets.append(["DESeq2_{}/DESeq2.session_info.txt".format(sheet_name)])
    return targets

def run_deepTools_qc():
    """Return the deepTools QC targets.

    Nothing is requested unless "deepTools_qc" is part of `mode`. Otherwise
    the per-sample bigWig tracks, the plotEnrichment table and the
    read-filtering estimates are requested. The fragment-size metrics are
    added for paired-end data, the correlation/PCA outputs when there is
    more than one sample, and the "_allelic" counterparts when `mode`
    contains "allelic-mapping" or "allelic-whatshap".

    The result mixes plain strings and nested lists.
    """
    if "deepTools_qc" not in mode:
        return []

    targets = [
        expand("bamCoverage/{sample}.RPKM.bw", sample=samples),
        expand("bamCoverage/{sample}.scaleFactors.bw", sample=samples),
        expand("bamCoverage/{sample}.coverage.bw", sample=samples),
        expand("bamCoverage/{sample}.uniqueMappings.fwd.bw", sample=samples),
        expand("bamCoverage/{sample}.uniqueMappings.rev.bw", sample=samples),
        "deepTools_qc/plotEnrichment/plotEnrichment.tsv",
        expand("deepTools_qc/estimateReadFiltering/{sample}_filtering_estimation.txt", sample=samples),
    ]
    if pairedEnd:
        targets.append("deepTools_qc/bamPEFragmentSize/fragmentSize.metric.tsv")

    # Sample-to-sample comparisons need at least two samples.
    if len(samples) > 1:
        targets.append([
            "deepTools_qc/multiBigwigSummary/coverage.bed.npz",
            "deepTools_qc/plotCorrelation/correlation.pearson.bed_coverage.tsv",
            "deepTools_qc/plotCorrelation/correlation.spearman.bed_coverage.tsv",
            "deepTools_qc/plotPCA/PCA.bed_coverage.tsv",
        ])

    if 'allelic-mapping' in mode or 'allelic-whatshap' in mode:
        targets.append(["deepTools_qc/plotEnrichment/plotEnrichment_allelic.tsv"])
        if len(samples) > 1:
            targets.append([
                "deepTools_qc/multiBigwigSummary/coverage_allelic.bed.npz",
                "deepTools_qc/plotCorrelation/correlation.pearson.bed_coverage_allelic.tsv",
                "deepTools_qc/plotCorrelation/correlation.spearman.bed_coverage_allelic.tsv",
                "deepTools_qc/plotPCA/PCA.bed_coverage_allelic.tsv",
            ])
    return targets


def run_threePrimeSeq():
    """Return the targets of the three-prime-seq branch.

    Nothing is requested unless "three-prime-seq" is part of `mode`.
    Otherwise the filtered BAM files, the three_prime_seq outputs and the
    featureCounts table are requested. A DESeq2 target is added only when a
    sample sheet is set and `isMultipleComparison` is falsy.

    The result mixes plain strings and nested lists.
    """
    if "three-prime-seq" not in mode:
        return []

    targets = [
        expand("filtered_bam/{sample}.filtered.bam", sample=samples),
        "three_prime_seq/combined_polyA.png",
        "three_prime_seq/counts.tsv",
        "featureCounts/counts.tsv",
    ]
    if sampleSheet:
        sheet_name = os.path.splitext(os.path.basename(sampleSheet))[0]
        if not isMultipleComparison:
            targets.append(["DESeq2_{}/DESeq2.session_info.txt".format(sheet_name)])
    return targets


def run_GenomicContamination():
    """Return the genomic-contamination report target.

    The report is requested only when the global `dnaContam` is truthy;
    otherwise an empty list is returned.
    """
    if not dnaContam:
        return []
    return ["GenomicContamination/genomic_contamination_featurecount_report.tsv"]

### execute before  starts #####################################################
################################################################################
onstart:
    # The parameter and environment summaries are printed in verbose mode only.
    verbose = config.get("verbose", False)
    if verbose:
        if not fromBAM:
            # FASTQ start: full parameter overview.
            print()
            print("--- Workflow parameters --------------------------------------------------------")
            print("mode:", mode)
            print("samples:", samples)
            print("paired:", pairedEnd)
            print("read extension:", reads)
            print("fastq dir:", fastq_dir)
            print("filter GTF:", filterGTF)
            print("FeatureCounts library type:", libraryType)
            print("Sample sheet:", sampleSheet)
            print("-" * 80, "\n")

            print("--- Environment ----------------------------------------------------------------")
            print("$TMPDIR: ", os.getenv('TMPDIR', ""))
            print("$HOSTNAME: ", os.getenv('HOSTNAME', ""))
            print("-" * 80, "\n")
        else:
            # BAM start: shorter overview, without the FASTQ-related settings.
            print("--- Workflow parameters --------------------------------------------------------")
            print("samples:" + str(samples))
            print("input dir:" + indir)
            print("paired:", pairedEnd)
            print("-" * 80)

            print("--- Environment ----------------------------------------------------------------")
            print("$TMPDIR: " + os.getenv('TMPDIR', ""))
            print("$HOSTNAME: " + os.getenv('HOSTNAME', ""))
            print("-" * 80)

    # Record the tool versions of the conda environments used by this workflow.
    if toolsVersion:
        cf.writeTools([CONDA_SHARED_ENV, CONDA_RNASEQ_ENV], outdir, "mRNAseq", maindir)
    # Keep a copy of the sample sheet in the output directory.
    if sampleSheet:
        cf.copySampleSheet(sampleSheet, outdir)

### main rule ##################################################################
################################################################################
# The target rule `all` is defined once per entry point. Each helper returns
# an empty list when its branch is not requested, so the inputs adapt to
# `mode` and to the other workflow settings.

if not fromBAM:
    # FASTQ start: every branch may contribute targets.
    rule all:
        input:
            run_FastQC(fastqc),
            run_Trimming(trim, fastqc),
            run_alignment_free(),            # Salmon
            run_alignment(),                 # classical mapping + counting
            run_allelesp_mapping(),        # allelic-mapping
            make_nmasked_genome(),
            run_deepTools_qc(),
            run_GenomicContamination(),
            run_threePrimeSeq(),
            "multiQC/multiqc_report.html"

else:

    # BAM start: the FASTQ-level, Salmon, N-masked genome and three-prime-seq
    # targets are not requested.
    rule all:
        input:
            run_alignment(),
            run_allelesp_mapping(),
            run_deepTools_qc(),
            run_GenomicContamination(),
            "multiQC/multiqc_report.html"

### execute after  finished ####################################################
################################################################################
onsuccess:
    # The success banner is shown in verbose mode only.
    verbose = config.get("verbose", False)
    if verbose:
        print("\n--- RNA-seq workflow finished successfully! ------------------------------------\n")

onerror:
    # The error banner is always shown, regardless of the verbose setting.
    print("\n !!! ERROR in RNA-seq workflow! !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n")
