#!/usr/bin/env python3

"""
Program: zol-scape
Author: Rauf Salamzade
Kalan Lab
UW Madison, Department of Medical Microbiology and Immunology
"""

# BSD 3-Clause License
#
# Copyright (c) 2023-2025, Kalan-Lab
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met: 
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice, 
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


from collections import defaultdict
from rich_argparse import RawTextRichHelpFormatter
from time import sleep
from zol import util
import argparse
import multiprocessing
import os
import pandas as pd
import sys

# Set OMP_NUM_THREADS=1 in this process's environment at import time.
# NOTE(review): presumably meant to keep OpenMP-backed libraries/tools single-threaded
# so that parallelism is governed only by --jobs / --threads - confirm the intent.
os.environ["OMP_NUM_THREADS"] = "1"
def create_parser():
    """
    Build the zol-scape command-line interface and parse sys.argv.

    Despite the name, this returns the parsed arguments rather than the parser.

    Returns:
        argparse.Namespace with the attributes big_scape_input, big_scape_results,
        outdir, zol_parameters, print_mode, threads and jobs.
    """
    parser = argparse.ArgumentParser(description = """
    Program: zol-scape
    Author: Rauf Salamzade
    Affiliation: Kalan Lab, UW Madison, Department of Medical Microbiology and
        Immunology

    zolscape is a wrapper which runs zol for each GCF identified by BiG-SCAPE analysis to
    complement CORASON analysis.

    You can run from start to finish as such (will run 10 zol jobs at once, each using
    4 threads):

    zol-scape -i Input_Folder_to_BiG-SCAPE/ -r Result_Folder_to_BiG-SCAPE/ \\
              -o ZOL-SCAPE_Results/ -j 10 -c 4

    If you have access to an HPC and would like to parallelize using that - you
    can use the --print-mode option to create a task file which lists individual zol
    commands per line which can be parallelized on an HPC via a job array:

    zol-scape -i Input_Folder_to_BiG-SCAPE/ -r Result_Folder_to_BiG-SCAPE/ \\
              -o ZOL-SCAPE_Results/ -p -c 4

    Then, after you run the printed list of zol commands, to create a finalized
    consolidated tsv/xlsx of zol results across all GCFs, simply run the command
    again, same as before:

    zol-scape -i Input_Folder_to_BiG-SCAPE/ -r Result_Folder_to_BiG-SCAPE/ \\
              -o ZOL-SCAPE_Results/ -p -c 4
    """, formatter_class = RawTextRichHelpFormatter)

    parser.add_argument('-i',
        '--big-scape-input',
        help = "Path to folder which was input into BiG-SCAPE.",
        required = True)
    # The original help text stopped mid-sentence ("- if multiple"); the completion
    # below describes what zolscape() does with the directory: every '*_clustering_*.tsv'
    # file found under it is parsed and its GCFs are processed separately.
    parser.add_argument('-r',
        '--big-scape-results',
        help = "Path to BiG-SCAPE results directory - if multiple clustering\n"
               "result files are found, GCFs from each are processed separately.",
        required = True)
    parser.add_argument('-o', '--outdir', help = "Output directory.", required = True)
    parser.add_argument('-z',
        '--zol-parameters',
        help = "Parameters to pass to zol - please surround by double quotes [Default is ''].",
        required = False,
        default = "")
    parser.add_argument('-p',
        '--print-mode',
        action = 'store_true',
        help = "Print zol commands - one per line - and exit - to allow for\n"
               "parallelizing on an HPC using a job array.",
        required = False,
        default = False)
    parser.add_argument('-c',
        '--threads',
        type = int,
        help = "The number of threads to use per zol job [Default is 1].",
        required = False,
        default = 1)
    parser.add_argument('-j',
        '--jobs',
        type = int,
        help = "The number of parallel zol jobs to run at once [Default is 1].",
        required = False,
        default = 1)
    return parser.parse_args()

def _excel_column_letter(index):
    """Return the spreadsheet column letters for a 0-based column index (0 -> 'A', 26 -> 'AA')."""
    letters = ''
    number = index + 1
    while number > 0:
        number, remainder = divmod(number - 1, 26)
        letters = chr(ord('A') + remainder) + letters
    return letters


def _has_flag(parameters, *flags):
    """Return True if any of `flags` occurs as a whitespace-delimited token in `parameters`."""
    tokens = set(parameters.split())
    return any(flag in tokens for flag in flags)


def _two_color_scale(min_color, max_color, min_value, max_value):
    """Build the options dict for a numeric 2-color-scale conditional format."""
    return {'type': '2_color_scale',
            'min_color': min_color, 'max_color': max_color,
            'min_value': min_value, 'max_value': max_value,
            'min_type': 'num', 'max_type': 'num'}


def _three_color_scale(min_color, mid_color, max_color, min_value, mid_value, max_value):
    """Build the options dict for a numeric 3-color-scale conditional format."""
    return {'type': '3_color_scale',
            'min_color': min_color, 'mid_color': mid_color, 'max_color': max_color,
            'min_value': min_value, 'mid_value': mid_value, 'max_value': max_value,
            'min_type': 'num', 'mid_type': 'num', 'max_type': 'num'}


def zolscape():
    """
    Run zol on every GCF reported by BiG-SCAPE and consolidate the per-GCF reports.

    Steps:
      1. Map BGC ids to the '.gbk' files found under the BiG-SCAPE input directory.
      2. Parse every '*_clustering_*.tsv' file under the BiG-SCAPE results directory
         into clustering file -> GCF id -> list of BGC ids.
      3. If the zol results directory is empty, build one zol command per GCF and
         either write them to '<outdir>/zol.cmds' and exit (print mode) or run them
         in a multiprocessing pool. If it is not empty, existing results are reused.
      4. Concatenate each GCF's 'Final_Results/Consolidated_Report.tsv' into
         '<outdir>/zol-scape_Results.tsv' and write a formatted
         '<outdir>/zol-scape_Results.xlsx'.

    Arguments come from create_parser(). The function always terminates the process
    via sys.exit: status 1 if an input directory is invalid, 0 otherwise.
    """
    myargs = create_parser()

    bigscape_input_dir = os.path.abspath(myargs.big_scape_input) + '/'
    bigscape_results_dir = os.path.abspath(myargs.big_scape_results) + '/'
    outdir = os.path.abspath(myargs.outdir) + '/'
    zol_parameters = myargs.zol_parameters
    print_mode = myargs.print_mode
    threads = myargs.threads
    jobs = myargs.jobs

    # create output directory if needed, or warn of over-writing
    if os.path.isdir(outdir):
        sys.stderr.write("Output directory exists. Will continue, possibly overwrite some results, but generally attempt to use checkpoints to skip steps successfuly completed...\n ")
        sleep(5)
    else:
        util.setup_ready_directory([outdir])

    # Explicit check instead of `assert`: assertions are removed when Python runs with -O,
    # which would silently skip this validation.
    if not (os.path.isdir(bigscape_input_dir) and os.path.isdir(bigscape_results_dir)):
        msg = 'Either the BiG-SCAPE input directory or results directory could not be validated as a directory - please check the paths provided!'
        sys.stderr.write(msg + '\n')
        sys.exit(1)

    # create logging object
    log_file = outdir + 'Progress.log'
    log_object = util.create_logger_object(log_file)
    version_string = util.get_version()

    sys.stdout.write(f'Running zol-scape version {version_string}\n')
    log_object.info(f'Running zol-scape version {version_string}')

    # log command used (appended, so re-runs into the same directory accumulate)
    parameters_file = outdir + 'Command_Issued.txt'
    with open(parameters_file, 'a+') as parameters_handle:
        parameters_handle.write(' '.join(sys.argv) + '\n')

    # Step 1: Parse BGCs from input dir
    msg = '--------------------\nStep 1: Determine BGC GenBank paths in BiG-SCAPE input directory\n--------------------'
    log_object.info(msg)
    sys.stdout.write(msg + '\n')

    # BGC id is the file name without its final extension; if the same id occurs in
    # several sub-directories, the path visited last by os.walk wins.
    bgc_ids_to_paths = {}
    for currentpath, _folders, files in os.walk(bigscape_input_dir):
        for file in files:
            if file.endswith('.gbk'):
                bgc_id = '.'.join(file.split('.')[: -1])
                bgc_ids_to_paths[bgc_id] = os.path.join(currentpath, file)

    # Step 2: Parse GCFs from results dir
    msg = '--------------------\nStep 2: Parse GCF clustering info from BiG-SCAPE results directory\n--------------------'
    log_object.info(msg)
    sys.stdout.write(msg + '\n')

    # clustering file path -> GCF id -> list of BGC ids
    cluster_gcf_bgcs = defaultdict(lambda: defaultdict(list))
    for currentpath, _folders, files in os.walk(bigscape_results_dir):
        for file in files:
            if '_clustering_' in file and file.endswith('.tsv'):
                # Plain concatenation is kept because this path is written verbatim
                # into the 'Clustering File Path' column of the results.
                cluster_file = currentpath + '/' + file
                with open(cluster_file) as ocf:
                    for line in ocf:
                        if line.startswith('#'): continue
                        line = line.strip()
                        if not line: continue  # tolerate blank lines
                        bgc_id, gcf_id = line.split('\t')
                        cluster_gcf_bgcs[cluster_file][gcf_id].append(bgc_id)

    # The position of a clustering file in this sorted list is part of each GCF's
    # result directory name, so the same list is used to build and to collect results.
    cluster_files = sorted(cluster_gcf_bgcs)

    # Step 3: Run zol or create zol task file
    zol_results_dir = outdir + 'zol_Results/'
    if not os.path.isdir(zol_results_dir):
        util.setup_ready_directory([zol_results_dir], delete_if_exist = True)

    if not os.listdir(zol_results_dir):
        zol_cmds = []
        for cluster_index, cluster in enumerate(cluster_files):
            for gcf, bgc_ids in cluster_gcf_bgcs[cluster].items():
                uniq_gcf_id = 'ClusterFile_' + str(cluster_index) + '_' + gcf
                output_dir = zol_results_dir + uniq_gcf_id + '/'
                input_gbks = []
                for bgc in bgc_ids:
                    if bgc in bgc_ids_to_paths:
                        input_gbks.append(bgc_ids_to_paths[bgc])
                    else:
                        msg = f'Warning: unable to find GenBank for BGC {bgc}'
                        sys.stderr.write(msg + '\n')
                        log_object.warning(msg)
                bgc_ids_in_gcf = len(bgc_ids)
                bgc_gbks_found = len(input_gbks)
                msg = f'For clustering results {cluster} for GCF {gcf},\n'
                msg += f'{bgc_ids_in_gcf} BGCs were reported and {bgc_gbks_found} BGC GenBank files were found.\n'
                msg += 'Missing BGC GenBank files could be because they are MIBiG reference\n'
                msg += 'BGCs if BiG-SCAPE was run using the --mibig arugment.'
                sys.stderr.write(msg + '\n')
                log_object.warning(msg)
                if input_gbks:
                    zol_cmd = ['zol', '-c', str(threads)]
                    if zol_parameters.strip():
                        zol_cmd.append(zol_parameters)
                    zol_cmd += ['-o', output_dir, '-i'] + input_gbks
                    # The logger rides along as the final element for util.multi_process
                    # (presumably it separates the logger from the command - confirm in zol.util).
                    zol_cmds.append(zol_cmd + [log_object])
                else:
                    msg = f'Warning: no BGC GenBanks found for GCF {gcf} in clustering results file {cluster}'
                    sys.stderr.write(msg + '\n')
                    log_object.warning(msg)

        if print_mode:
            # Open for writing, iterate over ALL commands, and drop the trailing logger
            # element so that only the command tokens are joined.
            with open(outdir + 'zol.cmds', 'w') as outf:
                outf.write(''.join(' '.join(cmd[:-1]) + '\n' for cmd in zol_cmds))
            msg = f'Wrote zol commands for individual GCFs to {outdir + "zol.cmds"}'
            log_object.info(msg)
            sys.stderr.write(msg + '\n')
            util.close_logger_object(log_object)
            sys.exit(0)
        else:
            pool = multiprocessing.Pool(jobs)
            try:
                pool.map(util.multi_process, zol_cmds)
            finally:
                pool.close()
                pool.join()
    else:
        msg = 'Note, something was found in the zol result\'s directory from a previous run, will attempt to use available results.'
        sys.stderr.write(msg + '\n')
        log_object.info(msg)

    # Flags are detected as whole whitespace-delimited tokens so that the header and
    # the row layout always agree, whether or not the user padded the value with spaces.
    domain_mode = _has_flag(zol_parameters, '-dom', '--domain-mode')
    selection_mode = _has_flag(zol_parameters, '-s', '--selection-analysis')

    zol_sheet_header = ['Clustering File Path',
                        'GCF ID',
                        'Ortholog Group (OG) ID',
                        'OG is Single Copy?',
                        'Proportion of Total Gene Cluster Instances with OG']
    if domain_mode:
        zol_sheet_header.append('Proportion of Complete Gene Cluster Instances with OG')
    zol_sheet_header.append('OG Median Length (bp)')
    if domain_mode:
        zol_sheet_header.append('Single-Linkage Full Protein Cluster')
    zol_sheet_header += ['OG Consensus Order', 'OG Consensus Direction', 'Tajima\'s D',
                         'Proportion of Filtered Codon Alignment is Segregating Sites',
                         'Entropy', 'Upstream Region Entropy', 'Median Beta-RD-gc', 'Max Beta-RD-gc',
                         'Proportion of sites which are highly ambiguous in codon alignment',
                         'Proportion of sites which are highly ambiguous in trimmed codon alignment',
                         'Median GC', 'Median GC Skew',
                         'BGC score (GECCO weights)', 'Viral score (V-Score)']
    if selection_mode:
        zol_sheet_header += ['GARD Partitions Based on Recombination Breakpoints',
                             'Number of Sites Identified as Under Positive or Negative Selection by FUBAR',
                             'Average delta(Beta, Alpha) by FUBAR across sites',
                             'Proportion of Sites Under Selection which are Positive',
                             'P-value for gene-wide episodic selection by BUSTED']
    zol_sheet_header += ['KO Annotation (E-value)', 'PGAP Annotation (E-value)', 'PaperBLAST Annotation (E-value)',
                         'CARD Annotation (E-value)', 'IS Finder (E-value)',
                         'MIBiG Annotation (E-value)', 'VOG Annotation (E-value)',
                         'VFDB Annotation (E-value)', 'Pfam Domains', 'CDS Locus Tags',
                         'OG Consensus Sequence']

    # Relative to a zol report row, two columns are prepended (clustering file, GCF id)
    # and one column (custom db annotation) is dropped.
    # NOTE(review): the index of the dropped column (18, or 20 with selection analysis)
    # is carried over unchanged; it does not vary with domain mode and is not derived
    # from the header above - verify against the column layout zol actually writes.
    dropped_column = 20 if selection_mode else 18

    consolidated_table_file = outdir + 'zol-scape_Results.tsv'
    num_rows = 1  # spreadsheet row number of the last written row (header is row 1)
    with open(consolidated_table_file, 'w') as zctf_handle:
        zctf_handle.write('\t'.join(zol_sheet_header) + '\n')
        for cluster_index, cluster in enumerate(cluster_files):
            for gcf in cluster_gcf_bgcs[cluster]:
                uniq_gcf_id = 'ClusterFile_' + str(cluster_index) + '_' + gcf
                zol_res_dir = zol_results_dir + uniq_gcf_id + '/'
                gcf_result_file = zol_res_dir + '/Final_Results/Consolidated_Report.tsv'
                if not os.path.isfile(gcf_result_file): continue
                with open(gcf_result_file) as ogrf:
                    # A distinct index name keeps cluster_index intact for the next GCF.
                    for line_index, line in enumerate(ogrf):
                        if line_index == 0: continue  # skip the report's own header
                        ls = line.strip().split('\t')
                        row = [cluster, gcf] + ls[:dropped_column] + ls[dropped_column + 1:]
                        zctf_handle.write('\t'.join(row) + '\n')
                        num_rows += 1

    zr_numeric_columns = set(['Proportion of Total Gene Cluster Instances with OG',
                              'Proportion of Complete Gene Cluster Instances with OG',
                              'OG Median Length (bp)', 'OG Consensus Order',
                              'Tajima\'s D', 'GARD Partitions Based on Recombination Breakpoints',
                              'Number of Sites Identified as Under Positive or Negative Selection by FUBAR',
                              'Average delta(Beta, Alpha) by FUBAR across sites',
                              'Proportion of Sites Under Selection which are Positive',
                              'Proportion of Filtered Codon Alignment is Segregating Sites',
                              'Entropy', 'Upstream Region Entropy', 'Median Beta-RD-gc', 'Max Beta-RD-gc',
                              'Proportion of sites which are highly ambiguous in codon alignment',
                              'Proportion of sites which are highly ambiguous in trimmed codon alignment', 'Median GC', 'Median GC Skew',
                              'BGC score (GECCO weights)', 'Viral score (V-Score)'])

    zr_data = util.load_table_in_panda_data_frame(consolidated_table_file,
        zr_numeric_columns)

    # construct spreadsheet
    zol_spreadsheet_file = outdir + 'zol-scape_Results.xlsx'
    writer = pd.ExcelWriter(zol_spreadsheet_file, engine = 'xlsxwriter')
    workbook = writer.book

    warn_format = workbook.add_format({'bg_color': '#bf241f',
        'bold': True,
        'font_color': '#FFFFFF'})
    na_format = workbook.add_format({'font_color': '#a6a6a6',
        'bg_color': '#FFFFFF',
        'italic': True})
    header_format = workbook.add_format({'bold': True,
        'text_wrap': True,
        'valign': 'top',
        'fg_color': '#D7E4BC',
        'border': 1})

    zr_data.to_excel(writer,
        sheet_name = 'zol-scape Results',
        index = False,
        na_rep = "NA")
    zr_sheet = writer.sheets['zol-scape Results']

    # Column letters are looked up from the header instead of being hard-coded, so the
    # formats land on the intended columns in domain mode as well (two extra columns).
    # Assumes the data frame keeps the column order of the TSV header - TODO confirm
    # against util.load_table_in_panda_data_frame.
    column_letters = {name: _excel_column_letter(idx) for idx, name in enumerate(zol_sheet_header)}

    def data_range(column_name):
        letter = column_letters[column_name]
        return letter + '2:' + letter + str(num_rows)

    zr_sheet.conditional_format(data_range('OG is Single Copy?'),
        {'type': 'cell',
        'criteria': '==',
        'value': '"False"',
        'format': warn_format})
    zr_sheet.conditional_format('A2:CA' + str(num_rows),
        {'type': 'cell',
        'criteria': '==',
        'value': '"NA"',
        'format': na_format})
    zr_sheet.conditional_format('A1:CA1',
        {'type': 'cell',
        'criteria': '!=',
        'value': 'NA',
        'format': header_format})

    # (column, color scale) pairs, applied in the same order as before.
    color_scales = [
        # prop gene-clusters with hg
        ('Proportion of Total Gene Cluster Instances with OG',
         _two_color_scale("#f7de99", "#c29006", 0.0, 1.0)),
        # gene-lengths
        ('OG Median Length (bp)',
         _two_color_scale("#a3dee3", "#1ebcc9", 100, 2500)),
        # taj-d
        ('Tajima\'s D',
         _three_color_scale("#f7a09c", "#e0e0e0", "#87cefa", -2.0, 0.0, 2.0)),
        # prop seg sites
        ('Proportion of Filtered Codon Alignment is Segregating Sites',
         _two_color_scale("#eab3f2", "#a37ba8", 0.0, 1.0)),
        # entropy
        ('Entropy',
         _two_color_scale("#f7a8bc", "#fa6188", 0.0, 1.0)),
        # upstream region entropy
        ('Upstream Region Entropy',
         _two_color_scale("#f7a8bc", "#fa6188", 0.0, 1.0)),
        # median beta-rd gc
        ('Median Beta-RD-gc',
         _three_color_scale("#fac087", "#e0e0e0", "#9eb888", 0.75, 1.0, 1.25)),
        # max beta-rd gc
        ('Max Beta-RD-gc',
         _three_color_scale("#fac087", "#e0e0e0", "#9eb888", 0.75, 1.0, 1.25)),
        # ambiguity full ca
        ('Proportion of sites which are highly ambiguous in codon alignment',
         _two_color_scale("#ed8c8c", "#ab1616", 0.0, 1.0)),
        # ambiguity trim ca
        ('Proportion of sites which are highly ambiguous in trimmed codon alignment',
         _two_color_scale("#ed8c8c", "#ab1616", 0.0, 1.0)),
        # GC
        ('Median GC',
         _two_color_scale("#abffb7", "#43bf55", 0.0, 1.0)),
        # GC Skew
        ('Median GC Skew',
         _two_color_scale("#c7afb4", "#965663", -2.0, 2.0)),
        # BGC score
        ('BGC score (GECCO weights)',
         _two_color_scale("#f5aca4", "#c75246", -7.0, 13.0)),
        # viral score
        ('Viral score (V-Score)',
         _two_color_scale("#dfccff", "#715a96", 0.0, 10.0)),
    ]
    for column_name, scale_options in color_scales:
        zr_sheet.conditional_format(data_range(column_name), scale_options)

    # Freeze the header row
    zr_sheet.freeze_panes(1, 0)

    # Close through the pandas writer so the workbook is saved and any file handle the
    # writer opened is released.
    writer.close()

    sys.stdout.write(f'Done running zol-scape!\nFinal spreadsheet can be found at: \n{zol_spreadsheet_file}\n')
    log_object.info(f'Done running zol-scape!\nFinal spreadsheet can be found at: \n{zol_spreadsheet_file}\n')

    # Close logging object and exit
    util.close_logger_object(log_object)
    sys.exit(0)

# Script entry point: run the zol-scape workflow only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__': 
    zolscape()
