"""
This script takes the rbh hits file, specifically of a 4-way rbh search,
  and mutates one group to match the number of chromosomes of another group.

This program considers the chromosome-scale genomes of four species,
with one known outgroup species. The program assumes the other three species
are a polytomy.

      ,========= outer
    =={
      {  ,====== species1
      '=={
         {====== species2
         {
         '====== species3

Using fusion-with-mixing events, the program looks for support for these three
possible phylogenetic topologies:

      hypothesis 1               hypothesis 2               hypothesis 3

     ,========= outer           ,========= outer           ,========= outer
   =={                        =={                        =={
     {  ,====== species1        {  ,====== species2        {  ,====== species3
     '=={                 OR    '=={                 OR    '=={
        {  ,=== species2           {  ,=== species1           {  ,=== species1
        '=={                       '=={                       '=={
           '=== species3              '=== species3              '=== species2


The config file will look something like this:

```
# this parameter is the rbh file you will use to look for groups
rbh_file: "EMU_HCA_RESLi_SRO_reciprocal_best_hits.rbh"
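outgroup: "SRO"
species1: "HCA"
species2: "EMU"
species3: "RESLi"
# The outgroup and species1-3 sample names must also be listed under 'species'
# in the config, and sample names may not contain underscores.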

# This is the decay simulation.
num_swaps: 5000 # around 7 minutes
num_simulations: 5000

# Randomization trials. Should be divisible by 10,000: trials are run in
# rounds of 10,000 and any remainder is dropped.
# Each of the four genomes will be randomized in this many trials
num_randomizations:  1000000
```

"""

# data stuff
import pandas as pd
import numpy as np
# simulation stuff
import time
import random
from random import randrange
import sys
# plotting stuff
import ast
import itertools
import matplotlib.pyplot as plt
import matplotlib.patches as mplpatches
import seaborn as sns
from matplotlib.collections import LineCollection
from matplotlib.colors import ListedColormap, LinearSegmentedColormap
from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.pylab as pl
import operator

# Load the run parameters from config.yaml into the `config` dict used below.
configfile: "config.yaml"

# Name of this tool; it is the leading directory of every output path in this file.
config["tool"] = "odp_genome_mutation_analysis"

#OdpF.check_legality(config)

# Make sure every field that this workflow reads from the config is present,
# so that a missing key produces a readable message instead of a KeyError.
# "num_randomizations" is read right after these checks to compute the number
# of randomization rounds, so it is validated here too.
check_these = ["rbh_file",
               "outgroup",
               "species1",
               "species2",
               "species3",
               "num_simulations",
               "num_randomizations",
               "species"]
for check_this in check_these:
    if check_this not in config:
        raise IOError("You must specify '{}' in config".format(check_this))

# Make sure that the sample named in each group field is one of the entries
# of config["species"].
check_these = ["outgroup",
               "species1",
               "species2",
               "species3"]
for check_this in check_these:
    if config[check_this] not in config["species"]:
        raise IOError("The sample {} is not in the species".format(config[check_this]))

# The randomization trials are split into rounds of a fixed size.
# Any remainder of num_randomizations / round size is truncated away.
num_randomizations_per_round = 10000
num_rand_rounds = int(
    config["num_randomizations"] / num_randomizations_per_round)

# Reject any sample name that contains an underscore.
for thissample in config["species"]:
    if "_" not in thissample:
        continue
    raise IOError("Sample names can't have '_' char: {}".format(thissample))

# The three ingroup species, then the outgroup followed by the ingroup.
three_sp = [config[key] for key in ("species1", "species2", "species3")]
four_species = [config["outgroup"]] + three_sp

# Dash-joined species string; it fills the {spstr} field of the target paths.
spstring = "-".join(four_species)

# The rbh file must contain a "<species>_scaf" column for exactly the four
# species of this analysis: no fewer, and no more.
tempdf = pd.read_csv(config["rbh_file"], sep = "\t", index_col = None)

# every requested species needs a scaffold column in the rbh file
for thisspecies in four_species:
    _scaf_column = "{}_scaf".format(thisspecies)
    if _scaf_column in tempdf.columns:
        continue
    raise IOError("Species {} was one of the defined species, but wasn't in the provided rbh file".format(thisspecies))

# every scaffold column in the rbh file must belong to a requested species
_rbh_species = [col.replace("_scaf", "")
                for col in tempdf.columns
                if col.endswith("_scaf")]
for thisspecies in _rbh_species:
    if thisspecies in four_species:
        continue
    raise IOError("Species {} was in the rbh file, but wasn't in the specified species for this analysis.".format(thisspecies))

rule all:
    # Target rule: lists every file that a complete run of this workflow produces.
    input:
        # first complete the simulations
        # (one trial file per species per randomization round, rounds numbered from 1)
        expand(config["tool"] + "/sim_randomization/{spstr}/trials/{spstr}_random_sim_{species}_{randcore}.txt",
               randcore = list(range(1, num_rand_rounds+1) ),
               spstr = spstring, species  = four_species),

        ## analyze the input data
        # measurements of the observed (non-randomized) rbh file
        expand(config["tool"] + "/sim_randomization/{spstr}/initialization/{spstr}_random_sim_init.txt",
               spstr = spstring),
        expand(config["tool"] + "/sim_randomization/{spstr}/initialization/{spstr}_rows_annotated.rbh.groupby",
               spstr = spstring),
        # one file of each kind below for each of the three ingroup species
        expand(config["tool"] + "/sim_randomization/{spstr}/initialization/{spstr}_groups_supporting_{sp}_sister.txt",
               spstr = spstring, sp = three_sp),
        expand(config["tool"] + "/sim_randomization/{spstr}/initialization/{spstr}_df_of_fusionsUnmixed_{sp}.txt",
               spstr = spstring, sp = three_sp),
        expand(config["tool"] + "/sim_randomization/{spstr}/initialization/{spstr}_df_of_groups_supporting_{sp}_sister.txt",
               spstr = spstring, sp = three_sp),

        # plots
        # (one set of figure/stat files per randomized species)
        expand(config["tool"] + "/figures/{spstr}/{spstr}_randomization_marginals_{species}-randomizations.pdf",
               spstr = spstring, species = four_species),
        expand(config["tool"] + "/figures/{spstr}/{spstr}_randomization_stats_{species}-randomizations.txt",
               spstr = spstring, species = four_species),
        expand(config["tool"] + "/figures/{spstr}/{spstr}_support_{species}-randomizations.pdf",
               spstr = spstring, species = four_species)

def flatten3(li):
    """
    Collapse a three-level nested iterable into one flat list.

    Every element of `li` must be an iterable of iterables. The innermost
    items are returned in their original order.
    """
    return [leaf
            for outer in li
            for middle in outer
            for leaf in middle]

def overlaps(x, y):
    """
    Report whether two closed ranges share at least one point.

    `x` and `y` are indexable as [start, end]. Ranges that merely touch
    (the end of one equals the start of the other) count as overlapping.

    from https://stackoverflow.com/questions/64745139/check-if-two-integer-ranges-overlap
    """
    latest_start = max(x[0], y[0])
    earliest_end = min(x[1], y[1])
    return latest_start <= earliest_end

def _get_groups_supporting_polytomies(grouped_df,
                                      outgroup, speciesA,
                                      speciesB,
                                      speciesC):
    """
    Takes a df of potentially informative relationships and looks for
     phylogenetic support of this tree:

              ,========= outer
            =={
              {  ,====== speciesA
              '=={
                 {====== speciesB
                 {
                 '====== speciesC

    Just returns a dataframe of the groupings that support this
     hypothesis

     This data structure: [[5,14], [7,10,5], [8,12]]
       - shows that there are three pairings of FLGs that support this hypothesis
       - the pairings have 5 and 14 genes, 7 and 10 and 5, 8 and 12 genes...
    """
    polytomydf = grouped_df.groupby(["{}_scaf".format(speciesA),
                                     "{}_scaf".format(speciesB),
                                     "{}_scaf".format(speciesC)]).filter(
                                          lambda x: len(x) >= 2)
    return polytomydf

def _get_groups_supporting_hypothesis(grouped_df,
                                      outgroup, speciesA,
                                      speciesB,
                                      speciesC):
    """
    Takes a df of potentially informative relationships and looks for
     phylogenetic support of this tree:

              ,========= outer
            =={
              {  ,====== speciesA
              '=={
                 {  ,=== speciesB
                 '=={
                    '=== speciesC

    The rows of grouped_df are bundled by their (speciesB, speciesC) scaffold
     pair. Only pairs with at least two rows are considered. Within each pair,
     rows that are ambiguous (see the comments in the loop) are removed, and
     the pair is kept only if more than one row remains.

    Each kept pair is then classified with the "{species}_pos" columns of
     speciesB and speciesC. Every row is reduced to a [min, max] position range:
       - fused and mixed:   at least two rows have overlapping ranges in
                            speciesB or in speciesC
       - fused and unmixed: no two rows overlap in either species

    Parameters:
      - grouped_df: dataframe with a "{species}_scaf" column for each of the
                    four species and a "{species}_pos" column for speciesB
                    and speciesC
      - outgroup, speciesA, speciesB, speciesC: the species names used as
                    column-name prefixes

    Returns four dataframes:
      - hypAdf:           the rows of the fused-and-mixed pairs
      - hypAdf_grouped:   hypAdf grouped by (speciesB, speciesC) scaffold,
                          with every other column aggregated to a list
      - hypAFUdf:         the rows of the fused-and-unmixed pairs
      - hypAFUdf_grouped: hypAFUdf grouped and aggregated the same way
     When no pair falls in a category, both of its dataframes are empty and
     have the columns of the filtered input.

    Raises IOError when a position cell cannot be interpreted.
    """
    # keep only the rows whose (speciesB, speciesC) scaffold pair occurs at least twice
    informative = grouped_df.groupby(["{}_scaf".format(speciesB), "{}_scaf".format(speciesC)]).filter(lambda x: len(x) >= 2)
    informative = informative.sort_values(["{}_scaf".format(speciesC),
                                           "{}_scaf".format(speciesB),
                                           "{}_scaf".format(speciesA),
                                           "{}_scaf".format(outgroup)]).reset_index(drop=True)

    # test the hypothesis: one subgroup per (speciesB, speciesC) scaffold pair
    hypAdf = informative.groupby(["{}_scaf".format(speciesB),
                                  "{}_scaf".format(speciesC)])

    fused_and_mixed   = []
    fused_and_unmixed = []
    for group_name, df_group in hypAdf:
        # For example this subgroup where out=COW, spA=HCA, spB=EMU, spC=RES
        #
        # There is a special edge case where the linkage group is split between
        #  two groups in sp1 and the outgroup. Here we just get rid of one of the
        #  rows since there is still some evidence of a fusion event.
        #
        #       COW_scaf HCA_scaf EMU_scaf RES_scaf
        #     0     COW3     HCA2     EMU9     RES5
        #     1    COW14     HCA5     EMU9     RES5
        #     2     COW3     HCA5     EMU9     RES5
        #
        # We go through the two rows that have the same chrom in the outgroup
        #  (COW3) and get rid of those that have the same sp1 chrom as the
        #  singleton row (HCA5).
        #
        # So we keep these two rows:
        #       COW_scaf HCA_scaf EMU_scaf RES_scaf
        #     0     COW3     HCA2     EMU9     RES5
        #     1    COW14     HCA5     EMU9     RES5
        #
        # and we get rid of this row:
        #       COW_scaf HCA_scaf EMU_scaf RES_scaf
        #     2     COW3     HCA5     EMU9     RES5

        # NOTE(review): `done` and `keep_these` are never read in this function.
        done = False
        keep_these = []
        # rows whose (outgroup, speciesB, speciesC) scaffold combination is unique
        temp_singletons = df_group.groupby(["{}_scaf".format(outgroup),
                                     "{}_scaf".format(speciesB),
                                     "{}_scaf".format(speciesC)
                                       ]).filter(lambda x: len(x) < 2)
        # rows that share their (outgroup, speciesB, speciesC) combination with another row
        temp_dups = df_group.groupby(["{}_scaf".format(outgroup),
                                     "{}_scaf".format(speciesB),
                                     "{}_scaf".format(speciesC)
                                       ]).filter(lambda x: len(x) > 1)
        if (len(temp_dups) > 1) and (len(temp_singletons) > 0):
            if len(temp_singletons) > 1:
                # In this scenario there are two rows that should not be considered:
                #  the ones with COW3
                # These are ambiguous and could be derived fissions in HCA or a derived
                #  fusion in COW. So, we just remove these and they will later be flagged
                #  as ambiguous.
                #
                #    spA: HCA  spB: CLA  spC: PMA
                #    subgroup
                #      COW_scaf HCA_scaf    CLA_scaf PMA_scaf  count
                #    0     COW3    HCA11  CLA15_hapA     PMA3      5
                #    1     COW3     HCA2  CLA15_hapA     PMA3      5
                #    2     COW5     HCA6  CLA15_hapA     PMA3      6
                #    3     COW2     HCA9  CLA15_hapA     PMA3      5
                #    temp_dups
                #      COW_scaf HCA_scaf    CLA_scaf PMA_scaf  count
                #    0     COW3    HCA11  CLA15_hapA     PMA3      5
                #    1     COW3     HCA2  CLA15_hapA     PMA3      5
                #    temp_singletons
                #      COW_scaf HCA_scaf    CLA_scaf PMA_scaf  count
                #    2     COW5     HCA6  CLA15_hapA     PMA3      6
                #    3     COW2     HCA9  CLA15_hapA     PMA3      5
                #
                # All duplicated rows are dropped by replacing them with an empty frame.
                temp_dups = pd.DataFrame(columns = temp_singletons.columns)
            else:
                # Exactly one singleton: get its speciesA chrom and drop the
                #  duplicated rows that sit on that same speciesA chrom.
                opposite_chrom = temp_singletons.iloc[0]["{}_scaf".format(speciesA)]
                temp_dups = temp_dups.loc[~(temp_dups["{}_scaf".format(speciesA)] == opposite_chrom), ]
        subgroup = pd.concat([temp_dups, temp_singletons])

        # Remove the entries where the outgroup has two of the same scaf.
        # These are derived splits in species1.
        #
        # For example this subgroup where out=COW, spA=HCA, spB=EMU, spC=RES
        #       COW_scaf HCA_scaf EMU_scaf RES_scaf
        #     0     COW3     HCA1    EMU19     RES2
        #     1     COW3     HCA2    EMU19     RES2
        #     2     COW3     HCA3    EMU19     RES2
        #
        # would be deleted:
        #     Empty DataFrame
        #     Columns: [COW_scaf, EMU_scaf, HCA_scaf, RES_scaf]
        #     Index: []
        #
        subgroup = subgroup.groupby(["{}_scaf".format(outgroup),
                                     "{}_scaf".format(speciesB),
                                     "{}_scaf".format(speciesC)
                                       ]).filter(lambda x: len(x) < 2)

        # This gets rid of cases where species1 has two of the same scaf.
        # These are cases where we're not sure whether the outer split
        #  is derived in the outer group or a merge in the ancestor of
        #  speciesA, speciesB, and speciesC.
        #
        # For example this subgroup where out=COW, spA=HCA, spB=EMU, spC=RES
        #
        #       COW_scaf EMU_scaf HCA_scaf RES_scaf
        #     0     COW1    EMU13     HCA1    RES14
        #     1     COW2    EMU13     HCA1    RES14
        #     2     COW4    EMU13     HCA1    RES14
        #
        # would be deleted:
        #     Empty DataFrame
        #     Columns: [COW_scaf, EMU_scaf, HCA_scaf, RES_scaf]
        #     Index: []
        subgroup = subgroup.groupby(["{}_scaf".format(speciesA),
                                     "{}_scaf".format(speciesB),
                                     "{}_scaf".format(speciesC)
                                     ]).filter(lambda x: len(x) < 2)

        # make sure that the ranges of the genes overlap for mixing
        overlapping = False
        if len(subgroup) > 1:
            for thissp in [speciesB, speciesC]:
                # get a list of ranges of the extent of genes in the groups
                ranges_list = []
                for index in subgroup.index:
                    # unwrapped is the raw cell; unwrapped2 starts as the same
                    #  value and is replaced by a list of ints when the cell
                    #  is a list or tuple
                    unwrapped  = subgroup.loc[index, "{}_pos".format(thissp)]
                    unwrapped2 = subgroup.loc[index, "{}_pos".format(thissp)]
                    if type(unwrapped) == str:
                        # NOTE(review): a string cell is only printed. unwrapped2
                        #  stays a string, so the min/max below are taken over
                        #  its characters - confirm whether this is intended.
                        print("string", unwrapped)
                        pass
                    elif type(unwrapped) in [list, tuple]:
                        everything_int = True
                        for x in unwrapped:
                            if type(x) not in [int, float]:
                                everything_int = False
                        if everything_int:
                            # if everything here is an int or float, then we're OK
                            # The list is OK and can be processed
                            unwrapped2 = [int(x) for x in unwrapped]
                        else:
                            # If there are things that aren't ints they're probably
                            # a multi-nested list. Each element must be a string
                            # of comma-separated integers.
                            unwrapped2 = []
                            for i in unwrapped:
                                if type(i) == str:
                                    for ii in i.split(","):
                                        unwrapped2.append(int(ii))
                                else:
                                    raise IOError("The type here is {}, but we only know how to handle strings.".format(
                                        type(i)))
                    else:
                        raise IOError("Not sure how to handle this. Contact the developer.")
                    # the extent of this row's genes in thissp
                    thismin = int(min(unwrapped2))
                    thismax = int(max(unwrapped2))
                    ranges_list.append([thismin, thismax])
                # now make sure at least two of the groups have overlapped
                # (one overlapping pair of rows, in either species, is enough)
                for i in range(len(ranges_list)):
                    for j in range(i+1, len(ranges_list)):
                        if overlaps(ranges_list[i], ranges_list[j]):
                            overlapping = True

        # only subgroups that still have more than one row count as fusions
        if (len(subgroup) > 1):
            if overlapping:
                fused_and_mixed.append(subgroup)
            else:
                fused_and_unmixed.append(subgroup)

    # fused-and-mixed results: the rows, and the per-pair aggregation of the rows
    if len(fused_and_mixed) > 0:
        hypAdf = pd.concat(fused_and_mixed)
        hypAdf_grouped = hypAdf.groupby(["{}_scaf".format(speciesB),
                                         "{}_scaf".format(speciesC),
                                         ]).agg(list).reset_index()
    else:
        # nothing found: return empty frames that still carry the expected columns
        hypAdf = pd.DataFrame(columns = informative.columns)
        hypAdf_grouped = pd.DataFrame(columns = informative.columns)

    # fused-and-unmixed results, built the same way
    if len(fused_and_unmixed) > 0:
        hypAFUdf = pd.concat(fused_and_unmixed)
        hypAFUdf_grouped = hypAFUdf.groupby(["{}_scaf".format(speciesB),
                                           "{}_scaf".format(speciesC),
                                           ]).agg(list).reset_index()
    else:
        hypAFUdf = pd.DataFrame(columns = informative.columns)
        hypAFUdf_grouped = pd.DataFrame(columns = informative.columns)



    return hypAdf,hypAdf_grouped,hypAFUdf,hypAFUdf_grouped

def _row_string(row, outgroup, sp1, sp2, sp3):
    """
    takes a row and prints the 4-species string
    """
    ogsc = "{}_scaf".format(outgroup)
    s1sc = "{}_scaf".format(sp1)
    s2sc = "{}_scaf".format(sp2)
    s3sc = "{}_scaf".format(sp3)
    ogv  = row[ogsc]
    s1v  = row[s1sc]
    s2v  = row[s2sc]
    s3v  = row[s3sc]
    return "{}_{}_{}_{}".format(ogv, s1v, s2v, s3v)

def _get_sp_string_from_df(df, outgroup, sp1, sp2, sp3):
    """
    Convert every row of a dataframe to its 4-species identifier string.

    Returns a list with one _row_string() result per row of df, in row
    order. An empty dataframe gives an empty list.
    """
    return [_row_string(record, outgroup, sp1, sp2, sp3)
            for _, record in df.iterrows()]

def _get_derived_fissions(df, outgroup, sp1, sp2, sp3):
    """
    takes in a dataframe.
    outputs three dataframes containing the derived fusions for
     sp1, sp2, sp3
    """
    sp1_derived_fissions = df.groupby(["{}_scaf".format(outgroup),
                                       "{}_scaf".format(sp2),
                                       "{}_scaf".format(sp3)
                                       ]).filter(lambda x: len(x) > 1)
    sp2_derived_fissions = df.groupby(["{}_scaf".format(outgroup),
                                       "{}_scaf".format(sp1),
                                       "{}_scaf".format(sp3)
                                       ]).filter(lambda x: len(x) > 1)
    sp3_derived_fissions = df.groupby(["{}_scaf".format(outgroup),
                                       "{}_scaf".format(sp1),
                                       "{}_scaf".format(sp2)
                                       ]).filter(lambda x: len(x) > 1)
    return sp1_derived_fissions, sp2_derived_fissions, sp3_derived_fissions

def stats_on_df(df,
                outgroup, species1,
                species2, species3):
    """
    This calculates some statistics about the dataframe.

    The rows of df are grouped by the scaffolds of the four species, and only
    the groups with more than four rbh entries are analyzed. Each remaining
    group is then annotated with the relationships that it supports.

    The counts record the number of genes in the FLG pairs:

     This data structure: [[5,14], [7,10,5], [8,12]]
       - shows that there are three pairings of FLGs that support this hypothesis
       - the pairings have 5 and 14 genes, 7 and 10 and 5, 8 and 12 genes...

    Parameters:
      - df: rbh dataframe with an "rbh" column and a "{species}_scaf" column
            for each of the four species. The "{species}_pos" columns are
            needed by _get_groups_supporting_hypothesis.
      - outgroup, species1, species2, species3: species names used as the
            column-name prefixes

    The output is a list of 13 items:
    - [0]:  df of fused-and-mixed groups supporting sp1 as sister
    - [1]:  "" supporting sp2 as sister
    - [2]:  "" supporting sp3 as sister
    - [3]:  counts supporting hyp1 data structure
    - [4]:  counts supporting hyp2 data structure
    - [5]:  counts supporting hyp3 data structure
    - [6]:  rbh entries of the groups supporting sp1 as sister clade
    - [7]:  rbh entries of the groups supporting sp2 as sister clade
    - [8]:  rbh entries of the groups supporting sp3 as sister clade
    - [9]:  df of fused-but-unmixed groups for sp1
    - [10]: "" for sp2
    - [11]: "" for sp3
    - [12]: groupby dataframe annotated with what relationships the rows imply
            (column "data_support"; "singleton" when nothing applies)
    """
    # NOTE: species2 comes before species1 here. This fixes the group-key
    #  order and therefore the column order of the grouped dataframe.
    scaf_cols = ["{}_scaf".format(x)
                 for x in [outgroup, species2, species1, species3] ]
    grouped_multiple = df.groupby(scaf_cols).agg(list).reset_index()
    # get the size: the number of rbh entries in each scaffold combination
    grouped_multiple["count"] = grouped_multiple.rbh.str.len()
    grouped_multiple = grouped_multiple.loc[grouped_multiple["count"] > 4, ]
    grouped_multiple = grouped_multiple.sort_values(["{}_scaf".format(species3),
                                                "{}_scaf".format(species1),
                                                "{}_scaf".format(species2),
                                                "{}_scaf".format(outgroup)]).reset_index(drop=True)
    grouped_multiple["data_support"] = "singleton"

    def _row_strings(frame):
        # Row identifiers as a set, so the per-row membership tests below
        #  are constant-time instead of a scan through a list.
        return set(_get_sp_string_from_df(frame, outgroup,
                                          species1, species2, species3))

    # get the groups that support a sister clade hypothesis
    # FM = fused and mixed, FU = fused but unmixed
    r1_FM_df, r1_FM, r1_FU_df, r1_FU = _get_groups_supporting_hypothesis(
        grouped_multiple, outgroup, species1, species2, species3)
    r2_FM_df, r2_FM, r2_FU_df, r2_FU = _get_groups_supporting_hypothesis(
        grouped_multiple, outgroup, species2, species1, species3)
    r3_FM_df, r3_FM, r3_FU_df, r3_FU = _get_groups_supporting_hypothesis(
        grouped_multiple, outgroup, species3, species1, species2)

    # now find the groups that are derived splits
    sp1_derFiss, sp2_derFiss, sp3_derFiss = _get_derived_fissions(
        grouped_multiple, outgroup, species1, species2, species3)

    # get the groups that are fused in spA spB and spC
    polytomydf = _get_groups_supporting_polytomies(
        grouped_multiple, outgroup, species1, species2, species3)

    # (label, row identifiers) pairs. The order of this list is the order in
    #  which the labels are joined into the data_support string.
    annotations = [
        # sister hypotheses
        ("{}_sis".format(species1), _row_strings(r1_FM_df)),
        ("{}_sis".format(species2), _row_strings(r2_FM_df)),
        ("{}_sis".format(species3), _row_strings(r3_FM_df)),
        # fused but unmixed
        ("{}_FU".format(species1), _row_strings(r1_FU_df)),
        ("{}_FU".format(species2), _row_strings(r2_FU_df)),
        ("{}_FU".format(species3), _row_strings(r3_FU_df)),
        # derived fissions
        ("{}_DerFissOrOGfus".format(species1), _row_strings(sp1_derFiss)),
        ("{}_DerFissOrOGfus".format(species2), _row_strings(sp2_derFiss)),
        ("{}_DerFissOrOGfus".format(species3), _row_strings(sp3_derFiss)),
        # polytomies
        ("{}_{}_{}_polytomy".format(species1, species2, species3),
         _row_strings(polytomydf)),
    ]

    for index, row in grouped_multiple.iterrows():
        row_string = _row_string(row, outgroup, species1, species2, species3)
        support_list = [label for label, members in annotations
                        if row_string in members]
        support_string = "_".join(support_list)
        # rows without any support keep the default "singleton" annotation
        if support_string != "":
            grouped_multiple.loc[index, "data_support"] = support_string

    return [r1_FM_df, r2_FM_df, r3_FM_df,
            list(r1_FM["count"]),
            list(r2_FM["count"]),
            list(r3_FM["count"]),
            flatten3(list(r1_FM["rbh"])),
            flatten3(list(r2_FM["rbh"])),
            flatten3(list(r3_FM["rbh"])),
            r1_FU_df, r2_FU_df, r3_FU_df,
            grouped_multiple
            ]

# randomization simulation
rule randomization_simulation:
    """
    This rule performs the randomization simulation.

    In every round the scaffold column of the species named by the {species}
    wildcard is shuffled, and the stats of the shuffled table are measured.
    The shuffles are cumulative: each round permutes the result of the
    previous round.

    It prints out the stats from each round onto a line: three tab-separated
    fields, the hyp1, hyp2, and hyp3 count structures from stats_on_df.
    """
    input:
        rbh_file = config["rbh_file"],
    output:
        random_sim = config["tool"] + "/sim_randomization/{spstr}/trials/{spstr}_random_sim_{species}_{randcore}.txt"
    threads: 1
    params:
        outgroup  = config["outgroup"],
        species1  = config["species1"],
        species2  = config["species2"],
        species3  = config["species3"],
        species = lambda wildcards: wildcards.species,
        num_rounds = num_randomizations_per_round
    run:
        # now we load up the rbh table, keeping only the columns that are used
        df = pd.read_csv(input.rbh_file, index_col = None, sep = "\t")
        scaf_cols = [x for x in df.columns if "_scaf" in x]
        pos_cols = [x for x in df.columns if "_pos" in x]
        df = df[['rbh'] + scaf_cols + pos_cols]

        # the scaffold column of the species that is randomized in this job
        shuffled_col = "{}_scaf".format(params.species)

        start_time = time.time()
        # the context manager closes the output even if a round raises
        with open(output.random_sim, "w") as outhandle:
            # a bounded loop: exactly num_rounds rounds, numbered from 1
            for counter in range(1, params.num_rounds + 1):
                # randomize the chosen species' scaffold column
                df[shuffled_col] = np.random.permutation(df[shuffled_col].values)
                stats = stats_on_df(
                        df,
                        params.outgroup,
                        params.species1,
                        params.species2,
                        params.species3)
                print("{}\t{}\t{}".format(stats[3], stats[4],
                                          stats[5]), file = outhandle)

                # progress report every 10 rounds, overwriting the same line
                if counter % 10 == 0:
                    elapsed = time.time() - start_time
                    days,  rem       = divmod(elapsed, 86400)
                    hours, rem       = divmod(rem, 3600)
                    minutes, seconds = divmod(rem, 60)
                    days = int(days)
                    hours = int(hours)
                    minutes = int(minutes)
                    seconds = int(seconds)
                    print("  Done with {} iterations. Time elapsed: {}d {}h {}m {}s. \r".format(
                        counter, days, hours, minutes, seconds ), end = "\r")

# measurement of the observed (non-randomized) data
rule randomization_initialization:
    """
    This measures the initial state of the genome.
    """
    input:
        rbh_file = config["rbh_file"]
    output:
        random_sim = config["tool"] + "/sim_randomization/{spstr}/initialization/{spstr}_random_sim_init.txt",
        annotated_gb = config["tool"] + "/sim_randomization/{spstr}/initialization/{spstr}_rows_annotated.rbh.groupby",
        sp1_sisdf  = expand(config["tool"] + "/sim_randomization/{{spstr}}/initialization/{{spstr}}_df_of_groups_supporting_{sp}_sister.txt",
                            sp = three_sp),
        sp1_FUdf   = expand(config["tool"] + "/sim_randomization/{{spstr}}/initialization/{{spstr}}_df_of_fusionsUnmixed_{sp}.txt",
                            sp = three_sp),
        sp1groupdf = expand(config["tool"] + "/sim_randomization/{{spstr}}/initialization/{{spstr}}_groups_supporting_{sp}_sister.txt",
                            sp = three_sp)
    threads: 1
    params:
        outgroup = config["outgroup"],
        species1 = config["species1"],
        species2 = config["species2"],
        species3 = config["species3"],
        num_rounds = num_rand_rounds,
        df_prefix = lambda wildcards: config["tool"] + "/sim_randomization/{st}/initialization/{st}_df_of_groups_supporting_".format(st = wildcards.spstr),
        df_suffix = "_sister.txt",
        FU_prefix = lambda wildcards: config["tool"] + "/sim_randomization/{st}/initialization/{st}_df_of_fusionsUnmixed_".format(st = wildcards.spstr),
        grouppref = lambda wildcards: config["tool"] + "/sim_randomization/{st}/initialization/{st}_groups_supporting_".format(st = wildcards.spstr),
        groupsuff = "_sister.txt",
    run:
        # now we load up the rbh table
        df = pd.read_csv(input.rbh_file, index_col = None, sep = "\t")
        #scaf_cols = [x for x in df.columns if "_scaf" in x]
        #pos_cols =  [x for x in df.columns if "_pos" in x]
        #df = df[['rbh'] + scaf_cols + pos_cols]

        stats = stats_on_df(
            df,
            params.outgroup,
            params.species1,
            params.species2,
            params.species3)

        # print out the dataframes of groups that support each hypothesis
        sp1out = params.df_prefix + params.species1 + params.df_suffix
        stats[0].to_csv(sp1out, sep="\t", index = False)
        sp2out = params.df_prefix + params.species2 + params.df_suffix
        stats[1].to_csv(sp2out, sep="\t", index = False)
        sp3out = params.df_prefix + params.species3 + params.df_suffix
        stats[2].to_csv(sp3out, sep="\t", index = False)

        # print out the number of groups supporting each hypothesis
        outhandle = open(output.random_sim, "w")
        print("{}\t{}\t{}".format(stats[3], stats[4], stats[5]),
              file = outhandle)
        outhandle.close()

        # just print the group counts
        sp1out = params.grouppref + params.species1 + params.groupsuff
        with open(sp1out, "w") as f:
            for x in stats[6]:
                print(x, file = f)
        sp1out = params.grouppref + params.species2 + params.groupsuff
        with open(sp1out, "w") as f:
            for x in stats[7]:
                print(x, file = f)
        sp1out = params.grouppref + params.species3 + params.groupsuff
        with open(sp1out, "w") as f:
            for x in stats[8]:
                print(x, file = f)

        # print out the dataframes of groups that have fusions, but not fusions with mixing
        sp1out = params.FU_prefix + params.species1 + ".txt"
        stats[9].to_csv(sp1out, sep="\t", index = False)
        sp2out = params.FU_prefix + params.species2 + ".txt"
        stats[10].to_csv(sp2out, sep="\t", index = False)
        sp3out = params.FU_prefix + params.species3 + ".txt"
        stats[11].to_csv(sp3out, sep="\t", index = False)

        # print out the annotated groupby dataframe
        stats[-1].to_csv(output.annotated_gb, sep = "\t", index = False)



rule make_randomization_summary:
    """
    This reads in all of the randomization data and turns it into a plottable,
    machine-readable file.

    The goal is to only execute this once per randomization trial.

    Each non-blank input line holds three tab-separated Python literals, one
    per hypothesis (hyp1, hyp2, hyp3). Each literal is a list of groupings and
    each grouping is a list of numbers (judging by the statistic names, the
    number of genes in each FLG of that grouping).

    For every such input line, one tab-separated output line is written with
    these five fields per hypothesis, hyp1 first, then hyp2, then hyp3:
      - max_number_of_genes_in_single_FLG
      - max_number_of_genes_in_grouping
      - total_number_of_genes_in_FLGs
      - total_number_of_groupings
      - mean_number_of_genes_in_FLGs

    No header line is written. A hypothesis without any groupings (or without
    any FLGs) reports 0 for its max and mean fields.
    """
    input:
        random_sim = config["tool"] + "/sim_randomization/{spstr}/trials/{spstr}_random_sim_{species}_{randcore}.txt"
    output:
        sim_plot   = config["tool"] + "/sim_randomization/{spstr}/plottable/{spstr}_random_sim_{species}_{randcore}_plottable.txt"
    run:
        def summarize_groupings(groupings):
            """Return the five summary statistics of one hypothesis, in column order."""
            flg_sizes = [size for grouping in groupings for size in grouping]
            # max() and the mean are undefined on empty input, so report 0
            # there instead of hiding the error behind a bare except.
            max_single_flg  = max(flg_sizes) if flg_sizes else 0
            max_in_grouping = max(sum(g) for g in groupings) if groupings else 0
            mean_per_flg    = sum(flg_sizes)/len(flg_sizes) if flg_sizes else 0
            return [max_single_flg,
                    max_in_grouping,
                    sum(flg_sizes),
                    len(groupings),
                    mean_per_flg]

        # both handles are context-managed so the output is closed (and
        # flushed) even if a line fails to parse
        with open(input.random_sim, "r") as infile, \
             open(output.sim_plot, "w") as outhandle:
            for line in infile:
                line = line.strip()
                if not line:
                    continue
                fields = line.split("\t")
                row = []
                # columns 0, 1, 2 are hyp1, hyp2, hyp3. Every hypothesis is
                #  summarized from its OWN column.
                for hyp_index in range(3):
                    hyp_data = ast.literal_eval(fields[hyp_index])
                    row.extend(summarize_groupings(hyp_data))
                print("\t".join(str(x) for x in row), file = outhandle)

rule cat_together_randomization_plottables:
    """
    This just cats together all the plottables into a single file
    so we can easily load it for plotting later.

    For one {spstr} and one randomized {species}, the input is the list of
    per-round plottable files for randcore = 1 .. num_rand_rounds.

    The shell command first writes a single tab-delimited header line naming
    the 15 columns (five statistics for each of hyp1, hyp2, and hyp3), then
    appends the input files in the order they are listed. The awk call exists
    only to print that header; echo feeds it the one input record that makes
    the printf action run.
    """
    input:
        sim_plot = expand(config["tool"] + "/sim_randomization/{{spstr}}/plottable/{{spstr}}_random_sim_{{species}}_{randcore}_plottable.txt",
               randcore = list(range(1, num_rand_rounds+1)) )
    output:
        outfile = config["tool"] + "/sim_randomization/{spstr}/plottable_final/{spstr}_random_sim_final_plottable_{species}.txt"
    shell:
        """
        echo "" | awk '{{printf("hyp1_max_number_of_genes_in_single_FLG\\thyp1_max_number_of_genes_in_grouping\\thyp1_total_number_of_genes_in_FLGs\\thyp1_total_number_of_groupings\\thyp1_mean_number_of_genes_in_FLGs\\thyp2_max_number_of_genes_in_single_FLG\\thyp2_max_number_of_genes_in_grouping\\thyp2_total_number_of_genes_in_FLGs\\thyp2_total_number_of_groupings\\thyp2_mean_number_of_genes_in_FLGs\\thyp3_max_number_of_genes_in_single_FLG\\thyp3_max_number_of_genes_in_grouping\\thyp3_total_number_of_genes_in_FLGs\\thyp3_total_number_of_groupings\\thyp3_mean_number_of_genes_in_FLGs\\n")}}' > {output.outfile}
        cat {input.sim_plot} >> {output.outfile}
        """

rule plottable_init_file:
    """
    just make the initial measurement of the data plottable

    The input holds three tab-separated Python literals per non-blank line,
    one per hypothesis (hyp1, hyp2, hyp3). Each literal is a list of groupings
    and each grouping is a list of numbers.

    The output is a two-line, tab-delimited table: a header naming five
    statistics for each hypothesis (hyp1 first), then one line of values.
    Every value starts at 0. The max and mean fields are only overwritten
    when the hypothesis has something to measure, so an empty hypothesis
    reports 0 for them.
    """
    input:
        init = config["tool"] + "/sim_randomization/{spstr}/initialization/{spstr}_random_sim_init.txt"
    output:
        init = config["tool"] + "/sim_randomization/{spstr}/initialization/{spstr}_random_sim_init_plottable.txt"
    run:
        hypotheses = ["hyp1", "hyp2", "hyp3"]
        # the order here is the column order within each hypothesis
        stat_names = ["max_number_of_genes_in_single_FLG",
                      "max_number_of_genes_in_grouping",
                      "total_number_of_genes_in_FLGs",
                      "total_number_of_groupings",
                      "mean_number_of_genes_in_FLGs"]
        columns = ["{}_{}".format(hyp, stat)
                   for hyp in hypotheses
                   for stat in stat_names]

        # get the info for the init data
        values = {col: 0 for col in columns}
        with open(input.init, "r") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                fields = line.split("\t")
                for hyp_index, hyp in enumerate(hypotheses):
                    hyp_data = ast.literal_eval(fields[hyp_index])
                    unwrapped = [item for sublist in hyp_data
                                 for item in sublist]

                    # max() and the mean are undefined on empty input, so
                    # those fields keep their current value in that case
                    if unwrapped:
                        values[hyp + "_max_number_of_genes_in_single_FLG"] = max(unwrapped)
                        values[hyp + "_mean_number_of_genes_in_FLGs"] = sum(unwrapped)/len(unwrapped)
                    if hyp_data:
                        values[hyp + "_max_number_of_genes_in_grouping"] = max(sum(x) for x in hyp_data)

                    values[hyp + "_total_number_of_genes_in_FLGs"] = sum(unwrapped)
                    values[hyp + "_total_number_of_groupings"] = len(hyp_data)

        # header line, then the single line of measurements
        with open(output.init, "w") as outhandle:
            print("\t".join(columns), file = outhandle)
            print("\t".join([str(values[col]) for col in columns]),
                  file = outhandle)


rule plot_joint_clouds:
    """
    Makes density maps with marginal distribution plots.

    Inputs:
      - data: the concatenated table of randomization results
              (one row per simulation, one column per hypN_<statistic>)
      - init: the same columns measured once on the unrandomized data

    Outputs:
      - stats: a text report with the init values, the number of simulations,
               and for each column (and each pair of columns within one
               hypothesis) how many simulations are >= the init value and
               nonzero, plus that count divided by the number of simulations
               ("alpha")
      - pdf:   one page per hypothesis, per ordered pair of statistics. Each
               page is a 2D histogram of the simulations with marginal
               histograms, and the init measurement drawn as a red dot.
    """
    input:
        data = config["tool"] + "/sim_randomization/{spstr}/plottable_final/{spstr}_random_sim_final_plottable_{species}.txt",
        init = config["tool"] + "/sim_randomization/{spstr}/initialization/{spstr}_random_sim_init_plottable.txt"
    output:
        pdf   = config["tool"] + "/figures/{spstr}/{spstr}_randomization_marginals_{species}-randomizations.pdf",
        stats = config["tool"] + "/figures/{spstr}/{spstr}_randomization_stats_{species}-randomizations.txt"
    run:
        import matplotlib.cm as cm

        # the report handle is context-managed so that it is always closed
        with open(output.stats, "w") as outhandle:
            initdf = pd.read_csv(input.init, index_col = None, sep = "\t")
            print("# INIT STATS", file = outhandle)
            print(initdf.iloc[0], file=outhandle)
            print("", file = outhandle)

            df = pd.read_csv(input.data, index_col = None, sep = "\t")

            # just collect some basic stats
            print("# SIMULATION STATS", file = outhandle)
            print("total number of simulations: {}".format(len(df)), file = outhandle)
            print("minimum alpha: {}".format(1/len(df)), file = outhandle)
            print("", file = outhandle)

            # collect stats about each individual field
            print("# SINGLE VARIABLE STATS", file = outhandle)
            for thisfield in df.columns:
                initval = initdf.iloc[0][thisfield]
                # keep the simulations that are at least as extreme as the
                #  init measurement, and that are not zero
                tempdf = df.loc[df[thisfield] >= initval,]
                tempdf = tempdf.loc[tempdf[thisfield] != 0, ]
                print("Variable: {}".format(thisfield), file =outhandle)
                print("  - Num of simulations >= init's value of {}: {}".format(
                    initval, len(tempdf)),
                      file = outhandle)
                if len(tempdf) == 0:
                    print("  - alpha: too_low_to_measure",
                          file = outhandle)
                else:
                    print("  - alpha: {}".format(len(tempdf)/len(df)),
                          file = outhandle)
            print("", file = outhandle)

            # print joint probabilities
            print("# TWO-VAR JOINT PROB", file = outhandle)
            for thishyp in ["hyp1", "hyp2", "hyp3"]:
                these = [x for x in df.columns if thishyp in x]
                combos = list(itertools.combinations(these, 2))
                for thiscombo in combos:
                    field0 = thiscombo[0]
                    field1 = thiscombo[1]
                    initval0 = initdf.iloc[0][field0]
                    initval1 = initdf.iloc[0][field1]
                    # both fields must be >= their init value, and nonzero
                    tempdf = df.loc[df[field0] >= initval0, ]
                    tempdf = tempdf.loc[tempdf[field0] != 0, ]
                    tempdf = tempdf.loc[tempdf[field1] >= initval1, ]
                    tempdf = tempdf.loc[tempdf[field1] != 0, ]
                    print("Variables: {} and {}".format(field0, field1), file =outhandle)
                    print("  - Num of simulations >= init's values of {} and {}: {}".format(
                        initval0, initval1, len(tempdf)),
                          file = outhandle)
                    if len(tempdf) == 0:
                        print("  - alpha: too_low_to_measure",
                              file = outhandle)
                    else:
                        print("  - alpha: {}".format(len(tempdf)/len(df)),
                              file = outhandle)
                print("", file = outhandle)
            print("", file = outhandle)

        import seaborn as sns
        sns.set_theme(style="ticks")

        # define custom colormap with fixed colour and alpha gradient
        # use simple linear interpolation in the entire scale
        cdict = {'red':   [(0.,0.,0),
                           (0.,0.,0)],

                 'green': [(0.,0.,0.),
                           (0.,0.,0.)],

                 'blue':  [(0.,0.,0.),
                           (0.,0.,0.)],

                 'alpha': [(1.,1,1),
                           (0.,0,0)]}

        # NOTE(review): nothing below looks up "alpha_gradient" by name (the
        #  2D histogram uses newcmp). cm.register_cmap also appears to have
        #  been removed from recent matplotlib releases - confirm against the
        #  pinned matplotlib version before relying on this call.
        cm.register_cmap(cmap=LinearSegmentedColormap("alpha_gradient", cdict))

        # black at every level, with alpha ramping linearly from 0 to 1
        N = 256
        vals = np.ones((N, 4))
        vals[:, 0] = np.linspace(0, 0, N)
        vals[:, 1] = np.linspace(0, 0, N)
        vals[:, 2] = np.linspace(0, 0, N)
        vals[:, 3] = np.linspace(0, 1, N)
        newcmp = ListedColormap(vals)


        with PdfPages(output.pdf) as pdf_pages:
            # strip the leading "hypN_" to get the statistic names
            possible_vars = list(set(["_".join(x.split("_")[1::]) for x in df.columns]))
            combos = list(itertools.combinations(possible_vars, 2))
            for thiscombo in combos:
                for thishyp in ["hyp1", "hyp2", "hyp3"]:

                    # plot each pair twice, once with each statistic on x
                    for orientation in [[thiscombo[0], thiscombo[1]],
                                        [thiscombo[1], thiscombo[0]]]:
                        # PLOTTING this hypothesis, in this orientation
                        field0 = "{}_{}".format(thishyp, orientation[0])
                        field1 = "{}_{}".format(thishyp, orientation[1])
                        initval0 = initdf.iloc[0][field0]
                        initval1 = initdf.iloc[0][field1]

                        # one point per distinct (x, y) is enough for the scatter
                        tempdf = df[[field0, field1]]
                        tempdf = tempdf.drop_duplicates(
                            subset = [field0, field1]).reset_index(drop = True)

                        maxfield0 = max(df[field0])
                        maxfield1 = max(df[field1])

                        xmax = max(initval0, maxfield0) * 1.2
                        ymax = max(initval1, maxfield1) * 1.2
                        print("plotting {} and {}".format(field0, field1))
                        g = sns.JointGrid(data=df,
                                          x=field0,
                                          y=field1,
                                          xlim = [0,xmax],
                                          ylim = [0,ymax],
                                          marginal_ticks=True)
                        # plot the outliers
                        g.ax_joint.scatter(
                            data=tempdf,
                            x=field0,
                            y=field1,
                            linewidth = 0,
                            c='#808080', marker='o',
                            s=2)
                        # plot the init data
                        g.ax_joint.scatter(
                            data=initdf,
                            x=field0,
                            y=field1,
                            c='r', marker='o',
                            zorder=10)
                        cax = g.figure.add_axes([.15, .55, .02, .2])

                        g.plot_joint(sns.histplot, discrete=(True, True),
                            #cmap="light:#03012d", cbar=True, cbar_ax=cax)
                            #pthresh = 0.0001,
                            #cmap="gist_yarg", cbar=True, cbar_ax=cax)
                            cmap=newcmp, cbar=True, cbar_ax=cax)

                        # The bin edges must extend one unit PAST the largest
                        #  value. Edges that stop at max-1 drop every
                        #  observation equal to max, and fold max-1 into the
                        #  bar of max-2 because the final bin is closed.
                        xbins=np.arange(0, maxfield0 + 2, 1)
                        _ = g.ax_marg_x.hist(df[field0], color="#000000", alpha=.6,
                                              bins=xbins, align="left")
                        ybins=np.arange(0, maxfield1 + 2, 1)
                        _ = g.ax_marg_y.hist(df[field1], color="#000000", alpha=.6,
                                              bins=ybins,
                                              orientation="horizontal", align="left")

                        fig = g.figure
                        pdf_pages.savefig(fig)
                        plt.close(fig)

rule plot_hypothesis_support:
    """
    Plots the support for each of the three hypotheses
    (species1, species2, or species3 as sister).

    One PDF page is made per statistic in plot_these. Each page has three
    stacked histograms sharing both axes - the simulated values for hyp1,
    hyp2, and hyp3 - and each panel gets a vertical red line at the value
    measured on the initial (unrandomized) data.

    Raises an IOError if one of the statistics to plot is missing from the
    columns of the input table.
    """
    input:
        data = config["tool"] + "/sim_randomization/{spstr}/plottable_final/{spstr}_random_sim_final_plottable_{species}.txt",
        init = config["tool"] + "/sim_randomization/{spstr}/initialization/{spstr}_random_sim_init_plottable.txt"
    output:
        pdf =  config["tool"] + "/figures/{spstr}/{spstr}_support_{species}-randomizations.pdf"
    params:
        species = lambda wildcards: wildcards.species
    run:
        from matplotlib import pyplot as plt
        import matplotlib
        # fonttype 42 embeds TrueType fonts, so the text stays editable
        matplotlib.rcParams['pdf.fonttype'] = 42
        matplotlib.rcParams['ps.fonttype'] = 42
        initdf = pd.read_csv(input.init, index_col = None, sep = "\t")
        df = pd.read_csv(input.data, index_col = None, sep = "\t")

        import seaborn as sns
        sns.set_theme(style="ticks")

        # strip the leading "hypN_" to get the statistic names
        possible_vars = list(set(["_".join(x.split("_")[1::]) for x in df.columns]))
        plot_these = ["total_number_of_genes_in_FLGs",
                      "total_number_of_groupings",
                      "max_number_of_genes_in_grouping",
                      "max_number_of_genes_in_single_FLG"]

        plot_to_xlabel = {"total_number_of_genes_in_FLGs": "Genes in PI LGs",
                          "total_number_of_groupings": "Number of LG groupings",
                          "max_number_of_genes_in_grouping": "Genes in largest LG grouping",
                          "max_number_of_genes_in_single_FLG": "Genes in largest LG"
                         }

        for x in plot_these:
            if x not in possible_vars:
                raise IOError("{} was not in the input df".format(x))

        # panel order is hyp1, hyp2, hyp3, labelled by the sister species
        hypotheses  = ["hyp1", "hyp2", "hyp3"]
        sister_keys = ["species1", "species2", "species3"]

        with PdfPages(output.pdf) as pdf_pages:
            for thisthing in plot_these:
                fig, ax = plt.subplots(3, 1, sharey = True, sharex = True)
                fig.suptitle("{}, {} randomized".format(thisthing, params.species), fontsize=8)
                fig.set_size_inches(5, 5)
                columns = ["{}_{}".format(x, thisthing) for x in hypotheses]

                for panel, thiscol in zip(ax, columns):
                    sns.histplot(df[thiscol], discrete = True,
                                 label=thiscol, color = "#000000",
                                 ax=panel)
                    panel.set_title(thiscol, fontsize = 6)
                    panel.ticklabel_format(axis='y', style='sci', scilimits=(0,0))
                    panel.tick_params(axis='both', which='major', labelsize=8)
                    panel.xaxis.label.set_size(8)
                    panel.yaxis.label.set_size(8)
                    sns.despine(offset=2, trim=False)

                for panel, thiscol, sister_key in zip(ax, columns, sister_keys):
                    # mark the init measurement. ymin and ymax of axvline are
                    #  fractions of the axis height, so the line spans the
                    #  bottom three quarters of the panel.
                    panel.axvline(initdf.iloc[0][thiscol], 0, 0.75, color = "#c44e52")
                    # now set the names
                    panel.set(ylabel="count\n{} sister".format(config[sister_key]))

                # set the xlabel
                ax[2].set_xlabel(plot_to_xlabel[thisthing])
                plt.tight_layout()

                pdf_pages.savefig(fig)
                plt.close(fig)
