"""
This script takes the groupby rbh hits file, to which a group column has
  been added. This script unwraps that file into a row-by-row
  reciprocal best hit file.

the input file must be named:
  {sample1}_{sample2}_{sampleN}_reciprocal_best_hits.*.groupby

... where sample1, sample2, sampleN are the sample names in the config.

There can be as many samples as you'd like.
"""

import ast
from itertools import groupby
from itertools import product
import math
import numpy as np
import odp_functions as odpf
from operator import itemgetter
import pandas as pd
import statistics
import sys

# Load the workflow configuration from "config.yaml".
configfile: "config.yaml"

# Name of this tool. It is used further down as the top-level directory
#  for every output path (config["tool"] + "/output/...").
config["tool"] = "odp_groupby_to_rbh"

# NOTE(review): presumably validates the loaded config; the helper lives
#  in odp_functions and is not visible here - confirm what it checks.
odpf.check_legality(config)

# 'groupby_files' lists the grouped rbh tables to unwrap. Without it there
#  is nothing to plan, so fail immediately.
if "groupby_files" not in config:
    raise IOError("You must specify 'groupby_files' in config")

# come up with all of the analyses
#  Each input file yields one analysis, named by the sorted species names
#  joined with "_". The species are read from the "<species>_scaf" columns.
analyses = []
for thisfile in config["groupby_files"]:
    # Only the column names are needed here, so read the header alone
    #  instead of parsing every data row.
    df = pd.read_csv(thisfile, sep = "\t", index_col = None, nrows = 0)
    # Strip only the trailing "_scaf". str.replace would also remove any
    #  "_scaf" that occurs inside the species name itself.
    all_species = [x[:-len("_scaf")] for x in df.columns
                   if x.endswith("_scaf")]
    species_string = "_".join(sorted(all_species))

    print(species_string)
    num_species = len(all_species)
    if num_species < 3:
        raise IOError("There must be more than two species.")
    analyses.append(species_string)

class _BareNanToNone(ast.NodeTransformer):
    """Rewrite bare ``nan``/``NaN`` identifiers in a parsed literal to None.

    Only identifier tokens are touched. Quoted strings that merely contain
    the letters "nan" (for example a gene called 'banana1') are left alone.
    """

    def visit_Name(self, node):
        if node.id in ("nan", "NaN"):
            return ast.copy_location(ast.Constant(value=None), node)
        return node

def parse_pd_list_or_string(pd_list, rowcount):
    """Turn one cell of a groupby rbh table into a list of per-row values.

    Parameters:
      pd_list  - the raw cell value. Either the text of a Python list
                 (e.g. "['g1', nan, 'g3']"), a plain string, a missing
                 value (None or float NaN), or any other scalar.
      rowcount - how many rows the group expands to. Only used when the
                 cell is not a list and has to be repeated.

    Returns:
      - list text  -> the parsed list, with bare nan/NaN entries as None
      - missing    -> [None] * rowcount
      - any scalar -> [value] * rowcount

    Raises:
      ValueError / SyntaxError if the cell starts with "[" but is not a
      valid Python literal.
    """
    if isinstance(pd_list, str):
        text = pd_list.strip()
        if text.startswith("["):
            # the thing is a list, parse it as a list. nan is not a valid
            #  Python literal, so swap the bare identifiers for None on the
            #  syntax tree rather than with a blind text replacement.
            tree = ast.parse(text, mode = "eval")
            return ast.literal_eval(_BareNanToNone().visit(tree))
        if text in ("nan", "NaN"):
            # a cell that is literally "nan" is a missing value
            return [None] * int(rowcount)
        # if it isn't a list, the single value applies to every row
        return [pd_list] * int(rowcount)
    if pd_list is None or (isinstance(pd_list, float) and np.isnan(pd_list)):
        return [None] * int(rowcount)
    # Any other scalar (e.g. a numeric position) is shared by every row.
    #  Returning nothing here would make the caller fail when it indexes
    #  the result.
    return [pd_list] * int(rowcount)

# Target rule: requests, for every entry in `analyses`, the unwrapped rbh
#  table and the matching colors table under <tool>/output/.
rule all:
    input:
        # just the unwrapped data
        expand(config["tool"] + "/output/{sp}.unwrapped.rbh",
               sp = analyses),
        # one colors table per analysis
        expand(config["tool"] + "/output/{sp}.colors.tsv",
               sp = analyses)

# Unwraps every groupby rbh file into one row per reciprocal best hit and
#  writes, per file:
#    <prefix><species_string>.unwrapped.rbh  - the row-by-row table
#    <prefix><species_string>.colors.tsv     - one line per gene with its
#                                              species, scaffold, position,
#                                              gene, color, group and rbh id
#  All input files are handled inside this single job.
rule unwrap_rbh_file_with_group_column:
    input:
        rbh_files = config["groupby_files"]
    output:
        rbh_unwrapped = expand(config["tool"] + "/output/{sp}.unwrapped.rbh",
                               sp = analyses),
        colors        = expand(config["tool"] + "/output/{sp}.colors.tsv",
                               sp = analyses)
    params:
        prefix = config["tool"] + "/output/"
    threads: 1
    run:
        for thisfile in input.rbh_files:
            df = pd.read_csv(thisfile, sep = "\t", index_col = None)
            if any(x.startswith("Unnamed") for x in df.columns):
                print("There is a column that is unnamed, and probably doesn't have data in that column. Check your input.", file = sys.stderr)
            df = df[[x for x in df.columns if not x.startswith("Unnamed")]]
            print(df)
            if df["count"].isna().any():
                raise IOError("There are some rows in this dataset where count is nan. That should not occur in a rbh or rbh.groupby file. Check that your input has no rows that are erroneously empty.")

            # get all the species in the dataframe
            #  Strip only the trailing "_scaf". Splitting on the first "_"
            #  truncates species names that contain underscores, which then
            #  breaks both the column lookups below and the output file
            #  names declared in `output`.
            complete_species_list = [x[:-len("_scaf")] for x in df.columns
                                     if x.endswith("_scaf")]
            species_string = "_".join(sorted(complete_species_list))
            print(species_string)

            rbh_entries = []
            for index, row in df.iterrows():
                rbh_list = ast.literal_eval(row["rbh"])
                rowcount = row["count"]
                species_to_gene_list = {}
                species_to_gene_pos  = {}
                species_to_gene_scaf = {}
                parse_pairs = [("gene", species_to_gene_list),
                               ("pos",  species_to_gene_pos),
                               ("scaf", species_to_gene_scaf) ]
                # make lookup tables: species -> list with one value per
                #  member of this group
                for thisspecies in complete_species_list:
                    for thiscol, thislist in parse_pairs:
                        colname = "{}_{}".format(thisspecies, thiscol)
                        thislist[thisspecies] = parse_pd_list_or_string(
                                                             row[colname], rowcount)

                # one output row per rbh id in the group
                for i, thisrbh in enumerate(rbh_list):
                    thisentry = {"rbh": thisrbh,
                                 "gene_group": row["gene_group"],
                                 "count":      rowcount}
                    # optional group-level annotations are copied as-is
                    for keep_this in ["ALG", "color", "alpha", "alpha_type"]:
                        if keep_this in row:
                            thisentry[keep_this] = row[keep_this]
                    for thisspecies in complete_species_list:
                        # get col names
                        scafcol = "{}_scaf".format(thisspecies)
                        genecol = "{}_gene".format(thisspecies)
                        poscol  = "{}_pos".format(thisspecies)
                        thisentry[scafcol] = species_to_gene_scaf[thisspecies][i]
                        thisentry[genecol] = species_to_gene_list[thisspecies][i]
                        thisentry[poscol]  = species_to_gene_pos[thisspecies][i]
                    rbh_entries.append(thisentry)

            unwrapped = pd.DataFrame(rbh_entries)
            # these columns describe the group, not the single hit
            unwrapped = unwrapped[[x for x in unwrapped.columns
                                   if x not in ["count", "alpha", "alpha_type"]]]
            outfile = params.prefix + species_string + ".unwrapped.rbh"
            print(outfile)
            unwrapped.to_csv(outfile, sep="\t", na_rep = "None", index = False)

            # now we go through the dataframe and print out the colors
            outfile = params.prefix + species_string + ".colors.tsv"
            # the context manager closes the file even if a row fails
            with open(outfile, "w") as outhandle:
                if "color" not in unwrapped:
                    print("There is no color column, so the color outfile will be empty.",
                          file = sys.stderr)
                    print("There is no color column, so the color outfile will be empty.", file = outhandle)
                else:
                    print(unwrapped)
                    # suffix match, so a species whose name merely contains
                    #  "_gene" does not add its other columns here
                    gene_cols = [x for x in unwrapped.columns
                                 if x.endswith("_gene")]
                    for index, row in unwrapped.iterrows():
                        for thiscol in gene_cols:
                            # skip species with no gene in this hit; pd.isna
                            #  catches both None and NaN
                            if pd.isna(row[thiscol]):
                                continue
                            sp = thiscol[:-len("_gene")]
                            scaf = row["{}_scaf".format(sp)]
                            pos  = int(row["{}_pos".format(sp)])
                            print("{}\t{}\t{}\t{}\t{}\t{}\t{}".format(
                                sp, scaf, pos,
                                row[thiscol], row["color"], row["gene_group"],
                                row["rbh"]),
                                  file = outhandle)
