#!/usr/bin/env python

## learned most goatools/python things from this great tutorial: GO Tutorial in Python - Solutions.ipynb, which comes from here: http://gohandbook.org/doku.php ; https://nbviewer.jupyter.org/urls/dessimozlab.github.io/go-handbook/GO%20Tutorial%20in%20Python%20-%20Solutions.ipynb

from goatools import obo_parser
import os
import argparse
import pandas as pd
import sys
import subprocess

# Command-line interface: one required input file, plus optional obo source,
# output file name and two flags controlling which rows/tables get written.
parser = argparse.ArgumentParser(
    description=(
        "This script takes a GO-annotated tab-delimited file with gene IDs in the first column, and associated "
        "GO terms in second column (can be multiple delimited by semi-colons). By default, it returns one "
        "tab-delimited summary table with counts and percentages for each GO term. If the `--by-namespace` "
        "flag is added, it will also return separate summaries for each GO namespace. For version info, "
        "run `bit-version`."
    )
)

required = parser.add_argument_group('required arguments')

required.add_argument(
    "-i", "--input-GO-annotations",
    metavar="<FILE>",
    help="Input annotations file",
    action="store",
    dest="input_annotations",
    required=True,
)

parser.add_argument(
    "-g", "--GO-obo-file",
    metavar="<STR>",
    help=(
        'GO obo file to use (e.g. from: http://geneontology.org/docs/download-ontology/). By default will '
        'use "goslim_metagenomics.obo". "go-basic.obo" is also a pre-packaged option (enter `-g go_basic` to '
        'specify it). Or a different obo-formatted file can be specified here (it should probably be the one '
        'the annotations are based off of, or were slimmed based on).'
    ),
    action="store",
    dest="obo",
    default="goslim_metagenomics",
)

parser.add_argument(
    "-o", "--output-file",
    metavar="<FILE>",
    help='Output file name (default: "GO-summary.tsv").',
    action="store",
    default="GO-summary.tsv",
)

parser.add_argument(
    "--keep-zeroes",
    help=(
        "By default the program will remove any rows for GO terms with 0 counts to them. Add this flag "
        "to keep them."
    ),
    action="store_true",
)

parser.add_argument(
    "--by-namespace",
    help=(
        'By default the program will return a single summary table with all GO terms combined. Add this flag '
        'to also return individual tables for each GO namespace. (Will append namespace ID in front of last '
        'extension of specified output file name, e.g. with default "GO-summary.tsv" will also produce '
        '"GO-summary-MF.tsv" for the Molecular Function namespace.)'
    ),
    action="store_true",
)


# running the script with no arguments at all prints the help menu rather than an argparse error
if len(sys.argv) == 1:
    parser.print_help(sys.stderr)
    sys.exit(0)

args = parser.parse_args()

### checking and setting up obo file location
# GO_DB_DIR is the directory the pre-packaged obo files are read from
go_data_dir = os.environ.get("GO_DB_DIR")

if go_data_dir is None:
    # failing with a readable message instead of a bare KeyError traceback
    print('\n    The "GO_DB_DIR" environment variable is not set, so the GO reference files cannot be located.\n', file=sys.stderr)
    sys.exit(1)

## downloading default GO databases if they are not present already
checking_db_dir = subprocess.run(["helper-bit-setup-GO-dbs"])

# values accepted by `-g` that refer to a pre-packaged obo file, mapped to its file name
packaged_obo_files = {
    "goslim_metagenomics": "goslim_metagenomics.obo",
    "go_basic": "go-basic.obo",
}

if args.obo in packaged_obo_files:
    # os.path.join works whether or not GO_DB_DIR ends with a path separator
    go_obo = os.path.join(go_data_dir, packaged_obo_files[args.obo])
else:
    # anything else is taken as a path to a user-provided obo file
    go_obo = args.obo

### loading GO database
# obsolete terms are loaded too, so they can be recognized if they show up in the input annotations
print("\n\tGO obo file being used:")
go = obo_parser.GODag(go_obo, load_obsolete=True)
print("")

#### PARSING OBO TO BUILD TABLE
# getting unique GO Term IDs in the loaded obo file
unique_GO_terms_set = set(go.keys())

# function to get info for a go term:
def get_general_info(go_id):
    """
    Look up `go_id` in the loaded GO object (`go`) and return its basic details.

    Returns a list laid out as: [go_id, namespace, depth, name, is_obsolete]
    """
    record = go[go_id]

    return [go_id, record.namespace, record.depth, record.name, record.is_obsolete]

# creating list then table of GO info for each term (one row per GO ID)
table_list = [get_general_info(GO_ID) for GO_ID in unique_GO_terms_set]

header = ["GO_term", "namespace", "depth", "name", "is_obsolete"]

# the IDs come out of a set, whose iteration order can differ from run to run, so ties in depth
# are broken by GO term ID to keep the row order (and the written tables) reproducible
GO_df = pd.DataFrame(table_list, columns=header).sort_values(by=["depth", "GO_term"])

#### REMOVING ANY TERMS THAT ARE OBSOLETE FROM THE TABLE AND MAKING A SET OF OBSOLETE IDS TO HANDLE THEM PROPERLY IF IN THE INPUT ANNOTATIONS (http://geneontology.org/docs/GO-term-elements)
# building a boolean mask from the column's values (using `in` on a pandas Series checks its
# index labels, not its values, so it can't be used to ask "is any term obsolete?")
is_obsolete_mask = GO_df["is_obsolete"].astype(bool)

# always defined, and simply empty when the obo holds no obsolete terms, so the counting
# step further down can test membership without hitting an undefined name
obsolete_terms = set(GO_df.loc[is_obsolete_mask, "GO_term"])

# keeping only current terms and dropping the now-redundant "is_obsolete" column; the copy
# makes this an independent table that columns can be added to later
GO_df = GO_df.loc[~is_obsolete_mask, ["GO_term", "namespace", "depth", "name"]].copy()

#### COUNTING HOW MANY TIMES EACH TERM APPEARS IN ANNOTATION FILE
# creating dictionary to hold counts, starting every non-obsolete term at 0
term_list = GO_df["GO_term"]
term_counts_dict = {key: 0 for key in term_list}

# iterating through input file
with open(args.input_annotations, "r") as annots:
    for line in annots:

        fields = line.strip().split("\t")

        # skipping lines with no annotation column, and comment lines
        if len(fields) <= 1 or line.startswith("#"):
            continue

        # second column holds the GO terms, semi-colon delimited, spaces ignored
        for term in fields[1].replace(" ", "").split(";"):

            # a blank annotation column or a trailing/doubled semi-colon leaves an empty
            # string here, which isn't a term at all and shouldn't stop the run
            if not term:
                continue

            # only moving forward if term isn't obsolete
            if term in obsolete_terms:
                continue

            # if the term isn't in the dictionary, the annotation or slim source used isn't being provided here and it should be
            if term not in term_counts_dict:
                print('\n    So the GO term "' + str(term) + '" shows up in the provided annotation file')
                print('    "' + str(args.input_annotations) + '", but it is not present in the obo reference')
                print('    file used: "' + str(args.obo) + "\". We aren't sure how to deal with this. You might")
                print('    want to pass a different obo file to the `-g` argument.\n')
                print('        Exiting for now :(\n')
                sys.exit(1)

            term_counts_dict[term] += 1

# attaching counts by looking each row's term up in the dictionary, rather than relying on the
# dictionary's values happening to be in the same order as the table's rows
GO_df['term_counts'] = GO_df["GO_term"].map(term_counts_dict).astype(int)


#### ADDING COLUMN WITH PERCENTAGE OF TOTAL TERMS FOR EACH TABLE
# each term's count expressed as a percentage of all counted annotations, appended as the last column
total_term_counts = GO_df["term_counts"].sum()
GO_df = GO_df.assign(term_perc_of_annotated=GO_df["term_counts"] / total_term_counts * 100)


#### SPLITTING TO CREATE OTHER 3 OUTPUT TABLES IF --by-namespace FLAG WAS PROVIDED
def _namespace_summary(full_df, namespace):
    """
    Return the rows of `full_df` whose "namespace" column equals `namespace`, with a
    "term_perc_of_annotated" column appended holding each term's share (in percent)
    of that namespace's total counts.
    """
    subset_df = full_df[full_df["namespace"] == namespace]
    percs_df = pd.DataFrame({"term_perc_of_annotated": subset_df["term_counts"] / subset_df["term_counts"].sum() * 100})

    # NOTE(review): the combined table already has a "term_perc_of_annotated" column by this point,
    # so the namespace tables end up with two columns of that name (percent of all terms, then percent
    # within the namespace). Left as is to keep the output layout unchanged -- confirm whether the
    # first one should be dropped or renamed.
    return pd.concat([subset_df, percs_df], axis=1)


if args.by_namespace:

    biological_process_NS_df = _namespace_summary(GO_df, "biological_process")
    molecular_function_NS_df = _namespace_summary(GO_df, "molecular_function")
    cellular_component_NS_df = _namespace_summary(GO_df, "cellular_component")

    # creating basename for additional output tables: only the last extension is removed, so any
    # directory part and any other dots in the specified output file name are kept intact
    out_base = os.path.splitext(args.output_file)[0]

#### WRITING OUT SUMMARY TABLE(S)
def _write_summary_table(table_df, out_path, keep_zeroes):
    """
    Write `table_df` to `out_path` as a tab-delimited file, without the index.

    Nothing is written if no row has a "term_counts" value above 0. Otherwise rows
    with 0 counts are dropped first, unless `keep_zeroes` is True.

    Returns True if the table was written, False if it had no counts.
    """
    counted_df = table_df[table_df["term_counts"] > 0]

    if counted_df.empty:
        return False

    table_to_write = table_df if keep_zeroes else counted_df

    # handing the path to pandas (rather than writing its returned string through our own
    # text-mode file handle) leaves line-ending handling entirely to pandas
    table_to_write.to_csv(out_path, index=False, sep="\t")

    return True


# combined table: if it has no counts at all there is nothing worth reporting, so we stop here
if not _write_summary_table(GO_df, args.output_file, args.keep_zeroes):
    print("\n\tThere were no counts to any terms :( Sure we're working with the right files here?\n")
    sys.exit(0)

if args.by_namespace:

    # (table, suffix added to the output basename, label used in the message)
    namespace_outputs = [
        (biological_process_NS_df, "-BP.tsv", "Biological Process"),
        (molecular_function_NS_df, "-MF.tsv", "Molecular Function"),
        (cellular_component_NS_df, "-CC.tsv", "Cellular Component"),
    ]

    for namespace_df, suffix, label in namespace_outputs:
        # a namespace without counts is only reported to the user, the others are still written
        if not _write_summary_table(namespace_df, out_base + suffix, args.keep_zeroes):
            print("\n\tThere were no counts to any " + label + " terms, so that table wasn't reported.\n")
