#!/usr/bin/env python

import os
import argparse
import sys
import pandas as pd

# Column names given to the descriptive part of every summary table.
GO_INFO_COLUMNS = ["GO_term", "namespace", "depth", "name"]


def _build_parser():
    """Create the command-line parser for `bit-combine-go-summaries`."""
    parser = argparse.ArgumentParser(
        description="This script is for combining GO summary tables produced by "
                    "`bit-summarize-go-annotations`. For version info, run `bit-version`.")

    required = parser.add_argument_group('required arguments')

    required.add_argument("-i", "--input-files", metavar="<FILE(s)>", nargs="+",
                          type=str, help="space-delimited list of `bit-summarize-go-annotations` output files",
                          action="store", required=True)
    parser.add_argument("-n", "--sample-names", metavar="<NAME(s)>",
                        help='Sample names provided as a comma-delimited list, be sure it matches the order of the input files (by default will use basename of input files up to last period)',
                        action="store", default='')
    parser.add_argument("-o", "--output-file", metavar="<FILE>",
                        help='Output combined summaries (default: "combined-GO-summaries.tsv")',
                        action="store", default="combined-GO-summaries.tsv")

    return parser


def _fail(*messages):
    """Print each message (on stdout, as this script always has) and exit with status 1."""
    for message in messages:
        print(message)
    sys.exit(1)


def _assign_sample_names(input_files, sample_names):
    """Map every input file to the sample name used for its output columns.

    Args:
        input_files: paths of the summary tables, in command-line order.
        sample_names: comma-delimited names matching `input_files` in order,
            or an empty string to derive each name from the file's basename
            up to its last period.

    Returns:
        A dict with file path as key and sample name as value, in input order.

    Exits with status 1 when the number of names doesn't match the number of
    files, when a file is listed twice, or when two files would share a name.
    """
    if sample_names:
        names = sample_names.split(",")

        if len(names) != len(input_files):
            _fail("\n    It seems the number of provided sample names doesn't match the number of provided input files :(",
                  "\n    Check usage with `bit-combine-go-summaries -h`.\n")
    else:
        names = [os.path.basename(path).rsplit('.', 1)[0] for path in input_files]

    samples = {}
    used_names = set()

    for path, name in zip(input_files, names):

        # a repeated file would silently replace its own earlier entry
        if path in samples:
            _fail('\n    It seems the file "' + path + '" is trying to get in here twice.',
                  "\n    That's not gonna fly :(\n")

        # a repeated name would make the merged column names collide
        if name in used_names:
            _fail('\n    It seems the sample name "' + name + '" would be used for more than one input file.',
                  "\n    Unique sample names can be provided with `-n`.\n")

        samples[path] = name
        used_names.add(name)

    return samples


def _combine_tables(samples):
    """Merge the per-sample summary tables into one wide table.

    Args:
        samples: dict with file path as key and sample name as value.

    Returns:
        A DataFrame with one row per GO term seen in any input: the term's
        descriptive columns followed by `<sample>_counts` and
        `<sample>_perc_of_annotated` for every sample (0 where a term is
        absent from a sample), sorted by namespace and then depth.
    """
    go_info_parts = []
    count_columns = []
    combined = None

    for path, sample in samples.items():

        curr_tab = pd.read_csv(path, sep="\t")

        # the descriptive columns are taken by position (first four columns),
        # assumed to be GO term, namespace, depth and name - TODO confirm
        # against the `bit-summarize-go-annotations` output format
        go_info = curr_tab.iloc[:, :4].copy()
        go_info.columns = GO_INFO_COLUMNS

        # header-only tables are skipped so they can't disturb the column dtypes
        if not go_info.empty:
            go_info_parts.append(go_info)

        # trimming down the current table and naming its columns after the sample
        counts_column = str(sample) + "_counts"
        count_columns.append(counts_column)
        curr_sub_tab = curr_tab[["GO_term", "counts", "percent_of_annotated"]].rename(
            columns={"counts": counts_column,
                     "percent_of_annotated": str(sample) + "_perc_of_annotated"})

        if combined is None:
            combined = curr_sub_tab
        else:
            combined = combined.merge(curr_sub_tab, on="GO_term", how="outer")

    # terms missing from a sample come out of the outer merges as NAs
    combined = combined.fillna(0)

    # those NAs turned the affected count columns into floats; restoring
    # integers so counts aren't written as e.g. "3.0" for some samples only
    for column in count_columns:
        values = combined[column]
        if pd.api.types.is_float_dtype(values) and (values % 1 == 0).all():
            combined[column] = values.astype("int64")

    # one descriptive row per GO term, keeping the first one encountered
    if go_info_parts:
        go_df = pd.concat(go_info_parts, ignore_index=True)
        go_df = go_df.drop_duplicates(subset="GO_term", keep="first")
    else:
        go_df = pd.DataFrame(columns=GO_INFO_COLUMNS)

    final_tab = go_df.merge(combined, on="GO_term", how="outer")

    return final_tab.sort_values(by=["namespace", "depth"])


def main(argv=None):
    """Run the command-line program.

    Args:
        argv: argument list without the program name; defaults to
            `sys.argv[1:]`.

    Prints the help menu and exits with status 0 when called with no
    arguments. Otherwise writes the combined tab-delimited table to the
    requested output file.
    """
    parser = _build_parser()

    if argv is None:
        argv = sys.argv[1:]

    if not argv:
        parser.print_help(sys.stderr)
        sys.exit(0)

    args = parser.parse_args(argv)

    samples = _assign_sample_names(args.input_files, args.sample_names)
    final_tab = _combine_tables(samples)

    # letting pandas open the file itself so line endings are handled properly
    final_tab.to_csv(args.output_file, index=False, sep="\t")


if __name__ == "__main__":
    main()
