"""
This script takes the rbh hits file, specifically of a 4-way rbh search,
  and mutates one group to match the number of chromosomes of another group.

This program considers the chromosome-scale genomes of four species,
with one known outgroup species. The program assumes the other three species
are a polytomy.

      ,========= outer
    =={
      {  ,====== species1
      '=={
         {====== species2
         {
         '====== inner

Using fusion-with-mixing events, the program looks for support for these three
possible phylogenetic topologies:

      hypothesis 1               hypothesis 2               hypothesis 3

     ,========= outer           ,========= outer           ,========= outer
   =={                        =={                        =={
     {  ,====== species1        {  ,====== species2        {  ,====== inner
     '=={                 OR    '=={                 OR    '=={
        {  ,=== species2           {  ,=== species1           {  ,=== species1
        '=={                       '=={                       '=={
           '=== inner                 '=== inner                 '=== species2


The config file will look something like this:

```
# this parameter is the rbh file you will use to look for groups
rbh_file: "EMU_HCA_RESLi_SRO_reciprocal_best_hits.rbh"
outer:  "SRO"
species1: "HCA"
species2: "EMU"
inner: "RESLi"

# This is the decay simulation.
num_swaps: 5000 # around 7 minutes
num_simulations: 5000

# Randomization trials. Should be divisible by 10,000 (the number of randomizations run per job)
# Each of the four genomes will be randomized in this many trials
num_randomizations:  1000000
```

"""

# data stuff
import pandas as pd
import numpy as np
# simulation stuff
import time
import random
from random import randrange
import sys
# plotting stuff
import ast
import itertools
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.patches as mplpatches
from matplotlib.collections import LineCollection
from matplotlib.colors import ListedColormap, LinearSegmentedColormap
from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.pylab as pl
import operator

configfile: "config.yaml"

config["tool"] = "odp_genome_mutation_analysis"

#OdpF.check_legality(config)

# Make sure the appropriate fields are present. Every key that is read while
#  this file is parsed is listed here, so that a missing key produces a
#  readable message rather than a bare KeyError further down.
check_these = ["rbh_file",
               "species1",
               "species2",
               "inner",
               "outer",
               "num_simulations",
               "num_randomizations",
               "species"]
for check_this in check_these:
    if check_this not in config:
        raise IOError("You must specify '{}' in config".format(check_this))

# make sure that the four group fields name samples that are in config["species"]
check_these = ["species1",
               "species2",
               "inner",
               "outer"]
for check_this in check_these:
    if config[check_this] not in config["species"]:
        raise IOError("The sample {} is not in the species".format(config[check_this]))

# Randomizations: each simulation job runs num_randomizations_per_round
#  shuffles, so the requested total is split into num_rand_rounds jobs.
#  Any remainder is truncated.
num_randomizations_per_round=10000
num_rand_rounds = int(config["num_randomizations"]/num_randomizations_per_round)
if num_rand_rounds < 1:
    # zero rounds would leave the downstream rules with no input files at all
    raise IOError("num_randomizations must be at least {}".format(
        num_randomizations_per_round))
#print("number of randomization rounds is : {}".format(num_rand_rounds))

# make sure none of the sample names have underscores
for thissample in config["species"]:
    if "_" in thissample:
        raise IOError("Sample names can't have '_' char: {}".format(thissample))

four_species = [config["outer"], config["species1"],
                config["species2"], config["inner"]]

# check that the four species are present in the dataframe
tempdf = pd.read_csv(config["rbh_file"], index_col = None, sep = "\t")
print(tempdf.columns)
for thisspecies in four_species:
    if "{}_scaf".format(thisspecies) not in tempdf.columns:
        raise IOError("Species {} was one of the defined species, but wasn't in the provided rbh file".format(thisspecies))
# ...and that the rbh file has no species beyond those four. Only the trailing
#  "_scaf" is stripped, so a species name is never altered by the removal.
for thisspecies in [x[:-len("_scaf")] for x in tempdf.columns if x.endswith("_scaf")]:
    if thisspecies not in four_species:
        raise IOError("Species {} was in the rbh file, but wasn't in the specified species for this analysis.".format(thisspecies))

# Default target rule. It has no action of its own; it only lists the files
#  that the workflow should produce.
rule all:
    input:
        # just check the data
        # (the three files written by rule randomization_initialization)
        config["tool"] + "/sim_randomization/initialization/random_sim_init.txt",
        config["tool"] + "/sim_randomization/initialization/groups_supporting_sp1_sister.txt",
        config["tool"] + "/sim_randomization/initialization/groups_supporting_sp2_sister.txt",

        # one raw randomization file per species and per round,
        #  with rounds numbered 1..num_rand_rounds
        expand(config["tool"] + "/sim_randomization/{species}/random_sim_{species}_{randcore}.txt",
               randcore = list(range(1, num_rand_rounds+1) ),
               species  = four_species),
        # per-species figures and statistics. These paths hard-code the same
        #  directory name that config["tool"] is set to above.
        expand("odp_genome_mutation_analysis/figures/randomization_marginals_mutating_{species}.pdf",
               species = four_species),
        expand("odp_genome_mutation_analysis/figures/randomization_stats_mutating_{species}.txt",
               species = four_species),
        expand("odp_genome_mutation_analysis/figures/FLG_support_{species}.pdf",
               species = four_species)

def flatten3(li):
    """
    Collapses exactly three levels of nesting into one flat list.

    [[[1, 2], [3]], [[4]]] becomes [1, 2, 3, 4]. The order of the
     innermost items is preserved.
    """
    return [leaf
            for outer_level in li
            for middle_level in outer_level
            for leaf in middle_level]

def _get_counts_supporting_hypothesis(informative,
                                      outer, speciesA,
                                      speciesB, inner):
    """
    Takes a df of potentially informative relationships and looks for
     phylogenetic support of this tree:

              ,========= outer
            =={
              {  ,====== speciesA
              '=={
                 {  ,=== speciesB
                 '=={
                    '=== inner

    Just returns a dataframe of the groupings that support this
     hypothesis

     This data structure: [[5,14], [7,10,5], [8,12]]
       - shows that there are three pairings of FLGs that support this hypothesis
       - the pairings have 5 and 14 genes, 7 and 10 and 5, 8 and 12 genes...
    """
    # test hypothesis1
    hypAdf = informative.groupby(["{}_scaf".format(speciesB),
                                  "{}_scaf".format(inner)])

    keeps = []
    for group_name, df_group in hypAdf:
        # Remove the entries where the outer has two of the same scaf.
        # These are derived splits in species1.
        subgroup = df_group.groupby(["{}_scaf".format(outer),
                                     "{}_scaf".format(speciesB),
                                     "{}_scaf".format(inner)
                                       ]).filter(lambda x: len(x) < 2)

        # This gets rid of cases where species1 has two of the same scaf.
        # These are cases where we're not sure whether the outer split
        #  is derived in the outer group or a merge in the ancestor of
        #  species1, species2, and inner.
        subgroup = subgroup.groupby(["{}_scaf".format(speciesA),
                                     "{}_scaf".format(speciesB),
                                     "{}_scaf".format(inner)
                                     ]).filter(lambda x: len(x) < 2)

        if len(subgroup) > 1:
            keeps.append(subgroup)

    if len(keeps) > 0:
        hypAdf = pd.concat(keeps)
        hypAdf = hypAdf.groupby(["{}_scaf".format(speciesB),
                                 "{}_scaf".format(inner),
                                 ]).agg(list).reset_index()

    else:
        hypAdf = pd.DataFrame(columns = informative.columns)
    return hypAdf

def stats_on_df(df,
                outer, species1,
                species2, inner):
    """
    This calculates some statistics about the dataframe.

    The rows of df are first collapsed into groups that share the same
     scaffold in all four species. Groups of four or fewer genes are dropped.
     The remaining groups are then tested against two hypotheses:
       - hyp1: species2 and inner are sister (species1 branches off first)
       - hyp2: species1 and inner are sister (species2 branches off first)

    The output is a list of four items:
    - item 0: counts supporting hyp1 data structure
    - item 1: counts supporting hyp2 data structure
    - item 2: flat list of rbh entries supporting hyp1
    - item 3: flat list of rbh entries supporting hyp2

     The counts data structure looks like: [[5,14], [7,10,5], [8,12]]
       - there are three pairings of FLGs that support this hypothesis
       - the pairings have 5 and 14 genes, 7 and 10 and 5, 8 and 12 genes...

    NOTE(review): the pre-filter used to be keyed on (species2, inner) for
     both hypotheses. hyp2 groups by (species1, inner), so that filter threw
     away rows that could have supported hyp2. Each hypothesis now gets a
     pre-filter keyed on its own sister pair, which leaves hyp1 unchanged
     and makes the two hypotheses symmetric.
    """
    scaf_cols = ["{}_scaf".format(x)
                 for x in [outer, species1, species2, inner] ]
    inner_col = "{}_scaf".format(inner)

    grouped_multiple = df.groupby(scaf_cols).agg(list).reset_index()
    # the size of a group is the number of rbh entries it collected
    grouped_multiple["count"] = grouped_multiple.rbh.str.len()
    grouped_multiple = grouped_multiple.loc[grouped_multiple["count"] > 4, ]

    # only keep the columns that are used downstream
    grouped_multiple = grouped_multiple[scaf_cols + ["rbh", "count"]]

    def _informative_for(sister):
        """
        Only keeps the groups whose (sister, inner) scaffold pair is seen at
         least twice. A pair seen once can't be a fusion of two groups.
        """
        pair_cols = ["{}_scaf".format(sister), inner_col]
        return grouped_multiple.groupby(pair_cols).filter(lambda x: len(x) >= 2)

    r1 = _get_counts_supporting_hypothesis(_informative_for(species2),
                                           outer, species1, species2, inner)
    r2 = _get_counts_supporting_hypothesis(_informative_for(species1),
                                           outer, species2, species1, inner)
    return [list(r1["count"]), list(r2["count"]),
            flatten3(list(r1["rbh"])),
            flatten3(list(r2["rbh"]))]

# randomization simulation
rule randomization_simulation:
    """
    This rule performs the randomization simulation.

    The scaffold column of the {species} wildcard is shuffled num_rounds
     times. After every shuffle the hyp1 and hyp2 counts from stats_on_df
     are printed as one tab-separated line of the output file.

    The shuffles are cumulative: each one permutes the column left behind
     by the previous shuffle.
    """
    input:
        rbh_file = config["rbh_file"],
    output:
        random_sim = config["tool"] + "/sim_randomization/{species}/random_sim_{species}_{randcore}.txt"
    threads: 1
    params:
        outer   = config["outer"],
        species1  = config["species1"],
        species2  = config["species2"],
        inner  = config["inner"],
        species = lambda wildcards: wildcards.species,
        num_rounds = num_randomizations_per_round
    run:
        # now we load up the rbh table, and only keep what stats_on_df needs
        df = pd.read_csv(input.rbh_file, index_col = None, sep = "\t")
        scaf_cols = [x for x in df.columns if "_scaf" in x]
        # the copy makes sure that the shuffle below writes to our own table
        df = df[['rbh'] + scaf_cols].copy()

        # this is the column that gets shuffled
        shuffled_col = "{}_scaf".format(params.species)

        start_time = time.time()
        # the with block closes the file even if one of the rounds fails
        with open(output.random_sim, "w") as outhandle:
            for counter in range(1, params.num_rounds + 1):
                df[shuffled_col] = np.random.permutation(df[shuffled_col].values)
                stats = stats_on_df(
                        df,
                        params.outer,
                        params.species1,
                        params.species2,
                        params.inner)
                print("{}\t{}".format(stats[0], stats[1]), file = outhandle)

                # progress report every ten rounds, written over the same line
                if counter % 10 == 0:
                    elapsed = time.time() - start_time
                    days,  rem       = divmod(elapsed, 86400)
                    hours, rem       = divmod(rem, 3600)
                    minutes, seconds = divmod(rem, 60)
                    print("  Done with {} iterations. Time elapsed: {}d {}h {}m {}s. \r".format(
                        counter, int(days), int(hours), int(minutes), int(seconds) ), end = "\r")

# randomization simulation
rule randomization_initialization:
    """
    This measures the initial state of the genome.

    stats_on_df is run once on the rbh table as it is, without any shuffle.
     - random_sim gets the hyp1 and hyp2 counts on one tab-separated line
     - species1sis gets the third item of the stats, one entry per line
     - species2sis gets the fourth item of the stats, one entry per line
    """
    input:
        rbh_file = config["rbh_file"],
    output:
        random_sim  = config["tool"] + "/sim_randomization/initialization/random_sim_init.txt",
        species1sis = config["tool"] + "/sim_randomization/initialization/groups_supporting_sp1_sister.txt",
        species2sis = config["tool"] + "/sim_randomization/initialization/groups_supporting_sp2_sister.txt",
    threads: 1
    params:
        outer  = config["outer"],
        species1 = config["species1"],
        species2 = config["species2"],
        inner = config["inner"],
        num_rounds = num_rand_rounds
    run:
        # now we load up the rbh table, and only keep what stats_on_df needs
        df = pd.read_csv(input.rbh_file, index_col = None, sep = "\t")
        scaf_cols = [x for x in df.columns if "_scaf" in x]
        df = df[['rbh'] + scaf_cols]

        stats = stats_on_df(
            df,
            params.outer,
            params.species1,
            params.species2,
            params.inner)

        # all three outputs are written the same way, so that no file handle
        #  is left open if something fails
        with open(output.random_sim, "w") as outhandle:
            print("{}\t{}".format(stats[0], stats[1]),
                  file = outhandle)
        with open(output.species1sis, "w") as f:
            for x in stats[2]:
                print(x, file = f)
        with open(output.species2sis, "w") as f:
            for x in stats[3]:
                print(x, file = f)

rule make_randomization_summary:
    """
    This reads in all of the randomization data and turns it into a plottable,
    machine-readable file.

    The goal is to only execute this once per randomization trial.

    Each non-empty line of the input holds two tab-separated list literals,
     one for hyp1 and one for hyp2. Each line becomes one line of ten
     tab-separated numbers: these five fields for hyp1, then for hyp2:
      - max_number_of_genes_in_single_FLG
      - max_number_of_genes_in_grouping
      - total_number_of_genes_in_FLGs
      - total_number_of_groupings
      - mean_number_of_genes_in_FLGs

    A hypothesis without any grouping gets a 0 in all five fields.
    """
    input:
        random_sim = config["tool"] + "/sim_randomization/{species}/random_sim_{species}_{randcore}.txt"
    output:
        sim_plot = "odp_genome_mutation_analysis/sim_randomization/plottable/{species}/random_sim_{species}_{randcore}_plottable.txt"
    run:
        def _summarize(groupings):
            """
            Returns the five fields for one hypothesis. groupings looks
             like [[5,14], [7,10,5], [8,12]].
            """
            flg_sizes = [size for grouping in groupings for size in grouping]
            # The empty case is handled with explicit defaults. It used to
            #  be handled with bare excepts, which also hid any other error.
            if flg_sizes:
                mean_size = sum(flg_sizes)/len(flg_sizes)
            else:
                mean_size = 0
            return [max(flg_sizes, default = 0),
                    max((sum(grouping) for grouping in groupings), default = 0),
                    sum(flg_sizes),
                    len(groupings),
                    mean_size]

        with open(input.random_sim, "r") as f, \
             open(output.sim_plot, "w") as outhandle:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                fields = line.split("\t")
                hyp1_data = ast.literal_eval(fields[0])
                hyp2_data = ast.literal_eval(fields[1])
                row = _summarize(hyp1_data) + _summarize(hyp2_data)
                print("\t".join([str(x) for x in row]),
                      file = outhandle)

rule cat_together_randomization_plottables:
    """
    This just cats together all the plottables of one species into a single
     file, below a header line that names the ten columns.

    Only randcore is expanded in the input. {{species}} is escaped, so it
     stays a wildcard that is filled in from the output file. Passing a
     species list to expand() on top of that is at best ignored, and at
     worst lists every input file four times.
    """
    input:
        sim_plot = expand("odp_genome_mutation_analysis/sim_randomization/plottable/{{species}}/random_sim_{{species}}_{randcore}_plottable.txt",
               randcore = list(range(1, num_rand_rounds+1) ))
    output:
        outfile = "odp_genome_mutation_analysis/sim_randomization/plottable_final/random_sim_final_plottable_{species}.txt"
    shell:
        """
        printf 'hyp1_max_number_of_genes_in_single_FLG\\thyp1_max_number_of_genes_in_grouping\\thyp1_total_number_of_genes_in_FLGs\\thyp1_total_number_of_groupings\\thyp1_mean_number_of_genes_in_FLGs\\thyp2_max_number_of_genes_in_single_FLG\\thyp2_max_number_of_genes_in_grouping\\thyp2_total_number_of_genes_in_FLGs\\thyp2_total_number_of_groupings\\thyp2_mean_number_of_genes_in_FLGs\\n' > {output.outfile}
        cat {input.sim_plot} >> {output.outfile}
        """

rule plottable_init_file:
    """
    just make the initial measurement of the data plottable

    The input holds two tab-separated list literals, one for hyp1 and one
     for hyp2. The output is a header line and one data line. There are ten
     tab-separated columns: five fields for hyp1, then the same five for hyp2.

    A hypothesis without any grouping gets a 0 in all five fields. If the
     input has more than one non-empty line, the last one is the one that is
     reported.
    """
    input:
        init = "odp_genome_mutation_analysis/sim_randomization/initialization/random_sim_init.txt"
    output:
        init = "odp_genome_mutation_analysis/sim_randomization/initialization/random_sim_init_plottable.txt"
    run:
        # the five fields, in the order that they are printed
        stat_names = ["max_number_of_genes_in_single_FLG",
                      "max_number_of_genes_in_grouping",
                      "total_number_of_genes_in_FLGs",
                      "total_number_of_groupings",
                      "mean_number_of_genes_in_FLGs"]

        def _summarize(groupings):
            """
            Returns the five fields for one hypothesis. groupings looks
             like [[5,14], [7,10,5], [8,12]].
            """
            flg_sizes = [size for grouping in groupings for size in grouping]
            # The empty case is handled with explicit defaults. It used to
            #  be handled with bare excepts, which also hid any other error.
            if flg_sizes:
                mean_size = sum(flg_sizes)/len(flg_sizes)
            else:
                mean_size = 0
            return [max(flg_sizes, default = 0),
                    max((sum(grouping) for grouping in groupings), default = 0),
                    sum(flg_sizes),
                    len(groupings),
                    mean_size]

        # these are reported if the init file has no data line at all
        hyp1_stats = [0] * len(stat_names)
        hyp2_stats = [0] * len(stat_names)

        with open(input.init, "r") as f:
            for line in f:
                line = line.strip()
                if line:
                    fields = line.split("\t")
                    hyp1_stats = _summarize(ast.literal_eval(fields[0]))
                    hyp2_stats = _summarize(ast.literal_eval(fields[1]))

        header = ["{}_{}".format(hyp, stat)
                  for hyp in ["hyp1", "hyp2"]
                  for stat in stat_names]
        with open(output.init, "w") as outhandle:
            print("\t".join(header),
                  file = outhandle)
            print("\t".join([str(x) for x in hyp1_stats + hyp2_stats]),
                  file = outhandle)

rule plot_joint_clouds:
    """
    Makes density maps with marginal distribution plots.

    Two files are written:
     - stats: the init values, the number of simulations, and for single
       fields and pairs of fields the fraction of simulations that are >= the
       init values (simulations with a 0 in the field are not counted)
     - pdf: for every ordered pair of fields, one page for hyp1 and one page
       for hyp2. A page shows the joint 2D histogram of the simulations, the
       init measurement as a red dot, and the marginal histograms.

    Changes that matter when comparing to older output:
     - The two pages of a pair share their axis limits, and the limits cover
       the data and the init value of both hypotheses. The limits used to
       come from hyp1 only.
     - The marginal bins are computed per page and include the largest value.
       They used to stop one short of it, and hyp2 reused the bins of hyp1.
     - The pages follow the column order of the input, not the order of a
       set, which could change between runs.
    """
    input:
        data = "odp_genome_mutation_analysis/sim_randomization/plottable_final/random_sim_final_plottable_{species}.txt",
        init = "odp_genome_mutation_analysis/sim_randomization/initialization/random_sim_init_plottable.txt"
    output:
        pdf = "odp_genome_mutation_analysis/figures/randomization_marginals_mutating_{species}.pdf",
        stats = "odp_genome_mutation_analysis/figures/randomization_stats_mutating_{species}.txt"
    run:
        initdf = pd.read_csv(input.init, index_col = None, sep = "\t")
        df = pd.read_csv(input.data, index_col = None, sep = "\t")
        init_row = initdf.iloc[0]
        num_sims = len(df)
        if num_sims == 0:
            # every alpha below is a fraction of num_sims
            raise IOError("There are no simulations in {}".format(input.data))

        def _num_at_least_init(fields):
            """
            Counts the simulations that are >= the init value, and not 0,
             in every one of the fields.
            """
            keep = pd.Series(True, index = df.index)
            for field in fields:
                keep = keep & (df[field] >= init_row[field]) & (df[field] != 0)
            return int(keep.sum())

        def _print_alpha(num_hits, handle):
            if num_hits == 0:
                print("  - alpha: too_low_to_measure",
                      file = handle)
            else:
                print("  - alpha: {}".format(num_hits/num_sims),
                      file = handle)

        # the with block closes and flushes the stats file before plotting
        with open(output.stats, "w") as outhandle:
            print("# INIT STATS", file = outhandle)
            print(init_row, file=outhandle)
            print("", file = outhandle)

            # just collect some basic stats
            print("# SIMULATION STATS", file = outhandle)
            print("total number of simulations: {}".format(num_sims), file = outhandle)
            print("minimum alpha: {}".format(1/num_sims), file = outhandle)
            print("", file = outhandle)

            # collect stats about each individual field
            print("# SINGLE VARIABLE STATS", file = outhandle)
            for thisfield in df.columns:
                num_hits = _num_at_least_init([thisfield])
                print("Variable: {}".format(thisfield), file =outhandle)
                print("  - Num of simulations >= init's value of {}: {}".format(
                    init_row[thisfield], num_hits),
                      file = outhandle)
                _print_alpha(num_hits, outhandle)
            print("", file = outhandle)

            # print joint probabilities
            print("# TWO-VAR JOINT PROB", file = outhandle)
            for thishyp in ["hyp1", "hyp2"]:
                these = [x for x in df.columns if thishyp in x]
                for field0, field1 in itertools.combinations(these, 2):
                    num_hits = _num_at_least_init([field0, field1])
                    print("Variables: {} and {}".format(field0, field1), file =outhandle)
                    print("  - Num of simulations >= init's values of {} and {}: {}".format(
                        init_row[field0], init_row[field1], num_hits),
                          file = outhandle)
                    _print_alpha(num_hits, outhandle)
                print("", file = outhandle)
            print("", file = outhandle)

        # CALL THIS TO GET THE VISUAL STYLE WE NEED
        # NOTE(review): odp_plot is not imported anywhere in the part of the
        #  file that was reviewed. Confirm that it is in scope, otherwise
        #  this line fails with a NameError.
        odp_plot.format_matplotlib()

        # Colormap for the 2D histogram: always black, with an alpha that goes
        #  from transparent for the lowest value to opaque for the highest.
        # (The named colormap that used to be registered here was never used,
        #  and cm.register_cmap no longer exists in matplotlib >= 3.9.)
        N = 256
        vals = np.zeros((N, 4))
        vals[:, 3] = np.linspace(0, 1, N)
        newcmp = ListedColormap(vals)

        def _axis_max(fields):
            """
            Returns the axis limit that shows the data and the init value
             of all of the fields, plus 20% of headroom.
            """
            biggest = max(max(init_row[field], df[field].max())
                          for field in fields)
            # never return 0, the axis would have no extent
            return biggest * 1.2 or 1

        def _plot_page(pdf_pages, field0, field1, xmax, ymax):
            """
            Plots field0 (x) against field1 (y) on a new page of the pdf.
            """
            # every distinct point once, so that rare combinations are
            #  still visible under the density map
            unique_points = df[[field0, field1]].drop_duplicates(
                ).reset_index(drop = True)

            print("plotting {} and {}".format(field0, field1))
            g = sns.JointGrid(data=df,
                              x=field0,
                              y=field1,
                              xlim = [0,xmax],
                              ylim = [0,ymax],
                              marginal_ticks=True)
            # plot the outliers
            g.ax_joint.scatter(
                data=unique_points,
                x=field0,
                y=field1,
                linewidth = 0,
                c='#808080', marker='o',
                s=2)
            # plot the init data
            g.ax_joint.scatter(
                data=initdf,
                x=field0,
                y=field1,
                c='r', marker='o',
                zorder=10)
            cax = g.figure.add_axes([.15, .55, .02, .2])

            g.plot_joint(sns.histplot, discrete=(True, True),
                cmap=newcmp, cbar=True, cbar_ax=cax)

            # One bin per unit. The +2 makes sure that the largest value has
            #  a bin of its own, and that there are always two bin edges.
            xbins=np.arange(0, df[field0].max() + 2, 1)
            _ = g.ax_marg_x.hist(df[field0], color="#000000", alpha=.6,
                                  bins=xbins, align="left")
            ybins=np.arange(0, df[field1].max() + 2, 1)
            _ = g.ax_marg_y.hist(df[field1], color="#000000", alpha=.6,
                                  bins=ybins,
                                  orientation="horizontal", align="left")

            pdf_pages.savefig(g.figure)
            plt.close(g.figure)

        with PdfPages(output.pdf) as pdf_pages:
            # the field names without the hyp1_/hyp2_ prefix, in column order
            possible_vars = list(dict.fromkeys(
                "_".join(x.split("_")[1::]) for x in df.columns))
            for thiscombo in itertools.combinations(possible_vars, 2):
                for orientation in [[thiscombo[0], thiscombo[1]],
                                    [thiscombo[1], thiscombo[0]]]:
                    xfields = ["{}_{}".format(hyp, orientation[0])
                               for hyp in ["hyp1", "hyp2"]]
                    yfields = ["{}_{}".format(hyp, orientation[1])
                               for hyp in ["hyp1", "hyp2"]]
                    # the hyp1 and hyp2 pages share their limits, so that
                    #  they can be compared side by side
                    xmax = _axis_max(xfields)
                    ymax = _axis_max(yfields)
                    for field0, field1 in zip(xfields, yfields):
                        _plot_page(pdf_pages, field0, field1, xmax, ymax)

rule plot_hypothesis_support:
    """
    Plots the hypothesis support of S1 or S2 or S3 as sister.

    For every summary statistic listed in `plot_these`, one PDF page is
    written. Each page holds two stacked histograms that share both axes:
    the hyp1 column of the randomization table on top and the hyp2 column
    below. The value found in the first row of the initialization table is
    marked on each panel with a vertical line.

    Raises IOError when one of the statistics in `plot_these` is not present
    (after stripping the leading prefix such as "hyp1_") among the columns of
    the randomization table.
    """
    input:
        data = "odp_genome_mutation_analysis/sim_randomization/plottable_final/random_sim_final_plottable_{species}.txt",
        init = "odp_genome_mutation_analysis/sim_randomization/initialization/random_sim_init_plottable.txt"
    output:
        pdf = "odp_genome_mutation_analysis/figures/FLG_support_{species}.pdf",
    params:
        species = lambda wildcards: wildcards.species
    run:
        from matplotlib import pyplot as plt
        import matplotlib
        # fonttype 42 makes matplotlib write TrueType (Type 42) fonts
        # to the pdf/ps output instead of the default Type 3 fonts.
        matplotlib.rcParams['pdf.fonttype'] = 42
        matplotlib.rcParams['ps.fonttype'] = 42
        initdf = pd.read_csv(input.init, index_col = None, sep = "\t")
        df = pd.read_csv(input.data, index_col = None, sep = "\t")

        # CALL THIS TO GET THE VISUAL STYLE WE NEED
        odp_plot.format_matplotlib()

        # Column names look like "<prefix>_<statistic>"; drop the first
        # underscore-delimited token to recover the statistic names.
        # A set is enough here since it is only used for membership tests.
        possible_vars = {"_".join(col.split("_")[1:]) for col in df.columns}
        plot_these = ["total_number_of_genes_in_FLGs",
                      "total_number_of_groupings",
                      "max_number_of_genes_in_grouping",
                      "max_number_of_genes_in_single_FLG"]

        # human-readable x-axis label for each statistic
        plot_to_xlabel = {"total_number_of_genes_in_FLGs": "Genes in PI LGs",
                          "total_number_of_groupings": "Number of LG groupings",
                          "max_number_of_genes_in_grouping": "Genes in largest LG grouping",
                          "max_number_of_genes_in_single_FLG": "Genes in largest LG"
                         }

        # fail early, before any figure is created, if a statistic is missing
        for x in plot_these:
            if x not in possible_vars:
                raise IOError("{} was not in the input df".format(x))

        with PdfPages(output.pdf) as pdf_pages:
            for thisthing in plot_these:
                # one page per statistic: hyp1 on top, hyp2 below
                fig, ax = plt.subplots(2, 1, sharey = True, sharex = True)
                fig.suptitle("{}, {} randomized".format(thisthing, params.species), fontsize=8)
                fig.set_size_inches(5, 2.5)
                columns = ["{}_{}".format(x, thisthing) for x in ["hyp1", "hyp2"]]

                # the observed (non-randomized) values come from the first
                # row of the initialization table
                hyp1_init = initdf.iloc[0][columns[0]]
                hyp2_init = initdf.iloc[0][columns[1]]

                for thisax, thiscol in zip(ax, columns):
                    sns.histplot(df[thiscol], discrete = True,
                                 label=thiscol, color = "#000000",
                                 ax=thisax)
                    thisax.set_title(thiscol, fontsize = 6)
                    thisax.ticklabel_format(axis='y', style='sci', scilimits=(0,0))
                    thisax.tick_params(axis='both', which='major', labelsize=8)
                    thisax.xaxis.label.set_size(8)
                    thisax.yaxis.label.set_size(8)
                # despine acts on every axis of the figure, so one call
                # after both panels are drawn is sufficient
                sns.despine(fig=fig, offset=2, trim=False)
                ax[0].set(ylabel="count\n{} sister".format(config["species1"]))
                ax[1].set(ylabel="count\n{} sister".format(config["species2"]))

                # Mark the observed values. The x-limits are left to
                # matplotlib's autoscaling. The second and third positional
                # arguments of axvline are fractions of the axes height, so
                # each line spans the bottom 75% of its panel.
                ax[0].axvline(hyp1_init, 0, 0.75, color = "#c44e52")
                ax[1].axvline(hyp2_init, 0, 0.75, color = "#c44e52")

                # set the xlabel
                ax[1].set_xlabel(plot_to_xlabel[thisthing])
                plt.tight_layout()

                pdf_pages.savefig(fig)
                # close this specific figure rather than whichever is current
                plt.close(fig)
