#!/usr/bin/env python

"""
This script filters the "detail-tsv"-formatted output of KOFamScan to retain only hits scoring above the
KO-specific score threshold, keeping only the best hit (lowest e-value) for each gene.
Outputs a 3-column tab-delimited table with: gene_ID, KO_ID, and KO_function
(genes with no hit passing the threshold are kept, with "NA" in the KO columns).

Example KOFamScan command that produces the input for this script:
    exec_annotation -p profiles/ -k ko_list --cpu 15 -f detail-tsv -o 5492-KO-tab.tmp 5492-genes.faa --tmp-dir 5492-tmp-KO --report-unannotated

This script would then be run as:
    bit-filter-KOFamScan-results -i 5492-KO-tab.tmp -o 5492-annotations.tsv
"""

import sys
import argparse
import pandas as pd

# Column order of the output table.
OUTPUT_COLUMNS = ["gene_ID", "KO_ID", "KO_function"]


def build_parser():
    """Build the command-line parser for this script.

    Returns:
        An ``argparse.ArgumentParser`` with a required ``-i/--input-file`` and an
        optional ``-o/--output-file`` (default "output.tsv").
    """

    parser = argparse.ArgumentParser(
        description="This script filters the 'detail-tsv'-formatted output file from KOFamScan to retain only those above the KO-specific score threshold, and retains only the hit with the lowest e-value for each gene if there are multiple. It outputs a 3-column tab-delimited file with: gene_ID, KO_ID, and KO_function. For version info, run `bit-version`"
    )

    required = parser.add_argument_group('required arguments')

    required.add_argument("-i", "--input-file", help="Input annotation table", metavar="<FILE>", action="store", required=True)
    parser.add_argument("-o", "--output-file", metavar="<FILE>", help='Output table filename (default: "output.tsv")', action="store", default="output.tsv")

    return parser


def parse_kofamscan_hits(lines):
    """Pick the best above-threshold KO hit for each gene.

    The input is expected to be tab-delimited and to look like this (shown here
    space-aligned for readability):

    #   gene name     KO         thrshld score    E-value   "KO definition"
    #   ---------     ------     ------- ------   --------- -------------
        k119_6520_1   K01999     126.63  44.1     6.7e-11   "branched-chain amino acid transport system substrate-binding protein"
        k119_6520_1   K11954     433.33  39.0     1.8e-09   "neutral amino acid transport system substrate-binding protein"
        k119_19560_1  K02014     164.80  144.6    2.1e-41   "iron complex outermembrane recepter protein"
        k119_19560_1  K15721     575.63  122.5    7.3e-35   "pesticin/yersiniabactin receptor"

    Lines starting with "#" are skipped. Any leading "*" characters and the
    surrounding whitespace are discarded before splitting on tabs. A line that
    holds only a gene ID (no hit fields) registers the gene with no annotation.

    Args:
        lines: iterable of text lines (e.g. an open file handle).

    Returns:
        dict mapping gene ID -> {"KO_ID": ..., "KO_function": ...}, in order of
        first appearance. Genes with no hit passing the threshold keep "NA" for
        both values. Among passing hits the one with the lowest e-value wins;
        on a tie the earliest one is kept.
    """

    annot_dict = {}
    best_e_values = {}

    for raw_line in lines:

        if raw_line.startswith("#"):
            continue

        fields = raw_line.lstrip("*").strip().split("\t")
        gene_id = fields[0]

        # blank (or whitespace-only) lines carry no gene; without this check they
        # would end up in the output table as a gene with an empty ID
        if not gene_id:
            continue

        # adding gene ID if not present, to ensure all end up in final table
        annot_dict.setdefault(gene_id, {"KO_ID": "NA", "KO_function": "NA"})

        # nothing there if no annotations for current gene, skipping
        if len(fields) == 1:
            continue

        # only considering if its score is above the threshold
          # some, though very few, like K15869, don't have a threshold score due to
          # having too few representatives, so if no threshold, just taking
        threshold = fields[2]
        if not (threshold == "" or float(fields[3]) > float(threshold)):
            continue

        # keeping this hit if it is the first passing one for the gene, or if its
        # e-value is strictly lower than the best seen so far
        e_value = float(fields[4])
        if gene_id not in best_e_values or e_value < best_e_values[gene_id]:
            annot_dict[gene_id] = {"KO_ID": fields[1], "KO_function": fields[5].strip('"')}
            best_e_values[gene_id] = e_value

    return annot_dict


def build_annotation_table(annot_dict):
    """Convert the per-gene annotation dictionary into the output table.

    Args:
        annot_dict: dict as returned by ``parse_kofamscan_hits``.

    Returns:
        A pandas DataFrame with columns gene_ID, KO_ID and KO_function, one row
        per gene. The columns are named explicitly so that an empty dictionary
        still produces a table with the full 3-column header.
    """

    rows = [
        (gene_id, annot["KO_ID"], annot["KO_function"])
        for gene_id, annot in annot_dict.items()
    ]

    return pd.DataFrame(rows, columns=OUTPUT_COLUMNS)


def main():
    """Parse arguments, filter the KOFamScan table, and write the result."""

    parser = build_parser()

    # printing help (rather than an argparse error) if run with no arguments
    if len(sys.argv) == 1:
        parser.print_help(sys.stderr)
        sys.exit(0)

    args = parser.parse_args()

    with open(args.input_file, "r") as annots:
        annot_dict = parse_kofamscan_hits(annots)

    annot_tab = build_annotation_table(annot_dict)

    with open(args.output_file, "w") as out:
        out.write(annot_tab.to_csv(index=False, sep="\t"))


if __name__ == "__main__":
    main()
