"""
This converts an rbh file to a groupby file.

The input must be named like one of these (output from the odp_nway_rbh script):
  {sample1}_{sample2}_{sampleN}_reciprocal_best_hits.rbh
  OR
  {sample1}_{sample2}_{sampleN}_rbhhmm_plus_other_species.rbh
The output keeps the input file name and appends ".groupby", for example:
  {sample1}_{sample2}_{sampleN}_reciprocal_best_hits.rbh.groupby

... where sample1, sample2, sampleN is the sample name in the config.

If there is a gene_group column, its contents are replaced with the string "None".

One important thing this script does is calculate the false discovery rate
    of finding groups of size N in a randomized dataset.

There can be as many samples as you'd like.
"""

import numpy as np
import odp_functions as OdpF
import pandas as pd

configfile: "config.yaml"

config["tool"] = "odp_rbh_to_groupby"

OdpF.check_legality(config)

# make sure we have the right info in the config.yaml before it is used
if "num_permutations" not in config:
    raise IOError("you must specify num_permutations in the config to help calculate the false discovery rate")
if "rbh_file" not in config:
    raise IOError("You must specify 'rbh_file' in config")

# randomizations
# The permutations are split into rounds of num_permutations_per_round, one
#  job per round. The number of rounds is rounded down, so a num_permutations
#  that is not a multiple of num_permutations_per_round runs fewer
#  permutations than requested.
num_permutations_per_round=10000
num_permutation_rounds = int(config["num_permutations"]/num_permutations_per_round)
if num_permutation_rounds < 1:
    # Zero rounds would yield an FDR table in which nothing was ever observed,
    #  which cannot be used to assign an alpha to any group.
    raise IOError("num_permutations must be at least {} (one round of permutations). It currently is {}".format(
        num_permutations_per_round, config["num_permutations"]))

# make sure none of the sample names have underscores,
#  because the species are parsed out of the file name by splitting on "_"
for thissample in config["species"]:
    if "_" in thissample:
        raise IOError("Sample names can't have '_' char: {}".format(thissample))

# come up with all of the analyses
myfile = config["rbh_file"]
# the file name without the directory and without the recognized ending
species_string = myfile.split("/")[-1].replace("_reciprocal_best_hits.rbh", "").replace("_rbhhmm_plus_other_species.rbh", "")
# figure out which of the two recognized endings the input file has
ending = ""
if config["rbh_file"].endswith("_reciprocal_best_hits.rbh"):
    ending = "_reciprocal_best_hits.rbh"
elif config["rbh_file"].endswith("_rbhhmm_plus_other_species.rbh"):
    ending = "_rbhhmm_plus_other_species.rbh"
else:
    raise IOError("The input rbh_file must end in _reciprocal_best_hits.rbh or _rbhhmm_plus_other_species.rbh. It currently is {}".format(config["rbh_file"].split("/")[-1]))
all_species = list(sorted(species_string.split("_")))
print("species_in_analysis")
print(all_species)

if len(all_species) < 3:
    raise IOError("There must be more than two species.")

analyses_list = [all_species]

# make sure all of the species in the analyses are in the config
for entry in analyses_list:
    for thisspecies in entry:
        if thisspecies not in config["species"]:
            raise IOError ("You specified {} in the analyses, but it isn't defined in the species category of the config file".format(thisspecies))

config["nways"] = len(all_species)

rule all:
    input:
        # the FDR table (one alpha value per group size)
        "{}/output/FDR/{}_FDR.tsv".format(config["tool"], species_string),
        # the groupby file, named after the input file plus ".groupby"
        "{}/output/{}{}.groupby".format(config["tool"], species_string, ending)

def permute_n_times(df, num_permutations):
    """
    Shuffles every column of df independently num_permutations times and
    records which group sizes show up in each shuffled table.

    After each shuffle the rows are grouped by all columns. For every distinct
    group size that occurs at least once in that shuffle, the tally for that
    size is incremented by one (not by the number of groups of that size).

    Parameters:
      - df: a pandas DataFrame. All of its columns are used as grouping keys.
            It is not modified; the shuffling happens on a copy.
      - num_permutations: how many shuffles to perform.

    Returns a dict of {group_size: number_of_shuffles_with_that_size}, with
      the keys in ascending order. Sizes that never occurred are left out.

    Progress is printed to stdout every 10 shuffles.
    """
    # work on a copy so the caller's table is left untouched
    shuffled = df.copy()
    scafs = list(shuffled.columns)
    # No fixed upper limit on the group size, a dict grows as sizes are seen.
    observations = {}
    for i in range(num_permutations):
        for thisscaf in scafs:
            shuffled[thisscaf] = np.random.permutation(shuffled[thisscaf].values)
        # the distinct group sizes present in this shuffle
        sizes_seen = shuffled.groupby(scafs).size().unique()
        for size in sizes_seen:
            size = int(size)
            observations[size] = observations.get(size, 0) + 1
        if i % 10 == 0:
            print("  - Finished {}/{} ({:.2f}%) analyses.  ".format(
                i, num_permutations,
                (i/num_permutations)*100), end = "\r")
    return {size: observations[size] for size in sorted(observations)}

rule single_permutation_test:
    """
    This performs one round of the FDR permutation test. The number of
    permutations in the round is num_permutations_per_round.

    The output has one line per observed group size: the group size, a tab,
    and the number of permutations in which that size was seen.

    These results are summed with the results from the other rounds.
    """
    input:
        rbh = config["rbh_file"]
    output:
        alpha = config["tool"] + "/output/FDR/sim/" + species_string + "_sim_{simround}.tsv"
    threads: 1
    params:
        thisround = lambda wildcards: wildcards.simround,
        num_permutations = num_permutations_per_round
    run:
        # only the scaffold columns of the species in this analysis are permuted
        scaf_columns = [x + "_scaf" for x in all_species]
        scaf_table = pd.read_csv(input.rbh, sep = "\t", index_col = 0)[scaf_columns]

        size_counts = permute_n_times(scaf_table, params.num_permutations)
        with open(output.alpha, "w") as outhandle:
            for groupsize, times_seen in size_counts.items():
                outhandle.write("{}\t{}\n".format(groupsize, times_seen))

rule permutation_test_for_compariasons:
    """
    Sums the per-round permutation results into a single FDR table.

    Each input file has lines of: group size, number of permutations in which
    that size was seen. The output table has one row per group size with:
      - Num_Genes_In_Chr_Group: the group size
      - Num_Permutations: in how many permutations that size was seen
      - Total_Tests: how many permutations were performed in total
      - alpha: Num_Permutations / Total_Tests

    The table covers every size from 1 to 4999, and is extended if a larger
    group size was seen in the permutations.
    """
    input:
        fdr_results = expand(config["tool"] + "/output/FDR/sim/" + species_string + "_sim_{simround}.tsv",
               simround = list(range(1,num_permutation_rounds+1)))
    output:
        alpha = config["tool"] + "/output/FDR/" + species_string + "_FDR.tsv"
    threads: workflow.cores - 1
    params:
        # The number of permutations that were actually run. This is smaller
        #  than config["num_permutations"] when that value is not a multiple
        #  of num_permutations_per_round, since the round count is rounded down.
        num_permutations = num_permutation_rounds * num_permutations_per_round
    run:
        if params.num_permutations < 1:
            raise IOError("No permutations were run, so the false discovery rate cannot be calculated. Increase num_permutations in the config.")

        # tally the observations from every round, without a size limit
        seen = {}
        for thisfile in input.fdr_results:
            with open(thisfile, "r") as f:
                for line in f:
                    line = line.strip()
                    if line:
                        print(line)
                        fields = [int(x) for x in line.split()]
                        seen[fields[0]] = seen.get(fields[0], 0) + fields[1]

        # Keep a row for every size, including the ones never observed,
        #  so that the table has no gaps between 1 and the largest size.
        largest_size = max(4999, max(seen, default = 0))
        observations = {size: seen.get(size, 0)
                        for size in range(1, largest_size + 1)}

        resultsDF = pd.Series(observations).to_frame()
        resultsDF.reset_index(inplace = True)
        resultsDF.columns = ["Num_Genes_In_Chr_Group", "Num_Permutations"]
        resultsDF["Total_Tests"] = params.num_permutations
        resultsDF["alpha"] = resultsDF["Num_Permutations"]/params.num_permutations
        print(resultsDF)
        resultsDF.to_csv(output.alpha, sep="\t", index = False)

rule groupby_rbh_results:
    """
    Groups the rows of the rbh file by the combination of {species}_scaf
    columns and annotates each group with an alpha value from the FDR table.

    For each group:
      - every non-grouping column becomes a list of the values of its rows
      - count is the number of rows in the group
      - alpha is looked up in the FDR table by count. If that alpha is 0,
        the alpha of the next smaller size with a non-zero alpha is used
        instead, and alpha_type is set to "less_than" rather than "equal_to".
      - gene_group is set to the string "None"

    Only groups with more than one row are kept. The groups are sorted by
    count, largest first.
    """
    input:
        rbh   = config["rbh_file"],
        alpha = config["tool"] + "/output/FDR/" + species_string + "_FDR.tsv"
    output:
        rbh  = config["tool"] + "/output/" + species_string + ending + ".groupby"
    threads: 1
    run:
        # the species in the order they appear in the file name
        species_in_name = species_string.split("_")
        groupbycols = ["{}_scaf".format(x) for x in species_in_name]

        df = pd.read_csv(input.rbh, sep = "\t")
        alphadf = pd.read_csv(input.alpha, sep = "\t")

        df = df.reset_index(drop = True)
        print(df)
        print(df.columns)
        grouped_multiple = df.groupby(groupbycols).agg(list).reset_index()

        # get the size, and only keep the groups with more than one member
        grouped_multiple["count"] = grouped_multiple["rbh"].str.len()
        grouped_multiple = grouped_multiple.loc[grouped_multiple["count"] > 1]

        # sort
        grouped_multiple = grouped_multiple.sort_values(by="count", ascending=False)

        # calculate the alpha score for each group
        alpha_dict = dict(zip(alphadf.Num_Genes_In_Chr_Group, alphadf.alpha))
        # a size that is missing from the FDR table is treated like a size
        #  that was never observed in the permutations
        grouped_multiple["alpha"] = grouped_multiple["count"].map(
            lambda size: alpha_dict.get(size, 0.0))
        grouped_multiple["alpha_type"] = "equal_to"
        for index, row in grouped_multiple.iterrows():
            if row["alpha"] != 0:
                continue
            # Reassign the alpha to the next smaller size that has a non-zero
            #  alpha. Stop at size 1 rather than looking up a size that is
            #  not in the table.
            smaller = row["count"] - 1
            while smaller >= 1 and alpha_dict.get(smaller, 0) == 0:
                smaller -= 1
            if smaller >= 1:
                grouped_multiple.loc[index, "alpha"] = alpha_dict[smaller]
                grouped_multiple.loc[index, "alpha_type"] = "less_than"
        grouped_multiple["gene_group"] = "None"

        # column order: the fixed columns, then the per-species columns,
        #  then everything else, the last two sorted by name
        upfront   = ["rbh", "gene_group",
                     "count", "alpha", "alpha_type"]
        species_cols = sorted(x for x in grouped_multiple.columns
                              if x.split("_")[0] in species_in_name)
        other_cols   = sorted(x for x in grouped_multiple.columns
                              if x.split("_")[0] not in species_in_name
                              and x not in upfront)
        grouped_multiple = grouped_multiple[upfront + species_cols + other_cols]
        grouped_multiple.to_csv(output.rbh, sep="\t", index = False)
