""" Pipeline for calling SNPs, subsampling, and preparing scripts for running Structure on the DLX cluster. """

import Bio.SeqIO
import collections
import datetime
import os
import plotly
import time
import json
import numpy as np
import pandas as pd

# Upstream haplotype workflow. Paths wrapped in haps(...) below are resolved
# relative to its working directory.
subworkflow haps:
    workdir: "../haplotype_pipeline"

configfile: "../config/config.yml"
# Discover the loci to process from the curated haplotype FASTA files found on
# disk; every "{locus}" expansion below iterates over this list.
# NOTE(review): this module-level name is also the name Snakemake uses for a
# rule's wildcards object inside run blocks (see filter_haps) — confirm the
# shadowing is harmless before reusing it elsewhere.
wildcards = glob_wildcards("haplotypes_curated/{locus}.fna")
loci = wildcards.locus
# Both subsampling rules declare the same snp_subsamples/ outputs; when either
# could produce a file, the skip-filter variant is preferred.
ruleorder: snp_subsample_skip_filter > snp_subsample

# Default target: requests every final product of the pipeline.
rule all:
    input:
        # SNP-only alignment for each discovered locus.
        expand("snp_sites/{locus}.fna", locus=loci),
        # One script per (subsample, K) combination; subsamples are numbered
        # from 1 and K spans k_min..k_max inclusive.
        expand("structure/strx_{n}_k{k}.sh", n=range(1, config["snp_subsample"]["num_subsamples"] + 1), k=range(config["structure"]["k_min"], config["structure"]["k_max"] + 1)),
        # One script per K for the "all SNPs, all loci" data set.
        expand("structure/strx_all-snps-all-loci_k{k}.sh", k=range(config["structure"]["k_min"], config["structure"]["k_max"] + 1)),
        # Histogram of SNP sites per locus.
        "results/snp_histogram.html"

rule set_indivs_to_drop:
    # Decide which individuals to exclude, based on the read-count table
    # produced by the haplotype subworkflow (rows are indexed by locus, columns
    # by individual).
    #
    # Outputs:
    #   csv         - the read-count table with below-threshold cells set to 0
    #                 and fully dropped individuals removed.
    #   json_loci   - {individual: [locus, ...]} for individuals that are only
    #                 dropped from specific loci.
    #   json_indivs - list of individuals dropped from every locus.
    input:
        csv=haps("results/read_counts.csv")
    output:
        csv="results/read_counts_filtered.csv",
        json_loci="results/indivs_to_drop_from_loci.json",
        json_indivs="results/indivs_to_drop_all.json"
    params:
        read_count_threshold=config['read_count_threshold'],
        loci_missing_threshold=config['loci_missing_threshold']
    log:
        "logs/set_indivs_to_drop.log"
    benchmark:
        "benchmarks/set_indivs_to_drop.txt"
    run:
        counts = pd.read_csv(input.csv, index_col=0)
        filtered_counts = counts.copy()
        min_reads = params.read_count_threshold

        # Every (locus, individual) cell whose read count is below the
        # threshold, visited locus by locus and then column by column.
        low_cells = [
            (locus, indiv_id)
            for locus in counts.index
            for indiv_id in counts.columns
            if counts.at[locus, indiv_id] < min_reads
        ]

        # individual -> list of loci at which that individual is below threshold
        low_loci_by_indiv = collections.defaultdict(list)
        for locus, indiv_id in low_cells:
            low_loci_by_indiv[indiv_id].append(locus)
            filtered_counts.at[locus, indiv_id] = 0

        # An individual failing at least this many loci is removed entirely.
        max_low_loci = params.loci_missing_threshold * len(counts.index)
        dropped_everywhere = [
            indiv_id
            for indiv_id, low_loci in low_loci_by_indiv.items()
            if len(low_loci) >= max_low_loci
        ]
        # Fully dropped individuals need no per-locus entry.
        for indiv_id in dropped_everywhere:
            del low_loci_by_indiv[indiv_id]
        filtered_counts = filtered_counts.drop(dropped_everywhere, axis=1)

        filtered_counts.to_csv(output.csv)
        with open(output.json_loci, "w") as loci_json:
            json.dump(low_loci_by_indiv, loci_json)
        with open(output.json_indivs, "w") as indivs_json:
            json.dump(dropped_everywhere, indivs_json)

rule filter_haps:
    # Copy one locus's curated haplotype FASTA, leaving out the records of
    # individuals that set_indivs_to_drop flagged either globally or for this
    # particular locus.
    input:
        fna="haplotypes_curated/{locus}.fna",
        json_loci=rules.set_indivs_to_drop.output.json_loci,
        json_indivs=rules.set_indivs_to_drop.output.json_indivs
    output:
        fna="haplotypes_filtered/{locus}.fna"
    log:
        "logs/filter_haps/{locus}.log"
    benchmark:
        "benchmarks/filter_haps/{locus}.txt"
    run:
        with open(input.json_indivs) as handle:
            dropped_everywhere = set(json.load(handle))
        with open(input.json_loci) as handle:
            dropped_loci_by_indiv = {indiv: set(locus_list) for indiv, locus_list in json.load(handle).items()}

        this_locus = wildcards.locus
        kept_records = []
        for record in Bio.SeqIO.parse(input.fna, 'fasta'):
            # The individual ID is the part of the record ID before the first '_'.
            indiv = record.id.split('_')[0]
            if indiv in dropped_everywhere:
                continue
            if indiv in dropped_loci_by_indiv and this_locus in dropped_loci_by_indiv[indiv]:
                continue
            kept_records.append(record)
        Bio.SeqIO.write(kept_records, output.fna, "fasta")

# Run the external snp-sites tool on one locus's filtered haplotype FASTA,
# writing the result to snp_sites/{locus}.fna.
rule snp_sites:
    input:
        fna=rules.filter_haps.output.fna
    output:
        fna="snp_sites/{locus}.fna"
    log:
        "logs/snp_sites/{locus}.log"
    benchmark:
        "benchmarks/snp_sites/{locus}.txt"
    shell:
        # stderr of the tool is redirected into the rule's log file.
        "snp-sites -o {output} {input} 2> {log}"

# Run scripts/snp_subsample.py over the snp_sites/ directory. The rule declares
# three kinds of output: a filtered FASTA per locus in snp_sites_filtered/,
# num_subsamples numbered subsample files, and one all-snps-all-loci file.
rule snp_subsample:
    input:
        expand("snp_sites/{locus}.fna", locus=loci)
    output:
        expand("snp_sites_filtered/{locus}.fna", locus=loci),
        # Subsample files are numbered 1..num_subsamples.
        expand("snp_subsamples/{filename_base}{subsample}_strx.txt", filename_base=config["snp_subsample"]["output_filename_base"], subsample=range(1, 1 + config["snp_subsample"]["num_subsamples"])),
        "snp_subsamples/all-snps-all-loci_strx.txt"
    params:
        fn_base = config["snp_subsample"]["output_filename_base"],
        num_subsamples = config["snp_subsample"]["num_subsamples"],
        output_format = config["snp_subsample"]["output_format"]
    shell:
        # The script is given the input directory and an output path prefix
        # rather than the individual files listed above.
        # NOTE(review): the declared subsample outputs end in "_strx.txt" —
        # presumably the script derives that suffix from --output-format;
        # confirm against scripts/snp_subsample.py.
        "python scripts/snp_subsample.py snp_sites/ snp_subsamples/{params.fn_base} "
        "--all-snps-all-loci --output-format {params.output_format} --num_subsamples {params.num_subsamples} "
        "--output-filtered-fasta-dir snp_sites_filtered/"

# Variant of snp_subsample that runs scripts/snp_subsample.py with
# --skip-filter and does not write filtered FASTA files. It declares the same
# snp_subsamples/ outputs as snp_subsample and is preferred via the ruleorder
# directive.
rule snp_subsample_skip_filter:
    input:
        expand("snp_sites_filtered/{locus}.fna", locus=loci)
    output:
        expand("snp_subsamples/{filename_base}{subsample}_strx.txt", filename_base=config["snp_subsample"]["output_filename_base"], subsample=range(1, 1 + config["snp_subsample"]["num_subsamples"])),
        "snp_subsamples/all-snps-all-loci_strx.txt"
    params:
        fn_base = config["snp_subsample"]["output_filename_base"],
        num_subsamples = config["snp_subsample"]["num_subsamples"],
        output_format = config["snp_subsample"]["output_format"]
    shell:
        # NOTE(review): the declared inputs live in snp_sites_filtered/, but the
        # command passes snp_sites/ as the input directory. Confirm which
        # directory the script should read when --skip-filter is given; as
        # written, the files the command reads are not the ones Snakemake tracks.
        "python scripts/snp_subsample.py snp_sites/ snp_subsamples/{params.fn_base} "
        "--all-snps-all-loci --output-format {params.output_format} --num_subsamples {params.num_subsamples} "
        "--skip-filter"


rule write_subsample_strx_script:
    # Write one batch script per K (k_min..k_max inclusive) for a single SNP
    # subsample. Each script launches num_runs background Structure runs,
    # staggered by 10 seconds, and then waits for all of them to finish.
    input:
        "snp_subsamples/{filename_base}{{subsample}}_strx.txt".format(filename_base=config["snp_subsample"]["output_filename_base"])
    output:
        expand("structure/strx_{{subsample}}_k{k}.sh", k=range(config["structure"]["k_min"], config["structure"]["k_max"] + 1))
    params:
        subsample = "{subsample}",
        strx_path = config["structure"]["path"],
        num_runs = config["structure"]["num_runs"],
        k_min = config["structure"]["k_min"],
        k_max = config["structure"]["k_max"],
        email = config["structure"]["email"]
    run:
        # Derive the -N (individuals) and -L (loci) arguments from the input
        # file. Blank lines are skipped so they do not inflate the row count.
        num_rows = 0
        num_loci = 0
        with open(input[0], 'r') as strx_file:
            for line in strx_file:
                if not line.strip():
                    continue
                if num_rows == 0:
                    # Every whitespace-separated column except the first is
                    # counted as a locus.
                    num_loci = len(line.split()) - 1
                num_rows += 1
        # The file is taken to hold two rows per individual. Integer division
        # keeps -N an integer ("50", not "50.0") in the generated command line.
        num_indivs = num_rows // 2

        # The header is identical for every K. The node list is echoed into the
        # job output; without "echo" bash would try to execute that line.
        header = (
            "#!/bin/bash\n"
            "#SBATCH --mail-type=ALL\n"
            "#SBATCH --mail-user={email}\n"
            'echo "SBATCH_NODELIST: $SBATCH_NODELIST"\n\n'
        ).format(email=params.email)
        run_template = "{strx_path} -i {input_fn} -K {k} -L {L} -N {N} -o {base}{subsample}_k{k}_run{run}.out &\nsleep 10\n\n"

        for k in range(params.k_min, params.k_max + 1):
            run_commands = [
                run_template.format(
                    strx_path=params.strx_path,
                    input_fn=input[0],
                    k=k,
                    L=num_loci,
                    N=num_indivs,
                    run=run_number,
                    subsample=params.subsample,
                    base=config["snp_subsample"]["output_filename_base"],
                )
                for run_number in range(1, params.num_runs + 1)
            ]
            # "wait" keeps the batch job alive until all background runs exit.
            script = header + "".join(run_commands) + 'wait'
            with open("structure/strx_{subsample}_k{k}.sh".format(subsample=params.subsample, k=k), 'w') as script_file:
                script_file.write(script)

rule write_all_snps_strx_script:
    # Write one batch script per K (k_min..k_max inclusive) for the
    # "all SNPs, all loci" data set. Each script launches num_runs background
    # Structure runs, staggered by 10 seconds, and then waits for all of them.
    input:
        "snp_subsamples/all-snps-all-loci_strx.txt"
    output:
        expand("structure/strx_all-snps-all-loci_k{k}.sh", k=range(config["structure"]["k_min"], config["structure"]["k_max"] + 1))
    params:
        strx_path = config["structure"]["path"],
        num_runs = config["structure"]["num_runs"],
        k_min = config["structure"]["k_min"],
        k_max = config["structure"]["k_max"],
        email = config["structure"]["email"]
    run:
        # Derive the -N (individuals) and -L (loci) arguments from the input
        # file. Blank lines are skipped so they do not inflate the row count.
        num_rows = 0
        num_loci = 0
        with open(input[0], 'r') as strx_file:
            for line in strx_file:
                if not line.strip():
                    continue
                if num_rows == 0:
                    # Every whitespace-separated column except the first is
                    # counted as a locus.
                    num_loci = len(line.split()) - 1
                num_rows += 1
        # The file is taken to hold two rows per individual. Integer division
        # keeps -N an integer ("50", not "50.0") in the generated command line.
        num_indivs = num_rows // 2

        # The header is identical for every K. The node list is echoed into the
        # job output; without "echo" bash would try to execute that line.
        header = (
            "#!/bin/bash\n"
            "#SBATCH --mail-type=ALL\n"
            "#SBATCH --mail-user={email}\n"
            'echo "SBATCH_NODELIST: $SBATCH_NODELIST"\n\n'
        ).format(email=params.email)
        run_template = "{strx_path} -i {input_fn} -K {k} -L {L} -N {N} -o all-snps-all-loci_k{k}_run{run}.out &\nsleep 10\n\n"

        for k in range(params.k_min, params.k_max + 1):
            run_commands = [
                run_template.format(
                    strx_path=params.strx_path,
                    input_fn=input[0],
                    k=k,
                    L=num_loci,
                    N=num_indivs,
                    run=run_number,
                )
                for run_number in range(1, params.num_runs + 1)
            ]
            # "wait" keeps the batch job alive until all background runs exit.
            script = header + "".join(run_commands) + 'wait'
            with open("structure/strx_all-snps-all-loci_k{k}.sh".format(k=k), 'w') as script_file:
                script_file.write(script)

rule histogram:
    # Plot overlaid histograms of SNP sites per locus, before and after
    # filtering, and record the mean number of filtered SNP sites per locus.
    input:
        snps_raw=expand("snp_sites/{locus}.fna", locus=loci),
        snps_filtered=expand("snp_sites_filtered/{locus}.fna", locus=loci)
    output:
        hist="results/snp_histogram.html",
        log="logs/histogram/snps_per_locus.log"
    run:
        def count_snp_sites(fasta_path):
            # The SNP count of a locus is the sequence length of the first
            # record in its FASTA file. A file with no records counts as 0
            # instead of raising StopIteration. The handle is closed
            # explicitly because the parser is never exhausted.
            with open(fasta_path) as handle:
                first_record = next(Bio.SeqIO.parse(handle, 'fasta'), None)
            return 0 if first_record is None else len(first_record.seq)

        # Parse every file exactly once; the filtered counts are reused for
        # the average below.
        filtered_counts = [count_snp_sites(filename) for filename in input.snps_filtered]
        raw_counts = [count_snp_sites(filename) for filename in input.snps_raw]

        snps_filt_hist = plotly.graph_objs.Histogram(x=filtered_counts, name="snps filtered", autobinx=True, opacity=0.5)
        snps_raw_hist = plotly.graph_objs.Histogram(x=raw_counts, name="snps unfiltered", autobinx=True, opacity=0.5)
        data = [snps_filt_hist, snps_raw_hist]
        layout = plotly.graph_objs.Layout(barmode='overlay', title=('SNP sites per locus'), xaxis=dict(title='Number of SNPs'), yaxis=dict(title='Number of loci'))
        plotly.offline.plot(plotly.graph_objs.Figure(data=data, layout=layout), filename=output.hist, auto_open=False)

        # Mean SNP sites per filtered locus; "nan" when there are no loci, so
        # an empty locus list does not raise ZeroDivisionError.
        if filtered_counts:
            avg_snps_per_locus = str(sum(filtered_counts) / len(filtered_counts))
        else:
            avg_snps_per_locus = str(float('nan'))
        with open(output.log, 'w') as logfile:
            logfile.write(avg_snps_per_locus)
