#!/usr/bin/env python

"""
This is a helper program of GToTree (https://github.com/AstrobioMike/GToTree/wiki)
to facilitate generating single-copy gene HMMs.

For examples, please visit the GToTree wiki here: https://github.com/AstrobioMike/GToTree/wiki/example-usage
For details on the process, please see: https://github.com/AstrobioMike/GToTree/wiki/SCG-sets
"""

VERSION = "v1.6.35"
import os
import sys
import argparse
import textwrap
import shutil
import urllib.request
import pandas as pd
import numpy as np
import gzip
import io
import subprocess
import filecmp
import multiprocessing as mp
from Bio import SeqIO
from functools import partial
from math import ceil
from glob import glob

# printing the program banner before anything else happens
print("\n\n                                 GToTree " + VERSION)
print("                        (github.com/AstrobioMike/GToTree)\n")
print(" ------------------------------------------------------------------------------- ")

# building the command-line interface
parser = argparse.ArgumentParser(description="This program takes a list of NCBI accessions and generates a set of single-copy gene HMMs based on PFams.\
                                              Please see the wiki here for details on the process: https://github.com/AstrobioMike/GToTree/wiki/SCG-sets")

required = parser.add_argument_group('required arguments')

# the accessions file is read one accession per line (each needs to start with GCA or GCF)
required.add_argument("-a", "--target-accessions", help="File holding the target NCBI assembly accessions (GCA_* and/or GCF_*), one per line", action="store", required=True)
parser.add_argument("-o", "--output-directory", help='Name of desired output directory (default: "gtt-gen-SCG-HMMs-output")', action="store", default="gtt-gen-SCG-HMMs-output", dest="output_dir")
parser.add_argument("-n", "--num-cpus", help='Number of cpus to use on parallelizable steps (default: 1)', action="store", default=1, type=int)
parser.add_argument("-p", "--percent-single-copy", help='The percent of target genomes that need to have exactly 1 copy identified for an HMM to be included, needs to be between 1 and 100 (default: 90)', action="store", default=90, type=int)
parser.add_argument("--rerun", help="Specify this flag if you'd like to try to restart a prior run from any currently existing files", action="store_true")
parser.add_argument("--keep-working-dir", help="Specify this flag if you'd like to keep the working directory", action="store_true")

# printing the help menu (rather than an argparse error) when no arguments were given
if len(sys.argv)==1:
    parser.print_help(sys.stderr)
    sys.exit(0)

args = parser.parse_args()

################################################################################

def main():
    """ drives the whole workflow, handing the results of each stage to the next """

    # validating settings and preparing the output/working directories
    single_copy_cutoff = check_input_percent(args.percent_single_copy)
    work_dir, out_dir, resuming = setup_output_dir(args.target_accessions, args.output_dir, args.rerun)
    cpus = check_cpus(args.num_cpus)

    # gathering the amino-acid sequences of the target genomes
    accessions, dbs_to_fetch = get_target_accs(args.target_accessions)
    assembly_tab, _missing_accs = get_ncbi_tabs(dbs_to_fetch, accessions, work_dir, out_dir, resuming)
    usable_accs = get_amino_acids(assembly_tab, work_dir, out_dir, cpus, resuming)

    # searching the coverage-filtered PFams against those sequences
    candidate_pfam_IDs = get_and_filter_pfam_hmms(work_dir, resuming)
    run_hmmsearch(work_dir, cpus, resuming)

    # keeping the PFams that pass the single-copy cutoff and writing the final HMM set
    num_wanted_pfams = filter_HMM_hits(usable_accs, candidate_pfam_IDs, single_copy_cutoff, work_dir, out_dir, resuming)
    filter_final_HMM_set(num_wanted_pfams, work_dir, out_dir, resuming)

    # closing notes for the user
    report_missed_accessions(out_dir, resuming)
    note_gtt_store_SCG_HMMs()

    # the working directory is only left in place when explicitly requested
    if args.keep_working_dir:
        return

    shutil.rmtree(work_dir)

################################################################################

# ANSI escape templates (each takes the text to colorize through %-formatting)
tty_colors = {
    color_name: '\033[0;' + ansi_code + 'm%s\033[0m'
    for color_name, ansi_code in (('green', '32'), ('yellow', '33'), ('red', '31'))
}


### functions ###
def color_text(text, color='green'):
    """ returns text wrapped in the requested color, or unchanged when stdout is not a terminal """

    if not sys.stdout.isatty():
        return text

    return tty_colors[color] % text


def wprint(text):
    """ prints text wrapped at 80 columns, with every line indented by two spaces """

    wrapper = textwrap.TextWrapper(width=80, initial_indent="  ",
                                   subsequent_indent="  ", break_on_hyphens=False)

    print(wrapper.fill(text))


def check_input_percent(input_percent):
    """ returns the input percent if it is greater than 0 and no more than 100, otherwise reports and exits with status 1 """

    # valid values are handed straight back
    if 0 < input_percent <= 100:
        return(input_percent)

    print("")
    wprint(color_text("The --percent-single-copy (-p) parameter needs to be between 1–100.", "yellow"))
    print("  See `gtt-gen-SCG-HMMs -h` for usage info.")
    print("\nExiting for now.\n")
    sys.exit(1)

def check_cpus(num_cpus):
    """
    checks the number of requested cpus is usable on the current system

    Returns the number of cpus to use, which is always at least 1:
      - requests below 1 are raised to 1
      - requests above what os.cpu_count() reports are lowered to one less
        than that count (but never below 1)
      - if os.cpu_count() can't determine the count (returns None), the
        request is left as is
    """

    # anything below 1 can't be used to size a worker pool
    if num_cpus < 1:

        print("")
        print("  The number of cpus specified (" + str(num_cpus) + ") is less than 1, we are going to utilize 1 moving forward :)")
        print("")

        return(1)

    available_cpus = os.cpu_count()

    # os.cpu_count() returns None when the count is undeterminable, in which case there is nothing to compare against
    if available_cpus is None:
        return(num_cpus)

    if num_cpus > available_cpus:

        # leaving one cpu free, but never dropping to 0 (e.g. on a single-cpu system)
        usable_cpus = max(1, available_cpus - 1)

        print("")
        wprint(color_text("The number of cpus specified (" + str(num_cpus) + ") is greater than what it seems is available on the current system (" + str(available_cpus) + ").", "yellow"))
        wprint("To be on the safer side, we are going to utilize " + str(usable_cpus) + " moving forward :)")
        print("")

        num_cpus = usable_cpus

    return(num_cpus)


def setup_output_dir(input_accessions_file, output_dir, rerun):
    """
    sets up the output directory and the re-usable working directory inside it,
    and decides whether a requested rerun can actually be honored

    A copy of the input accessions file is kept in the working directory so a
    later `--rerun` can be checked against it.

    Returns a tuple of (tmp_dir, output_dir, rerun), where both directories end
    with "/" and rerun is set to False whenever there is nothing to restart from.
    Exits with status 1 if the output directory exists and `rerun` is False, or
    if the stored input accessions differ from the current ones.
    """

    output_dir = output_dir.rstrip("/") + "/"
    tmp_dir = output_dir + "tmp-dir/"
    stored_accessions_file = tmp_dir + "input-accessions.txt"

    # fresh start: nothing exists yet
    if not os.path.exists(output_dir):

        if rerun:
            print("")
            wprint("It seems there is no data available from a previous run, the `--rerun` flag will be ignored.")
            rerun = False

        # makedirs so that nested output paths (e.g. "results/run-1") work, this creates the output directory too
        os.makedirs(tmp_dir)

        # making copy of input accession file to be able to check if the same on a restart
        shutil.copy(input_accessions_file, stored_accessions_file)

        return(tmp_dir, output_dir, rerun)

    # output directory exists, but the user didn't ask to work with it again
    if not rerun:
        print("")
        wprint(color_text("It seems the output directory '" + str(output_dir) + "' already exists.", "yellow"))
        wprint("We don't want to accidentally overwrite anything, if you'd like to intentionally work with this directory again, add the `--rerun` flag.")
        print("\nExiting for now.\n")
        sys.exit(1)

    # making sure input accessions match before allowing rerun
    if os.path.exists(stored_accessions_file):

        # shallow=False forces a comparison of the contents rather than trusting matching size/modification time
        if not filecmp.cmp(input_accessions_file, stored_accessions_file, shallow=False):

            print("")
            wprint(color_text("It seems the input accessions don't match those of the previous run, and we shouldn't attempt a `--rerun` in these situations :(", "yellow"))
            print("  We should maybe try with a new output directory.\n")
            print("Exiting for now.\n")
            sys.exit(1)

    else:

        print("")
        wprint("It seems there were no stored input accessions from the previous run, the `--rerun` flag will be ignored.")

        os.makedirs(tmp_dir, exist_ok=True)

        shutil.copy(input_accessions_file, stored_accessions_file)

        # setting rerun to false since the input accessions weren't there to be checked
        rerun = False

    return(tmp_dir, output_dir, rerun)


def get_target_accs(input_file):
    """
    reads in the target accessions (one per line, blank lines are skipped)

    Returns a tuple of (target_accs, needed_dbs), where target_accs is a list of
    the unique accessions in the order they first appear, and needed_dbs is
    "GCA", "GCF", or "both" depending on which accession prefixes were seen.
    Exits with status 1 if an accession doesn't start with GCA or GCF, or if
    the file holds no accessions at all.
    """

    # a dict keeps the accessions unique while holding onto their input order
    target_accs = {}
    types = set()

    with open(input_file) as accs:
        for line in accs:

            acc = line.strip()

            # blank lines (e.g. at the end of the file) aren't accessions
            if not acc:
                continue

            # keeping track if GCA and/or GCF so we know which we need to download
            # to construct links
            if acc.startswith("GCA"):
                types.add("GCA")

            elif acc.startswith("GCF"):
                types.add("GCF")

            else:
                print(color_text("\n  It seems input accession '" + str(acc) + "' is not typical NCBI-accession format :(", "yellow"))
                print("\nExiting for now.\n")
                sys.exit(1)

            target_accs[acc] = None

    if not target_accs:
        print(color_text("\n  It seems the input file '" + str(input_file) + "' doesn't hold any accessions :(", "yellow"))
        print("\nExiting for now.\n")
        sys.exit(1)

    if types == {"GCA"}:
        needed_dbs = "GCA"

    elif types == {"GCF"}:
        needed_dbs = "GCF"

    else:
        needed_dbs = "both"

    return(list(target_accs), needed_dbs)


def _read_assembly_summary(url):
    """ downloads one NCBI assembly summary table, keeping just the accession and ftp-path columns """

    # the column names are on the second line of these files, hence header=1
    with urllib.request.urlopen(url) as response:
        return(pd.read_csv(response, sep="\t", header=1, low_memory=False, usecols=["# assembly_accession", "ftp_path"]))


def get_ncbi_tabs(needed_dbs, target_accs, tmp_dir, output_dir, rerun):
    """
    downloading needed NCBI assembly summary files to construct download links

    needed_dbs is "GCA", "GCF", or anything else for both. The table is filtered
    down to the target accessions and stored in tmp_dir (along with the list of
    accessions that weren't found) so a rerun can re-use them. Accessions that
    weren't found are also reported in "Missed-accessions.tsv" in output_dir.

    Returns a tuple of (ncbi_tab, accs_not_found).
    """

    stored_tab_file = tmp_dir + "ncbi-assembly-info.tsv"
    stored_not_found_file = tmp_dir + "not-found-accs.txt"
    missed_accs_file = output_dir + "Missed-accessions.tsv"

    genbank_url = "ftp://ftp.ncbi.nlm.nih.gov/genomes/genbank/assembly_summary_genbank.txt"
    refseq_url = "ftp://ftp.ncbi.nlm.nih.gov/genomes/refseq/assembly_summary_refseq.txt"

    # checking if can use from previous run (both files are needed, the second may be missing if the prior run died in between writing them)
    if rerun and os.path.exists(stored_tab_file) and os.path.exists(stored_not_found_file):

        print("")
        wprint("NCBI assembly summary files from previous run found and loaded...")

        ncbi_tab = pd.read_csv(stored_tab_file, sep="\t", header=0, low_memory=False)

        with open(stored_not_found_file) as accs:
            accs_not_found = [acc.strip() for acc in accs if acc.strip()]

    else:

        print("")
        wprint("Downloading needed NCBI assembly summary files...")

        if needed_dbs == "GCA":
            ncbi_tab = _read_assembly_summary(genbank_url)

        elif needed_dbs == "GCF":
            ncbi_tab = _read_assembly_summary(refseq_url)

        else:
            ncbi_tab = pd.concat([_read_assembly_summary(genbank_url), _read_assembly_summary(refseq_url)])

        # filtering down to just those needed
        wanted_accs = set(target_accs)
        ncbi_tab = ncbi_tab[ncbi_tab["# assembly_accession"].isin(wanted_accs)]

        # figuring out if any were not found, to track and report
        accs_not_found = list(wanted_accs - set(ncbi_tab["# assembly_accession"].tolist()))

        # writing out to be re-used if run fails in the middle
        ncbi_tab.to_csv(stored_tab_file, index=False, header=True, sep="\t")

        with open(stored_not_found_file, "w") as out:
            for acc in accs_not_found:
                out.write(acc + "\n")

    if len(accs_not_found) != 0:

        with open(missed_accs_file, "w") as out:
            out.write("accession\twhy_missing\n")

            for acc in accs_not_found:
                out.write(acc + "\t" + "not found" + "\n")

        print("")
        print(color_text("    " + str(len(accs_not_found)) + " accession(s) not found in NCBI assembly tables.", "yellow"))
        print("      Reported in:")
        print("        " + missed_accs_file)

    return(ncbi_tab, accs_not_found)


def _read_acc_list(path):
    """ reads a one-accession-per-line file into a list, skipping blank lines """

    with open(path) as accs:
        return([acc.strip() for acc in accs if acc.strip()])


def _write_acc_list(path, accs):
    """ writes accessions one per line (leaves an empty file if there are none) """

    with open(path, "w") as out:
        for acc in accs:
            out.write(acc + "\n")


def _record_missed_accessions(output_dir, accs, reason):
    """ appends accessions and why they were missed to "Missed-accessions.tsv", starting the file with a header if needed """

    missed_accs_file = output_dir + "Missed-accessions.tsv"
    needs_header = not os.path.exists(missed_accs_file)

    # appending, because this file may already hold accessions reported by an earlier step
    with open(missed_accs_file, "a") as out:

        if needs_header:
            out.write("accession\twhy_missing\n")

        for acc in accs:
            out.write(acc + "\t" + reason + "\n")


def get_amino_acids(ncbi_tab, tmp_dir, output_dir, num_cpus, rerun):
    """
    controller for getting amino acids

    Gets the amino-acid sequences for all accessions in ncbi_tab (in parallel
    if num_cpus is greater than 1) and combines them into "all-genes.faa" in
    tmp_dir. Accessions that failed to download, or failed at prodigal, are
    added to "Missed-accessions.tsv" in output_dir. The accession lists are
    also stored in tmp_dir so a rerun can pick up from here.

    Returns the list of accessions whose sequences were successfully retrieved.
    """

    all_seqs_file = tmp_dir + "all-genes.faa"
    remaining_accs_file = tmp_dir + "remaining_accs.txt"
    download_failed_file = tmp_dir + "download_failed.txt"
    prodigal_failed_file = tmp_dir + "prodigal_failed.txt"

    # checking if can use from a previous run
    if rerun and all(os.path.exists(file) for file in (all_seqs_file, remaining_accs_file, download_failed_file, prodigal_failed_file)):

        print("")
        wprint("Amino-acid sequences from previous run detected and being used...")

        return(_read_acc_list(remaining_accs_file))

    print("")
    wprint("Getting target genome amino-acid sequences...")

    if num_cpus == 1:

        remaining_accs, download_failed, prodigal_failed = dl_genomes_and_get_amino_acids(ncbi_tab, tmp_dir)

    else:

        # Running this step in parallel with greater than ~15 cpus can often cause it to hang in my testing.
        # After some poking around and a short consult with the wonderful Evan Bolyen, it is likely this is an
        # I/O problem. So setting this to a max of 12 cpus to be on the safer side
        # (higher amounts of cpus do work on the parallel hmmsearch step later)
        num_cpus = min(num_cpus, 12)

        remaining_accs, download_failed, prodigal_failed = parallelize_get_amino_acids(ncbi_tab, dl_genomes_and_get_amino_acids, num_cpus, tmp_dir)

    # writing out so can re-start if needed (and adding to ongoing report file "Missed-accessions.tsv")
    _write_acc_list(remaining_accs_file, remaining_accs)
    _write_acc_list(download_failed_file, download_failed)
    _write_acc_list(prodigal_failed_file, prodigal_failed)

    if len(download_failed) > 0:
        _record_missed_accessions(output_dir, download_failed, "download failed")

    if len(prodigal_failed) > 0:
        _record_missed_accessions(output_dir, prodigal_failed, "prodigal failed")

    # here combining all target seqs into one and removing intermediate files
    # (building under a temporary name, so an interrupted run can't leave a partial "all-genes.faa" for a rerun to pick up)
    all_individual_AA_files = sorted(glob(tmp_dir + "*-to-combine.faa"))
    partial_all_seqs_file = all_seqs_file + ".tmp"

    with open(partial_all_seqs_file, "wb") as all_seqs:
        for file in all_individual_AA_files:
            with open(file, "rb") as sub_seqs:
                shutil.copyfileobj(sub_seqs, all_seqs)
            os.remove(file)

    os.replace(partial_all_seqs_file, all_seqs_file)

    return(remaining_accs)


def _download_gzipped_text(url):
    """ downloads a gzipped file and returns its decompressed contents as text """

    with urllib.request.urlopen(url) as response:
        return(gzip.decompress(response.read()).decode("UTF-8", "ignore"))


def dl_genomes_and_get_amino_acids(ncbi_tab, tmp_dir):
    """
    main function for downloading and getting amino acid sequences

    For each row of ncbi_tab (which needs the columns "# assembly_accession" and
    "ftp_path"), this tries to download the protein fasta; if that doesn't work,
    it downloads the genomic fasta and runs prodigal on it. Sequences are written
    to "<accession>-to-combine.faa" in tmp_dir with headers renamed to
    "<accession>_<n>".

    Returns a tuple of three lists of accessions:
    (remaining_accs, download_failed, prodigal_failed).
    """

    # tracking those that succeed and those that fail
    remaining_accs = []
    download_failed = []
    prodigal_failed = []

    AA_suffix = "_protein.faa.gz"
    nt_suffix = "_genomic.fna.gz"

    for _, row in ncbi_tab.iterrows():
        curr_acc = row["# assembly_accession"]
        curr_base_url = row["ftp_path"]
        base_name = os.path.basename(curr_base_url)
        AA_url = curr_base_url + "/" + base_name + AA_suffix

        # only set if we end up needing to run prodigal, so we know what to clean up
        tmp_AA_file = None

        # attempting to download amino-acid file first
        # (catching Exception rather than everything, so a keyboard interrupt still stops the program)
        try:

            curr_AA = io.StringIO(_download_gzipped_text(AA_url))

        except Exception:

            # now trying to get the nucleotide file and run prodigal on it
            try:

                curr_nt = _download_gzipped_text(curr_base_url + "/" + base_name + nt_suffix)

            except Exception:

                download_failed.append(curr_acc)
                continue

            # writing out so can run prodigal on it
            tmp_nt_file = tmp_dir + curr_acc + ".fa"
            tmp_AA_file = tmp_dir + curr_acc + ".faa"

            with open(tmp_nt_file, "w") as out:
                out.write(curr_nt)

            prodigal_call = subprocess.run(["prodigal", "-c", "-q", "-i", tmp_nt_file, "-a", tmp_AA_file], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

            os.remove(tmp_nt_file)

            if prodigal_call.returncode != 0:
                prodigal_failed.append(curr_acc)

                # not leaving behind whatever prodigal may have written before failing
                if os.path.exists(tmp_AA_file):
                    os.remove(tmp_AA_file)

                continue

            # SeqIO.parse takes a file name as well as a handle, so this fits in with those downloaded as AA seqs
            curr_AA = tmp_AA_file

        # moving forward with processing amino acid file (includes renaming headers)
        # I initially had these all writing to the same "master" output file holding all amino acid sequences,
        # but that started showing odd behavior when being run in parallel with many cpus (occasionally some seqs weren't showing up),
        # so now they are each individually written to their own file, and then combined afterwards
        with open(tmp_dir + curr_acc + "-to-combine.faa", "w") as curr_out:

            # numbering from 1 to keep headers unique while labeling them with source accession
            for n, seq_record in enumerate(SeqIO.parse(curr_AA, "fasta"), start=1):

                curr_out.write(">" + curr_acc + "_" + str(n) + "\n")
                curr_out.write(str(seq_record.seq).replace("*", "") + "\n") # removing stop-codon stars from prodigal files

        remaining_accs.append(curr_acc)

        # cleaning up intermediate file if it exists
        if tmp_AA_file is not None and os.path.exists(tmp_AA_file):
            os.remove(tmp_AA_file)

    return(remaining_accs, download_failed, prodigal_failed)


def parallelize_get_amino_acids(df, func, num_cpus, tmp_dir):
    """
    run dl_genomes_and_get_amino_acids() in parallel

    df is split by rows into num_cpus chunks, and func is called on each chunk
    in a separate process as func(chunk, tmp_dir=tmp_dir). Each call needs to
    return three collections of accessions (remaining, download-failed,
    prodigal-failed), which are concatenated across chunks in chunk order.

    Returns a tuple of three lists: (remaining_accs, download_failed, prodigal_failed).
    """

    # splitting the row positions rather than the dataframe itself, since np.array_split()
    # on a dataframe depends on the deprecated DataFrame.swapaxes()
    row_position_chunks = np.array_split(np.arange(len(df)), num_cpus)
    df_split = [df.iloc[positions] for positions in row_position_chunks]

    function_call = partial(func, tmp_dir=tmp_dir) # allows providing additional constant argument "tmp_dir" to pool.map next

    # pool.map() blocks until all chunks are done, so the pool can be released right after
    with mp.Pool(num_cpus) as pool:
        output = pool.map(function_call, df_split)

    remaining_accs, download_failed, prodigal_failed = [], [], []

    for chunk_output in output:
        remaining_accs.extend(chunk_output[0])
        download_failed.extend(chunk_output[1])
        prodigal_failed.extend(chunk_output[2])

    return(remaining_accs, download_failed, prodigal_failed)


def get_and_filter_pfam_hmms(tmp_dir, rerun, full_pfams_hmm_file="all-pfams.hmm", filtered_pfam_IDs_file="filtered-pfam-IDs.txt", filtered_pfams_info_file="filtered-pfams-info.tsv", filtered_pfams_hmm_file="filtered-pfams.hmm"):
    """
    getting all latest PFams and filtering based on coverage

    Downloads the current Pfam-A HMMs and their information table, keeps the
    PFams whose value in column 33 of that table is greater than 50, and pulls
    their HMMs out with `hmmfetch`. All files are written to tmp_dir, and are
    re-used on a rerun if they are all present.

    Returns the list of kept PFam IDs (formatted as "<ID>.<version>").
    Exits with status 1 if a download or the `hmmfetch` call fails.
    """

    full_hmms_path = tmp_dir + full_pfams_hmm_file
    filtered_IDs_path = tmp_dir + filtered_pfam_IDs_file
    filtered_info_path = tmp_dir + filtered_pfams_info_file
    filtered_hmms_path = tmp_dir + filtered_pfams_hmm_file

    if rerun and all(os.path.exists(file) for file in (filtered_IDs_path, filtered_info_path, filtered_hmms_path)):

        print("")
        wprint("PFams that were already filtered based on average coverage of the underlying proteins were detected from a previous run...")

        filtered_pfam_IDs = []

        with open(filtered_info_path) as pfams_info:
            for line in pfams_info:

                # skipping the header
                if line.startswith("PFam_ID"):
                    continue

                filtered_pfam_IDs.append(line.strip().split("\t")[0])

        return(filtered_pfam_IDs)

    print("")
    wprint("Downloading all PFam HMMs and filtering based on average coverage of underlying proteins...")

    # decompressing straight to disk as it downloads, so the whole file doesn't need to be held in memory
    # (catching Exception rather than everything, so a keyboard interrupt isn't reported as a failed download)
    try:
        with urllib.request.urlopen("ftp://ftp.ebi.ac.uk/pub/databases/Pfam/current_release/Pfam-A.hmm.gz") as response:
            with gzip.GzipFile(fileobj=response) as pfam_hmms, open(full_hmms_path, "wb") as out:
                shutil.copyfileobj(pfam_hmms, out)

    except Exception:
        print("")
        wprint(color_text("Downloading the master PFam HMM file failed :(", "yellow"))
        print("Exiting for now.\n")
        sys.exit(1)

    try:
        with urllib.request.urlopen("ftp://ftp.ebi.ac.uk/pub/databases/Pfam/current_release/database_files/pfamA.txt.gz") as response:
            pfam_hmm_info_tab = pd.read_csv(response, sep="\t", compression="gzip", header=None, low_memory=False, usecols=[0,1,3,27,33])

    except Exception:
        print("")
        wprint(color_text("Downloading the PFam HMM information file failed :(", "yellow"))
        print("Exiting for now.\n")
        sys.exit(1)

    # creating list of length-coverage filtered pfams and writing out base info table we're keeping and ID file for giving to hmmfetch
    filtered_pfam_IDs = []

    with open(filtered_info_path, "w") as out, open(filtered_IDs_path, "w") as out_IDs:

        out.write("PFam_ID" + "\t" + "short_name" + "\t" + "name" + "\n")

        # columns are accessed by their position in the downloaded table: 0 and 27 make up the
        # versioned ID, 1 and 3 are written out as short name and name, and 33 is what is filtered on
        for _, row in pfam_hmm_info_tab.iterrows():

            if row[33] > 50:
                ID_with_version = str(row[0]) + "." + str(row[27])
                filtered_pfam_IDs.append(ID_with_version)
                out.write(ID_with_version + "\t" + str(row[1]) + "\t" + str(row[3]) + "\n")
                out_IDs.write(ID_with_version + "\n")

    with open(filtered_hmms_path, "w") as out:
        pfam_filter_run = subprocess.run(["hmmfetch", "-f", full_hmms_path, filtered_IDs_path], stdout=out)

    if pfam_filter_run.returncode != 0:
        print("")
        wprint(color_text("Something went wrong with the `hmmfetch` call :(", "yellow"))
        print("Exiting for now.\n")
        sys.exit(1)

    return(filtered_pfam_IDs)


def run_hmmsearch(tmp_dir, num_cpus, rerun, input_seqs="all-genes.faa", filtered_pfams_hmm_file="filtered-pfams.hmm"):
    """ runs the hmmsearch step (serially or in parallel), unless a rerun can re-use the stored hit table """

    hits_table = tmp_dir + "pfam-hits.tab"

    # nothing to do if restarting and the combined hit table is already there
    if rerun == True and os.path.exists(hits_table):

        print("")
        wprint("Filtered-PFam hmmsearch already performed, using previous results...")
        return

    print("")
    wprint("Running hmmsearch of all proteins against the filtered PFams...")

    # anything other than exactly one cpu goes through the parallel route
    if num_cpus != 1:
        parallelize_hmmsearch(tmp_dir, input_seqs, filtered_pfams_hmm_file, num_cpus, hmmsearch)
        return

    hmmsearch(tmp_dir + input_seqs, tmp_dir + filtered_pfams_hmm_file, hits_table)


def hmmsearch(input_seqs, filtered_pfams_hmm_file, output_table):
    """
    main hmmsearch function

    Searches the HMMs in filtered_pfams_hmm_file against the sequences in
    input_seqs with one cpu, using the gathering cutoffs held in the HMMs
    (`--cut_ga`), and writes the per-target hit table to output_table.

    Raises subprocess.CalledProcessError if hmmsearch exits with a non-zero
    status, so a failed search can't silently lead to an incomplete hit table.
    """

    subprocess.run(["hmmsearch", "--cut_ga", "--cpu", "1", "--tblout", output_table, filtered_pfams_hmm_file, input_seqs], stdout=subprocess.DEVNULL, check=True)


def parallelize_hmmsearch(tmp_dir, input_seqs, filtered_pfams_hmm_file, num_cpus, func):
    """
    function to parallelize main hmmsearch function

    Splits the sequence file into blocks (files named "sub-*" in tmp_dir), runs
    func on each block in a pool of num_cpus processes as
    func(block_file, hmm_file, block_file + ".tab"), then combines the block
    tables into "pfam-hits.tab" in tmp_dir and removes the block files.
    If the sequence file is empty, an empty "pfam-hits.tab" is written.
    """

    combined_table = tmp_dir + "pfam-hits.tab"

    # clearing out any split files left by an interrupted run, otherwise they
    # (and their old result tables) would be picked up below as if they were new blocks
    for file in glob(tmp_dir + "sub-*"):
        os.remove(file)

    # splitting sequence file into blocks to run in parallel
      # tested splitting the HMMs instead, came out pretty even in my limited testing
    # getting number of lines
    with open(tmp_dir + input_seqs) as seqs:
        lines = sum(1 for _ in seqs)

    # nothing to search, so there are no hits to report
    if lines == 0:
        with open(combined_table, "wb"):
            pass
        return

    # getting lines per block, rounding up to next 10 so will always be even
    lines_per_block = int(ceil(lines / num_cpus / 10)) * 10

    # splitting the file
    subprocess.run(["split", "-l", str(lines_per_block), "-a", "3", tmp_dir + input_seqs, tmp_dir + "sub-"], check=True)

    # getting the split file names (sorted so the combined table follows the order of the input)
    split_files = sorted(glob(tmp_dir + "sub-*"))

    # building arguments tuple list to pass to the pool parallel call
    args_tuple_list = [(file, tmp_dir + filtered_pfams_hmm_file, file + ".tab") for file in split_files]

    # starmap() blocks until all blocks are done, so the pool can be released right after
    with mp.Pool(num_cpus) as pool:
        pool.starmap(func, args_tuple_list)

    # combining results tables
    with open(combined_table, "wb") as out:
        for file in split_files:
            with open(file + ".tab", "rb") as subtab:
                shutil.copyfileobj(subtab, out)

    # removing split files and results now that all are combined
    for file in glob(tmp_dir + "sub-*"):
        os.remove(file)


def filter_HMM_hits(target_genome_accs, filtered_pfam_IDs, percent_single_copy, tmp_dir, output_dir, rerun, pfam_hits_file="pfam-hits.tab", wanted_pfams_info_file="Wanted-PFams-info.tsv", filtered_pfams_info_file="filtered-pfams-info.tsv", wanted_pfam_IDs_file="wanted-PFam-IDs.txt"):
    """
    filtering pfam hit table

    Counts the hits of each PFam in each genome, and keeps the PFams that had
    exactly 1 hit in at least percent_single_copy percent of the genomes.
    The kept PFam IDs are written to tmp_dir (for `hmmfetch`), and the info
    table of the kept PFams and the full hit-count table are written to
    output_dir. These are re-used on a rerun if they are all present.

    Returns the number of kept PFams.
    """

    wanted_info_path = output_dir + wanted_pfams_info_file
    hit_counts_path = output_dir + "PFam-HMM-hit-counts.tsv"
    wanted_IDs_path = tmp_dir + wanted_pfam_IDs_file

    if rerun and all(os.path.exists(file) for file in (wanted_info_path, hit_counts_path, wanted_IDs_path)):

        print("")
        wprint("Filtered-PFam hmmsearch results table was already parsed down, using previous results...")

        wanted_pfams = []
        with open(wanted_info_path) as pfams_info:
            for line in pfams_info:

                # skipping the header
                if line.startswith("PFam_ID"):
                    continue

                wanted_pfams.append(line.strip().split("\t")[0])

        return(len(wanted_pfams))

    print("")
    wprint("Parsing hit table to find which PFams had exactly 1 hit in >= " + str(percent_single_copy) + "% of the included genomes...")

    # counting in a plain array with position lookups (one row per genome, one column per PFam),
    # which is far faster than updating a dataframe cell for every hit
    row_of_acc = {acc: row for row, acc in enumerate(target_genome_accs)}
    col_of_pfam = {pfam: col for col, pfam in enumerate(filtered_pfam_IDs)}
    counts = np.zeros( (len(target_genome_accs), len(filtered_pfam_IDs)), dtype=int)

    with open(tmp_dir + pfam_hits_file, "r") as hits:
        for line in hits:

            if line.startswith("#"):
                continue

            line_list = line.strip().split()

            # skipping anything that can't be a hit line (e.g. blank lines)
            if len(line_list) < 4:
                continue

            # sequence headers are "<accession>_<n>", and the accessions themselves hold one underscore
            curr_acc = "_".join(line_list[0].split("_")[:2])

            # the 4th column of the hit table holds the accession of the HMM
            curr_pfam = line_list[3]

            counts[row_of_acc[curr_acc], col_of_pfam[curr_pfam]] += 1

    count_tab = pd.DataFrame(counts, index=target_genome_accs, columns=filtered_pfam_IDs)

    # with no genomes there is nothing that can pass the cutoff (and no percent to calculate)
    if len(target_genome_accs) > 0:
        percent_with_one_hit = (counts == 1).sum(axis=0) / len(target_genome_accs) * 100
        wanted_pfams = [pfam for pfam, percent in zip(filtered_pfam_IDs, percent_with_one_hit) if percent >= percent_single_copy]

    else:
        wanted_pfams = []

    # writing out because they are needed for the hmmfetch program
    with open(wanted_IDs_path, "w") as out:
        out.write("\n".join(wanted_pfams))
        out.write("\n")

    # generating and writing out wanted PFams info table
    filtered_pfams_info_tab = pd.read_csv(tmp_dir + filtered_pfams_info_file, header=0, sep="\t")
    wanted_pfams_info_tab = filtered_pfams_info_tab[filtered_pfams_info_tab.PFam_ID.isin(wanted_pfams)]
    wanted_pfams_info_tab.to_csv(wanted_info_path, sep="\t", header=True, index=False)

    print("    Wanted PFam info table written to:")
    print("      " + wanted_info_path + "\n")

    # writing out full count table
    count_tab.to_csv(hit_counts_path, sep="\t", header=True, index=True)

    print("    PFam HMM hit-count table written to:")
    print("      " + hit_counts_path)

    return(len(wanted_pfams))


def filter_final_HMM_set(wanted_pfam_count, tmp_dir, output_dir, rerun, filtered_pfams_hmm_file="filtered-pfams.hmm", wanted_pfam_IDs_file="wanted-PFam-IDs.txt"):
    """
    filter final HMM set

    Pulls the HMMs of the wanted PFams out of the filtered PFam HMMs with
    `hmmfetch`, writing them to a file in output_dir. On a rerun, if that file
    already exists, this only reports where it is.

    Exits with status 1 if the `hmmfetch` call fails.
    """

    # if specific output_dir was provided by user, going to use that as prefix to final output hmm file
    # (just its last component, so output directories given as a path like "results/my-run" work too)
    if args.output_dir != "gtt-gen-SCG-HMMs-output":
        wanted_HMM_out_file = os.path.basename(str(output_dir).rstrip("/")) + ".hmm"

    # otherwise using just numbers
    else:
        wanted_HMM_out_file = "Wanted-" + str(wanted_pfam_count) + "-SCG-targets.hmm"

    wanted_HMM_out_path = output_dir + wanted_HMM_out_file

    if rerun and os.path.exists(wanted_HMM_out_path):

        print(color_text("\n -------------------------------------------------------------------------------\n"))

        wprint("It seems all is done already, you should specify a non-existent output directory or a different input accessions file if you'd like to do something different :)")
        print("")
        print("  The previous run generated the following SCG-HMMs file holding " + str(wanted_pfam_count) + " target genes:")
        print(color_text("    " + wanted_HMM_out_path))
        print("")

        return

    print("")
    wprint("Filtering for the final SCGs' HMMs holding " + str(wanted_pfam_count) + " targets...")

    with open(wanted_HMM_out_path, "w") as out:
        pfam_filter_run = subprocess.run(["hmmfetch", "-f", tmp_dir + filtered_pfams_hmm_file, tmp_dir + wanted_pfam_IDs_file], stdout=out)

    if pfam_filter_run.returncode != 0:
        print("")
        wprint(color_text("Something went wrong with the `hmmfetch` call :(", "yellow"))
        print("Exiting for now.\n")
        sys.exit(1)

    print(color_text("\n -------------------------------------------------------------------------------\n"))

    print("  New SCG-HMMs holding " + str(wanted_pfam_count) + " target genes written to:")
    print(color_text("    " + wanted_HMM_out_path))
    print("")


def report_missed_accessions(output_dir, rerun):
    """ points the user to "Missed-accessions.tsv" in the output directory, if that file exists """

    missed_accs_file = output_dir + "Missed-accessions.tsv"

    # no file means there is nothing to report
    if not os.path.exists(missed_accs_file):
        return

    wprint("Any input accessions that didn't make it through are reported in:")
    wprint(color_text("  " + missed_accs_file, "yellow"))
    print("")


def note_gtt_store_SCG_HMMs():
    """ mentions the `gtt-store-SCG-HMMs` program, but only if the 'GToTree_HMM_dir' environment variable is set """

    # checking for the variable explicitly, rather than catching (and hiding) any and all errors
    if "GToTree_HMM_dir" not in os.environ:
        return

    wprint("If you'd like to add this new SCG-HMM set to the stored GToTree ones, you can do so with the `gtt-store-SCG-HMMs` program :)")
    print("")


# only running the workflow when this file is executed as a script
if __name__ == "__main__":
    main()
