#!/usr/bin/env python

"""
This is a helper program of GToTree (https://github.com/AstrobioMike/GToTree/wiki) 
to facilitate accessing and using taxonomy and genomes from the glorious Genome Taxonomy 
Database (gtdb.ecogenomic.org/).

For examples, please visit the GToTree wiki here: https://github.com/AstrobioMike/GToTree/wiki/example-usage
"""

import sys
import os
import urllib.request
import tarfile
import pandas as pd
import textwrap
import argparse
import shutil
import filecmp

# Command-line interface. The long description/help texts are single string
# literals continued with backslashes, so the indentation of the continuation
# lines is part of each string.
parser = argparse.ArgumentParser(description="This is a helper program to facilitate using taxonomy \
                                              and genomes from the Genome Taxonomy Database (gtdb.ecogenomic.org) with GToTree.\
                                              It primarily returns NCBI accessions and GTDB summary tables based on GTDB-taxonomy\
                                              searches. It also currently has filtering capabilities built-in for specifying only\
                                              GTDB representative species or RefSeq representative genomes (see help menu and links\
                                              therein for explanations of what these are). For examples, please visit the GToTree\
                                              wiki here: github.com/AstrobioMike/GToTree/wiki/example-usage",
                                 epilog="Ex. usage: gtt-get-accessions-from-GTDB -t Archaea --GTDB-representatives-only\n")

# what to search for: a taxon name and, optionally, the rank it should be looked up at
parser.add_argument("-t", "--target-taxon", help="Target taxon (enter 'all' for all)", action="store", dest="target_taxon")
parser.add_argument("-r", "--target-rank", help='Target rank', action="store", dest="target_rank")
# alternative actions (table download, summary counts)
parser.add_argument("--get-table", action="store_true", help="Provide just this flag alone to download and parse a GTDB metadata \
                                           table. Archaea and Bacteria tables pulled from here \
                                           (https://data.gtdb.ecogenomic.org/releases/latest/) and combined, and the \
                                           GTDB taxonomy column is split for easier manual searching if wanted.")
parser.add_argument("--get-rank-counts", action="store_true", help="Provide just this flag alone to see counts of how many \
                                           unique taxa there are for each rank.")
parser.add_argument("--get-taxon-counts", action="store_true", help="Provide this flag along with a specified taxon to the `-t` flag \
                                                                     to see how many of that taxon are in the database.")
# representative-genome filters (main() rejects the combination of both)
parser.add_argument("--GTDB-representatives-only", action="store_true", help="Provide this flag along with a specified taxon to the `-t` flag \
                                                                        to pull accessions only for genomes designated as GTDB species \
                                                                        representatives (see e.g.: https://gtdb.ecogenomic.org/faq#gtdb_species_clusters).")
parser.add_argument("--RefSeq-representatives-only", action="store_true", help="Provide this flag along with a specified taxon to the `-t` flag \
                                                                        to pull accessions only for genomes designated as RefSeq representative \
                                                                        genomes (see e.g.: https://www.ncbi.nlm.nih.gov/refseq/about/prokaryotes/#representative_genomes).\
                                                                        (Useful for subsetting a view across a broad level of diversity, see \
                                                                        also `gtt-subset-GTDB-accessions`.)")

# controls whether the stored GTDB table is compared against the latest release
parser.add_argument("--do-not-check-GTDB-version", action="store_true", help="By default, this program checks if it is using the latest version of \
                                                                        the GTDB database. Provide this flag to prevent this from occurring, using \
                                                                        the version already present.")

# controls where the GTDB table is stored and looked for
parser.add_argument("--store-GTDB-metadata-in-current-working-dir", action="store_true", help="By default, GToTree uses a system-wide variable \
                                                                        to know where to put and search the GTDB metadata. Provide this flag to \
                                                                        ignore that and store the master table in the current working directory.")


# with no arguments at all, print the help menu (to stderr) and exit successfully
if len(sys.argv)==1:
    parser.print_help(sys.stderr)
    sys.exit(0)

# parsed at module level: `args` is read as a global by main() and get_gtdb_tab()
args = parser.parse_args()


################################################################################

def main():
    """ validates the option combination, loads the GTDB table, and runs the requested action

    Reads the module-level `args`. Actions, checked in this order:
      - `--get-table`: only fetch/copy the table (done inside `get_gtdb_tab()`)
      - `--get-rank-counts`: print the number of unique taxa at each rank
      - `--get-taxon-counts` with `-t`: print how many entries the taxon has
      - `-t` (optionally with `-r`): write the accessions and metadata subset files

    The program always ends through `sys.exit()`: status 1 for an invalid
    option combination (or an unusable 'GTDB_dir' location), status 0 otherwise.
    If no action was requested, the table is still loaded (as before) and a
    short notice is printed instead of exiting silently.
    """

    if args.get_table:
        get_gtdb_tab(args.do_not_check_GTDB_version, args.store_GTDB_metadata_in_current_working_dir)
        sys.exit(0)

    ## some checks to prevent things that either should be provided together, or should be mutually exclusive
    if args.get_taxon_counts and not args.target_taxon:
        print("")
        wprint(color_text("A specific taxon needs to also be provided to the `-t` flag in order to use `--get-taxon-counts`.", "yellow"))
        print("")
        wprint("  E.g.: gtt-get-accessions-from-GTDB --get-taxon-counts -t Alteromonas")
        print("")
        sys.exit(1)

    if args.GTDB_representatives_only and args.RefSeq_representatives_only:
        print("")
        wprint(color_text("Only one of `--GTDB-representatives-only` or `--RefSeq-representatives-only` can be provided.", "yellow"))
        print("")
        sys.exit(1)

    ## unless storing in the current working dir was requested, the 'GTDB_dir'
    ## env variable must be set and writable (the helper exits with a hint otherwise)
    if not args.store_GTDB_metadata_in_current_working_dir:
        check_location_var_is_set_and_writable("GTDB_dir")

    ## moving on
    gtdb_tab = get_gtdb_tab(args.do_not_check_GTDB_version, args.store_GTDB_metadata_in_current_working_dir)

    ## optional representative-genome subset, passed along to whichever action runs
    ## (the two flags were verified to be mutually exclusive above)
    rep_kwargs = {}

    if args.GTDB_representatives_only:
        rep_kwargs = {"gtdb_rep_tab": gtdb_tab[gtdb_tab["gtdb_representative"] == "t"],
                      "representatives_source": "GTDB"}

    elif args.RefSeq_representatives_only:
        rep_kwargs = {"gtdb_rep_tab": gtdb_tab[gtdb_tab["ncbi_refseq_category"] == "representative genome"],
                      "representatives_source": "RefSeq"}

    if args.get_rank_counts:
        get_unique_taxa_counts_of_all_ranks(gtdb_tab, **rep_kwargs)
        sys.exit(0)

    if args.get_taxon_counts:
        get_unique_taxon_counts(args.target_taxon, gtdb_tab, **rep_kwargs)
        sys.exit(0)

    if args.target_taxon:
        # `args.target_rank` is None when `-r` wasn't given, which is get_accessions()' default
        get_accessions(args.target_taxon, gtdb_tab, rank=args.target_rank, **rep_kwargs)
        sys.exit(0)

    ## nothing above matched (e.g. only `-r` or only a representatives flag was provided)
    print("")
    wprint(color_text("No action was requested. Provide a target taxon to the `-t` flag, or one of `--get-table`, `--get-rank-counts`, or `--get-taxon-counts`.", "yellow"))
    print("")
    sys.exit(0)

################################################################################


# setting some colors
# ANSI escape templates, keyed by color name; each has one `%s` slot for the text
tty_colors = dict(
    green='\033[0;32m%s\033[0m',
    yellow='\033[0;33m%s\033[0m',
    red='\033[0;31m%s\033[0m',
)


### functions ###
def color_text(text, color='green'):
    """ returns `text` wrapped in the `tty_colors` template for `color` when
    stdout is a terminal, and `text` unchanged otherwise (e.g. when piped) """

    if not sys.stdout.isatty():
        return text

    template = tty_colors[color]
    return template % text


def wprint(text):
    """ prints `text` wrapped to 80 columns, with every line indented by two
    spaces and without breaking words at hyphens """

    wrapper = textwrap.TextWrapper(
        width=80,
        initial_indent="  ",
        subsequent_indent="  ",
        break_on_hyphens=False,
    )

    print(wrapper.fill(text))


def check_location_var_is_set_and_writable(variable):
    """ makes sure the environment variable `variable` holds a writable location

    An unset variable and one set to an empty string are treated the same way.
    If the variable is missing/empty, or the location it holds is not writable
    for the current user (per `os.access(..., os.W_OK)`), a hint is printed and
    the program exits with status 1.

    Returns an empty tuple when the check passes (kept as it always was; the
    caller in this file ignores it).
    """

    # explicit lookup instead of raising/catching to detect a missing or empty variable
    path = os.environ.get(variable, "")

    if path == "":
        print()
        wprint(color_text("The environment variable '" + str(variable) + "' does not seem to be set :(", "yellow"))
        print()
        wprint("Try to set it with `gtt-data-locations set`, or use the `--store-GTDB-metadata-in-current-working-dir` argument with this program.")
        print("")
        sys.exit(1)

    # making sure path is writable for the user
    if not os.access(path, os.W_OK):
        print()
        wprint(color_text("The environment variable '" + str(variable) + "' does not seem to be writable :(", "yellow"))
        print()
        wprint("You can modify the location with `gtt-data-locations set`, or use the `--store-GTDB-metadata-in-current-working-dir` argument with this program.")
        print("")
        sys.exit(1)

    return ()


def _write_accessions(tab, acc_out_filename):
    """ writes the 'ncbi_genbank_assembly_accession' column of `tab` to a file, one accession per line """

    with open(acc_out_filename, "w") as out:
        out.writelines(acc + "\n" for acc in tab["ncbi_genbank_assembly_accession"])


def get_accessions(taxon, gtdb_tab, gtdb_rep_tab=None, rank=None, representatives_source=None):
    """ get accessions based on specified taxon, if the provided taxon is in more than one rank, will require specified rank

    Parameters:
      taxon                   taxon to pull, or "all" for every genome in the table
      gtdb_tab                the full GTDB table (needs the 7 rank columns and
                              'ncbi_genbank_assembly_accession')
      gtdb_rep_tab            subset of `gtdb_tab` holding only representative genomes;
                              used when `representatives_source` is given
      rank                    rank to search `taxon` at (case-insensitive); if not given
                              it is taken from where the taxon occurs
      representatives_source  "GTDB" or "RefSeq" when only representatives are wanted

    Files are written to the current working directory:
      - taxon given: "GTDB-<taxon>-<rank>[-<source>-rep]-accs.txt" and "...-metadata.tsv"
      - "all": "GTDB-arc-and-bac-accessions.txt", or with representatives
        "GTDB-arc-and-bac-<source, lowercase>-rep-accessions.txt" and "...-rep-metadata.tsv"

    Exits the program (after printing why) if the taxon occurs at more than one
    rank and no rank was given (status 0, as before), if `rank` is not a valid
    rank, or if the taxon doesn't occur at the given rank (status 1).
    `get_unique_taxon_counts()` itself exits if the taxon isn't found at all.
    """

    valid_ranks = ['domain', 'phylum', 'class', 'order', 'family', 'genus', 'species']

    # accessions come from the representatives subset if one was requested, else from the full table
    source_tab = gtdb_rep_tab if representatives_source else gtdb_tab

    if taxon == "all":

        if not representatives_source:

            _write_accessions(source_tab, "GTDB-arc-and-bac-accessions.txt")

            print("")
            wprint("The NCBI accessions were written to: " + color_text("GTDB-arc-and-bac-accessions.txt"))
            print("")
            wprint("The GTDB table that already exists holds all of these: " + color_text("GTDB-arc-and-bac-metadata.tsv"))
            print("")

        else:

            # naming the outputs after the actual source; for RefSeq this gives the same
            # "refseq" names as always, GTDB representatives no longer get labeled "refseq"
            source_label = representatives_source.lower()
            tab_out_filename = "GTDB-arc-and-bac-" + source_label + "-rep-metadata.tsv"
            acc_out_filename = "GTDB-arc-and-bac-" + source_label + "-rep-accessions.txt"

            # already subset, writing out that and accessions list
            source_tab.to_csv(tab_out_filename, sep="\t", index=False)
            _write_accessions(source_tab, acc_out_filename)

            print("")
            wprint("The targeted NCBI accessions were written to: " + color_text(acc_out_filename))
            print("")
            wprint("A subset GTDB table of these targets was written to: " + color_text(tab_out_filename))
            print("")

        return

    # checking a provided rank before using it as a column name
    if rank:

        rank = rank.lower()

        if rank not in valid_ranks:
            print("")
            wprint(color_text("The specified rank '" + rank + "' is not one of: " + ", ".join(valid_ranks) + ".", "yellow"))
            print("")
            sys.exit(1)

    # this also reports the counts, and exits if the taxon isn't found at all
    ranks_found_in = get_unique_taxon_counts(taxon, gtdb_tab, gtdb_rep_tab, return_ranks_found_in=True,
                                             representatives_source=representatives_source)

    if not rank:

        if len(ranks_found_in) > 1:
            # adjacent literals (not a backslash inside the string) so no run of spaces ends up in the message
            wprint(color_text("Since '" + str(taxon) + "' occurs at more than 1 rank, we'll need to specify "
                              "which rank is wanted as well before we pull the accessions. This can be specified with the `-r` flag.", "yellow"))
            print("")
            sys.exit(0)

        rank = ranks_found_in[0]

    elif rank not in ranks_found_in:
        # otherwise empty output files would be written
        wprint(color_text("'" + str(taxon) + "' was not found at the specified rank '" + rank + "' (see where it occurs above).", "yellow"))
        print("")
        sys.exit(1)

    # subsetting table and writing out that and accessions list
    gtdb_sub_tab = source_tab[source_tab[rank] == taxon]

    source_suffix = "-" + representatives_source + "-rep" if representatives_source else ""
    tab_out_filename = "GTDB-" + taxon + "-" + rank + source_suffix + "-metadata.tsv"
    acc_out_filename = "GTDB-" + taxon + "-" + rank + source_suffix + "-accs.txt"

    gtdb_sub_tab.to_csv(tab_out_filename, sep="\t", index=False)
    _write_accessions(gtdb_sub_tab, acc_out_filename)

    wprint("The targeted NCBI accessions were written to:")
    wprint("  " + color_text(acc_out_filename))
    print("")
    wprint("A subset GTDB table of these targets was written to:")
    wprint("  " + color_text(tab_out_filename))
    print("")


def get_unique_taxon_counts(taxon, gtdb_tab, gtdb_rep_tab=None, return_ranks_found_in=False, representatives_source=None):
    """ reports how many entries a taxon has at each rank (or the table totals for "all")

    Counts are printed for `gtdb_tab` and, if `representatives_source` is
    given, also for `gtdb_rep_tab`. Exits the program (status 0) if the taxon
    isn't found at any rank of the table(s) searched.

    If `return_ranks_found_in` is set (and `taxon` isn't "all"), returns the
    list of ranks the taxon occurs at: in `gtdb_rep_tab` when
    `representatives_source` is given, in `gtdb_tab` otherwise. Returns None
    in all other cases.
    """

    if taxon == "all":

        n_total = len(gtdb_tab.index)
        print("")
        wprint("  There are " + str(n_total) + " total genomes in the database.")
        print("")

        if representatives_source:
            n_reps = len(gtdb_rep_tab.index)
            wprint(color_text("In considering only " + representatives_source + " representative genomes:", "yellow"))
            print("")
            wprint("  There are " + str(n_reps) + " total representative genomes in the database.")
            print("")

        return None

    rank_names = ['domain', 'phylum', 'class', 'order', 'family', 'genus', 'species']

    def report_ranks_holding_taxon(table, entries_label, blank_line_after_each):
        """ prints the count for each rank of `table` the taxon occurs at, returns those ranks """
        hits = []

        for rank_name in rank_names:
            n_hits = int((table[rank_name] == taxon).sum())

            if n_hits > 0:
                hits.append(rank_name)
                wprint("  The rank '" + rank_name + "' has " + str(n_hits) + " " + taxon + entries_label)

                if blank_line_after_each:
                    print("")

        return hits

    # all genomes
    print("")
    ranks_with_taxon = report_ranks_holding_taxon(gtdb_tab, " entries.", False)

    if not ranks_with_taxon:
        wprint(color_text("Input taxon '" + taxon + "' doesn't seem exist at any rank :(", "yellow"))
        wprint("Be aware, to be on the safe side, our searching is case-sensitive.")
        print("")
        sys.exit(0)

    print("")

    if not representatives_source:
        return ranks_with_taxon if return_ranks_found_in else None

    # representative genomes only
    wprint(color_text("In considering only " + representatives_source + " representative genomes:", "yellow"))
    print("")
    rep_ranks_with_taxon = report_ranks_holding_taxon(gtdb_rep_tab, " representative genome entries.", True)

    if not rep_ranks_with_taxon:
        wprint(color_text("Input taxon '" + taxon + "' doesn't seem exist at any rank as a representative genome :(", "yellow"))
        print("")
        sys.exit(0)

    return rep_ranks_with_taxon if return_ranks_found_in else None


def get_unique_taxa_counts_of_all_ranks(gtdb_tab, gtdb_rep_tab=None, representatives_source=None):
    """ prints, for each of the 7 ranks, how many unique taxa `gtdb_tab` holds;
    when `representatives_source` is "RefSeq" the same is also printed for `gtdb_rep_tab` """

    rank_names = ('domain', 'phylum', 'class', 'order', 'family', 'genus', 'species')
    as_row = "    {:<10} {:}".format

    print("\n    {:<10} {:}".format("Rank", "Num. Unique Taxa"))
    for rank_name in rank_names:
        n_unique = gtdb_tab[rank_name].nunique()
        print(as_row(rank_name, str(n_unique)))

    print("")

    # a second table is only printed for RefSeq (the original author's note: for GTDB
    # it would be equivalent, as they have 1 rep genome for each unique rank in their system)
    if representatives_source != "RefSeq":
        return

    wprint(color_text("In considering only " + representatives_source + " representative genomes:", "yellow"))
    print("")

    print(as_row("Rank", "Num. Unique Rep. Taxa"))
    for rank_name in rank_names:
        n_unique_reps = gtdb_rep_tab[rank_name].nunique()
        print(as_row(rank_name, str(n_unique_reps)))

    print("")


def _read_gtdb_metadata(url):
    """ downloads one GTDB metadata archive and reads it into a table """

    # the response is closed as soon as the table has been read
    with urllib.request.urlopen(url) as response:
        tab = pd.read_csv(response, sep="\t", compression="gzip", on_bad_lines='skip', header=0, low_memory=False)

    # the first column is always renamed to "accession" and fully-empty rows are dropped
    # NOTE(review): presumably needed because the tar archive is read as a plain gzip
    # stream (tar bytes ending up in the first header and last lines) -- confirm
    tab.rename(columns={tab.columns[0]: "accession"}, inplace=True)
    tab.dropna(inplace=True, how="all")

    return tab


def gen_gtdb_tab(location):
    """ downloads and parses the GTDB info tables

    Downloads the archaeal and bacterial metadata from the latest GTDB release,
    combines them, splits the 'gtdb_taxonomy' column into one column per rank
    (inserted right after the first column, rank prefixes like "d__" removed),
    and writes to `location`:
      - "GTDB-arc-and-bac-metadata.tsv"  the combined table
      - "GTDB-version-info.txt"          the release's VERSION file

    `location` is used as a plain string prefix, so it needs to end with a
    path separator. Returns the combined table.

    Exits the program with status 1 if any entry doesn't have a 7-field lineage.
    """

    ranks = ['domain', 'phylum', 'class', 'order', 'family', 'genus', 'species']

    # getting archaea and bacteria
    arc_tab = _read_gtdb_metadata("https://data.gtdb.ecogenomic.org/releases/latest/ar53_metadata.tar.gz")
    bac_tab = _read_gtdb_metadata("https://data.gtdb.ecogenomic.org/releases/latest/bac120_metadata.tar.gz")

    # combining
    gtdb_tab = pd.concat([arc_tab, bac_tab])

    # splitting gtdb taxonomy column into 7 (all rows at once rather than row by row)
    tax_lists = gtdb_tab["gtdb_taxonomy"].str.split(";")

    # a missing taxonomy value gives a missing length, which also counts as malformed here
    malformed = (tax_lists.str.len() != len(ranks)).to_numpy()

    if malformed.any():
        curr_acc = str(gtdb_tab["accession"].to_numpy()[malformed][0])
        wprint(color_text("GTDB entry " + curr_acc + " doesn't seem to have 7-column lineage info. Something is likely wrong :(", "yellow"))
        print("")
        wprint("If this continues to happen, please file an issue at github.com/AstrobioMike/GToTree/issues")
        print("")
        wprint("Aborting for now.")
        print("")
        sys.exit(1)

    lineages = pd.DataFrame(tax_lists.tolist(), columns=ranks)

    # the combined table's index holds repeated labels (both source tables keep their own),
    # so values are inserted positionally (as arrays) rather than aligned on the index
    for position, rank in enumerate(ranks, start=1):
        gtdb_tab.insert(position, rank, lineages[rank].str[3:].to_numpy())

    # writing out
    gtdb_tab.to_csv(location + "GTDB-arc-and-bac-metadata.tsv", index=False, sep="\t")

    urllib.request.urlretrieve("https://data.gtdb.ecogenomic.org/releases/latest/VERSION", location + "GTDB-version-info.txt")

    return(gtdb_tab)

def report_gtdb_version_info(location):
    """ reporting GTDB version info

    Reads "GTDB-version-info.txt" from `location` (a string prefix, so it needs
    to end with a path separator). The first non-empty line is reported as the
    version and the second as the release date, followed by a blank line.

    A file with fewer than two non-empty lines no longer raises an IndexError:
    the release date is left out if missing, and "(unknown version)" is
    reported if the file has no content at all.
    """

    with open(location + "GTDB-version-info.txt") as version_info_file:
        stripped_lines = (line.strip() for line in version_info_file)
        version_info = [line for line in stripped_lines if line != ""]

    gtdb_version = version_info[0] if version_info else "(unknown version)"
    report = "    Using GTDB " + gtdb_version

    if len(version_info) > 1:
        gtdb_release_date = version_info[1]
        report += ": " + gtdb_release_date

    print(report + "\n")


def check_stored_gtdb_up_to_date(location):
    """ checks if the stored gtdb metadata table is the latest

    Downloads the latest GTDB VERSION file next to the stored one in `location`
    (a string prefix, so it needs to end with a path separator) and compares
    their contents. If they differ, the stored table is regenerated with
    `gen_gtdb_tab()`. The downloaded file is always removed again.

    If the VERSION file can't be retrieved (e.g. no network connection), a
    notice is printed and the stored table is left in place rather than
    stopping the program.
    """

    latest_version_path = location + "GTDB-latest-version-info.txt"
    stored_version_path = location + "GTDB-version-info.txt"

    # getting latest version info from GTDB
    try:
        urllib.request.urlretrieve("https://data.gtdb.ecogenomic.org/releases/latest/VERSION", latest_version_path)

    except OSError as err:
        # urllib's URLError/HTTPError are OSError subclasses
        print("")
        wprint(color_text("Could not check if a newer version of the GTDB is available (" + str(err) + "), continuing with the stored version.", "yellow"))

        # not leaving a partial download behind
        if os.path.exists(latest_version_path):
            os.remove(latest_version_path)

        return

    # comparing vs old
    try:
        # shallow=False: always compare the contents, never just the os.stat() signatures
        if not filecmp.cmp(latest_version_path, stored_version_path, shallow=False):
            print("")
            wprint(color_text("A newer version of the GTDB is available, updating stored table...", "yellow"))

            gen_gtdb_tab(location)

    finally:
        # removed also if the update above fails
        os.remove(latest_version_path)


def get_gtdb_tab(prevent_database_version_check, store_db_metadata_in_current_working_dir):
    """ reads in file with gtdb info

    The table ("GTDB-arc-and-bac-metadata.tsv", with "GTDB-version-info.txt"
    next to it) lives in the location held by the 'GTDB_dir' environment
    variable, or in the current working directory if that variable is unset or
    empty, or if `store_db_metadata_in_current_working_dir` is set.

      - if both files exist there: the stored version is checked against the
        latest GTDB release (unless `prevent_database_version_check` is set)
        and the table is read in
      - otherwise: the table is downloaded and parsed with `gen_gtdb_tab()`

    Also reads the module-level `args`: if `args.get_table` is set, the table
    is only made available in the current working directory (copied from
    'GTDB_dir' if needed) and the program exits with status 0.

    Returns the table as a pandas DataFrame.
    """

    table_name = "GTDB-arc-and-bac-metadata.tsv"
    version_name = "GTDB-version-info.txt"

    # checking if there is a GTDB_dir (should be if conda install of GToTree was done)
    GTDB_dir = os.environ.get('GTDB_dir', "")

    using_GTDB_dir = bool(GTDB_dir) and not store_db_metadata_in_current_working_dir

    if using_GTDB_dir:
        # paths are built by string concatenation here and in the helpers this is passed to,
        # so making sure the location ends with a path separator
        location = os.path.join(GTDB_dir, "")
    else:
        location = "./"

    stored_table = location + table_name

    if os.path.exists(stored_table) and os.path.exists(location + version_name):

        # making sure the stored version is the latest available from GTDB (unless the do not check flag was provided)
        if not prevent_database_version_check:
            check_stored_gtdb_up_to_date(location)

        # copying (if stored elsewhere and not here yet) if just table wanted
        if args.get_table:

            if using_GTDB_dir and not os.path.exists(table_name):
                shutil.copy(stored_table, table_name)
                print("")
                # reporting version info (once)
                report_gtdb_version_info(location)
                wprint("GTDB table copied to:")
                print(color_text("    GTDB-arc-and-bac-metadata.tsv\n"))

            else:
                print("")
                wprint("The GTDB table " + color_text("GTDB-arc-and-bac-metadata.tsv") + " is already here :)")
                print("")

            sys.exit(0)

        # reading in if using program beyond just grabbing table
        print("")
        wprint(color_text("Reading in the GTDB info table...", "yellow"))

        # reporting version info
        report_gtdb_version_info(location)

        gtdb_tab = pd.read_csv(stored_table, sep="\t", low_memory=False)

    # generating when table doesn't exist yet
    else:

        print("")

        if using_GTDB_dir:
            wprint(color_text("Downloading and parsing archaeal and bacterial metadata tables from GTDB (only needs to be done once)...", "yellow"))
            print("")
        else:
            wprint(color_text("Downloading and parsing archaeal and bacterial metadata tables from GTDB...", "yellow"))

        gtdb_tab = gen_gtdb_tab(location)

        # reporting version info
        report_gtdb_version_info(location)

        if args.get_table:

            # copying to current directory if the table was stored elsewhere
            # (skipped if 'GTDB_dir' is the current directory, copying a file onto itself fails)
            if using_GTDB_dir and os.path.realpath(stored_table) != os.path.realpath(table_name):
                shutil.copy(stored_table, table_name)

            print("")
            wprint("GTDB table written to:")
            print(color_text("    GTDB-arc-and-bac-metadata.tsv\n"))
            sys.exit(0)

    return(gtdb_tab)


# calling main() only when this file is run as a program
if __name__ == "__main__":
    main()
