"""
This script takes the groupby RBH hits files and removes the rows
 that have a false discovery rate greater than a selected value.
"""

import pandas as pd
import odp_functions as OdpF
import sys

# Snakemake loads this YAML file into the global `config` dict.
configfile: "config.yaml"

# Used by the rules below as the top-level directory for all outputs.
config["tool"] = "odp_groupby_filter"

# Project helper; presumably validates the config -- see odp_functions.
OdpF.check_legality(config)

for this in ["groupby_files", "alpha"]:
    if this not in config:
        raise IOError("You must specify '{}' in config".format(this))

# Threshold compared against the "alpha" column in rule filter_rbh_files.
sys_alpha = float(config["alpha"])

# come up with all of the analyses
# NOTE(review): `myfile` is not referenced anywhere in the visible part of
#  this file; it is kept only in case code outside this view relies on it.
myfile = config["groupby_files"]
analyses_list = []
# `or []` turns an empty YAML entry (None) into the explicit
#  "at least one file" error below instead of a TypeError.
for thisfile in (config["groupby_files"] or []):
    # Only this ending is accepted, because the rules below declare their
    #  outputs with the matching "_reciprocal_best_hits.rbh.filt.groupby" name.
    if not thisfile.endswith("_reciprocal_best_hits.rbh.groupby"):
        raise IOError("Your groupby files must end with _reciprocal_best_hits.rbh.groupby")
    # Keep this derivation identical to the one in rule filter_rbh_files,
    #  so that the declared output paths and the written paths agree.
    species_string = thisfile.split("/")[-1].replace("_reciprocal_best_hits.rbh.groupby", "").replace("_rbhhmm_plus_other_species.rbh.groupby", "")
    analyses_list.append(species_string)

# Fail with a readable message instead of an IndexError on an empty list.
if not analyses_list:
    raise IOError("You must specify at least one file in 'groupby_files' in config")

# Each analysis name is the "_"-joined list of species; check every file,
#  not just the first one.
for analysis in analyses_list:
    if len(analysis.split("_")) < 3:
        raise IOError("There must be more than two species. Found: {}".format(analysis))

# Default target: request one filtered groupby table per analysis.
rule all:
    input:
        # Filtered groupby files, written below <tool>/output/
        expand("{tool}/output/{sp}_reciprocal_best_hits.rbh.filt.groupby",
               tool = config["tool"],
               sp = analyses_list)

# Filters every groupby table in one job: reads each input file as a
#  tab-separated table, keeps the rows whose "alpha" column is <= the
#  configured alpha, and writes the result to
#  <tool>/output/<species>_reciprocal_best_hits.rbh.filt.groupby
rule filter_rbh_files:
    input:
        rbh_files = config["groupby_files"]
    output:
        rbh_filt = expand(config["tool"]+"/output/{sp}_reciprocal_best_hits.rbh.filt.groupby",
                          sp = analyses_list)
    params:
        prefix = config["tool"]+"/output/",
        suffix = "_reciprocal_best_hits.rbh.filt.groupby",
        alpha = sys_alpha
    threads: 1
    run:
        # str.endswith accepts a tuple, so one call covers both endings.
        accepted_endings = ("_reciprocal_best_hits.rbh.groupby",
                            "_rbhhmm_plus_other_species.rbh.groupby")
        for thisfile in input.rbh_files:
            basename = thisfile.split("/")[-1]
            if not thisfile.endswith(accepted_endings):
                # Report the offending file itself. The config has no
                #  "groupby_file" key, so looking it up here would raise a
                #  KeyError and hide this message.
                raise IOError("The input groupby_file must end in _reciprocal_best_hits.rbh.groupby or _rbhhmm_plus_other_species.rbh.groupby. It currently is {}".format(basename))
            # The global analyses_list is deliberately not modified here:
            #  it defines the workflow targets and must not change while
            #  a job is running.
            species_string = basename.replace("_reciprocal_best_hits.rbh.groupby", "").replace("_rbhhmm_plus_other_species.rbh.groupby", "")

            df = pd.read_csv(thisfile, sep = "\t", index_col = None)
            # keep only the rows at or below the alpha threshold
            df = df.loc[df["alpha"] <= params.alpha, ]
            outfile = params.prefix + species_string + params.suffix

            # missing values are written as the literal string 'nan'
            df.to_csv(outfile, sep="\t", index = False, na_rep='nan')
