#!/usr/bin/env python3

### Program: apos
### Author: Rauf Salamzade
### Kalan Lab
### UW Madison, Department of Medical Microbiology and Immunology

# BSD 3-Clause License
#
# Copyright (c) 2023, Kalan-Lab
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import os
import sys
import argparse
import subprocess
from time import sleep
from zol import util, fai
import _pickle as cPickle
from Bio import SeqIO
from operator import itemgetter
from collections import defaultdict
import traceback
import pandas as pd
import math 

def create_parser(argv=None):
	"""
	Build the apos command-line interface and parse arguments.

	Parameters:
	- argv: optional list of argument strings to parse. When None (the default),
	        argparse reads the arguments from sys.argv[1:], so existing callers
	        that invoke create_parser() with no arguments behave as before.

	Returns:
	- argparse.Namespace with attributes: sample_genome, mobsuite_results,
	  genomad_results, target_genomes_db, fai_options, use_simple_blastp,
	  simple_blastp_identity_cutoff, simple_blastp_coverage_cutoff,
	  simple_blastp_evalue_cutoff, simple_blastp_sensitivity_mode, outdir, cpus.

	Note: argparse terminates the process (SystemExit) when required options
	are missing or values cannot be converted to the declared type.
	"""
	parser = argparse.ArgumentParser(description="""
	Program: apos
	Author: Rauf Salamzade
	Affiliation: Kalan Lab, UW Madison, Department of Medical Microbiology and Immunology

	apos - Assess Plasmid-Ome Similarity 
	
	apos wraps fai to assess the conservation of a sample's plasmid-ome 
	relative to a set of target genomes (e.g. genomes belonging to the same genus). Alternatively, 
	it can run a simple DIAMOND BLASTp analysis to just assess the presence of plasmid genes
	individually - without the requirement they are co-located in one scaffold like in the focal sample.
	""", formatter_class=argparse.RawTextHelpFormatter)

	# Required inputs/outputs. No default is given for required options because it could never be used.
	parser.add_argument('-i', '--sample_genome', help='Path to sample genome in GenBank or FASTA format.', required=True)
	parser.add_argument('-ms', '--mobsuite_results', help='Path to MOB-suite (mob_recon) results directory for a single sample/genome.', required=False, default=None)
	parser.add_argument('-gn', '--genomad_results', help='Path to geNomad results directory for a single sample/genome.', required=False, default=None)
	parser.add_argument('-tg', '--target_genomes_db', help='prepTG database directory for target genomes of interest.', required=True)

	# Options forwarded to fai (default mode).
	parser.add_argument('-fo', '--fai_options', help='Provide fai options to run. Should be surrounded by quotes. [Default is "-e 1e-10 -m 0.5 -dm -sct 0.0"]', required=False, default="-e 1e-10 -m 0.5 -dm -sct 0.0")

	# Options for the alternative simple DIAMOND BLASTp mode.
	parser.add_argument('-s', '--use_simple_blastp', action='store_true', help='Use a simple DIAMOND BLASTp search with no requirement for co-localization of hits.', required=False, default=False)
	parser.add_argument('-si', '--simple_blastp_identity_cutoff', type=float, help='If simple BLASTp mode requested : cutoff for identity between query proteins and matches in target genomes [Default is 40.0].', required=False, default=40.0)
	parser.add_argument('-sc', '--simple_blastp_coverage_cutoff', type=float, help='If simple BLASTp mode requested : cutoff for coverage between query proteins and matches in target genomes [Default is 70.0].', required=False, default=70.0)
	parser.add_argument('-se', '--simple_blastp_evalue_cutoff', type=float, help='If simple BLASTp mode requested : cutoff for E-value between query proteins and matches in target genomes [Default is 1e-5].', required=False, default=1e-5)
	parser.add_argument('-sm', '--simple_blastp_sensitivity_mode', help='Sensitivity mode for DIAMOND BLASTp. [Default is "very-sensitive"].', required=False, default="very-sensitive")

	parser.add_argument('-o', '--outdir', help='Output directory.', required=True)
	parser.add_argument('-c', '--cpus', type=int, help='The number of CPUs to use [Default is 1].', required=False, default=1)

	args = parser.parse_args(argv)
	return args

def _is_usable_dir(path):
	""" Return True when `path` was provided (is not None) and is an existing directory. """
	return path is not None and os.path.isdir(path)


def _match_scaffold(contig_id, scaffold_ids):
	"""
	Resolve a MOB-suite contig_id to exactly one scaffold ID of the sample genome.

	Exact matches (of the whole contig_id, or of its first whitespace-delimited
	token) are preferred so that e.g. 'contig_10' is not considered ambiguous
	just because 'contig_1' is also a scaffold. Only if neither exists is the
	prefix-based matching used, which must then be unique.

	Returns the matching scaffold ID, or None if no single scaffold can be determined.
	"""
	if contig_id in scaffold_ids:
		return contig_id
	tokens = contig_id.split()
	if tokens and tokens[0] in scaffold_ids:
		return tokens[0]
	prefix_matches = [scaff_id for scaff_id in scaffold_ids if contig_id.startswith(scaff_id)]
	if len(prefix_matches) == 1:
		return prefix_matches[0]
	return None


def _get_field(ls, header_index, header_name, fallback_idx):
	"""
	Return the value of column `header_name` from the split row `ls`.

	The column position is taken from `header_index` (column name -> index, built
	from the file's header line); if the name is absent, `fallback_idx` is used.
	'NA' is returned if the row is too short to contain the column.
	"""
	idx = header_index.get(header_name, fallback_idx)
	if idx < len(ls):
		return ls[idx]
	return 'NA'


def _parse_mobsuite_plasmids(mobsuite_results_dir, sample_genome, logObject):
	"""
	Gather plasmid contigs from every 'contig_report.txt' found under `mobsuite_results_dir`.

	Parameters:
	- mobsuite_results_dir: directory that is walked recursively for contig_report.txt files.
	- sample_genome: path to the sample genome in FASTA format; its record IDs are
	                 used to map MOB-suite contig IDs back to scaffolds.
	- logObject: logger used to record errors.

	Returns:
	- list of [software, plasmid_id, scaffold_id, plasmid_length, additional_info] entries.

	Exits the program (status 1) if the genome has duplicate record IDs or if a
	plasmid contig cannot be matched to a single scaffold.
	"""
	scaffold_ids = set()
	with open(sample_genome) as osg:
		for rec in SeqIO.parse(osg, 'fasta'):
			if rec.id in scaffold_ids:
				sys.stderr.write('Issues with multiple FASTA headers (up to first whitespace) in sample\'s genome being identical.\n')
				logObject.error('Issues with multiple FASTA headers (up to first whitespace) in sample\'s genome being identical.')
				sys.exit(1)
			scaffold_ids.add(rec.id)

	# (label written to additional_info, expected column name in the report header, positional fallback).
	# NOTE(review): the column names are those recalled from the MOB-suite documentation and the
	# fallbacks are the positions this script used previously (gc and circularity_status both
	# pointed at column 6) - confirm both against the MOB-suite version in use.
	attribute_columns = [
		('gc', 'gc', 6),
		('circularity_status', 'circularity_status', 6),
		('replicon_types', 'rep_type(s)', 7),
		('relaxase_types', 'relaxase_type(s)', 9),
		('mpf_types', 'mpf_type', 11),
		('orit_types', 'orit_type(s)', 13),
		('predicted_mobility', 'predicted_mobility', 15),
	]

	mobsuite_plasmids = []
	for root, dirs, files in os.walk(mobsuite_results_dir):
		for basename in files:
			if basename != 'contig_report.txt': continue
			header_index = {}
			with open(os.path.join(root, basename)) as ofn:
				for i, line in enumerate(ofn):
					# only strip the line terminator so that empty trailing columns are preserved
					ls = line.rstrip('\r\n').split('\t')
					if i == 0:
						header_index = {name.strip(): idx for idx, name in enumerate(ls)}
						continue
					# skip blank lines and non-plasmid contigs
					if len(ls) < 2 or ls[1] != 'plasmid': continue
					primary_cluster_id = ls[2]
					secondary_cluster_id = ls[3]
					plasmid_id = ls[4]
					plasmid_length = ls[5]
					scaff = _match_scaffold(plasmid_id, scaffold_ids)
					if scaff is None:
						sys.stderr.write('Issues with matching MOB-suite contig_id %s to a single record in the sample\'s genome.\n' % plasmid_id)
						logObject.error('Issues with matching MOB-suite contig_id %s to a single record in the sample\'s genome.' % plasmid_id)
						sys.exit(1)
					info_parts = ['primary_cluster_id=' + primary_cluster_id, 'secondary_cluster_id=' + secondary_cluster_id]
					for label, header_name, fallback_idx in attribute_columns:
						info_parts.append(label + '=' + _get_field(ls, header_index, header_name, fallback_idx))
					additional_info = '; '.join(info_parts)
					mobsuite_plasmids.append(['MOB-suite', 'MS-' + plasmid_id, scaff, plasmid_length, additional_info])
	return mobsuite_plasmids


def _parse_genomad_plasmids(genomad_results_dir):
	"""
	Gather plasmids from every '*_plasmid_summary.tsv' found under `genomad_results_dir`.

	Each data row is expected to have exactly 11 tab-separated columns (a ValueError
	is raised by the unpacking otherwise). '|' characters in the sequence name are
	replaced with '_' and the resulting name is used as both plasmid and scaffold ID.

	Returns:
	- list of [software, plasmid_id, scaffold_id, plasmid_length, additional_info] entries.
	"""
	genomad_plasmids = []
	for root, dirs, files in os.walk(genomad_results_dir):
		for basename in files:
			if not basename.endswith('_plasmid_summary.tsv'): continue
			with open(os.path.join(root, basename)) as ofn:
				for i, line in enumerate(ofn):
					if i == 0: continue
					line = line.strip('\n')
					if not line: continue
					plasmid_id, plasmid_length, topology, n_genes, genetic_code, plasmid_score, fdr, n_hallmarks, marker_enrichment, conjugation_genes, amr_genes = line.split('\t')
					plasmid_id = plasmid_id.replace('|', '_')
					additional_info = '; '.join(['plasmid_score=' + plasmid_score, 'topology=' + topology, 'genetic_code=' + genetic_code, 'n_hallmarks=' + n_hallmarks, 'marker_enrichment=' + marker_enrichment, 'conjugation_genes=' + conjugation_genes, 'amr_genes=' + amr_genes])
					genomad_plasmids.append(['geNomad', 'GN-' + plasmid_id, plasmid_id, plasmid_length, additional_info])
	return genomad_plasmids


def apos():
	"""
	Run the full apos workflow (command-line entry point).

	Steps:
	1. Parse command-line arguments and validate the inputs.
	2. Collect plasmid predictions from the MOB-suite and/or geNomad results.
	3. For each predicted plasmid, search the target genomes database with either
	   fai (default) or the simple DIAMOND BLASTp mode.
	4. Summarize, per target genome, the AAI and proportion of shared genes for
	   each plasmid plus an aggregate score (sum over plasmids of AAI * proportion).
	5. Write two TSV tables and an Excel workbook (APOS_Results.xlsx) to the output directory.

	Always terminates the process via sys.exit: status 0 on success, 1 on input problems.
	"""
	myargs = create_parser()

	sample_genome = myargs.sample_genome
	mobsuite_results_dir = myargs.mobsuite_results
	genomad_results_dir = myargs.genomad_results
	# Normalize directories to absolute paths with a trailing slash because file
	# names are appended to them below by plain string concatenation.
	preptg_db_dir = os.path.abspath(myargs.target_genomes_db) + '/'
	outdir = os.path.abspath(myargs.outdir) + '/'
	fai_options = myargs.fai_options
	use_simple_blastp_flag = myargs.use_simple_blastp
	simple_blastp_identity_cutoff = myargs.simple_blastp_identity_cutoff
	simple_blastp_coverage_cutoff = myargs.simple_blastp_coverage_cutoff
	simple_blastp_evalue_cutoff = myargs.simple_blastp_evalue_cutoff
	simple_blastp_sensitivity_mode = myargs.simple_blastp_sensitivity_mode
	cpus = myargs.cpus

	# validate the plasmid prediction inputs before touching the output directory
	have_mobsuite = _is_usable_dir(mobsuite_results_dir)
	have_genomad = _is_usable_dir(genomad_results_dir)
	if not (have_mobsuite or have_genomad):
		sys.stderr.write("Neither MOB-suite nor geNomad sample results directory provided as input or the paths are faulty!\n")
		sys.exit(1)

	# create output directory if needed, or warn of over-writing
	if os.path.isdir(outdir):
		sys.stderr.write("Output directory exists. Overwriting in 5 seconds ...\n")
		sleep(5)
	util.setupReadyDirectory([outdir])

	if util.is_genbank(sample_genome):
		try:
			sample_genome_fasta = outdir + os.path.splitext(os.path.basename(sample_genome))[0] + '.fna'
			util.convertGenomeGenBankToFasta([sample_genome, sample_genome_fasta])
			if not (os.path.isfile(sample_genome_fasta) and util.is_fasta(sample_genome_fasta)):
				raise RuntimeError('GenBank to FASTA conversion did not produce a valid FASTA file.')
			sample_genome = sample_genome_fasta
		except Exception:
			sys.stderr.write("Issue converting sample's genome from GenBank to FASTA format.\n")
			sys.exit(1)

	if not util.is_fasta(sample_genome):
		sys.stderr.write("Issue with processing/validating focal sample's genome. Was it in FASTA or GenBank format?\n")
		sys.exit(1)

	# create logging object
	log_file = outdir + 'Progress.log'
	logObject = util.createLoggerObject(log_file)
	version_string = util.getVersion()

	sys.stdout.write('Running apos version %s\n' % version_string)
	logObject.info('Running apos version %s' % version_string)

	# log command used
	parameters_file = outdir + 'Command_Issued.txt'
	with open(parameters_file, 'a+') as parameters_handle:
		parameters_handle.write(' '.join(sys.argv) + '\n')

	sys.stdout.write('Parsing plasmid prediction results\n')
	logObject.info('Parsing plasmid prediction results')

	# each entry: [software, plasmid_id, scaffold_id, plasmid_length, additional_info]
	sample_plasmids = []
	if have_mobsuite:
		sample_plasmids.extend(_parse_mobsuite_plasmids(mobsuite_results_dir, sample_genome, logObject))
	if have_genomad:
		sample_plasmids.extend(_parse_genomad_plasmids(genomad_results_dir))

	sys.stdout.write('Found %d plasmid predictions!\n' % len(sample_plasmids))
	logObject.info('Found %d plasmid predictions!' % len(sample_plasmids))

	fai_results_dir = outdir + 'fai_or_blast_Results/'
	util.setupReadyDirectory([fai_results_dir])

	# write overview table of the sample's plasmids
	plasmid_info_file = outdir + 'TemperateplasmidOme_Info.tsv'
	plasmid_lengths = []
	numeric_columns_t1 = ['plasmid_length']
	with open(plasmid_info_file, 'w') as plasmid_info_handle:
		plasmid_info_handle.write('\t'.join(['plasmid_id', 'plasmid_prediction_software', 'scaffold_id', 'plasmid_length', 'additional_attributes']) + '\n')
		for plasmid_software, plasmid_id, scaff, plasmid_length, additional_info in sample_plasmids:
			plasmid_info_handle.write('\t'.join([str(x) for x in [plasmid_id, plasmid_software, scaff, plasmid_length, additional_info]]) + '\n')
			plasmid_lengths.append(float(plasmid_length))

	# search for each plasmid in the target genomes
	for plasmid_software, plasmid_id, scaff, plasmid_length, additional_info in sample_plasmids:
		plasmid_fai_res_dir = fai_results_dir + plasmid_id + '/'

		# the whole scaffold is treated as the query locus
		start_coord = '1'
		end_coord = str(plasmid_length)
		if use_simple_blastp_flag:
			# run simple blastp
			util.setupReadyDirectory([plasmid_fai_res_dir])
			reference_genome_gbk = plasmid_fai_res_dir + 'reference.gbk'
			gene_call_cmd = ['runProdigalAndMakeProperGenbank.py', '-i', sample_genome, '-s', 'reference', '-l',
								'OG', '-o', plasmid_fai_res_dir]
			util.runCmdViaSubprocess(gene_call_cmd, logObject, check_files=[reference_genome_gbk])
			locus_genbank = plasmid_fai_res_dir + 'Reference_Query.gbk'
			locus_proteome = plasmid_fai_res_dir + 'Reference_Query.faa'
			fai.subsetGenBankForQueryLocus(reference_genome_gbk, locus_genbank, locus_proteome, scaff, int(start_coord), int(end_coord), logObject)
			util.diamondBlastAndGetBestHits(plasmid_id, locus_proteome, locus_proteome, preptg_db_dir, plasmid_fai_res_dir, logObject,
											identity_cutoff=simple_blastp_identity_cutoff, coverage_cutoff=simple_blastp_coverage_cutoff,
											blastp_mode=simple_blastp_sensitivity_mode, prop_key_prots_needed=0.0,
											evalue_cutoff=simple_blastp_evalue_cutoff, cpus=cpus)
		else:
			# run fai
			fai_cmd = ['fai', '-r', sample_genome, '-rc', scaff, '-rs', start_coord, '-re', end_coord,
					   '-tg', preptg_db_dir, fai_options, '-o', plasmid_fai_res_dir, '-c', str(cpus)]
			genome_wide_tsv_result_file = plasmid_fai_res_dir + 'Spreadsheet_TSVs/total_gcs.tsv'
			util.runCmdViaSubprocess(fai_cmd, logObject, check_files=[genome_wide_tsv_result_file])

	# parse per-plasmid results: target genome -> plasmid -> value ('NA' if the target genome had no row)
	tg_plasmid_aai = defaultdict(lambda: defaultdict(lambda: 'NA'))
	tg_plasmid_sgp = defaultdict(lambda: defaultdict(lambda: 'NA'))
	plasmids = set()
	tg_total_plasmids_found = defaultdict(int)
	for plasmid_software, plasmid_id, scaff, plasmid_length, additional_info in sample_plasmids:
		plasmids.add(plasmid_id)
		# the simple BLASTp mode writes the table directly into the plasmid's results directory
		if use_simple_blastp_flag:
			genome_wide_tsv_result_file = fai_results_dir + plasmid_id + '/total_gcs.tsv'
		else:
			genome_wide_tsv_result_file = fai_results_dir + plasmid_id + '/Spreadsheet_TSVs/total_gcs.tsv'

		with open(genome_wide_tsv_result_file) as ogwtrf:
			for i, line in enumerate(ogwtrf):
				if i == 0: continue
				line = line.strip()
				ls = line.split('\t')
				tg = ls[0]
				aai = ls[2]
				sgp = ls[4]
				tg_plasmid_aai[tg][plasmid_id] = aai
				tg_plasmid_sgp[tg][plasmid_id] = sgp
				if float(aai) > 0.0:
					tg_total_plasmids_found[tg] += 1

	sorted_plasmids = sorted(plasmids)

	# write table of similarity of the plasmid-ome to each target genome
	target_genome_listing_file = preptg_db_dir + 'Target_Genome_Annotation_Files.txt'
	result_file = outdir + 'TemperateplasmidOme_to_Target_Genomes_Similarity.tsv'
	numeric_columns_t2 = ['shared_plasmids', 'aggregate_score'] + [x + ' - AAI' for x in sorted_plasmids] + [x + ' - Proportion Shared Genes' for x in sorted_plasmids]
	tgs = set()
	all_scores = []
	with open(result_file, 'w') as result_handle, open(target_genome_listing_file) as otglf:
		result_handle.write('\t'.join(['target_genome', 'shared_plasmids', 'aggregate_score'] + [x + ' - AAI' + '\t' + x + ' - Proportion Shared Genes' for x in sorted_plasmids]) + '\n')
		for line in otglf:
			line = line.strip()
			if not line: continue
			ls = line.split('\t')
			tg = ls[0]
			tgs.add(tg)
			shared_plasmids = tg_total_plasmids_found[tg]
			# aggregate score: sum over plasmids of AAI * proportion of shared genes
			total_score = 0.0
			printlist = [tg, str(shared_plasmids)]
			per_plasmid_values = []
			for plasmid in sorted_plasmids:
				aai = tg_plasmid_aai[tg][plasmid]
				sgp = tg_plasmid_sgp[tg][plasmid]
				if aai != "NA":
					total_score += float(aai)*float(sgp)
				per_plasmid_values.append(aai)
				per_plasmid_values.append(sgp)
			printlist.append(str(total_score))
			all_scores.append(total_score)
			result_handle.write('\t'.join(printlist + per_plasmid_values) + '\n')

	# construct spreadsheet
	apos_spreadsheet_file = outdir + 'APOS_Results.xlsx'
	writer = pd.ExcelWriter(apos_spreadsheet_file, engine='xlsxwriter')
	workbook = writer.book
	header_format = workbook.add_format({'bold': True, 'text_wrap': True, 'valign': 'top', 'fg_color': '#D7E4BC', 'border': 1})

	t1_results_df = util.loadTableInPandaDataFrame(plasmid_info_file, numeric_columns_t1)
	t1_results_df.to_excel(writer, sheet_name='Sample PlasmidOme Overview', index=False, na_rep="NA")

	t2_results_df = util.loadTableInPandaDataFrame(result_file, numeric_columns_t2)
	t2_results_df.to_excel(writer, sheet_name='Similarity to Target Genomes', index=False, na_rep="NA")

	t1_worksheet = writer.sheets['Sample PlasmidOme Overview']
	t2_worksheet = writer.sheets['Similarity to Target Genomes']

	t1_worksheet.conditional_format('A1:F1', {'type': 'cell', 'criteria': '!=', 'value': 'NA', 'format': header_format})
	t2_worksheet.conditional_format('A1:HA1', {'type': 'cell', 'criteria': '!=', 'value': 'NA', 'format': header_format})

	# last spreadsheet row of each table (one header row + one row per entry)
	row_count = len(sample_plasmids) + 1
	t2_row_count = len(tgs) + 1

	# plasmid length (t1); default guards against no plasmids having been found
	t1_worksheet.conditional_format('D2:D' + str(row_count), {'type': '2_color_scale', 'min_color': "#d5f0ea", 'max_color': "#83a39c", 'min_type': 'num', 'max_type': 'num', "min_value": 0, "max_value": max(plasmid_lengths, default=0)})

	# total plasmid shared (t2)
	t2_worksheet.conditional_format('B2:B' + str(t2_row_count), {'type': '2_color_scale', 'min_color': "#d9d9d9", 'max_color': "#949494", 'min_type': 'num', 'max_type': 'num', "min_value": 0, "max_value": len(plasmids)})

	# aggregate score (t2); default guards against an empty target genome listing
	t2_worksheet.conditional_format('C2:C' + str(t2_row_count), {'type': '2_color_scale', 'min_color': "#f7db94", 'max_color': "#ad8f40", 'min_type': 'num', 'max_type': 'num', "min_value": 0, "max_value": max(all_scores, default=0.0)})

	# per-plasmid column pairs start at column index 3 (after target_genome, shared_plasmids, aggregate_score)
	for plasmid_number in range(len(plasmids)):
		columnid_pid = util.determineColumnNameBasedOnIndex(3 + 2*plasmid_number)
		columnid_sgp = util.determineColumnNameBasedOnIndex(4 + 2*plasmid_number)

		# aai
		t2_worksheet.conditional_format(columnid_pid + '2:' + columnid_pid + str(t2_row_count), {'type': '2_color_scale', 'min_color': "#d8eaf0", 'max_color': "#83c1d4", "min_value": 0.0, "max_value": 100.0, 'min_type': 'num', 'max_type': 'num'})

		# prop genes found
		t2_worksheet.conditional_format(columnid_sgp + '2:' + columnid_sgp + str(t2_row_count), {'type': '2_color_scale', 'min_color': "#e6bec0", 'max_color': "#f26168", "min_value": 0.0, "max_value": 1.0, 'min_type': 'num', 'max_type': 'num'})

	# Freeze the first row of both sheets
	t1_worksheet.freeze_panes(1, 0)
	t2_worksheet.freeze_panes(1, 0)

	# close the writer: saves the workbook and releases the file handle held by pandas
	writer.close()

	sys.stdout.write('Done running apos!\nFinal results can be found at:\n%s\n' % (apos_spreadsheet_file))
	logObject.info('Done running apos!\nFinal results can be found at:\n%s\n' % (apos_spreadsheet_file))

	# Close logging object and exit
	util.closeLoggerObject(logObject)
	sys.exit(0)



# Run the apos workflow only when this file is executed as a script (not when imported).
if __name__ == '__main__':
	apos()
