#!/usr/bin/env python3

### Program: atpoc
### Author: Rauf Salamzade
### Kalan Lab
### UW Madison, Department of Medical Microbiology and Immunology

# BSD 3-Clause License
#
# Copyright (c) 2023, Kalan-Lab
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import os
import sys
import argparse
import subprocess
from time import sleep
from zol import util, fai
import _pickle as cPickle
from Bio import SeqIO
from operator import itemgetter
from collections import defaultdict
import traceback
import pandas as pd
import math 

def create_parser():
	"""
	Build the atpoc command-line interface and parse sys.argv.

	Returns:
		argparse.Namespace with the attributes: sample_genome, vibrant_results,
		phispy_results, genomad_results, target_genomes_db, use_pyrodigal,
		fai_options, use_simple_blastp, simple_blastp_identity_cutoff,
		simple_blastp_coverage_cutoff, simple_blastp_evalue_cutoff,
		simple_blastp_sensitivity_mode, outdir and cpus.

	Note:
		argparse itself terminates the program (SystemExit) when a required
		argument is missing or when -h/--help is requested.
	"""
	parser = argparse.ArgumentParser(description="""
	Program: atpoc
	Author: Rauf Salamzade
	Affiliation: Kalan Lab, UW Madison, Department of Medical Microbiology and Immunology

	atpoc - Assess Temperate Phage-Ome Conservation 
	
	atpoc wraps fai to assess the conservation of a sample's integrated/temperate phage-ome 
	relative to a set of target genomes (e.g. genomes belonging to the same genus). Alternatively, 
	it can run a simple DIAMOND BLASTp analysis to just assess the presence of prophage genes
	individually - without the requirement they are co-located like in the focal sample.
	""", formatter_class=argparse.RawTextHelpFormatter)

	# Required inputs. No default is given for required arguments: argparse never uses it.
	parser.add_argument('-i', '--sample_genome', help='Path to sample genome in GenBank or FASTA format.', required=True)
	parser.add_argument('-tg', '--target_genomes_db', help='prepTG database directory for target genomes of interest.', required=True)
	parser.add_argument('-o', '--outdir', help='Output directory.', required=True)

	# Prophage prediction results; the caller checks that at least one is usable.
	parser.add_argument('-vi', '--vibrant_results', help='Path to VIBRANT results directory for a single sample/genome.', required=False, default=None)
	parser.add_argument('-ps', '--phispy_results', help='Path to PhiSpy results directory for a single sample/genome.', required=False, default=None)
	parser.add_argument('-gn', '--genomad_results', help='Path to GeNomad results directory for a single sample/genome.', required=False, default=None)

	# Gene calling and fai settings.
	parser.add_argument('-up', '--use_pyrodigal', action='store_true', help="Use default pyrodigal instead of prodigal-gv to call genes in\nphage regions to use as queries in fai/simple-blast. This\nis perhaps preferable if target genomes db was created with default pyrodigal/prodigal.", required=False, default=False)
	parser.add_argument('-fo', '--fai_options', help='Provide fai options to run. Should be surrounded by quotes. [Default is "-e 1e-10 -m 0.5 -dm -sct 0.4"]', required=False, default="-e 1e-10 -m 0.5 -dm -sct 0.4")

	# Simple DIAMOND BLASTp mode and its cutoffs.
	parser.add_argument('-s', '--use_simple_blastp', action='store_true', help='Use a simple DIAMOND BLASTp search with no requirement for co-localization of hits.', required=False, default=False)
	parser.add_argument('-si', '--simple_blastp_identity_cutoff', type=float, help='If simple BLASTp mode requested : cutoff for identity between query proteins and matches in target genomes [Default is 40.0].', required=False, default=40.0)
	parser.add_argument('-sc', '--simple_blastp_coverage_cutoff', type=float, help='If simple BLASTp mode requested : cutoff for coverage between query proteins and matches in target genomes [Default is 70.0].', required=False, default=70.0)
	parser.add_argument('-se', '--simple_blastp_evalue_cutoff', type=float, help='If simple BLASTp mode requested : cutoff for E-value between query proteins and matches in target genomes [Default is 1e-5].', required=False, default=1e-5)
	parser.add_argument('-sm', '--simple_blastp_sensitivity_mode', help='Sensitivity mode for DIAMOND BLASTp. [Default is "very-sensitive"].', required=False, default="very-sensitive")

	# Resources.
	parser.add_argument('-c', '--cpus', type=int, help='The number of CPUs to use [Default is 1].', required=False, default=1)

	return parser.parse_args()

def atpoc():
	"""
	Entry point of atpoc.

	Steps, in order:
	  1. Parse command-line arguments and prepare the output directory.
	  2. Validate that at least one of the VIBRANT / PhiSpy / geNomad result
	     directories exists, and that the sample genome is FASTA (GenBank input
	     is first converted to FASTA inside the output directory).
	  3. Collect prophage predictions from the provided result directories.
	  4. For every prediction, run either fai or the simple DIAMOND BLASTp
	     search against the prepTG target genomes database.
	  5. Read each per-phage 'total_gcs.tsv' table and write a per target
	     genome summary TSV (per-phage AAI, proportion of shared genes, and an
	     aggregate score that sums AAI * proportion over phages).
	  6. Write both tables into a formatted XLSX workbook.

	The function terminates the process through sys.exit (status 1 on input
	problems, status 0 on success) and therefore does not return.
	"""
	myargs = create_parser()

	sample_genome = myargs.sample_genome
	vibrant_results_dir = myargs.vibrant_results
	phispy_results_dir = myargs.phispy_results
	genomad_results_dir = myargs.genomad_results
	preptg_db_dir = myargs.target_genomes_db
	outdir = os.path.abspath(myargs.outdir) + '/'
	fai_options = myargs.fai_options
	use_simple_blastp_flag = myargs.use_simple_blastp
	use_pyrodigal_flag = myargs.use_pyrodigal
	simple_blastp_identity_cutoff = myargs.simple_blastp_identity_cutoff
	simple_blastp_coverage_cutoff = myargs.simple_blastp_coverage_cutoff
	simple_blastp_evalue_cutoff = myargs.simple_blastp_evalue_cutoff
	simple_blastp_sensitivity_mode = myargs.simple_blastp_sensitivity_mode
	cpus = myargs.cpus

	# create output directory if needed, or warn of over-writing
	if os.path.isdir(outdir):
		sys.stderr.write("Output directory exists. Overwriting in 5 seconds ...\n ")
		sleep(5)
	util.setupReadyDirectory([outdir])

	# A results directory is usable only if it was provided and exists. The run
	# is aborted only when none of the three is usable (the previous condition
	# was mis-parenthesized and rejected runs that provided geNomad results only).
	vibrant_usable = vibrant_results_dir is not None and os.path.isdir(vibrant_results_dir)
	phispy_usable = phispy_results_dir is not None and os.path.isdir(phispy_results_dir)
	genomad_usable = genomad_results_dir is not None and os.path.isdir(genomad_results_dir)
	if not (vibrant_usable or phispy_usable or genomad_usable):
		sys.stderr.write("Neither PhiSpy, VIBRANT, nor geNomad sample results directory provided as input or the paths are faulty!\n")
		sys.exit(1)

	if util.is_genbank(sample_genome):
		try:
			# <outdir>/<input basename without its last extension>.fna
			sample_genome_fasta = outdir + '.'.join(sample_genome.split('/')[-1].split('.')[:-1]) + '.fna'
			util.convertGenomeGenBankToFasta([sample_genome, sample_genome_fasta])
			# explicit check rather than assert, so it is not stripped under "python -O"
			if not (os.path.isfile(sample_genome_fasta) and util.is_fasta(sample_genome_fasta)):
				raise RuntimeError('GenBank to FASTA conversion did not produce a valid FASTA file.')
			sample_genome = sample_genome_fasta
		except Exception:
			sys.stderr.write("Issue converting sample's genome from GenBank to FASTA format.\n")
			sys.exit(1)

	if not util.is_fasta(sample_genome):
		sys.stderr.write("Issue with processing/validating focal sample's genome. Was it in FASTA or GenBank format?\n.")
		sys.exit(1)

	# create logging object
	log_file = outdir + 'Progress.log'
	logObject = util.createLoggerObject(log_file)
	version_string = util.getVersion()

	sys.stdout.write('Running atpoc version %s\n' % version_string)
	logObject.info('Running atpoc version %s' % version_string)

	# log command used (appended, so re-runs into the same directory accumulate)
	parameters_file = outdir + 'Command_Issued.txt'
	with open(parameters_file, 'a+') as parameters_handle:
		parameters_handle.write(' '.join(sys.argv) + '\n')

	sys.stdout.write('Parsing phage prediction results\n')
	logObject.info('Parsing phage prediction results')

	# each record: [software, phage_id, scaffold, start, end, length, additional_info]
	sample_phages = []
	if vibrant_usable:
		sample_phages.extend(_parse_vibrant_prophages(vibrant_results_dir))
	if phispy_usable:
		sample_phages.extend(_parse_phispy_prophages(phispy_results_dir))
	if genomad_usable:
		sample_phages.extend(_parse_genomad_prophages(genomad_results_dir, logObject))

	sys.stdout.write('Found %d prophage predictions!\n' % len(sample_phages))
	logObject.info('Found %d prophage predictions!' % len(sample_phages))

	fai_results_dir = outdir + 'fai_or_blast_Results/'
	util.setupReadyDirectory([fai_results_dir])
	phage_info_file = outdir + 'TemperatePhageOme_Info.tsv'
	phage_lengths = []
	numeric_columns_t1 = ['phage_length']
	with open(phage_info_file, 'w') as phage_info_handle:
		phage_info_handle.write('\t'.join(['phage_id', 'phage_prediction_software', 'scaffold_id', 'start', 'end', 'phage_length', 'additional_attributes']) + '\n')
		for phage in sample_phages:
			phage_software, phage_id, scaff, start_coord, end_coord, phage_length, additional_info = phage

			phage_fai_res_dir = fai_results_dir + phage_id + '/'

			phage_info_handle.write('\t'.join([str(x) for x in [phage_id, phage_software, scaff, start_coord, end_coord, phage_length, additional_info]]) + '\n')
			phage_lengths.append(float(phage_length))

			if use_simple_blastp_flag:
				# run simple blastp
				util.setupReadyDirectory([phage_fai_res_dir])
				reference_genome_gbk = phage_fai_res_dir + 'reference.gbk'
				gene_call_cmd = ['runProdigalAndMakeProperGenbank.py', '-i', sample_genome, '-s', 'reference', '-l',
									'OG', '-o', phage_fai_res_dir]
				if use_pyrodigal_flag:
					gene_call_cmd += ['-gcm', 'pyrodigal']
				else:
					gene_call_cmd += ['-gcm', 'prodigal-gv']
				util.runCmdViaSubprocess(gene_call_cmd, logObject, check_files=[reference_genome_gbk])
				locus_genbank = phage_fai_res_dir + 'Reference_Query.gbk'
				locus_proteome = phage_fai_res_dir + 'Reference_Query.faa'
				fai.subsetGenBankForQueryLocus(reference_genome_gbk, locus_genbank, locus_proteome, scaff, int(start_coord), int(end_coord), logObject)
				util.diamondBlastAndGetBestHits(phage_id, locus_proteome, locus_proteome, preptg_db_dir, phage_fai_res_dir, logObject,
												identity_cutoff=simple_blastp_identity_cutoff, coverage_cutoff=simple_blastp_coverage_cutoff,
												blastp_mode=simple_blastp_sensitivity_mode, prop_key_prots_needed=0.0,
												evalue_cutoff=simple_blastp_evalue_cutoff, cpus=cpus)
			else:
				# run fai
				# NOTE(review): fai_options is passed as one list element; this presumes
				# util.runCmdViaSubprocess tokenizes it (e.g. via a shell) - confirm.
				fai_cmd = ['fai', '-r', sample_genome, '-rc', scaff, '-rs', start_coord, '-re', end_coord,
						   '-tg', preptg_db_dir, fai_options, '-o', phage_fai_res_dir, '-c', str(cpus)]
				if not use_pyrodigal_flag:
					fai_cmd += ['-rgv']
				genome_wide_tsv_result_file = phage_fai_res_dir + 'Spreadsheet_TSVs/total_gcs.tsv'
				util.runCmdViaSubprocess(fai_cmd, logObject, check_files=[genome_wide_tsv_result_file])

	# target genome -> phage id -> value; 'NA' when a phage has no row for a target genome
	tg_phage_aai = defaultdict(lambda: defaultdict(lambda: 'NA'))
	tg_phage_sgp = defaultdict(lambda: defaultdict(lambda: 'NA'))
	phages = set([])
	tg_total_phages_found = defaultdict(int)
	for phage in sample_phages:
		phage_software, phage_id, scaff, start_coord, end_coord, phage_length, additional_info = phage
		phages.add(phage_id)
		genome_wide_tsv_result_file = fai_results_dir + phage_id + '/Spreadsheet_TSVs/total_gcs.tsv'
		if use_simple_blastp_flag:
			genome_wide_tsv_result_file = fai_results_dir + phage_id + '/total_gcs.tsv'

		with open(genome_wide_tsv_result_file) as ogwtrf:
			for i, line in enumerate(ogwtrf):
				if i == 0: continue
				line = line.strip()
				ls = line.split('\t')
				# columns read: 1st = target genome, 3rd = AAI, 5th = proportion of shared genes
				tg = ls[0]
				aai = ls[2]
				sgp = ls[4]
				tg_phage_aai[tg][phage_id] = aai
				tg_phage_sgp[tg][phage_id] = sgp
				if float(aai) > 0.0:
					tg_total_phages_found[tg] += 1

	# os.path.join keeps this correct whether or not the user supplied a trailing slash
	target_genome_listing_file = os.path.join(preptg_db_dir, 'Target_Genome_Annotation_Files.txt')
	result_file = outdir + 'TemperatePhageOme_to_Target_Genomes_Similarity.tsv'
	sorted_phages = sorted(phages)
	numeric_columns_t2 = ['shared_phages', 'aggregate_score'] + [x + ' - AAI' for x in sorted_phages] + [x + ' - Proportion Shared Genes' for x in sorted_phages]
	tgs = set([])
	all_scores = []
	with open(result_file, 'w') as result_handle, open(target_genome_listing_file) as otglf:
		result_handle.write('\t'.join(['target_genome', 'shared_phages', 'aggregate_score'] + [x + ' - AAI' + '\t' + x + ' - Proportion Shared Genes' for x in sorted_phages]) + '\n')
		for line in otglf:
			line = line.strip()
			if not line: continue
			ls = line.split('\t')
			tg = ls[0]
			tgs.add(tg)
			shared_phages = tg_total_phages_found[tg]
			total_score = 0.0
			printlist = []
			for phage in sorted_phages:
				aai = tg_phage_aai[tg][phage]
				sgp = tg_phage_sgp[tg][phage]
				if aai != "NA":
					total_score += float(aai)*float(sgp)
				printlist.append(aai)
				printlist.append(sgp)
			all_scores.append(total_score)
			result_handle.write('\t'.join([tg, str(shared_phages), str(total_score)] + printlist) + '\n')

	# construct spreadsheet
	atpoc_spreadsheet_file = outdir + 'ATPOC_Results.xlsx'
	writer = pd.ExcelWriter(atpoc_spreadsheet_file, engine='xlsxwriter')
	workbook = writer.book
	header_format = workbook.add_format({'bold': True, 'text_wrap': True, 'valign': 'top', 'fg_color': '#D7E4BC', 'border': 1})

	t1_results_df = util.loadTableInPandaDataFrame(phage_info_file, numeric_columns_t1)
	t1_results_df.to_excel(writer, sheet_name='Sample ProphageOme Overview', index=False, na_rep="NA")

	t2_results_df = util.loadTableInPandaDataFrame(result_file, numeric_columns_t2)
	t2_results_df.to_excel(writer, sheet_name='Similarity to Target Genomes', index=False, na_rep="NA")

	t1_worksheet = writer.sheets['Sample ProphageOme Overview']
	t2_worksheet = writer.sheets['Similarity to Target Genomes']

	# the overview sheet has 7 columns (A-G), so the header format spans A1:G1
	t1_worksheet.conditional_format('A1:G1', {'type': 'cell', 'criteria': '!=', 'value': 'NA', 'format': header_format})
	t2_worksheet.conditional_format('A1:HA1', {'type': 'cell', 'criteria': '!=', 'value': 'NA', 'format': header_format})

	# last spreadsheet row of each sheet (+1 for the header row); the overview
	# sheet holds one row per prediction record
	t1_row_count = len(sample_phages) + 1
	t2_row_count = len(tgs) + 1

	# phage length (t1) - skipped when there are no predictions (max() of nothing)
	if phage_lengths:
		t1_worksheet.conditional_format('F2:F' + str(t1_row_count), {'type': '2_color_scale', 'min_color': "#d5f0ea", 'max_color': "#83a39c", 'min_type': 'num', 'max_type': 'num', "min_value": 0, "max_value": max(phage_lengths)})

	# total phage shared (t2)
	t2_worksheet.conditional_format('B2:B' + str(t2_row_count), {'type': '2_color_scale', 'min_color': "#d9d9d9", 'max_color': "#949494", 'min_type': 'num', 'max_type': 'num', "min_value": 0, "max_value": len(phages)})

	# aggregate score (t2) - skipped when no target genomes were listed
	if all_scores:
		t2_worksheet.conditional_format('C2:C' + str(t2_row_count), {'type': '2_color_scale', 'min_color': "#f7db94", 'max_color': "#ad8f40", 'min_type': 'num', 'max_type': 'num', "min_value": 0, "max_value": max(all_scores)})

	# per-phage column pairs start at 0-based column index 3 (after target_genome,
	# shared_phages and aggregate_score): AAI first, then proportion of shared genes
	phage_index = 3
	for _ in sorted_phages:
		columnid_pid = util.determineColumnNameBasedOnIndex(phage_index)
		columnid_sgp = util.determineColumnNameBasedOnIndex(phage_index+1)
		phage_index += 2

		# aai
		t2_worksheet.conditional_format(columnid_pid + '2:' + columnid_pid + str(t2_row_count), {'type': '2_color_scale', 'min_color': "#d8eaf0", 'max_color': "#83c1d4", "min_value": 0.0, "max_value": 100.0, 'min_type': 'num', 'max_type': 'num'})

		# prop genes found
		t2_worksheet.conditional_format(columnid_sgp + '2:' + columnid_sgp + str(t2_row_count), {'type': '2_color_scale', 'min_color': "#e6bec0", 'max_color': "#f26168", "min_value": 0.0, "max_value": 1.0, 'min_type': 'num', 'max_type': 'num'})

	# Freeze the first row of both sheets
	t1_worksheet.freeze_panes(1, 0)
	t2_worksheet.freeze_panes(1, 0)

	# close workbook
	workbook.close()

	sys.stdout.write('Done running atpoc!\nFinal results can be found at:\n%s\n' % (atpoc_spreadsheet_file))
	logObject.info('Done running atpoc!\nFinal results can be found at:\n%s\n' % (atpoc_spreadsheet_file))

	# Close logging object and exit
	util.closeLoggerObject(logObject)
	sys.exit(0)

def _parse_vibrant_prophages(vibrant_results_dir):
	"""
	Collect prophage records from every 'VIBRANT_integrated_prophage_coordinates_*.tsv'
	table found under vibrant_results_dir.

	Returns a list of [software, phage_id, scaffold, start, end, length, additional_info]
	records; all values are strings as read from the table.
	"""
	records = []
	for root, _dirs, files in os.walk(vibrant_results_dir):
		for basename in files:
			if not (basename.startswith('VIBRANT_integrated_prophage_coordinates_') and basename.endswith('.tsv')):
				continue
			with open(os.path.join(root, basename)) as ofn:
				for i, line in enumerate(ofn):
					if i == 0: continue  # header
					line = line.strip()
					if not line: continue
					# 8 columns are expected; only columns 1, 2 and 6-8 are used
					scaff, phage_id, _, _, _, start_coord, end_coord, phage_length = line.split('\t')
					records.append(['VIBRANT', 'VB-' + phage_id, scaff, start_coord, end_coord, phage_length, ''])
	return records

def _parse_phispy_prophages(phispy_results_dir):
	"""
	Collect prophage records from every 'prophage_coordinates.tsv' table found
	under phispy_results_dir.

	Returns a list of [software, phage_id, scaffold, start, end, length, additional_info]
	records; the length is computed as abs(start - end) and the att site
	columns are kept as additional_info.
	"""
	records = []
	for root, _dirs, files in os.walk(phispy_results_dir):
		for basename in files:
			if basename != 'prophage_coordinates.tsv':
				continue
			with open(os.path.join(root, basename)) as ofn:
				# every line is treated as data (no header line is skipped)
				for line in ofn:
					line = line.strip('\n')
					if not line.strip(): continue
					phage_id, scaff, start_coord, end_coord, _, _, _, _, attL_seq, attR_seq, att_explanation = line.split('\t')
					phage_length = abs(int(start_coord)-int(end_coord))
					additional_info = '; '.join(['attL_sequence=' + attL_seq, 'attR_sequence=' + attR_seq, 'att_explanation=' + att_explanation])
					records.append(['PhiSpy', 'PS-' + phage_id, scaff, start_coord, end_coord, str(phage_length), additional_info])
	return records

def _parse_genomad_prophages(genomad_results_dir, logObject):
	"""
	Collect prophage records from every '*_virus_summary.tsv' table found under
	genomad_results_dir.

	Rows whose coordinates field is not of the form 'start-end' cannot be
	located on a scaffold and are skipped with a log message.
	NOTE(review): per the geNomad documentation such rows ('NA' coordinates)
	are viruses not predicted to be integrated - confirm that skipping them is
	the desired behavior for this tool.

	Returns a list of [software, phage_id, scaffold, start, end, length, additional_info]
	records; all values are strings as read from the table.
	"""
	records = []
	for root, _dirs, files in os.walk(genomad_results_dir):
		for basename in files:
			if not basename.endswith('_virus_summary.tsv'):
				continue
			with open(os.path.join(root, basename)) as ofn:
				for i, line in enumerate(ofn):
					# seq_name        length  topology        coordinates     n_genes genetic_code    virus_score     fdr     n_hallmarks     marker_enrichment       taxonomy
					if i == 0: continue
					line = line.strip('\n')
					if not line.strip(): continue
					phage_id, phage_length, topology, coords, n_genes, genetic_code, virus_score, fdr, n_hallmarks, marker_enrichment, taxonomy = line.split('\t')
					if coords.count('-') != 1:
						logObject.info('Skipping geNomad prediction %s because its coordinates field (%s) is not of the form start-end.' % (phage_id, coords))
						continue
					# the scaffold is the part of the sequence name before the first '|'
					scaff = phage_id.split('|')[0]
					phage_id = phage_id.replace('|', '_')
					start_coord, end_coord = coords.split('-')
					additional_info = '; '.join(['virus_score=' + virus_score, 'topology=' + topology, 'genetic_code=' + genetic_code, 'n_hallmarks=' + n_hallmarks, 'marker_enrichment=' + marker_enrichment, 'taxonomy=' + taxonomy])
					records.append(['geNomad', 'GN-' + phage_id, scaff, start_coord, end_coord, phage_length, additional_info])
	return records

def determineColumnNameBasedOnIndex(index):
	"""
	Function to determine spreadsheet column name for a given index

	The index is 0-based and is converted to the usual spreadsheet lettering:
	0 -> 'A', 25 -> 'Z', 26 -> 'AA', 701 -> 'ZZ', 702 -> 'AAA', and so on for
	any number of letters.

	Args:
		index: non-negative integer column index (0-based).

	Returns:
		The column name as an upper-case string.

	Raises:
		ValueError: if index is negative.
	"""
	if index < 0:
		raise ValueError('Column index must be non-negative, got %s' % str(index))

	alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
	letters = []
	# Column names form a bijective base-26 numbering (there is no "zero"
	# letter), hence the shift to 1-based before and the -1 inside the loop.
	remaining = index + 1
	while remaining > 0:
		remaining, offset = divmod(remaining - 1, 26)
		letters.append(alphabet[offset])
	# letters were produced from the least significant position upwards
	return ''.join(reversed(letters))

# Run the atpoc workflow only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
	atpoc()
