"""
This script finds linkage groups between multiple chromosome-scale genomes

Step 1: Performs n-ways reciprocal best blastp searches then outputs how many
        pairs are conserved between them. The algorithm can use blastp or
        DIAMOND blastp. Jackhammer is not implemented for this because it is
        very slow and doesn't perform better than blastp. The output is .rbh
        files.

Step 2: Converts an rbh file to a groupby file. This script calculates the
        false discovery rate for finding groups of size N given a randomized
        dataset.

Step 3: Filters the groupby file to only include groups of a size with a FDR
        less than or equal to 0.05.

Step 4: Convert the filtered groupby file back to a rbh file.
"""

import copy
import ast
from itertools import groupby
from itertools import combinations
from itertools import permutations
import matplotlib
from matplotlib import pyplot as plt
from multiprocessing import Pool
import math
import networkx as nx
from operator import itemgetter
import os
import pandas as pd
import numpy as np
import statistics
import sys
import odp_functions as odpf

configfile: "config.yaml"

config["tool"] = "odp_nway_rbh"

odpf.check_legality(config)
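
# The config.yaml layout is defined by the odp documentation; the sketch below
#  is only illustrative, based on the keys this script accesses:
#
#   diamond_or_blastp: blastp     # optional; "diamond" or "blastp" (default "blastp")
#   nways: 3                      # size of each n-way comparison; at least 2
#   num_permutations: 100000      # total permutations used for the FDR estimate
#   prot_to_group: groups.tsv     # optional; TSV mapping protein -> gene_group
#   analyses:
#     - ["SP1", "SP2", "SP3"]     # each entry is one set of species to compare
#   species:
#     SP1:                        # sample names must not contain underscores
#       proteins: SP1.pep         # protein fasta
#       chrom: SP1.chrom          # protein, scaffold, direction, start, stop
#     SP2:
#       ...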

# check diamond_or_blastp
if "diamond_or_blastp" not in config:
    config["diamond_or_blastp"] = "blastp"
else:
    if config["diamond_or_blastp"] not in ["diamond", "blastp"]:
        raise IOError("diamond_or_blastp must be either 'diamond' or 'blastp'")

# make sure none of the sample names have underscores
for thissample in config["species"]:
    if "_" in thissample:
        raise IOError("Sample names can't have '_' char: {}".format(thissample))

# make sure there are at least 2 samples
if len(config["species"]) < 2:
    raise IOError("There must be at least 2 samples")
# make sure that nways is there
if not "nways" in config:
    raise IOError("you must specify nways in the config file. must be at least 2.")
# make sure that nways is greater than equal to 2
if config["nways"] < 2:
    raise IOError("nways must be at least 2")

# randomizations
if not "num_permutations" in config:
    raise IOError("you must specify num_permutations in the config to help calculate the false discovery rate")
num_permutations_per_round=10000
num_permutation_rounds = int(config["num_permutations"]/num_permutations_per_round)
#print("number of randomization rounds is : {}".format(num_rand_rounds))

# make sure that num species is gteq nways
if not len(config["species"]) >= config["nways"]:
    raise IOError("The number of species must be greater than or equal to nways")

# come up with all of the analyses
analyses_list = [list(sorted(x)) for x in config["analyses"]]
all_species = list(set([x for x in odpf.flatten(analyses_list)]))
print("Here is an example of the first few analyses: {}".format(analyses_list[0:3]))
print("There are {} possible combinations.".format(len(analyses_list)))

# make sure all of the species in the analyses are in the config
for entry in analyses_list:
    for thisspecies in entry:
        if thisspecies not in config["species"]:
            raise IOError ("You specified {} in the analyses, but it isn't defined in species".format(thisspecies))

config["yaxisspecies"] = config["species"]

# come up with a list of blast/diamond jobs
# we must have every combination that is in the analyses_list
config["blastjobs"] = {}
config["analysisspecies"] = set()
config["analysispairs"] = {}
for thisanalysis in analyses_list:
    for tup in list(permutations(thisanalysis, 2)):
        # all the blast analyses
        config["blastjobs"]["{}_{}".format(tup[0], tup[1])] = list(tup)
        # the individual species
        config["analysisspecies"].add(tup[0])
        config["analysisspecies"].add(tup[1])
        # get the sorted reciprocal analyses
        t = list(sorted(tup))
        config["analysispairs"]["{}_{}".format(t[0], t[1])] = t

#print(config["blastjobs"])
#print(config["analysisspecies"])
#print(config["analysispairs"])
#print(        [config["tool"] + "/step1-rbh/{}_reciprocal_best_hits.rbh".format(
#            "_".join(thisanalysis)) for thisanalysis in analyses_list])
#print([config["tool"] + "/step0-blastp_results/reciprocal_best/{}_and_{}_recip.temp.blastp".format(
#            config["blastjobs"][x][0], config["blastjobs"][x][1])
#            for x in config["blastjobs"]])

rule all:
    input:
        # the reciprocal blast results
        [config["tool"] + "/step0-blastp_results/reciprocal_best/{}_and_{}_recip.temp.blastp".format(
            config["blastjobs"][this][0], config["blastjobs"][this][1])
         for this in config["blastjobs"]],
        # find the reciprocal best hits
        [config["tool"] + "/step1-rbh/{}_reciprocal_best_hits.rbh".format(
            "_".join(thisanalysis)) for thisanalysis in analyses_list],
        # the alpha file
        [config["tool"] + "/step2-groupby/FDR/{}.FDR.tsv".format(
            "_".join(thisanalysis)) for thisanalysis in analyses_list],
        # the groupby file
        [config["tool"] + "/step2-groupby/{}.rbh.groupby".format(
            "_".join(thisanalysis)) for thisanalysis in analyses_list],
        # the filtered groupby file
        [config["tool"] + "/step2-groupby/{}.rbh.filt.groupby".format(
            "_".join(thisanalysis)) for thisanalysis in analyses_list],
        [config["tool"] + "/step3-unwrap/{}.filt.unwrapped.rbh".format(
            "_".join(thisanalysis)) for thisanalysis in analyses_list]


rule filter_prots:
    """
    Sometimes the prot file will have sequences that are not present in
    the chrom file. Make a prot file of only the proteins in the chrom file.
    """
    input:
        prots = lambda wildcards: config["species"][wildcards.sample]["proteins"],
        chrom = lambda wildcards: config["species"][wildcards.sample]["chrom"]
    output:
        pep = config["tool"] + "/db/{sample}_prots.pep"
    threads: 1
    run:
        odpf.filter_fasta_chrom(input.chrom, input.prots, output.pep)

rule make_diamond_and_blast_db:
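    """
    Builds a DIAMOND database and a BLAST protein database from each filtered
    protein fasta so that either search method can be used downstream.
    """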
    input:
        pep = ancient(config["tool"] + "/db/{sample}_prots.pep")
    output:
        dmnd = config["tool"] + "/db/dmnd/{sample}_prots.dmnd",
        pdb  = config["tool"] + "/db/{sample}_prots.pep.pdb"
    threads: workflow.cores - 1
    shell:
        """
        diamond makedb --in {input.pep} --db {output.dmnd}
        makeblastdb -in {input.pep} -dbtype prot
        """

rule diamond_blast_x_to_y:
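    """
    Searches the proteins of sample1 against the database of sample2 with
    either blastp or DIAMOND blastp, depending on config["diamond_or_blastp"],
    using an e-value cutoff of 1e-5 and tabular (outfmt 6) output.
    """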
    input:
        pep1  = ancient(config["tool"] + "/db/{sample1}_prots.pep"),
        dmnd1 = ancient(config["tool"] + "/db/dmnd/{sample1}_prots.dmnd"),
        pep2  = ancient(config["tool"] + "/db/{sample2}_prots.pep"),
        pdb2  = ancient(config["tool"] + "/db/{sample2}_prots.pep.pdb"),
        dmnd2 = ancient(config["tool"] + "/db/dmnd/{sample2}_prots.dmnd"),
    output:
        blastp = config["tool"] + "/step0-blastp_results/{sample1}_against_{sample2}.blastp"
    threads: (workflow.cores - 1)
    params:
        search_method = config["search_method"]
    priority: 1
    shell:
        """
        if [ "{params.search_method}" = "blastp" ]; then
            blastp -query {input.pep1} -db {input.pep2} \
              -num_threads {threads} -evalue 1E-5 -outfmt 6 > {output.blastp}
        elif [ "{params.search_method}" = "diamond" ]; then
            diamond blastp --query {input.pep1} --db {input.dmnd2} \
              --threads {threads} --evalue 1E-5 --outfmt 6 --out {output.blastp}
        fi
        """

rule reciprocal_best_hits:
    """
    finds the reciprocal best hits.
    reports it in the form of the blastp results from x -> y search
    """
    input:
        blastp1to2 = config["tool"] + "/step0-blastp_results/{sample1}_against_{sample2}.blastp",
        blastp2to1 = config["tool"] + "/step0-blastp_results/{sample2}_against_{sample1}.blastp"
    output:
        blastp1to2 = config["tool"] + "/step0-blastp_results/reciprocal_best/{sample1}_and_{sample2}_recip.temp.blastp",
        blastp2to1 = config["tool"] + "/step0-blastp_results/reciprocal_best/{sample2}_and_{sample1}_recip.temp.blastp",
    threads: 1
    run:
        odpf.reciprocal_best_hits_blastp_or_diamond_blastp(
            input.blastp1to2, input.blastp2to1, output.blastp1to2)
        df = pd.read_csv(output.blastp1to2, sep = "\t", header = None)
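        # outfmt 6 columns are qseqid sseqid pident length mismatch gapopen
        #  qstart qend sstart send evalue bitscore; swap the query/subject IDs
        #  and their coordinate columns to express the same hits in the
        #  2-to-1 direction.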
        newdf = df[[1,0,2,3,4,5,8,9,6,7,10,11]]
        newdf.to_csv(output.blastp2to1, sep ="\t", header = None, index = None)

def get_component_size_dict(G):
    """
    Prints a dictionary of the component sizes of the graph G.

    Keys are the sizes of the components.
    Values are the number of components of that size.
    """
    component_size = {}
    for thisentry in list(nx.connected_components(G)):
        if len(thisentry) not in component_size:
            component_size[len(thisentry)] = 0
        component_size[len(thisentry)] += 1
    print(component_size)

rule n_ways_reciprocal_best:
    """
    Gets reciprocal best hits from 3 or more samples
    For a protein to be retained, it must be a reciprocal-best hit in three samples.

              A      B
             / \    /|\
            B___C  A-+-C , et cetera
                    \|/
                     D

    The output of this rule is a yaml file with admissible proteins from each sample
    """
    input:
        # just require all the blast jobs to be done because it is easier
        xtoyblastp = [config["tool"] + "/step0-blastp_results/reciprocal_best/{}_and_{}_recip.temp.blastp".format(
            config["blastjobs"][this][0], config["blastjobs"][this][1])
         for this in config["blastjobs"]],
        chrom = lambda wildcards: [config["species"][x]["chrom"]
                    for x in config["analysisspecies"]]
    output:
        acceptable_prots = config["tool"] + "/step0-blastp_results/reciprocal_best/{analysis}_acceptable_prots.txt",
        blast_network    = config["tool"] + "/step0-blastp_results/reciprocal_best/{analysis}_edges.txt",
        rbh              = config["tool"] + "/step1-rbh/{analysis}_reciprocal_best_hits.rbh"
    threads: 1
    run:
        # prot to group
        prot_to_group = {}
        if "prot_to_group" in config:
            if os.path.exists(config["prot_to_group"]):
                with open(config["prot_to_group"], "r") as f:
                    for line in f:
                        line = line.strip()
                        fields = line.split("\t")
                        prot_to_group[fields[0]] = fields[1]

        species_string = output.acceptable_prots.split("/")[-1].replace("_acceptable_prots.txt", "")
        all_species = species_string.split("_")
        combos = list(combinations(all_species, 2))
        blastfiles = [[config["tool"] + "/step0-blastp_results/reciprocal_best/{}_and_{}_recip.temp.blastp".format(x[0], x[1]),
             x[0], x[1]] for x in combos]
        gene_to_species = {}
        gene_list = set()

        # get the chrom files
        chrom_dicts = {}
        for thisspecies in all_species:
            if not os.path.exists(config["species"][thisspecies]["chrom"]):
                raise IOError("This chrom file doesn't exist: {}".format(
                    config["species"][thisspecies]["chrom"]))
            chrom_dicts[thisspecies] = pd.read_csv(
                config["species"][thisspecies]["chrom"],
                header=None, sep = "\t")
            chrom_dicts[thisspecies].columns = ["prot",
                "scaf", "direction", "start", "stop"]

        # initialize the graph
        G = nx.Graph()
        checked_names = set()
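        # Nodes are "<species>_<protein>" strings and edges connect reciprocal
        #  best hit pairs, weighted by bitscore. A valid n-way group is a
        #  connected component containing exactly one protein per species in
        #  which every protein is a reciprocal best hit of every other (a
        #  clique of n nodes, each with degree n-1). The pruning steps below
        #  discard everything else.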
        for analysis in blastfiles:
            thisfile = analysis[0]
            print(thisfile)
            a = analysis[1]
            b = analysis[2]
            with open(thisfile, "r") as f:
                for line in f:
                    line = line.strip()
                    if line:
                        splitb = line.split("\t")
                        agene = "{}_{}".format(a, splitb[0])
                        bgene = "{}_{}".format(b, splitb[1])
                        evalue = float(splitb[-2])
                        bitscore = float(splitb[-1])
                        if a not in checked_names:
                            if agene in gene_list:
                                raise IOError("""We saw a gene twice. {}.
                                This means that two species have the same prot ids.""".format(agene))
                        gene_list.add(agene)
                        if b not in checked_names:
                            if bgene in gene_list:
                                raise IOError("""We saw a gene twice. {}.
                                This means that two species have the same prot ids.""".format(bgene))
                        gene_list.add(bgene)
                        gene_to_species[agene] = a
                        gene_to_species[bgene] = b
                        #add these since we've added the genes already
                        checked_names.add(a)
                        checked_names.add(b)
                        # now add the edge
                        G.add_edge(agene, bgene, weight = bitscore)
        remnodes = set()
        #get rid of things that couldn't possibly be an n-way best
        for thisentry in list(nx.connected_components(G)):
            if len(thisentry) < len(all_species):
                for node in thisentry:
                    remnodes.add(node)
        for node in remnodes:
            G.remove_node(node)
        remnodes.clear()

        # now get rid of nodes that don't have the correct degree
        #  to be n-connected
        for thisnode in G.nodes:
            if G.degree[thisnode] != (len(all_species) - 1):
                remnodes.add(thisnode)
        for node in remnodes:
            G.remove_node(node)
        remnodes.clear()
        # now get the n-connected components
        nwaybest = []
        for thisentry in list(nx.connected_components(G)):
            if len(thisentry) == len(all_species):
                nwaybest.append(thisentry)
            else:
                for node in thisentry:
                    remnodes.add(node)
        #cleanup the graph
        for node in remnodes:
            G.remove_node(node)
        remnodes.clear()
        # print out the graph
        uniquenodes = set()
        with open(output.blast_network, "w") as f:
            for thisedge in G.edges:
                agene = "_".join(thisedge[0].split("_")[1::])
                bgene = "_".join(thisedge[1].split("_")[1::])
                print("{}\t{}".format(agene, bgene), file = f)
                uniquenodes.add(thisedge[0])
                uniquenodes.add(thisedge[1])
        with open(output.acceptable_prots, "w") as f:
            for thisnode in uniquenodes:
                thisgene = "_".join(thisnode.split("_")[1::])
                print(thisgene, file = f)
        # print out the list of genes
        CCs = []
        for thisentry in list(nx.connected_components(G)):
            ccdict = {"rbh": "rbh{}way_{}_{}".format(
                len(all_species), "_".join(all_species), len(CCs)+1)}
            for node in thisentry:
                thisgene = "_".join(node.split("_")[1::])
                ccdict["{}_gene".format(gene_to_species[node])] = thisgene
            CCs.append(ccdict)
        genesdf = pd.DataFrame(CCs)
        genesdf["gene_group"] = "None"
        print(genesdf)

        # now add the other info
        column_add = []
        for thisspecies in sorted(all_species):
            genesdf["{}_scaf".format(thisspecies)] = genesdf[
                "{}_gene".format(thisspecies)].map(
                    dict(zip(chrom_dicts[thisspecies].prot,
                             chrom_dicts[thisspecies].scaf)) )
            genesdf["{}_pos".format(thisspecies)] = genesdf[
                "{}_gene".format(thisspecies)].map(
                    dict(zip(chrom_dicts[thisspecies].prot,
                             chrom_dicts[thisspecies].start)))
            column_add = column_add + ["{}_gene".format(thisspecies),
                                       "{}_scaf".format(thisspecies),
                                       "{}_pos".format(thisspecies)]

        # add the gene_group info
        for index, row in genesdf.iterrows():
            for thisspecies in all_species:
                this_gene = row["{}_gene".format(thisspecies)]
                if this_gene in prot_to_group:
                    genesdf.loc[index, "gene_group"] = prot_to_group[this_gene]

        genesdf = genesdf[["rbh", "gene_group"] + column_add]
        genesdf.to_csv(output.rbh, sep="\t", index = False)


################################################################################
#     ____  ____   ___   __ __  ____  ____   __ __
#    /    ||    \ /   \ |  |  ||    \|    \ |  |  |
#   |   __||  D  )     ||  |  ||  o  )  o  )|  |  |
#   |  |  ||    /|  O  ||  |  ||   _/|     ||  ~  |
#   |  |_ ||    \|     ||  :  ||  |  |  O  ||___, |
#   |     ||  .  \     ||     ||  |  |     ||     |
#   |___,_||__|\_|\___/  \__,_||__|  |_____||____/
#
################################################################################
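
# A .rbh.groupby file collapses the .rbh table: rows are grouped by the tuple
#  of scaffolds (the {species}_scaf columns) that the orthologs occupy, so each
#  row is one candidate linkage group. The per-gene columns become lists,
#  "count" is the number of orthologs in the group, and "alpha" is the false
#  discovery rate for a group of that size, estimated from the permutations
#  below.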

def permute_n_times(df, num_permutations):
    """
    Estimates how often groups of each size occur by chance. Each iteration
    independently shuffles every column of df (the {species}_scaf columns),
    groups the rows by the resulting scaffold combinations, and records which
    group sizes were seen.

    Returns a dict mapping group size -> number of permutations in which at
    least one group of that size was observed.
    """
    observations = {x: 0 for x in range(1, 5000)}
    scafs = list(df.columns)
    for i in range(num_permutations):
        for thisscaf in df.columns:
            df[thisscaf] = np.random.permutation(df[thisscaf].values)
        subbed = df.groupby(scafs).size().reset_index(name = "counts")["counts"].value_counts().to_dict()
        for key in subbed:
            observations[key] += 1
        if i % 10 == 0:
            print("  - Finished {}/{} ({:.2f}%) permutations.  ".format(
                i, num_permutations,
                (i/num_permutations)*100), end = "\r")
    observations = {key: observations[key] for key in observations
                    if observations[key] != 0}
    return observations

rule single_permutation_test:
    """
    This performs one FDR test of 100000 rounds.

    These results will be cat'd with the results from other threads.
    """
    input:
        rbh = config["tool"] + "/step1-rbh/{analysis}_reciprocal_best_hits.rbh"
    output:
        alpha = config["tool"] + "/step2-groupby/FDR/sim/{analysis}_sim_{simround}.tsv"
    threads: 1
    params:
        thisround = lambda wildcards: wildcards.simround,
        num_permutations = num_permutations_per_round,
        analysis_string = lambda wildcards: wildcards.analysis
    run:
        these_species = params.analysis_string.split("_")
        df = pd.read_csv(input.rbh, sep = "\t", index_col = 0)
        scafs = ["{}_scaf".format(x) for x in these_species]
        df = df[scafs]

        observations = permute_n_times(df, params.num_permutations)
        with open(output.alpha, "w") as f:
            for key in observations:
                print("{}\t{}".format(key, observations[key]),
                      file = f)

rule permutation_test_for_comparisons:
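    """
    Combines the counts from every simulation round and converts them into an
    alpha (false discovery rate) per group size: the fraction of permutations
    in which a chance group of that size was observed.
    """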
    input:
        fdr_results = expand(config["tool"] + "/step2-groupby/FDR/sim/{{analysis}}_sim_{simround}.tsv",
               simround = list(range(1,num_permutation_rounds+1)))
    output:
        alpha = config["tool"] + "/step2-groupby/FDR/{analysis}.FDR.tsv"
    threads: workflow.cores - 1
    params:
        num_permutations = config["num_permutations"]
    run:
        observations = {x: 0 for x in range(1, 5000)}
        for thisfile in input.fdr_results:
            with open(thisfile, "r") as f:
                for line in f:
                    line = line.strip()
                    if line:
                        print(line)
                        fields = [int(x) for x in line.split()]
                        observations[fields[0]] += fields[1]

        resultsDF = pd.Series(observations).to_frame()
        resultsDF.reset_index(inplace = True)
        resultsDF.columns = ["Num_Genes_In_Chr_Group", "Num_Permutations"]
        resultsDF["Total_Tests"] = params.num_permutations
        resultsDF["alpha"] = resultsDF["Num_Permutations"]/params.num_permutations
        print(resultsDF)
        resultsDF.to_csv(output.alpha, sep="\t", index = False)

rule groupby_rbh_results:
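    """
    Groups the rbh table by the combination of scaffolds that each ortholog
    group occupies, records the group size ("count"), and annotates each group
    with the alpha (FDR) estimated for that size. Groups whose observed alpha
    is 0 are assigned the alpha of the next smaller group size that was seen
    in the permutations and are flagged as "less_than". Only groups with more
    than one entry are kept.
    """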
    input:
        rbh = config["tool"] + "/step1-rbh/{analysis}_reciprocal_best_hits.rbh",
        alpha = config["tool"] + "/step2-groupby/FDR/{analysis}.FDR.tsv"
    output:
        rbh  = config["tool"] + "/step2-groupby/{analysis}.rbh.groupby"
    params:
        analysis_string = lambda wildcards: wildcards.analysis
    threads: 1
    run:
        all_species = params.analysis_string.split("_")
        match_string = "_".join(all_species)

        df = pd.read_csv(input.rbh, sep = "\t")
        groupbycols = ["{}_scaf".format(x) for x in all_species]
        alphadf = pd.read_csv(input.alpha, sep = "\t")

        # calculate the alpha score for each level
        df = df.reset_index(drop = True)
        print(df)
        print(df.columns)
        grouped_multiple = df.groupby(groupbycols).agg(list).reset_index()

        # get the size
        grouped_multiple["count"] = grouped_multiple.rbh.str.len()
        grouped_multiple = grouped_multiple.loc[grouped_multiple["count"] > 1, ]

        # sort
        grouped_multiple = grouped_multiple.sort_values(by="count", ascending=False)
        alpha_dict = dict(zip(alphadf.Num_Genes_In_Chr_Group, alphadf.alpha))
        grouped_multiple["alpha"] = grouped_multiple["count"].map(alpha_dict)
        grouped_multiple["alpha_type"] = "equal_to"
        for index, row in grouped_multiple.iterrows():
            if row["alpha"] == 0:
                # now reassign the alpha to the next group that we've seen
                done = False
                countdown_index = 1
                while not done:
                    new_alpha = alpha_dict[row["count"] - countdown_index]
                    if new_alpha == 0:
                        countdown_index += 1
                    else:
                        done = True
                        grouped_multiple.loc[index, "alpha"] = new_alpha
                        grouped_multiple.loc[index, "alpha_type"] = "less_than"
        grouped_multiple["gene_group"] = "None"
        upfront   = ["rbh", "gene_group",
                     "count", "alpha", "alpha_type"]
        other_pt1 = [x for x in grouped_multiple.columns
                     if x.split("_")[0] not in all_species]
        other_pt2 = list(sorted([x for x in other_pt1
                                 if x not in upfront]))
        species_cols = list(sorted([x for x in grouped_multiple.columns
                                    if x.split("_")[0] in all_species]))
        upfront_species_other = upfront + species_cols + other_pt2
        grouped_multiple = grouped_multiple[upfront_species_other]
        grouped_multiple.to_csv(output.rbh, sep="\t", index = False)

rule filter_rbh_files:
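    """
    Keeps only the groups whose alpha (false discovery rate) is at most
    params.alpha (0.05).
    """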
    input:
        rbh  = config["tool"] + "/step2-groupby/{analysis}.rbh.groupby"
    output:
        rbh  = config["tool"] + "/step2-groupby/{analysis}.rbh.filt.groupby"
    params:
        alpha = 0.05
    threads: 1
    run:
        df = pd.read_csv(input.rbh, sep = "\t", index_col = None)
        df = df.loc[df["alpha"] <= params.alpha, ]
        df.to_csv(output.rbh, sep="\t", index = False, na_rep='nan')

def parse_pd_list_or_string(pd_list, rowcount):
    """
    Parses one cell of a .rbh.groupby column back into a list of length
    rowcount. Cells are either the string representation of a list, a single
    value that applies to every member of the group, or NaN.
    """
    if type(pd_list) == str:
        for entry in ["nan", "NaN"]:
            pd_list = pd_list.replace(entry, "None")
        if '[' in pd_list:
            # the cell is the string representation of a list, parse it
            return ast.literal_eval(pd_list)
        else:
            # a single value; repeat it for every member of the group
            return [pd_list] * rowcount
    elif (type(pd_list) == float) and (np.isnan(pd_list)):
        return [None] * rowcount

rule unwrap_rbh_file_with_group_column:
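    """
    Converts the filtered groupby file back into an rbh file. Each groupby row
    holds parallel lists of genes, scaffolds, and positions; this rule expands
    them into one row per reciprocal best hit.
    """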
    input:
        rbh = config["tool"] + "/step2-groupby/{analysis}.rbh.filt.groupby"
    output:
        rbh = config["tool"] + "/step3-unwrap/{analysis}.filt.unwrapped.rbh"
    params:
        prefix = config["tool"] + "/output/",
        analysis = lambda wildcards: wildcards.analysis
    threads: 1
    run:
        df = pd.read_csv(input.rbh, sep = "\t", index_col = None)
        if len([x for x in df.columns if x.startswith("Unnamed")]) > 0:
            print("There is a column that is unnamed, and probably doesn't have data in that column. Check your input.", file = sys.stderr)
        df = df[[x for x in df.columns if not x.startswith("Unnamed")]]
        if True in [np.isnan(x) for x in df["count"].unique()]:
            raise IOError("There are some rows in this dataset where count is nan. That should not occur in a rbh or rbh.groupby file. Check that your input has no rows that are erroneously empty.")

        # get all the species in the dataframe
        complete_species_list = [x.split("_")[0] for x in df.columns
                                 if x.endswith("_scaf")]
        species_string = params.analysis

        rbh_entries = []
        for index, row in df.iterrows():
            rbh_list = ast.literal_eval(row["rbh"])
            species_to_gene_list = {}
            species_to_gene_pos  = {}
            species_to_gene_scaf = {}
            # make lookup tables
            for thisspecies in complete_species_list:
                rowcount = row["count"]
                parse_pairs = [("gene", species_to_gene_list),
                               ("pos",  species_to_gene_pos),
                               ("scaf", species_to_gene_scaf) ]
                for thiscol, thislist in parse_pairs:
                    # gene_col
                    colname = "{}_{}".format(thisspecies, thiscol)
                    #species_to_gene_list[thisspecies] = ast.literal_eval(row[colname])
                    thislist[thisspecies] = parse_pd_list_or_string(
                                                         row[colname], rowcount)

            for i in range(len(rbh_list)):
                thisgroup = row["gene_group"]
                thisentry = {"rbh": rbh_list[i],
                             "gene_group": thisgroup,
                             "count":      row["count"]}
                for keep_this in ["ALG", "color", "alpha", "alpha_type"]:
                    if keep_this in row:
                        thisentry[keep_this] = row[keep_this]
                for thisspecies in complete_species_list:
                    # get col names
                    scafcol = "{}_scaf".format(thisspecies)
                    genecol = "{}_gene".format(thisspecies)
                    poscol  = "{}_pos".format(thisspecies)
                    thisentry[scafcol] = species_to_gene_scaf[thisspecies][i]
                    thisentry[genecol] = species_to_gene_list[thisspecies][i]
                    thisentry[poscol]  = species_to_gene_pos[thisspecies][i]
                rbh_entries.append(thisentry)

        unwrapped = pd.DataFrame(rbh_entries)
        unwrapped = unwrapped[[x for x in unwrapped.columns
                               if x not in ["count", "alpha", "alpha_type"]]]
        unwrapped.to_csv(output.rbh, sep="\t", na_rep = "None", index = False)
