#!/usr/bin/env python

import sys
import argparse

# Number of taxonomic ranks (domain through species) written per entry in the
# taxonomy table.
_NUM_TAX_RANKS = 7

# Placeholder written for any rank that is missing from a metadata row.
_MISSING_RANK = "NA"


def _build_parser():
    """Return the command-line parser for this script."""
    parser = argparse.ArgumentParser(description='This script is for parsing GTDB\'s assembly metadata file down to the target accessions.')

    required = parser.add_argument_group('required arguments')

    required.add_argument("-a", "--assembly_summary", help="GTDB's assembly metadata file", action="store", dest="all_assemblies", required=True)
    required.add_argument("-w", "--wanted_accessions", help="Single-column file with wanted accessions", action="store", dest="wanted_accs", required=True)
    parser.add_argument("-o", "--output_file", help='Wanted summary info only (default: "target-gtdb.tsv")', action="store", dest="output_file", default="target-gtdb.tsv")
    parser.add_argument("-f", "--found_accs_output_file", help='Accessions found in GTDB (default: "gtdb-found-accs.txt")', action="store", dest="found_accs_output_file", default="gtdb-found-accs.txt")
    parser.add_argument("-n", "--not_found_accs_output_file", help='Accessions not found in GTDB (default: "gtdb-not-found-accs.tsv")', action="store", dest="not_found_accs_output_file", default="gtdb-not-found-accs.tsv")
    parser.add_argument("-t", "--gtdb_tax_output_file", help='Target GTDB taxonomy table (default: "target-gtdb-tax.tsv")', action="store", dest="gtdb_tax_output_file", default="target-gtdb-tax.tsv")

    return parser


def _read_wanted_accessions(path):
    """Read the single-column wanted-accessions file.

    Returns a dict mapping the search key (the accession with its first four
    characters and anything from the first "." onward removed) to the
    accession exactly as it was given. Blank lines are skipped so they are
    not later reported as accessions that could not be found. If two inputs
    reduce to the same search key, the later one wins.
    """
    wanted = {}

    with open(path, "r") as handle:
        for line in handle:
            accession = line.strip()
            if not accession:
                continue

            # presumably drops a "GCF_"/"GCA_" prefix and the ".N" version
            search_key = accession.split(".")[0][4:]
            wanted[search_key] = accession

    return wanted


def _filter_assemblies(assemblies_path, wanted, summary_path, found_map_path):
    """Copy the header and every wanted row of the metadata file to a summary.

    Also writes a three-column table (what was searched, the accession found,
    and the full first-column accession) to `found_map_path`.

    Returns the list of input accessions that were found.
    """
    found = []

    with open(assemblies_path) as assemblies:
        with open(summary_path, "w") as summary_out, open(found_map_path, "w") as found_out:
            found_out.write("input_searched\tgtdb_acc_found\tfull_gtdb_acc\n")

            # keeping the metadata header in the filtered summary
            summary_out.write(assemblies.readline())

            for line in assemblies:
                full_acc = line.strip().split("\t")[0]

                # presumably drops a 7-character prefix such as "RS_GCF_" and
                # the ".N" version -- matching ignores both
                search_key = full_acc[7:].split(".")[0]

                # The version is deliberately ignored when matching, so an
                # entry is taken even if its version differs from the one
                # requested; the found-accessions table records both what was
                # searched and what was found.
                if search_key in wanted:
                    summary_out.write(line)
                    found_out.write("\t".join([wanted[search_key], full_acc[3:], full_acc]) + "\n")
                    found.append(wanted[search_key])

    return found


def _write_not_found(wanted, found, path):
    """Write the inputs that were not found, and the key each was searched as.

    The file is only created when at least one wanted accession is missing.
    """
    found_set = set(found)
    missing = [(accession, search_key) for search_key, accession in wanted.items() if accession not in found_set]

    if not missing:
        return

    with open(path, "w") as handle:
        handle.write("input\tsearched_as\n")

        for accession, search_key in missing:
            handle.write(accession + "\t" + search_key + "\n")


def _write_tax_table(summary_path, tax_path):
    """Build the taxonomy-only table from the filtered summary file.

    Columns 2-8 of each summary row are taken as the lineage. A row with
    fewer columns triggers a warning on stdout and has its missing ranks
    filled with "NA" so the output table stays rectangular.
    """
    with open(summary_path, "r") as summary, open(tax_path, "w") as tax_out:
        tax_out.write("base_gtdb_acc\tfull_gtdb_acc\tdomain\tphylum\tclass\torder\tfamily\tgenus\tspecies\n")

        # skipping header (tolerating a completely empty summary file)
        next(summary, None)

        for line in summary:
            fields = line.strip().split("\t")
            full_gtdb_acc = fields[0]
            base_gtdb_acc = full_gtdb_acc[3:]

            # slicing (rather than indexing) lets a short row reach the check
            # below instead of raising IndexError
            lineage = fields[1:1 + _NUM_TAX_RANKS]

            if len(lineage) != _NUM_TAX_RANKS:
                print("GTDB entry " + full_gtdb_acc + " doesn't seem to have full lineage info.")
                lineage = lineage + [_MISSING_RANK] * (_NUM_TAX_RANKS - len(lineage))

            tax_out.write("\t".join([base_gtdb_acc, full_gtdb_acc] + lineage) + "\n")


def main(argv=None):
    """Filter the GTDB metadata down to the wanted accessions and write all outputs.

    `argv` is the list of command-line arguments without the program name;
    it defaults to `sys.argv[1:]`. With no arguments, the help text is printed
    to stderr and the script exits with status 0.
    """
    if argv is None:
        argv = sys.argv[1:]

    parser = _build_parser()

    if not argv:
        parser.print_help(sys.stderr)
        sys.exit(0)

    args = parser.parse_args(argv)

    wanted = _read_wanted_accessions(args.wanted_accs)

    # tracking found so can write out those not found too
    found = _filter_assemblies(args.all_assemblies, wanted, args.output_file, args.found_accs_output_file)

    _write_not_found(wanted, found, args.not_found_accs_output_file)

    _write_tax_table(args.output_file, args.gtdb_tax_output_file)


if __name__ == "__main__":
    main()
