"""
This plots the amount that the genes in each group are mixed.

This version uses an unwrapped groupby dataframe.
Needs coordinates from the genomes
"""
# This block imports fasta-parser as fasta
import os
import sys
snakefile_path = os.path.dirname(os.path.realpath(workflow.snakefile))
dependencies_path = os.path.join(snakefile_path, "../dependencies/fasta-parser")
sys.path.insert(1, dependencies_path)
import fasta

import matplotlib.pyplot as plt
import matplotlib.patches as mplpatches
import odp_functions as OdpF
from matplotlib.backends.backend_pdf import PdfPages
import numpy as np

import operator
import pandas as pd
import random
import sys

def compress_homopolymer_string(input_string):
    """
    Collapses every run of identical neighbouring characters into a
     single character, e.g. "AAABBA" becomes "ABA".

    An empty input gives back an empty string.
    """
    kept = []
    last_seen = ""
    for symbol in input_string:
        # only the first character of each run is kept
        if symbol != last_seen:
            kept.append(symbol)
        last_seen = symbol
    return "".join(kept)

class MixedString:
    """
    Measures how interleaved the two symbols of a string are, and estimates
     by repeated random shuffling which mixing values strings of the same
     composition produce.

    Attributes set by the constructor:
      origstring        - the string as it was passed in
      compressed_string - origstring with homopolymer runs collapsed
      genestring        - the string that is analysed (one of the two above)
      charA, charB      - the two symbols, in sorted order
      charAcount, charBcount - occurrences of charA/charB in genestring
      mixing_value      - quantify_mixing() of genestring
      dist_dict         - {mixing value: number of shuffles with that value}
      mode, mean        - mode and mean of dist_dict
      number_of_sims_leq_mode - shuffles with a value <= mode
      numsims_between_0_and_mixingvalue - shuffles with a value <= mixing_value
      p_value_unmixed   - the previous count divided by num_iterations
    """
    def __init__(self, genestring, num_iterations = 1000000, compress_homopolymer = True):
        """
        genestring           - string made of exactly two different characters
        num_iterations       - number of random shuffles to simulate
        compress_homopolymer - analyse the homopolymer-compressed string
                               instead of the original one

        Raises IOError if the analysed string does not contain exactly two
         different characters.
        """
        self.compress_homopolymer = compress_homopolymer
        self.origstring = genestring
        # calculate compressed string
        self.compressed_string = compress_homopolymer_string(self.origstring)
        if compress_homopolymer:
            self.genestring = self.compressed_string
        else:
            self.genestring = self.origstring
        # first ensure there are exactly 2 chars
        self.charcount = MixedString._distinct_symbols(self.genestring)
        # sorted, so that which symbol becomes charA does not depend on the
        #  iteration order of the set (that order varies between runs)
        self.charcount_list = sorted(self.charcount)
        self.charA      = self.charcount_list[0]
        self.charB      = self.charcount_list[1]
        self.charAcount = self.genestring.count(self.charA)
        self.charBcount = self.genestring.count(self.charB)
        self.num_iterations = num_iterations
        # calculate the actual mixing value
        self.mixing_value = self.quantify_mixing(self.genestring)
        # calculate the distribution of possible values given this string
        self.dist_dict = self.distribution_of_mixing(
                      self.genestring, self.compress_homopolymer, self.num_iterations)
        self.mode = self.mode_of_dist(self.dist_dict)
        self.mean = self.mean_of_dist(self.dist_dict, self.num_iterations)
        # calculate other measures
        self.number_of_sims_leq_mode = sum(
            count for value, count in self.dist_dict.items()
            if value <= self.mode)
        self.numsims_between_0_and_mixingvalue = sum(
            count for value, count in self.dist_dict.items()
            if value <= self.mixing_value)
        # I don't think this measure is appropriate given
        #  that the Gaussian distribution naturally tightens around
        #  0.5 as the number of genes in A and B group increase
        self.p_value_unmixed = self.numsims_between_0_and_mixingvalue/self.num_iterations

    @staticmethod
    def _distinct_symbols(instr):
        """
        Returns the set of different characters in instr.
        Raises IOError unless there are exactly two of them.
        """
        symbols = set(instr)
        if len(symbols) > 2:
            raise IOError("The string must only have two characters. It was {}.".format(instr))
        if len(symbols) < 2:
            raise IOError("The string must have two different characters. It was {}.".format(instr))
        return symbols

    def quantify_mixing(self, instr):
        """
        quantifies how well mixed a string is.
        This doesn't measure entropy. With
          T = number of neighbouring positions holding different characters
          A, B = number of occurrences of the two characters
        it returns
          m = (T - 1)/((2*A*B/(A + B)) - 1)
        2*A*B/(A + B) is the mean of T over all arrangements of the string, so
          0 means unmixed (completely sequential, e.g. AAABBB),
          1 means as many transitions as an average shuffle,
          and perfectly alternating strings score above 1 (ABAB gives 2).
        A string with one of each character (AB) has no room to mix and
         returns 0 rather than dividing by zero.

        Raises IOError unless instr has exactly two different characters.
        """
        # the formula is symmetric, so the order of the two symbols is irrelevant
        Ac, Bc = MixedString._distinct_symbols(instr)
        NA = instr.count(Ac)
        NB = instr.count(Bc)

        # with only two symbols present, every unequal neighbour pair
        #  is an A->B or B->A transition
        Tcount = sum(1 for left, right in zip(instr, instr[1:])
                     if left != right)

        numerator = Tcount - 1
        denominator = (NA*(NB/(len(instr)))) + (NB*(NA/(len(instr)))) - 1
        if denominator == 0:
            # only happens when NA == NB == 1
            return 0.0
        return numerator/denominator

    def distribution_of_mixing(self, instr, compress_homopolymer, num_of_iterations = 1000000):
        """
        This estimates the distribution of mixing values for a given
         input string by shuffling it num_of_iterations times.
        If compress_homopolymer is set, every shuffled string is
         homopolymer-compressed before it is measured.

        Returns a dict {mixing value: number of shuffles with that value}.
         The keys 0 and 1 are always present. Progress is written to stderr.
        """
        progress = "             {}/{} ({:.1f} %) done: {} {:.4f}"
        thisstring = instr
        thisvalue = -1
        dist_counter = {0: 0, 1: 0}
        for i in range(num_of_iterations):
            thisstring = ''.join(random.sample(instr,len(instr)))
            if compress_homopolymer:
                thisstring = compress_homopolymer_string(thisstring)
            thisvalue = self.quantify_mixing(thisstring)
            dist_counter[thisvalue] = dist_counter.get(thisvalue, 0) + 1
            if i % 100 == 0:
                print(progress.format(
                    i+1, num_of_iterations, ((i+1)/num_of_iterations)*100,
                    thisstring[:40], thisvalue), end = "\r", file = sys.stderr)
        if num_of_iterations > 0:
            # final progress line, reporting the last shuffle at 100 %
            print(progress.format(
                num_of_iterations, num_of_iterations, 100.0,
                thisstring[:40], thisvalue), end = "\r", file = sys.stderr)

        print(file = sys.stderr)
        return dist_counter

    def mode_of_dist(self, dist_dict):
        """
        Finds the mode of the distribution. It is simply the largest bin.
        """
        return max(dist_dict.items(), key=operator.itemgetter(1))[0]

    def mean_of_dist(self, dist_dict, num_iterations):
        """
        Finds the mean of the distribution: every mixing value weighted
         by its count, divided by the number of iterations.
        """
        return sum(value * count for value, count in dist_dict.items())/num_iterations

    def __str__(self):
        """
        Text report: the input strings, the symbol counts, the observed
         mixing value, a text histogram of the simulated distribution and
         a comparison of the observed value against the mode.
        """
        parts = ["Gene mixing analysis\n"]
        parts.append(" Original Gene string: {}\n".format(self.origstring))
        parts.append(" Compressed    string: {}\n".format(self.compressed_string))
        parts.append(" Used    compressed? : {}\n".format(self.compress_homopolymer))
        parts.append(" Char {} count: {}\n".format(self.charA, self.charAcount))
        parts.append(" Char {} count: {}\n".format(self.charB, self.charBcount))
        parts.append(" mixing  value: {}\n".format(self.mixing_value))
        parts.append("\n")
        parts.append("m\tcount\thist\n")
        # histogram bars are scaled so that the mode is num_chars_histo wide
        num_chars_histo = 40
        max_count = self.dist_dict[self.mode]
        for key in sorted(self.dist_dict):
            colsize = int((self.dist_dict[key]/max_count)*num_chars_histo)
            parts.append("{:.4f}\t{}\t{}\n".format(
                key, self.dist_dict[key], "*" * colsize))
        if self.mode == 0:
            # nothing can lie below a mode of 0, the observed value has reached it
            fraction_of_mode = 1
        else:
            fraction_of_mode = min(1,(self.mixing_value/self.mode))
        parts.append("\nDistribution analysis\n")
        parts.append(" Mode: {}".format(self.mode) + "\n")
        parts.append(" Mean: {:.4f}\n".format(self.mean))
        parts.append(" input mixing value: {}\n".format(self.mixing_value))
        parts.append(" The observed string's mixing value, m, is {},\n".format(self.mixing_value))
        parts.append("  which is {:.2f}% of the way to an ideal randomly\n".format(fraction_of_mode * 100))
        parts.append("  mixed state, the mode of {:.4f}.\n".format(self.mode))
        parts.append(" The observed string's mixing value, m, is >= to {:.4f}% of all\n".format((self.numsims_between_0_and_mixingvalue/self.number_of_sims_leq_mode) * 100))
        parts.append("  simulations <= the mode (the minimal ideal mixed state)\n")

        return "".join(parts)

configfile: "config.yaml"

config["tool"] = "odp_rbh_plot_mixing"

OdpF.check_legality(config)

# these entries are mandatory in the config.yaml
for required_key in ("rbh_file", "merge_pairs", "separate"):
    if required_key not in config:
        raise IOError("You must specify '{}' in config".format(required_key))

# every entry of merge_pairs has to name two different groups
for pair in config["merge_pairs"]:
    if len(pair) != 2:
        raise IOError("There should be two groups in this pair: {}".format(pair))
    if pair[0] == pair[1]:
        raise IOError("The pairs should not be the same group: {}".format(pair))

# every entry of separate has to be one of the configured species
for separated in config["separate"]:
    if separated not in config["species"]:
        raise IOError("{} in separate must be in 'species' in the config.".format(
            separated))

# sample names are not allowed to contain underscores
for sample_name in config["species"]:
    if "_" in sample_name:
        raise IOError("Sample names can't have '_' char: {}".format(sample_name))

rule all:
    # Target rule: it has no action of its own and only lists the final
    #  files of the workflow. It is the first rule of this file, which is
    #  the rule Snakemake builds when no target is named.
    input:
        # one pdf of plots per species in the config
        expand(config["tool"] + "/output/{species}_plots.pdf",
               species = config["species"]),
        # now the mixing measures
        #  (per species: a table, and the text report of the simulations)
        expand(config["tool"] + "/output_mixing/{species}_mixing.tsv",
               species = config["species"]),
        expand(config["tool"] + "/output_mixing/{species}_mixing_simulation.txt",
               species = config["species"]),
        # final mixing measure
        #  (the per-species tables merged into one file)
        config["tool"] + "/output_mixing_merged/all_species_mixing.tsv"

def _overlap_length(range_a, range_b):
    """
    Length of the stretch shared by two [start, end] ranges,
     0 if they do not overlap.
    """
    shared_start = max(range_a[0], range_b[0])
    shared_end   = min(range_a[1], range_b[1])
    return max(0, shared_end - shared_start)

def mixing_analysis(rbh_file, species,
                    grouping_list_of_lists,
                    genome_size_tsv,
                    mixfile, simfile, compress_homopolymer = True):
    """
    Figures out the degree of how much the pairs are mixed.
    Gets the rows in which sample A and sample B have the same chromosome

    outputs a table of the measures of how mixed something is.

    This works on a per-species basis. One file per species, and
     the file will have all of the chromosomes in that species with at
     least 3 genes per group on that chromosome.

    Parameters:
      rbh_file  - tab-separated table. The first column is used as the
                  index, and the columns "gene_group", "<species>_scaf"
                  and "<species>_pos" are read.
      species   - the species name used in those column names
      grouping_list_of_lists - pairs [groupA, groupB] of gene_group values
      genome_size_tsv - tab-separated file, scaffold name in the second
                  column and its length in the third. Blank lines are skipped.
      mixfile   - output path of the results table (tab-separated)
      simfile   - output path of the text report of every MixedString analysis
      compress_homopolymer - passed on to MixedString

    Raises IOError if the share of a scaffold covered by both groups
     comes out below 0 or above 100 percent.
    """

    # read in the genome_sizes
    chr_size = {}
    with open(genome_size_tsv, "r") as f:
        for line in f:
            line = line.strip()
            if line:
                fields = line.split("\t")
                chr_size[fields[1]] = int(fields[2])

    list_of_results = []
    # the with-block closes the report even if the analysis raises
    with open(simfile, "w") as outsim:
        df = pd.read_csv(rbh_file, sep = "\t", index_col = 0)
        spscaf = "{}_scaf".format(species)
        sppos  = "{}_pos".format(species)
        df = df[["gene_group", spscaf, sppos]]
        df = df.dropna()
        # rows holding the string 'None' in any column are dropped as well
        df = df.mask(df.eq('None')).dropna()
        df[sppos] = pd.to_numeric(df[sppos])
        for thisgrouping in grouping_list_of_lists:
            groupA = thisgrouping[0]
            groupB = thisgrouping[1]
            remap_dict = {groupA: "A", groupB: "B"}
            ABdf  = df.loc[df["gene_group"].isin(thisgrouping), ]
            if ABdf.empty:
                continue
            # the split by group is the same for every scaffold, do it once
            group_names = ABdf["gene_group"].str.strip()
            A_all = ABdf.loc[group_names == groupA, ]
            B_all = ABdf.loc[group_names == groupB, ]
            for thisscaf in ABdf[spscaf].unique():
                Adf = A_all.loc[A_all[spscaf].str.strip() == thisscaf, ]
                Adf = Adf.sort_values(sppos)
                Bdf = B_all.loc[B_all[spscaf].str.strip() == thisscaf, ]
                Bdf = Bdf.sort_values(sppos)
                if len(Adf) < 3 or len(Bdf) < 3:
                    continue

                # the stretch of the scaffold spanned by both groups
                A_range = [min(Adf[sppos]), max(Adf[sppos])]
                B_range = [min(Bdf[sppos]), max(Bdf[sppos])]
                bases_covered = _overlap_length(A_range, B_range)

                if bases_covered == 0:
                    percent_genome_covered_mixed = 0
                else:
                    percent_genome_covered_mixed = (bases_covered/chr_size[thisscaf])*100
                if percent_genome_covered_mixed > 100:
                    raise IOError ("Percent was greater than 100")
                elif percent_genome_covered_mixed < 0:
                    raise IOError ("Percent was less than 0")

                bases_covered = int(bases_covered)
                # one character per gene, A or B, in order along the scaffold
                merged = pd.concat([Adf, Bdf])
                merged = merged.sort_values(sppos)
                mixstring = "".join(list(merged["gene_group"].map(remap_dict)))
                analysis = MixedString(mixstring, num_iterations = 100000, compress_homopolymer = compress_homopolymer)
                print("# Species   : {}".format(species), file = outsim)
                print("# Grouping A: {}".format(groupA), file = outsim)
                print("# Grouping B: {}".format(groupB), file = outsim)
                print("# Scaffold  : {}".format(thisscaf), file = outsim)
                print(analysis, file = outsim)
                print("# ################", file = outsim)

                if analysis.mode == 0:
                    # nothing can lie below a mode of 0
                    percent_of_way_to_mode = 100
                else:
                    percent_of_way_to_mode = min(1,(analysis.mixing_value/analysis.mode)) * 100

                # now add something to the table
                list_of_results.append({
                    "species": species,
                    "scaffold": thisscaf,
                    "mixed_groups": "{}+{}".format(groupA, groupB),
                    "genestring": analysis.genestring,
                    "charAcount": analysis.charAcount,
                    "charBcount": analysis.charBcount,
                    "m": analysis.mixing_value,
                    "mode": analysis.mode,
                    "mean": analysis.mean,
                    "bases_covered_mixing": bases_covered,
                    "scaffold_total_bases": chr_size[thisscaf],
                    "percent_scaffold_mixing": percent_genome_covered_mixed,
                    "m_percent_of_way_to_mode": percent_of_way_to_mode,
                    "p_value_unmixed": analysis.p_value_unmixed,
                    "num_iterations": analysis.num_iterations,
                })

    results = pd.DataFrame(list_of_results)
    results.to_csv(mixfile, sep = "\t", index = False)

def overlapping_plot(thisspecies, thischrom, chrom_size,
                     left_pos,   right_pos,
                     left_group, right_group,
                     left_color, right_color, pdf):
    """
    Draws one chromosome as a grey bar and marks the genes of two groups
     on it, to show how far the groups overlap. The figure is appended
     to pdf as a new page.

    Parameters:
      thisspecies, thischrom  - only used in the title
      chrom_size              - chromosome length, the x axis runs from 0 to it
      left_pos, right_pos     - positions of the genes of the two groups
      left_group, right_group - names of the groups, written under the title
      left_color, right_color - colors of the two groups
      pdf                     - object with a savefig() method, the caller
                                in this file passes a PdfPages
    """

    width = 3
    height = 2
    fig = plt.figure(figsize=(width,height))

    panel1=plt.axes([0.5/width,
                     0.5/height,
                     (width-1)/width,
                     1/height])
    panel1.set_xlim([0, chrom_size])
    panel1.set_ylim([0, 1])
    # tick_params takes booleans here. A string like 'off' is not empty
    #  and therefore would count as True.
    panel1.tick_params(axis='both', which='both',
                       bottom=False, labelbottom=False,
                       left=False, labelleft=False,
                       right=False, labelright=False,
                       top=False, labeltop=False)

    # the chromosome itself
    rectangle1=mplpatches.Rectangle( (0,0), chrom_size,1,
                                     linewidth=0,
                                     linestyle='-',
                                     edgecolor='black',
                                     facecolor="#999999")
    panel1.add_patch(rectangle1)

    # the two groups are vertically offset so both stay visible
    left_bottom  = 0.05
    left_top     = 0.75
    right_bottom = 0.25
    right_top    = 0.95

    tracks = [(left_pos,  left_bottom,  left_top,  left_color),
              (right_pos, right_bottom, right_top, right_color)]

    # first the horizontal lines that span the first to the last gene
    for positions, bottom, top, color in tracks:
        if len(positions) > 1:
            y_mid = ((top - bottom)/2) + bottom
            panel1.plot([min(positions), max(positions)],
                        [y_mid, y_mid], color = color)

    # then one vertical line per gene
    for positions, bottom, top, color in tracks:
        for pos in positions:
            panel1.plot([pos, pos], [bottom, top], color = color)

    thistitle = "{} {} (len {} bp)".format(
        thisspecies, thischrom, chrom_size)
    plt.figtext(0.5, 0.95, thistitle, fontsize=7, color = "black", ha ='center')
    plt.figtext(0.49, 0.90, "{}".format(left_group), fontsize=7, color=left_color, ha ='right')
    plt.figtext(0.51, 0.90, "{}".format(right_group), fontsize=7, color=right_color, ha ='left')

    plt.axis('off')
    # save and close this figure explicitly rather than whichever is current
    pdf.savefig(fig)
    plt.close(fig)

rule read_in_genome_sizes:
    """
    Writes the size of every sequence in the genome of one species to a
     tab-separated file with the columns: species, sequence id, length.

    This just reads in the genome sizes because it takes so long.
    """
    input:
        genome = lambda wildcards: config["species"][wildcards.species]["genome"]
    output:
        size = config["tool"] + "/input/{species}_chrom_sizes.tsv"
    threads: 1
    params:
        species = lambda wildcards: wildcards.species
    run:
        # the with-block closes the file even if parsing the fasta fails
        with open(output.size, "w") as outhandle:
            for record in fasta.parse(input.genome):
                print("{}\t{}\t{}".format(
                    params.species,
                    record.id,
                    len(record.seq)),
                    file = outhandle)

rule gen_mixing_plots:
    """
    Writes one pdf per species. Each page shows one chromosome and one
     pair from merge_pairs, with the positions of the genes of both groups
     drawn by overlapping_plot(). A page is made for every chromosome on
     which at least one of the two groups has a gene with a position.

    Raises IOError if the species of the rbh file (taken from its
     "<species>_gene" columns) and of the config do not match.
    """
    input:
        rbh = config["rbh_file"],
        sizes = expand(config["tool"] + "/input/{species}_chrom_sizes.tsv",
                       species = config["species"])
    output:
        plots = expand(config["tool"] + "/output/{species}_plots.pdf",
                       species = config["species"])
    threads: 1
    params:
        plot_stem = config["tool"] + "/output"
    run:
        merge_pairs = config["merge_pairs"]
        rbhfile     = config["rbh_file"]

        # na_values turns the string "None" into NaN while reading
        df = pd.read_csv(rbhfile, sep = "\t", header = 0,
                         na_values="None", index_col = None)

        species_list = [x.split("_")[0] for x in df.columns if x.endswith("_gene")]

        # tell us if there are missing species in either the config.yaml
        #  or in the species list
        rbh_species_missing_in_config = [x for x in species_list
                                         if x not in config["species"]]
        config_species_missing_in_rbh = [x for x in config["species"]
                                         if x not in species_list]
        if rbh_species_missing_in_config:
            raise IOError("There are some species in the rbh file that are missing in the config: {}".format(rbh_species_missing_in_config))
        if config_species_missing_in_rbh:
            raise IOError("There are some species in the config file that are missing in the rbh file: {}".format(config_species_missing_in_rbh))

        species_to_genome_size = {}
        # we need the genome sizes for everything
        for thisspecies in species_list:
            thisfile = config["tool"] + "/input/{}_chrom_sizes.tsv".format(thisspecies)
            species_to_genome_size[thisspecies] = {}
            with open(thisfile, "r") as f:
                for line in f:
                    line = line.strip()
                    if line:
                        fields = line.split("\t")
                        species_to_genome_size[thisspecies][fields[1]] = int(fields[2])

        left_color  = config.get("left_color",  "#c91b31")
        right_color = config.get("right_color", "#f3b74a")

        for thisspecies in species_list:
            scafcol = "{}_scaf".format(thisspecies)
            poscol  = "{}_pos".format(thisspecies)
            # the chromosomes do not depend on the pair, so get them once.
            #  dropna removes the missing values whatever object holds them
            species_chroms = [x for x in df[scafcol].dropna().unique()
                              if x not in ["nan", "NaN"]]
            with PdfPages('{}/{}_plots.pdf'.format(params.plot_stem,
                                                   thisspecies)) as pdf:
                for thispair in merge_pairs:
                    left_group  = thispair[0]
                    right_group = thispair[1]
                    for thischrom in species_chroms:
                        on_chrom = df[scafcol] == thischrom
                        # genes without a position cannot be drawn
                        left_pos  = list(df.loc[(df["gene_group"] == left_group)
                                                & on_chrom, poscol].dropna())
                        right_pos = list(df.loc[(df["gene_group"] == right_group)
                                                & on_chrom, poscol].dropna())
                        if (len(left_pos) == 0) and (len(right_pos) == 0):
                            continue
                        print(thisspecies, thischrom, thispair)
                        print(left_pos)
                        print(right_pos)
                        chrom_size = int(species_to_genome_size[thisspecies][thischrom])
                        overlapping_plot(thisspecies, thischrom, chrom_size,
                                         left_pos,   right_pos,
                                         left_group, right_group,
                                         left_color, right_color, pdf)

rule text_mixing_measures:
    """
    Runs mixing_analysis() for one species on all merge_pairs of the
     config, without homopolymer compression. Produces the table of
     mixing measures and the text report of the simulations.
    """
    input:
        rbh = config["rbh_file"],
        csize = config["tool"] + "/input/{species}_chrom_sizes.tsv"
    output:
        mf = config["tool"] + "/output_mixing/{species}_mixing.tsv",
        sf = config["tool"] + "/output_mixing/{species}_mixing_simulation.txt",
    threads: 1
    params:
        species = lambda wildcards: wildcards.species
    run:
        mixing_analysis(rbh_file = input.rbh,
                        species = params.species,
                        grouping_list_of_lists = config["merge_pairs"],
                        genome_size_tsv = input.csize,
                        mixfile = output.mf,
                        simfile = output.sf,
                        compress_homopolymer = False)


rule merge_all_measures:
    """
    Concatenates the mixing tables of all species into one table,
     sorted by the "mixed_groups" column.

    Tables without any content are skipped. If no table has content,
     an empty output file is written.
    """
    input:
        mf = expand(config["tool"] + "/output_mixing/{species}_mixing.tsv",
               species = config["species"])
    output:
        outtable = config["tool"] + "/output_mixing_merged/all_species_mixing.tsv"
    threads: 1
    run:
        dataframes = []
        for filename in input.mf:
            try:
                df = pd.read_csv(filename, sep = "\t", index_col = None)
            except pd.errors.EmptyDataError:
                # a species without results has a table without columns.
                #  Every other read error is a real problem and is raised.
                print("Skipping the empty table {}".format(filename),
                      file = sys.stderr)
                continue
            dataframes.append(df)
        if dataframes:
            final = pd.concat(dataframes)
            final = final.sort_values("mixed_groups")
        else:
            # pd.concat refuses an empty list
            final = pd.DataFrame()
        final.to_csv(output.outtable, sep = "\t", index = False)
