#!/usr/bin/env python

"""
This is a helper program of GToTree (https://github.com/AstrobioMike/GToTree/wiki)
to facilitate subsetting accessions pulled from the GTDB database (with 'gtt-get-accessions-from-GTDB').

It is intended to help when wanting a tree to span a breadth of diversity we know about, while also helping
to reduce over-representation of certain taxa.

There are two primary methods to use it:

1) If a specific class makes up > 0.05% (by default) of the total number of accessions, it will randomly subset
that class down to 1% (by default) of what it was. So if there are 40,000 total target genomes, and Gammaproteobacteria
have 8,000 of them (20% of the total), the program will randomly select 80 Gammaproteobacteria to include (1% of 8,000).

2) Select 1 random genome from each taxon of the specified rank: "phylum", "class", "order", "family", "genus", or "species".
"""

import sys
import argparse
import textwrap
import pandas as pd


# Command-line interface. The description/help strings below are emitted at
# runtime by argparse, so their text is kept exactly as-is.
parser = argparse.ArgumentParser(description='This script is a helper program of GToTree (https://github.com/AstrobioMike/GToTree/wiki)\
                                              to facilitate subsetting accessions pulled from the GTDB database (with\
                                              \'gtt-get-accessions-from-GTDB\' – the input file is the "*metadata.tsv" from that program).\
                                              It is intended to help when wanting a tree to span the breadth\
                                              of diversity we know about, while also helping to reduce over-representation of certain taxa. \
                                              There are 2 primary methods for using it: \
                                              1) If a specific Class makes up > 0.05% (by default) of the total number of target genomes, the script\
                                              will randomly subset that class down to 1% of what it was. So if there are 40,000 total target genomes,\
                                              and Gammaproteobacteria make up 8,000 of them (20% of the total), the program will randomly select 80 Gammaproteobacteria\
                                              to include (1% of 8,000). \
                                              2) Select 1 random genome from each taxa of the specified rank. \
                                              It ultimately outputs a new subset accessions file ready for use with the main GToTree program.',
                                epilog = "Ex. usage: gtt-subset-GTDB-classes -i GTDB-arc-and-bac-refseq-rep-metadata.tsv --get-only-individuals-for-the-rank order")

required = parser.add_argument_group('required arguments')

required.add_argument("-i", "--GTDB-table", help="GTDB metadata table produced with 'gtt-get-accessions-from-GTDB'", action="store", required=True)
parser.add_argument("-o", "--output-prefix", help='output prefix for output subset accessions (*.txt) and GTDB taxonomy files (*.tsv) (default: "subset-accessions")', action="store", dest="output_prefix", default="subset-accessions")

# both fractions are parsed as floats by argparse so that a malformed value is
# rejected up front with a usage error instead of a traceback later on
parser.add_argument("-p", "--cutoff-fraction", help='Fraction of total target genomes that any given Class must contribute in order for that class to be randomly subset (default: 0.0005)', action="store", default=0.0005, type=float)
parser.add_argument("-f", "--fraction-to-subset", help='Fraction those that are filtered should be randomly subset down to (default: 0.01)', action="store", dest="filter_fraction", default=0.01, type=float)

# this is being left for backwards-compatibility reasons only (same as `--get-only-individuals-for-the-rank order`)
parser.add_argument("--get-Order-representatives-only", action="store_true", help = "Provide this flag to simply get 1 random genome from each Order in GTDB (same as if specifying \
                    `--get-only-individuals-for-the-rank order`, but left here for backwards-compatibility purposes)", dest = "order_only")

# choices are given as an ordered sequence (not a set) so the usage/help text
# lists the ranks in the same, taxonomically sensible, order on every run
parser.add_argument("--get-only-individuals-for-the-rank", action="store", choices = ["phylum", "class", "order", "family", "genus", "species"],
                    help = "Use this option with a specified rank if wanting to randomly subset such that we retain 1 genome from each entry in a specific rank in GTDB", dest = "target_rank")


parser.add_argument("--seed", action = "store", help = "Specify the seed for random subsampling (default = 1)", default = 1, type = int)

# with no arguments at all, show the help text rather than argparse's
# "missing required argument" error
if len(sys.argv)==1:
    parser.print_help(sys.stderr)
    sys.exit(0)

args = parser.parse_args()


############################################################

# setting some colors
# ANSI escape-sequence templates, keyed by color name; "%s" is the text slot
tty_colors = dict(
    green='\033[0;32m%s\033[0m',
    yellow='\033[0;33m%s\033[0m',
    red='\033[0;31m%s\033[0m',
)


### functions ###
def color_text(text, color='green'):
    """Return *text* wrapped in the ANSI codes for *color* when stdout is a terminal.

    When stdout is not a tty (e.g. piped or redirected), *text* is returned
    unchanged so no escape codes end up in files or downstream programs.
    *color* must be a key of ``tty_colors``.
    """
    if not sys.stdout.isatty():
        return text

    template = tty_colors[color]
    return template % text

def wprint(text):
    """Print *text* wrapped to 80 columns, with every line indented by two spaces.

    Hyphenated words are never split across lines.
    """
    indent = "  "
    wrapped = textwrap.fill(
        text,
        width=80,
        initial_indent=indent,
        subsequent_indent=indent,
        break_on_hyphens=False,
    )
    print(wrapped)


# function to subset master table to one for each order
def order_subset_table(order_to_subset, input_tab, seed):
    """Reduce one order in *input_tab* to a single randomly chosen row.

    Rows whose "order" column differs from *order_to_subset* are kept in their
    existing sequence; one row sampled from the matching rows (using *seed* as
    the random state) is appended after them. Returns a new DataFrame.
    """
    order_column = input_tab["order"]

    # one reproducibly chosen representative of the requested order
    representative = input_tab.loc[order_column == order_to_subset].sample(
        n=1, random_state=int(seed)
    )

    untouched_rows = input_tab.loc[order_column != order_to_subset]

    return pd.concat([untouched_rows, representative])


# function to subset master table
def subset_table(class_to_subset, input_tab, seed, fraction=None):
    """Randomly thin one class in *input_tab* down to a fraction of its rows.

    Parameters
    ----------
    class_to_subset : value matched against the "class" column.
    input_tab : pandas DataFrame with a "class" column; it is not modified.
    seed : random state for the sampling (converted with ``int``).
    fraction : fraction of the class's rows to keep. When omitted, the
        command-line value ``args.filter_fraction`` is used.

    Returns
    -------
    A new DataFrame: all rows of other classes in their existing order,
    followed by the sampled rows of *class_to_subset*.

    A class that is present always keeps at least one row (plain truncation
    would drop classes with fewer than ``1 / fraction`` members entirely), and
    never more rows than it has. If no row matches, nothing is sampled.
    """
    if fraction is None:
        fraction = args.filter_fraction
    fraction = float(fraction)

    class_column = input_tab["class"]
    other_rows = input_tab.loc[class_column != class_to_subset]
    class_rows = input_tab.loc[class_column == class_to_subset]

    n_members = len(class_rows.index)
    if n_members:
        # clamp to [1, n_members]: keep the class represented, and avoid
        # asking for a larger sample than the population
        n_to_keep = min(n_members, max(1, int(n_members * fraction)))
    else:
        n_to_keep = 0

    sampled_rows = class_rows.sample(n=n_to_keep, random_state=int(seed))

    return pd.concat([other_rows, sampled_rows])


# function to subset arbitrary rank
def taxa_subset_table(taxa_to_subset, rank, input_tab, seed):
    """Reduce one taxon of the given *rank* in *input_tab* to a single random row.

    *rank* names the column to match *taxa_to_subset* against. Rows of every
    other taxon are kept in their existing sequence; one row sampled from the
    matching rows (using *seed* as the random state) is appended after them.
    Returns a new DataFrame.
    """
    rank_column = input_tab[rank]

    # one reproducibly chosen representative of the requested taxon
    representative = input_tab.loc[rank_column == taxa_to_subset].sample(
        n=1, random_state=int(seed)
    )

    remaining_rows = input_tab.loc[rank_column != taxa_to_subset]

    return pd.concat([remaining_rows, representative])


################################################################################



def _strip_gtdb_prefix(accession):
    """Return *accession* without a leading "RS_" or "GB_" prefix, if it has one."""
    # only strip a real prefix; anything else is passed through untouched
    # rather than losing its first three characters
    if accession.startswith(("RS_", "GB_")):
        return accession[3:]
    return accession


def _write_subset_outputs(subset_df, output_prefix):
    """Write the accessions list and the taxonomy table; return both paths."""
    accessions_path = output_prefix + ".txt"
    with open(accessions_path, "w") as out:
        out.writelines(_strip_gtdb_prefix(acc) + "\n" for acc in subset_df["accession"])

    taxonomy_path = output_prefix + "-taxonomy.tsv"
    subset_df.to_csv(taxonomy_path, sep = "\t", index = False)

    return accessions_path, taxonomy_path


def _report_subset(initial_size, final_size, accessions_message, accessions_path, taxonomy_path):
    """Print the before/after counts and where the two output files were written."""
    print("")
    wprint(color_text("{:,}".format(initial_size) + " initial entries were subset down to " + "{:,}".format(final_size) + "\n", "yellow"))
    print("")
    wprint(accessions_message)
    print(color_text("    " + str(accessions_path)) + "\n")
    wprint("A subset GTDB taxonomy table for these accessions written to:")
    print(color_text("    " + str(taxonomy_path)) + "\n")


# reading lineage table into pandas dataframe (first 8 columns only)
lineage_df = pd.read_csv(args.GTDB_table, delimiter="\t", usecols = range(8))

starting_size = len(lineage_df.index)

# just giving 1 of each order if requested (left here like this for backwards-compatibility purposes)
if args.order_only:

    args.target_rank = "order"


if args.target_rank:

    # getting one random rep genome for each unique entry of the wanted rank
    for entry in lineage_df[args.target_rank].unique():

        lineage_df = taxa_subset_table(entry, args.target_rank, lineage_df, args.seed)

    accessions_message = "Subset accessions file for GToTree written to:"

else:

    # no rank was requested, so we are randomly subsetting based on class

    # counting how many times each class shows up (dict keeps first-appearance
    # order, which determines the order classes are subset in below)
    class_dict = {}

    for curr_class in lineage_df["class"]:
        class_dict[curr_class] = class_dict.get(curr_class, 0) + 1

    # getting cutoff threshold of total number of entries
    cutoff = int(starting_size * float(args.cutoff_fraction))

    # getting which classes are at or above this threshold
    classes_to_subset = [key for key, count in class_dict.items() if count >= cutoff]

    # subsetting each class
    for curr_class in classes_to_subset:
        lineage_df = subset_table(curr_class, lineage_df, args.seed)

    accessions_message = "Subset accessions written to file:"


filtered_size = len(lineage_df.index)

# removing "RS_" and "GB_" prefixes, writing out output accs and the subset GTDB taxonomy
output_accessions, output_tax = _write_subset_outputs(lineage_df, args.output_prefix)

# reporting
_report_subset(starting_size, filtered_size, accessions_message, output_accessions, output_tax)
