""" Pipeline for calling haplotypes on Illumina sequences and combining with haplotypes from 454 data. """

import Bio.SeqIO
import collections
import datetime
import os
import plotly
import pprint
import time
import shutil
import json
import numpy as np
import pandas as pd

# Load the workflow configuration and index the reference FASTA by record id.
configfile: "../config/config.yml"
reference_file = config["reference"]
reference_sequences = {seq_record.id: seq_record for seq_record in Bio.SeqIO.parse(reference_file, 'fasta')}

# Every reference record id is treated as a locus.
loci = set(reference_sequences.keys())
# Loci that have 454 haplotypes: one file per locus, named "<locus>.fna".
# The suffix is removed explicitly; str.strip(".fna") would strip any of the
# characters '.', 'f', 'n', 'a' from BOTH ends and corrupt names such as
# "abc.fna" (-> "bc").  Names without the suffix are kept unchanged.
loci_454 = {
    fn[:-len(".fna")] if fn.endswith(".fna") else fn
    for fn in os.listdir(config["454_haplotypes_dir"])
}
loci_not_in_454 = loci - loci_454
# Not used in this part of the file; presumably consumed by the included
# rule files below -- TODO confirm.
pair1, pair2 = "R1", "R2"
hap_numbers = ["1", "2"]

# Both rules produce haplotypes/{locus}.fna; prefer the combining rule
# whenever its inputs are available.
ruleorder: combine_454_and_illumina > include_loci_not_in_454
include: "rules/haplotype_illumina_data.smk"
include: "rules/align_454_haplotypes.smk"

rule all:
    # Default target: requests every final table and figure of the workflow.
    input:
        # Tables and the interactive histogram.
        "results/matrix.tsv",
        "results/loci_histogram.html",
        "results/read_counts.csv",
        # Read-count figures.
        "figures/reads_heatmap.svg",
        "figures/reads_samples.svg",
        "figures/reads_loci.svg",

rule combine_454_and_illumina:
    # Build the final per-locus haplotype FASTA by concatenating the aligned,
    # filtered 454 haplotypes with the Illumina haplotypes for the same locus.
    input:
        haps_454="intermediates_454/haps_aligned_filtered/{locus}.fna",
        haps_illumina="intermediates_illumina/haplotypes/{locus}.fna"
    output:
        # protected(): Snakemake write-protects the file once it is created.
        protected("haplotypes/{locus}.fna")
    shell:
        # {input} expands to both input files in the order declared above,
        # so the 454 records come first in the output.
        "cat {input} > {output}"

rule include_loci_not_in_454:
    # Fallback for loci without 454 haplotypes: the Illumina haplotypes are
    # copied unchanged to the final per-locus FASTA.
    input:
        "intermediates_illumina/haplotypes/{locus}.fna"
    output:
        protected("haplotypes/{locus}.fna")
    params:
        locus="{locus}"
    run:
        # Previously a locus outside loci_not_in_454 made this block a silent
        # no-op, and the job failed later with a generic missing-output error.
        # Fail here with a message that names the actual problem.
        if params.locus not in loci_not_in_454:
            raise ValueError(
                "Locus '{}' is listed among the 454 haplotypes, so its Illumina "
                "haplotypes must be combined with the 454 data rather than "
                "copied on their own.".format(params.locus)
            )
        shutil.copy(input[0], output[0])

rule build_matrix:
    # Write a presence/absence matrix (individuals x loci) as TSV, plus an
    # HTML histogram of how many loci each individual has haplotypes for.
    #
    # TSV layout: individual_id, sequencing_runs (comma-separated), then one
    # 0/1 column per locus in sorted input-file order.
    input:
        expand("haplotypes/{locus}.fna", locus=loci)
    output:
        matrix='results/matrix.tsv',
        histogram='results/loci_histogram.html'
    benchmark:
        "benchmarks/build_matrix.txt"
    run:
        # --- individual -> sequencing runs --------------------------------
        individuals_to_seq_runs = collections.defaultdict(set)
        # Illumina: the run is the name of the directory holding the FASTQ
        # file; the individual id is the file name up to the first '_'.
        for working_dir, subdirs, files in os.walk(os.path.expanduser(config['illumina_fastq_dir'])):
            run_name = os.path.basename(working_dir)
            for filename in files:
                # Skip hidden files and names starting with "Icon".
                if filename.startswith(('.', 'Icon')):
                    continue
                individuals_to_seq_runs[filename.split('_')[0]].add(run_name)
        # 454: every record in the 454 haplotype files is attributed to the
        # single run "Tigrinum_454_2013".
        for filename in os.listdir(config["454_haplotypes_dir"]):
            with open(os.path.join(config["454_haplotypes_dir"], filename), 'r') as infile:
                for seq_record in Bio.SeqIO.parse(infile, 'fasta'):
                    individuals_to_seq_runs[seq_record.id.split('_')[0]].add("Tigrinum_454_2013")

        # --- individual -> loci -------------------------------------------
        # locus_ids fixes the column order once; it is used for BOTH the
        # header and the rows.  The original iterated the unordered `loci`
        # set for the rows, so cells could be misaligned with the header.
        locus_ids = []
        individuals_to_loci = collections.defaultdict(set)
        for locus_filename in sorted(input):
            locus_id = locus_filename.split('/')[-1].split('.fna')[0]
            locus_ids.append(locus_id)
            with open(locus_filename, 'r') as fasta:
                for line in fasta:
                    if line.startswith('>'):
                        # strip() keeps the trailing newline out of the id
                        # when the header contains no '_'.
                        indiv_id = line[1:].strip().split('_')[0]
                        individuals_to_loci[indiv_id].add(locus_id)

        # --- assemble and write the matrix (no trailing newline) ----------
        rows = ['\t'.join(["individual_id", "sequencing_runs"] + locus_ids)]
        for indiv_id in sorted(individuals_to_loci):
            present = individuals_to_loci[indiv_id]
            fields = [indiv_id, ', '.join(sorted(individuals_to_seq_runs[indiv_id]))]
            fields.extend('1' if locus_id in present else '0' for locus_id in locus_ids)
            rows.append('\t'.join(fields))
        with open(output.matrix, 'w') as file:
            file.write('\n'.join(rows))

        # --- histogram of loci per individual -----------------------------
        data = [plotly.graph_objs.Histogram(x=[len(found) for found in individuals_to_loci.values()])]
        layout = plotly.graph_objs.Layout(title=('Individuals in Loci'), xaxis=dict(title='Number of Loci'), yaxis=dict(title='Number of Individuals'))
        figure = plotly.graph_objs.Figure(data=data, layout=layout)
        # NOTE(review): auto_open=True asks plotly to open the result in a
        # browser; consider False for headless/cluster runs.
        plotly.offline.plot(figure, filename=output.histogram, auto_open=True)


rule plot_read_counts:
    # Draw three SVG figures from the read-count table: a heatmap of the
    # whole table, a bar plot of column totals and a bar plot of row totals.
    # NOTE(review): `plt` and `sns` are not imported in the visible part of
    # this file; presumably an included rules file imports matplotlib.pyplot
    # and seaborn -- TODO confirm.
    input:
        'results/read_counts.csv'
    output:
        heatmap='figures/reads_heatmap.svg',
        barplot_samples='figures/reads_samples.svg',
        barplot_loci='figures/reads_loci.svg'
    run:
        counts = pd.read_csv(input[0], index_col=0)

        # Heatmap of the full table.
        fig, ax = plt.subplots(figsize=(6,4))
        sns.heatmap(counts, ax=ax)
        ax.set_title("Read counts")
        fig.savefig(output.heatmap, format='svg')
        plt.close()

        # Column sums, titled as per-sample totals.
        counts.sum().plot(kind='bar', title='Read counts per sample', figsize=(6,4))
        plt.savefig(output.barplot_samples, format='svg')
        plt.close()

        # Row sums, titled as per-locus totals.
        counts.sum(axis=1).plot(kind='bar', title='Read counts per locus', figsize=(16,9))
        plt.savefig(output.barplot_loci, format='svg')
        plt.close()

rule count_reads:
    # Tabulate reads per reference sequence for every sample BAM and write
    # the table as CSV: one column per sample, one row per reference
    # sequence name, 0 where a sample has no reads for that sequence.
    input:
        expand("intermediates_illumina/alignments/{sample}.sorted.bam", sample=samples)
    output:
        'results/read_counts.csv'
    run:
        bam_suffix = '.sorted.bam'
        read_counts_dict = dict()
        for filename in input:
            # Remove the suffix explicitly.  str.strip('.sorted.bam') strips
            # any of those CHARACTERS from both ends, turning e.g.
            # "sample1.sorted.bam" into "ple1".
            basename = filename.split('/')[-1]
            sample_id = basename[:-len(bam_suffix)] if basename.endswith(bam_suffix) else basename
            bam = HTSeq.BAM_Reader(filename)
            # Count alignments per reference name, skipping records whose
            # interval (aln.iv) is falsy.
            read_counts_dict[sample_id] = collections.Counter(aln.iv.chrom for aln in bam if aln.iv)
        # Missing sample/sequence combinations become 0; counts are integers.
        read_counts_df = pd.DataFrame(read_counts_dict).fillna(0).astype(int)
        read_counts_df.to_csv(output[0])