"""
This merges the RBH files based on the anchor species. Makes a special RBH
that has some missing data.
"""

import numpy as np
import odp_functions as OdpF
import pandas as pd

configfile: "config.yaml"

# Tag the config with this workflow's name; it is also the output directory prefix.
config["tool"] = "odp_rbh_merge"

OdpF.check_legality(config)

# Both of these keys are mandatory. Report the first one that is absent.
required_keys = ("rbh_files", "anchor_species")
missing_key = next((key for key in required_keys if key not in config), None)
if missing_key is not None:
    raise IOError("You must specify '{}' in config".format(missing_key))

# Every input path must carry the ".rbh" extension, and every anchor species
# name must appear somewhere in the path string itself.
for rbh_path in config["rbh_files"]:
    has_rbh_suffix = rbh_path.endswith(".rbh")
    if not has_rbh_suffix:
        raise IOError("The rbh file {} must end with .rbh".format(rbh_path))
    absent_anchors = [species for species in config["anchor_species"]
                      if species not in rbh_path]
    if absent_anchors:
        # report the first missing anchor, in config order
        raise IOError("Species {} was not in {}, but should have been.".format(
            absent_anchors[0], rbh_path))

# Sorted, underscore-joined anchor names; used to name the output file.
anchor_string = "_".join(sorted(config["anchor_species"]))

# Target rule: requests the merged output table.
rule all:
    input:
        # the merged (grouped-by-anchor) file
        "{}/output/{}.mer".format(config["tool"], anchor_string)

# Merge every input .rbh table into one table keyed on the anchor species'
# "<species>_gene" columns and write it as a tab-separated file.
#
# Steps performed by the run block:
#   1. Read all input files, concatenate them, and drop the per-file "rbh"
#      ID column.
#   2. Group rows by the anchor gene columns and collapse each group into a
#      single row, taking the first non-missing value found in each column.
#      (groupby's default dropna=True excludes rows missing an anchor gene.)
#   3. For each "<species>_gene" column, keep only one row per gene - the
#      one with the fewest missing values. Rows in which that column is
#      missing are all kept.
#   4. Assign new "mer_<species...>_<i>" IDs in the "rbh" column and write
#      the table, with missing values written as "nan".
rule groupby_RBH_results:
    input:
        rbh   = config["rbh_files"]
    output:
        rbh = config["tool"] + "/output/" + anchor_string + ".mer"
    threads: 1
    params:
        # names of the gene columns for the anchor species
        anchors = ["{}_gene".format(x) for x in config["anchor_species"]]
    run:
        df_list = [pd.read_csv(thisfile, sep = "\t") for thisfile in input.rbh]
        all_data = pd.concat(df_list)
        # the per-file rbh IDs are meaningless after merging, so drop them
        all_data = all_data[[x for x in all_data.columns
                             if x not in ["rbh"]]]

        # There will invariably be some overlaps.
        # We have to go through and handle the duplicates
        grouped = all_data.groupby(params.anchors)
        keeps = []
        for name, group in grouped:
            # This collapses each group to a single row.
            # We are anchoring around the common species, so every row in
            #  the group shares the anchor genes. Stacking gives one entry
            #  per (anchor gene, column); head(1) then keeps the first
            #  value seen for each column, and unstack rebuilds one row.
            # The explicit dropna() makes sure missing values never win:
            #  older pandas drops them inside stack() already (so this is a
            #  no-op there), but the newer stack implementation keeps them.
            tempgroup = group.set_index(params.anchors[0]).stack().dropna().groupby(level=[0, 1]).head(1).unstack().reset_index()
            keeps.append(tempgroup)

        # now we turn this into an rbh
        grouped2 = pd.concat(keeps).reset_index(drop=True)
        # every column named "<species>_scaf" identifies one species
        all_species = sorted([x.replace("_scaf", "") for x in grouped2.columns
                       if "_scaf" in x])
        # number of missing fields per row; used to rank duplicate rows
        grouped2["nancount"] = grouped2.isnull().sum(axis=1)

        print(grouped2.columns)
        print("Before filtering there were {} rows.".format(len(grouped2)))

        # get rid of genes with more than one entry
        for thisanchor in params.anchors + \
            ["{}_gene".format(x) for x in all_species]:
            # Group by the column NAME, not a one-element list. With a
            #  length-1 list, pandas >= 2.0 yields 1-tuples as group keys,
            #  and pd.isnull((nan,)) is False, so the missing-value group
            #  would wrongly be cut down to a single row below.
            gb = grouped2.sort_values(by=[thisanchor,"nancount"],
                                      ascending = [True, True]
                                      ).groupby(by=thisanchor, dropna=False)
            keeps = []
            for name, group in gb:
                if pd.isnull(name):
                    # rows with no gene for this species are not
                    #  duplicates of each other - keep them all
                    temprow = group.reset_index(drop=True)
                else:
                    # keep the row with the most info (fewest NaNs sorts first)
                    temprow = group.head(1).reset_index(drop=True)
                keeps.append(temprow)
            grouped2 = pd.concat(keeps)
        grouped2 = grouped2[[x for x in grouped2.columns if x != "nancount"]]
        print("After filtering there were {} rows.".format(len(grouped2)))

        # build new sequential IDs and put the "rbh" column first
        sp_string = "_".join(all_species)
        all_cols = list(grouped2.columns)
        # a plain list is assigned by position, so the repeated index
        #  labels left over from pd.concat above do not matter here
        grouped2["rbh"] = ["mer_{}_{}".format(sp_string, i)
                           for i in range(1, len(grouped2)+1)]
        grouped2 = grouped2[["rbh"] + all_cols]
        # NOTE(review): display.float_format only changes how pandas prints
        #  frames; it does not affect to_csv. Confirm whether passing
        #  float_format to to_csv was intended.
        pd.set_option('display.float_format', lambda x: '%.3f' % x)
        grouped2.to_csv(output.rbh, sep="\t", na_rep = "nan", index = False)
