#!/usr/bin/env python

"""
This was originally a helper program of GToTree (https://github.com/AstrobioMike/GToTree/wiki)
to facilitate accessing and using taxonomy and genomes from the glorious Genome Taxonomy
Database (gtdb.ecogenomic.org/), but found use for it much more than in GToTree so included here too.
"""

import sys
import os
import urllib.request
import tarfile
import pandas as pd
import textwrap
import argparse
import shutil
import filecmp

# Command-line interface definition. The description/help strings use backslash
# line-continuations inside the literals, so the source indentation is part of the
# string values.
parser = argparse.ArgumentParser(description = "This is a helper program to facilitate using taxonomy \
                                              and genomes from the Genome Taxonomy Database (gtdb.ecogenomic.org).\
                                              It primarily returns NCBI accessions and GTDB summary tables based on GTDB-taxonomy\
                                              searches, which could then be passed to, e.g., `bit-dl-ncbi-assemblies`. It also \
                                              currently has filtering capabilities built-in for specifying only\
                                              GTDB representative species or RefSeq representative genomes (see help menu and links\
                                              therein for explanations of what these are). It always uses the latest GTDB release version.\
                                              For version info, run `bit-version`.",
                                 epilog="Ex. usage: bit-get-accessions-from-GTDB -t Archaea --GTDB-representatives-only\n")

# Search inputs (taxon and, optionally, the rank to look for it at), followed by
# the boolean flags that select an action or restrict the genomes considered.
parser.add_argument("-t", "--target-taxon", metavar = "<STR>", help = "Target taxon (enter 'all' for all)", action = "store")
parser.add_argument("-r", "--target-rank", metavar = "<STR>", help = 'Target rank', action = "store")
parser.add_argument("--get-table", action = "store_true", help = "Provide just this flag alone to download and parse a GTDB metadata \
                                           table. Archaea and Bacteria tables pulled from here \
                                           (https://data.ace.uq.edu.au/public/gtdb/data/releases/latest/) and combined, and the \
                                           GTDB taxonomy column is split for easier manual searching if wanted.")
parser.add_argument("--get-rank-counts", action = "store_true", help = "Provide just this flag alone to see counts of how many \
                                           unique taxa there are for each rank.")
parser.add_argument("--get-taxon-counts", action = "store_true", help = "Provide this flag along with a specified taxon to the `-t` flag \
                                                                     to see how many of that taxon are in the database.")
parser.add_argument("--GTDB-representatives-only", action = "store_true", help = "Provide this flag along with a specified taxon to the `-t` flag \
                                                                        to pull accessions only for genomes designated as GTDB species \
                                                                        representatives (see e.g.: https://gtdb.ecogenomic.org/faq#gtdb_species_clusters).")
parser.add_argument("--RefSeq-reference-genomes-only", action = "store_true", help = "Provide this flag along with a specified taxon to the `-t` flag \
                                                                                    to pull accessions only for genomes designated as RefSeq \"reference\" \
                                                                                    genomes (these used to be called \"representative\" genomes, see e.g.: \
                                                                                    https://www.ncbi.nlm.nih.gov/refseq/about/prokaryotes/#reference_genomes).\
                                                                                    (Useful for subsetting a view across a broad level of diversity.)")


# With no arguments at all, show the help menu (on stderr) and exit successfully.
if len(sys.argv)==1:
    parser.print_help(sys.stderr)
    sys.exit(0)


# Parsed at module level: `args` is read as a global by main() and get_gtdb_tab().
args = parser.parse_args()


################################################################################

def main():
    """Validate the requested flag combination, load the GTDB table, and dispatch.

    Reads the module-level `args` namespace. Every handled action ends with
    `sys.exit(0)`. Actions are checked in this order: `--get-table`,
    `--get-rank-counts`, `--get-taxon-counts`, then pulling accessions for `-t`.
    """

    if args.get_table:
        get_gtdb_tab()
        sys.exit(0)

    ## some checks to prevent things that either should be provided together, or should be mutually exclusive
    if args.get_taxon_counts and not args.target_taxon:
        print("")
        wprint(color_text("A specific taxon needs to also be provided to the `-t` flag in order to use `--get-taxon-counts`.", "yellow"))
        print("")
        wprint("  E.g.: bit-get-accessions-from-GTDB --get-taxon-counts -t Alteromonas")
        print("")
        sys.exit(0)

    # argparse stores `--RefSeq-reference-genomes-only` as `RefSeq_reference_genomes_only`
    if args.GTDB_representatives_only and args.RefSeq_reference_genomes_only:
        print("")
        wprint(color_text("Only one of `--GTDB-representatives-only` or `--RefSeq-reference-genomes-only` can be provided.", "yellow"))
        print("")
        sys.exit(0)

    ## moving on
    gtdb_tab = get_gtdb_tab()

    # when no representative filter is requested these stay None, which is what
    # the functions called below use as their defaults
    gtdb_rep_tab = None
    representatives_source = None

    if args.GTDB_representatives_only:
        gtdb_rep_tab = gtdb_tab[gtdb_tab["gtdb_representative"] == "t"]
        representatives_source = "GTDB"

    elif args.RefSeq_reference_genomes_only:
        gtdb_rep_tab = gtdb_tab[gtdb_tab["ncbi_refseq_category"] == "reference genome"]
        representatives_source = "RefSeq"

    if args.get_rank_counts:
        get_unique_taxa_counts_of_all_ranks(gtdb_tab, gtdb_rep_tab, representatives_source=representatives_source)
        sys.exit(0)

    if args.get_taxon_counts:
        get_unique_taxon_counts(args.target_taxon, gtdb_tab, gtdb_rep_tab, representatives_source=representatives_source)
        sys.exit(0)

    if args.target_taxon:
        # `args.target_rank` is None when `-r` was not given
        get_accessions(args.target_taxon, gtdb_tab, gtdb_rep_tab, rank=args.target_rank, representatives_source=representatives_source)
        sys.exit(0)

    # reaching here means no action flag and no target taxon were provided
    print("")
    wprint(color_text("Nothing to do: provide a target taxon to the `-t` flag, or one of `--get-table` or `--get-rank-counts`.", "yellow"))
    print("")

################################################################################


# ANSI escape templates for colored terminal output; "%s" marks where the text goes
tty_colors = {
    "green": "\033[0;32m%s\033[0m",
    "yellow": "\033[0;33m%s\033[0m",
    "red": "\033[0;31m%s\033[0m",
}


### functions ###
def color_text(text, color='green'):
    """Return `text` wrapped in the `tty_colors` template for `color`.

    The text is returned unchanged when stdout is not a terminal, so output
    that is piped or redirected carries no escape codes.
    """
    return tty_colors[color] % text if sys.stdout.isatty() else text


def wprint(text):
    """Print `text` wrapped at 80 columns, every line indented by two spaces."""
    wrapper = textwrap.TextWrapper(
        width=80,
        initial_indent="  ",
        subsequent_indent="  ",
        break_on_hyphens=False,  # keep hyphenated names (flags, filenames) on one line
    )
    print(wrapper.fill(text))


def _write_accession_list(tab, out_path):
    """Write the `ncbi_genbank_assembly_accession` column of `tab` to `out_path`, one per line."""
    with open(out_path, "w") as out:
        for acc in tab["ncbi_genbank_assembly_accession"].tolist():
            out.write(acc + "\n")


def get_accessions(taxon, gtdb_tab, gtdb_rep_tab=None, rank=None, representatives_source=None):
    """Write the accessions (and a subset table) for the specified taxon.

    Parameters:
        taxon: taxon name to pull, or "all" for every genome in the table.
        gtdb_tab: the full GTDB table (one column per rank plus metadata).
        gtdb_rep_tab: the representative-only subset; used when
            `representatives_source` is given.
        rank: rank to search the taxon at. Required (the program exits with a
            notice otherwise) when the taxon occurs at more than one rank.
        representatives_source: label of the representative set ("GTDB" or
            "RefSeq" as passed by main), or None to search all genomes.

    Output files are written to the current working directory. The function
    exits the program if the rank is ambiguous, or if the taxon is not present
    at the requested rank.
    """

    # the table being searched: the representative subset when one was requested
    search_tab = gtdb_rep_tab if representatives_source else gtdb_tab

    if taxon == "all":

        if representatives_source:

            # name the outputs after the actual representative source, so a GTDB
            # subset is not labeled "refseq" (RefSeq filenames are unchanged)
            source_label = representatives_source.lower()
            tab_out_filename = "GTDB-arc-and-bac-" + source_label + "-rep-metadata.tsv"
            acc_out_filename = "GTDB-arc-and-bac-" + source_label + "-rep-accessions.txt"

            # already subset, writing out that and accessions list
            search_tab.to_csv(tab_out_filename, sep="\t", index=False)
            _write_accession_list(search_tab, acc_out_filename)

            print("")
            wprint("The targeted NCBI accessions were written to: " + color_text(acc_out_filename))
            print("")
            wprint("A subset GTDB table of these targets was written to: " + color_text(tab_out_filename))
            print("")

        else:

            acc_out_filename = "GTDB-arc-and-bac-accessions.txt"
            _write_accession_list(search_tab, acc_out_filename)

            print("")
            wprint("The NCBI accessions were written to: " + color_text(acc_out_filename))
            print("")
            wprint("The GTDB table that already exists holds all of these: " + color_text("GTDB-arc-and-bac-metadata.tsv"))
            print("")

        return

    # reports the counts and exits on its own if the taxon isn't found anywhere
    ranks_found_in = get_unique_taxon_counts(taxon, gtdb_tab, gtdb_rep_tab, return_ranks_found_in=True,
                                             representatives_source=representatives_source)

    if rank:

        rank = rank.lower()

        # guarding against an unknown rank name (a KeyError on the table) and
        # against silently writing empty outputs when the taxon isn't at that rank
        if rank not in ranks_found_in:
            wprint(color_text("'" + str(taxon) + "' was not found at the rank '" + rank + "'. It was found at: "
                              + ", ".join(ranks_found_in) + ".", "yellow"))
            print("")
            sys.exit(1)

    elif len(ranks_found_in) > 1:

        wprint(color_text("Since '" + str(taxon) + "' occurs at more than 1 rank, we'll need to specify "
                          "which rank is wanted as well before we pull the accessions. "
                          "This can be specified with the `-r` flag.", "yellow"))
        print("")
        sys.exit(0)

    else:

        rank = ranks_found_in[0]

    # subsetting table and writing out that and accessions list
    gtdb_sub_tab = search_tab[search_tab[rank] == taxon]

    # swapping spaces for dashes in case it's a species taxon
    taxon_label = taxon.replace(" ", "-")

    source_suffix = ("-" + representatives_source + "-rep") if representatives_source else ""

    tab_out_filename = "GTDB-" + taxon_label + "-" + rank + source_suffix + "-metadata.tsv"
    acc_out_filename = "GTDB-" + taxon_label + "-" + rank + source_suffix + "-accs.txt"

    gtdb_sub_tab.to_csv(tab_out_filename, sep="\t", index=False)
    _write_accession_list(gtdb_sub_tab, acc_out_filename)

    wprint("The targeted NCBI accessions were written to:")
    wprint("  " + color_text(acc_out_filename))
    print("")
    wprint("A subset GTDB table of these targets was written to:")
    wprint("  " + color_text(tab_out_filename))
    print("")


def get_unique_taxon_counts(taxon, gtdb_tab, gtdb_rep_tab=None, return_ranks_found_in=False, representatives_source=None):
    """Report how many entries of a specific taxon exist at each rank.

    Parameters:
        taxon: taxon name to count (matching is case-sensitive), or "all" to
            report the total number of genomes.
        gtdb_tab: the full GTDB table with one column per rank.
        gtdb_rep_tab: the representative-only subset; read only when
            `representatives_source` is given.
        return_ranks_found_in: when True (and `taxon` is not "all"), return
            the list of ranks the taxon was found at.
        representatives_source: label of the representative set, or None.

    Returns:
        The list of ranks the taxon occurs at, taken from the representative
        table when `representatives_source` is given and from the full table
        otherwise. Returns None when `return_ranks_found_in` is False or
        `taxon` is "all".

    Exits the program with a notice if the taxon is not found at any rank.
    """

    if taxon == "all":
        count = len(gtdb_tab.index)
        print("")
        wprint("  There are " + str(count) + " total genomes in the database.")
        print("")

        if representatives_source:
            count = len(gtdb_rep_tab.index)
            wprint(color_text("In considering only " + representatives_source + " representative genomes:", "yellow"))
            print("")
            wprint("  There are " + str(count) + " total representative genomes in the database.")
            print("")

        return None

    ranks = ['domain', 'phylum', 'class', 'order', 'family', 'genus', 'species']

    ranks_found_in = []

    for rank in ranks:

        # one pass over the column gives both the membership test and the count
        count = int((gtdb_tab[rank] == taxon).sum())

        if count > 0:
            ranks_found_in.append(rank)
            wprint("  The rank '" + rank + "' has " + str(count) + " " + taxon + " entries.")

    if not ranks_found_in:
        wprint(color_text("Input taxon '" + taxon + "' doesn't seem to exist at any rank :(", "yellow"))
        wprint("Be aware, to be on the safe side, our searching is case-sensitive.")
        print("")
        sys.exit(0)

    print("")

    if representatives_source:

        ranks_found_in_rep = []

        wprint(color_text("In considering only " + representatives_source + " representative genomes:", "yellow"))
        print("")

        for rank in ranks:

            count = int((gtdb_rep_tab[rank] == taxon).sum())

            if count > 0:
                ranks_found_in_rep.append(rank)
                wprint("  The rank '" + rank + "' has " + str(count) + " " + taxon + " representative genome entries.")
                print("")

        if not ranks_found_in_rep:
            wprint(color_text("Input taxon '" + taxon + "' doesn't seem to exist at any rank as a representative genome :(", "yellow"))
            print("")
            sys.exit(0)

        # with a representative filter active, the ranks reported are those of the subset
        ranks_found_in = ranks_found_in_rep

    if return_ranks_found_in:
        return ranks_found_in

    return None


def get_unique_taxa_counts_of_all_ranks(gtdb_tab, gtdb_rep_tab=None, representatives_source=None):
    """Print, for each rank, the number of distinct taxa in the table.

    A second table built from `gtdb_rep_tab` is printed only when
    `representatives_source` is "RefSeq".
    """

    rank_names = ('domain', 'phylum', 'class', 'order', 'family', 'genus', 'species')
    row_layout = "    {:<10} {:}"

    def print_counts(header, tab):
        # one header line, then one line per rank holding its distinct-taxa count
        print(header)
        for rank_name in rank_names:
            print(row_layout.format(rank_name, str(tab[rank_name].nunique())))

    print_counts("\n" + row_layout.format("Rank", "Num. Unique Taxa"), gtdb_tab)
    print("")

    # only needed for RefSeq: for GTDB it is equivalent (i.e., they have 1 rep genome for each unique rank in their system)
    if representatives_source != "RefSeq":
        return

    wprint(color_text("In considering only " + representatives_source + " representative genomes:", "yellow"))
    print("")

    print_counts(row_layout.format("Rank", "Num. Unique Rep. Taxa"), gtdb_rep_tab)
    print("")


def _read_gtdb_metadata(url):
    """Download one gzipped GTDB metadata TSV and return it as a DataFrame."""

    # the context manager closes the HTTP response once the table has been read
    with urllib.request.urlopen(url) as response:
        tab = pd.read_csv(response, sep="\t", compression="gzip", on_bad_lines='skip', header=0, low_memory=False)

    tab.rename(columns={tab.columns[0]: "accession"}, inplace=True)
    tab.dropna(inplace=True, how="all")

    return tab


def gen_gtdb_tab(location):
    """Download and parse the GTDB info tables.

    The archaeal and bacterial metadata tables of the latest GTDB release are
    downloaded and combined, and the `gtdb_taxonomy` column is split into one
    column per rank, inserted right after the accession column.

    Parameters:
        location: path prefix the outputs are written under. It is joined by
            plain string concatenation, so it must end with a path separator.

    Returns:
        The combined table. It is also written to
        `location + "GTDB-arc-and-bac-metadata.tsv"`, and the release's
        VERSION.txt is saved to `location + "GTDB-version-info.txt"`.

    Exits the program with a notice if any entry lacks a 7-field lineage.
    """

    base_url = "https://data.ace.uq.edu.au/public/gtdb/data/releases/latest/"

    # getting archaea and bacteria
    arc_tab = _read_gtdb_metadata(base_url + "ar53_metadata.tsv.gz")
    bac_tab = _read_gtdb_metadata(base_url + "bac120_metadata.tsv.gz")

    # combining (with a fresh index so row labels are unique across the two tables)
    gtdb_tab = pd.concat([arc_tab, bac_tab], ignore_index=True)

    # splitting gtdb taxonomy column into 7, vectorized rather than row by row
    ranks = ["domain", "phylum", "class", "order", "family", "genus", "species"]
    lineages = gtdb_tab["gtdb_taxonomy"].str.split(";")

    # a missing taxonomy value has no length, so it is flagged here as well
    malformed = lineages.str.len() != len(ranks)

    if malformed.any():
        curr_acc = str(gtdb_tab.loc[malformed, "accession"].iloc[0])
        wprint(color_text("GTDB entry " + curr_acc + " doesn't seem to have 7-column lineage info. Something is likely wrong :(", "yellow"))
        print("")
        wprint("If this continues to happen, please file an issue at github.com/AstrobioMike/GToTree/issues")
        print("")
        wprint("Aborting for now.")
        print("")
        sys.exit(0)

    # inserting at position 1 in reverse order leaves the columns as domain -> species;
    # the first 3 characters of each field are the rank prefix (e.g. "d__") and are dropped
    for position in reversed(range(len(ranks))):
        gtdb_tab.insert(1, ranks[position], lineages.str[position].str[3:].tolist())

    # writing out
    gtdb_tab.to_csv(location + "GTDB-arc-and-bac-metadata.tsv", index=False, sep="\t")

    urllib.request.urlretrieve(base_url + "VERSION.txt", location + "GTDB-version-info.txt")

    return gtdb_tab


def report_gtdb_version_info(location):
    """Print the GTDB version and release date stored under `location`.

    Reads `location + "GTDB-version-info.txt"`; the first two non-blank lines
    are taken as the version and the release date, in that order.
    """

    with open(location + "GTDB-version-info.txt") as version_info_file:
        stripped_lines = (raw_line.strip() for raw_line in version_info_file)
        entries = [entry for entry in stripped_lines if entry != ""]

    print("    Using GTDB {}: {}\n".format(entries[0], entries[1]))


def check_stored_gtdb_up_to_date(location):
    """Check if the stored GTDB metadata table is the latest, regenerating it if not.

    The current VERSION.txt is downloaded next to the stored version file and
    the two are compared by content. If they differ, the stored table is
    rebuilt with `gen_gtdb_tab(location)`.

    Parameters:
        location: path prefix holding "GTDB-version-info.txt". It is joined by
            plain string concatenation, so it must end with a path separator.
    """

    latest_version_path = location + "GTDB-latest-version-info.txt"
    stored_version_path = location + "GTDB-version-info.txt"

    try:
        # getting latest version info from GTDB
        urllib.request.urlretrieve("https://data.ace.uq.edu.au/public/gtdb/data/releases/latest/VERSION.txt", latest_version_path)

        # comparing vs old; shallow=False forces a content comparison instead of
        # accepting files whose os.stat() signatures happen to match
        if not filecmp.cmp(latest_version_path, stored_version_path, shallow=False):
            print("")
            wprint(color_text("A newer version of the GTDB is available, updating stored table...", "yellow"))

            gen_gtdb_tab(location)

    finally:
        # the downloaded copy is only scratch; remove it even if the download
        # or the regeneration failed part-way
        if os.path.exists(latest_version_path):
            os.remove(latest_version_path)


def get_gtdb_tab():
    """Return the GTDB info table, downloading and parsing it first if needed.

    The table is stored under the directory named by the `GTDB_DIR`
    environment variable when that is set (should be if conda install of bit
    was done), otherwise in the current working directory. A stored table is
    checked against the latest GTDB release before use.

    When the module-level `args.get_table` is set, the table is only made
    available in the current working directory and the program exits.

    Returns:
        The GTDB table as a pandas DataFrame.
    """

    table_name = "GTDB-arc-and-bac-metadata.tsv"
    version_name = "GTDB-version-info.txt"

    # checking if there is a GTDB_DIR; unset or empty means "not available"
    gtdb_dir = os.environ.get("GTDB_DIR")

    if gtdb_dir:
        # paths here and in the helper functions are built as `location + filename`,
        # so make sure the directory ends with a separator
        location = os.path.join(gtdb_dir, "")
    else:
        location = "./"

    table_path = location + table_name

    if os.path.exists(table_path) and os.path.exists(location + version_name):

        # making sure the stored version is the latest available from GTDB
        check_stored_gtdb_up_to_date(location)

        if args.get_table:

            # copying to the current directory if it is stored elsewhere and not here yet
            if gtdb_dir and not os.path.exists(table_name):
                shutil.copy(table_path, table_name)
                print("")
                # reporting version info
                report_gtdb_version_info(location)
                wprint("GTDB table copied to:")
                print(color_text("    GTDB-arc-and-bac-metadata.tsv\n"))

            else:
                print("")
                wprint("The GTDB table " + color_text(table_name) + " is already here :)")
                print("")

            sys.exit(0)

        # reading in if using program beyond just grabbing table
        print("")
        wprint(color_text("Reading in the GTDB info table...", "yellow"))

        # reporting version info
        report_gtdb_version_info(location)

        return pd.read_csv(table_path, sep="\t", low_memory=False)

    # generating when table doesn't exist yet
    print("")
    wprint(color_text("Downloading and parsing archaeal and bacterial metadata tables from GTDB (only needs to be done once, or if there is a new version available)...", "yellow"))
    print("")

    gtdb_tab = gen_gtdb_tab(location)

    # reporting version info
    report_gtdb_version_info(location)

    if args.get_table:

        # with a GTDB_DIR the table was stored there, so also copy it to the current directory
        if gtdb_dir:
            shutil.copy(table_path, table_name)

        print("")
        wprint("GTDB table written to:")
        print(color_text("    GTDB-arc-and-bac-metadata.tsv\n"))
        sys.exit(0)

    return gtdb_tab


# call main() only when this file is executed as a script
if __name__ == "__main__":
    main()
