"""
This script takes an rbh file from odp_trio and performs an HMM search against
all of the other species in the config file.

The output of this is an rbh .tsv file, with or without a group column,
 that has the rbhs and their orthologs. It is allowed for there to be missing
 orthologs.
"""
# This block imports fasta-parser as fasta
import os
import sys
snakefile_path = os.path.dirname(os.path.realpath(workflow.snakefile))
dependencies_path = os.path.join(snakefile_path, "../dependencies/fasta-parser")
sys.path.insert(1, dependencies_path)
import fasta

import ast
from itertools import groupby
from itertools import product
import math
import numpy as np
import odp_functions as OdpF
from operator import itemgetter
import pandas as pd
import statistics

configfile: "config.yaml"

config["tool"] = "odp_rbh_to_HMM"

OdpF.check_legality(config)

# Default: every rbh must have one sequence from every rbh species.
# (This must be an assignment; a comparison here would raise a KeyError.)
if "missing_ok" not in config:
    config["missing_ok"] = False

if "rbh_file" not in config:
    raise IOError("You must specify 'rbh_file' in config.")

if "species" not in config:
    raise IOError("You must specify all of the species that will be involved in this analysis in the config.yaml file.")

# come up with the species in the rbh: every "<species>_scaf" column names one.
# Only the trailing "_scaf" is stripped, so the rest of the name is untouched.
testdf = pd.read_csv(config["rbh_file"], sep = "\t", index_col = None)
rbh_species = [x[:-len("_scaf")] for x in testdf.columns
               if x.endswith("_scaf")]
species_string = "_".join(sorted(rbh_species))

# NOTE(review): the check has always required three species while the old
#  message said "two". The check is kept and the message now matches it.
if len(rbh_species) < 3:
    raise IOError("There must be at least three species in the rbh file.")

# come up with other species to perform the HMM search against
other_species = [x for x in config["species"] if x not in rbh_species]
print()
print("other_species")
print(other_species)

# without any other species there is nothing to search the HMMs against
if not other_species:
    raise IOError("There must be at least one species to search against.")

config["nways"] = len(rbh_species)

# come up with number of rbhs
newstring = "rbh{}way_{}".format(len(rbh_species), "_".join(sorted(rbh_species)))
rbh_entries = list(testdf["rbh"])

rule all:
    """
    Default target. Requests the sorted HMM hit table of every other species
    and the final rbh file that has the other species' columns added.
    """
    input:
        # The commented-out entries below are intermediate targets that can be
        #  re-enabled to build and inspect single steps of the workflow.
        #expand(config["tool"] + "/fasta/unaligned/{rbh}.fasta",
        #       rbh = rbh_entries),
        #expand(config["tool"] + "/hmm/hmms/{rbh}.hmm",
        #       rbh = rbh_entries),
        #expand(config["tool"] + "/hmm/searches/{other_species}/{rbh}_against_{other_species}.tsv",
        #       other_species = other_species, rbh = rbh_entries),
        # The pattern only contains {other_species}, so no rbh list is passed.
        expand(config["tool"] + "/hmm/searches_agg/{other_species}_hmm_results.sorted.tsv",
               other_species = other_species),
        #expand("hmm/searches_agg_best/{species}_hmm_best.tsv",
        #       species = config["yaxisspecies"]),
        config["tool"] + "/output/" + species_string + "_rbhhmm_plus_other_species.rbh"

rule generate_fasta_of_each_rbh:
    """
    Writes one unaligned FASTA file per rbh, containing the sequence of
     every rbh species' gene that belongs to that rbh.

    Records are renamed to "<rbh>_<species>"; the original id is kept in
     record.name. Unless config["missing_ok"] is set, every rbh listed in the
     rbh file must end up with exactly one record per rbh species, otherwise
     an IOError is raised.
    """
    input:
        rbh = config["rbh_file"],
        proteins = lambda wildcards: [config["species"][x]["proteins"]
                    for x in rbh_species]
    output:
        MBH_fasta = expand(config["tool"] + "/fasta/unaligned/{rbh}.fasta",
                           rbh = rbh_entries)
    threads: 1
    run:
        print(" - Reading the dataframe of rbhs.")
        df = pd.read_csv(input.rbh, index_col = None, sep = "\t")
        print(df)

        # make a dict of gene_to_rbh, one per species.
        # Built column-wise; like the row-wise version, a later row wins if a
        #  gene id occurs more than once.
        print(" - Making a dict of gene_to_rbh")
        species_to_gene_to_rbh = {}
        for this_species in rbh_species:
            genecol = "{}_gene".format(this_species)
            species_to_gene_to_rbh[this_species] = dict(
                zip(df[genecol], df["rbh"]))

        # now make a dict of rbh_to_fasta_records
        print(" - Making a dict of rbh_to_fasta_records.")
        rbh_to_records = {}
        for this_species in rbh_species:
            gene_to_rbh = species_to_gene_to_rbh[this_species]
            infile = config["species"][this_species]["proteins"]
            for record in fasta.parse(infile):
                if record.id not in gene_to_rbh:
                    continue
                thisrbh = gene_to_rbh[record.id]
                # keep the original gene id in .name, put the rbh in .id
                record.name = record.id
                record.id = "{}_{}".format(thisrbh, this_species)
                rbh_to_records.setdefault(thisrbh, []).append(record)

        # make sure that all entries have as many sequences as species
        if config["missing_ok"]:
            print(" - Missing records are OK. This is good for rbh_merge.")
        else:
            print(" - Verifying that the records are all complete.")
            # Walk the rbhs of the rbh file, not just the ones that collected
            #  records, so an rbh with zero sequences is reported here too.
            for thisrbh in df["rbh"]:
                num_records = len(rbh_to_records.get(thisrbh, []))
                if num_records != len(rbh_species):
                    raise IOError("{} only has {} genes".format(
                        thisrbh, num_records))

        # print out all of the records to fasta files
        print(" - Printing out the FASTA records.")
        for thisrbh in rbh_to_records:
            outfile = config["tool"] + "/fasta/unaligned/{}.fasta".format(thisrbh)
            with open(outfile, "w") as output_handle:
                for thisrec in rbh_to_records[thisrbh]:
                    print(thisrec, file = output_handle)

rule align_fasta_of_each_group:
    """
    Aligns the unaligned FASTA file of a single rbh with mafft.

    mafft writes the alignment to stdout, which is redirected into the
     aligned FASTA file of that rbh.
    """
    input:
        MBH_fasta = config["tool"] + "/fasta/unaligned/{rbh}.fasta"
    output:
        aligned =   config["tool"] + "/fasta/aligned/{rbh}.aligned.fasta",
    threads: 1
    shell:
        """
        mafft --localpair --anysymbol --maxiterate 1000 {input.MBH_fasta} > {output.aligned}
        """

rule make_hmm:
    """
    Builds one profile HMM per rbh from its aligned FASTA file with hmmbuild.
    """
    input:
        aligned = config["tool"] + "/fasta/aligned/{rbh}.aligned.fasta"
    output:
        hmm     = config["tool"] + "/hmm/hmms/{rbh}.hmm"
    threads: 1
    shell:
        """
        hmmbuild {output.hmm} {input.aligned}
        """

rule hmm_search_against_genome:
    """
    Searches the HMM of one rbh against the protein file of one of the
     other species with hmmsearch.

    The hits are saved in tabular form through --tblout. The regular
     hmmsearch report on stdout is discarded.
    """
    input:
        # the protein file is looked up in the config for the requested species
        proteins = lambda wildcards: config["species"][wildcards.other_species]["proteins"],
        hmm = config["tool"] + "/hmm/hmms/{rbh}.hmm"
    output:
        tsv = config["tool"] + "/hmm/searches/{other_species}/{rbh}_against_{other_species}.tsv"
    threads: 2
    shell:
        """
        hmmsearch --tblout {output.tsv} \
          --cpu {threads} \
          --noali \
          --notextw \
          {input.hmm} \
          {input.proteins} > /dev/null
        """

rule aggregate_hmm_results_by_species:
    """
    Merges all per-rbh hmmsearch tables of one species into a single
     headerless, tab-separated table laid out like blastp outfmt 6.

    The hmmsearch --tblout header fields are:
    # target name        accession  query name                         accession    E-value  score  bias   E-value  score  bias   exp reg clu  ov env dom rep inc description of target

    the headers for blastp outfmt 6 are:
      qseqid sseqid pident length mismatch gapopen qstart qend sstart send evalue bitscore

    Only the target name, query name, first E-value and first score are taken
     from hmmsearch. The remaining blast columns are filled with constants.
    """
    input:
        tsv = expand(config["tool"] + "/hmm/searches/{{other_species}}/{rbh}_against_{{other_species}}.tsv",
                     rbh = rbh_entries)
    output:
        tsv =        config["tool"] + "/hmm/searches_agg/{other_species}_hmm_results.tsv"
    params:
        this_species = lambda wildcards: wildcards.other_species
    threads: 1
    run:
        hit_rows = []
        for tblout_path in input.tsv:
            with open(tblout_path, "r") as tblout:
                for raw_line in tblout:
                    stripped = raw_line.strip()
                    # skip blank lines and the "#" comment/header lines
                    if not stripped or stripped.startswith("#"):
                        continue
                    cols = stripped.split()
                    # the query name carries ".aligned", which is removed
                    #  to get back the rbh name
                    query = cols[2].replace(".aligned", "")
                    hit_rows.append(dict(qseqid   = query,
                                         sseqid   = cols[0],
                                         pident   = 50.0,
                                         length   = 50.0,
                                         mismatch = 0,
                                         gapopen  = 0,
                                         qstart   = 0,
                                         qend     = 0,
                                         sstart   = 0,
                                         send     = 0,
                                         evalue   = float(cols[4]),
                                         bitscore = float(cols[5])))
        hits = pd.DataFrame(hit_rows)
        hits.to_csv(output.tsv, sep="\t", index = False, header = False)

rule sort_the_collated_hmm_results:
    """
    Sorts the aggregated hits of one species by column 12 (the bitscore),
     numerically and from highest to lowest.

    LC_ALL=C makes sort read the "." decimal separator the same way on every
     machine, regardless of the user's locale.
    """
    input:
        tsv =        config["tool"] + "/hmm/searches_agg/{other_species}_hmm_results.tsv"
    output:
        tsv =        config["tool"] + "/hmm/searches_agg/{other_species}_hmm_results.sorted.tsv"
    params:
        this_species = lambda wildcards: wildcards.other_species
    shell:
        """
        LC_ALL=C sort -k12,12nr {input.tsv} > {output.tsv}
        """

rule best_gene_for_each_hmm:
    """
    Reduces the sorted hits of one species to one-to-one HMM/gene pairs.

    Hits are ordered by descending bitscore. Each gene (sseqid) first keeps
     only its highest-scoring HMM, then each HMM (qseqid) keeps only its
     highest-scoring remaining gene.
    """
    input:
        tsv = config["tool"] + "/hmm/searches_agg/{other_species}_hmm_results.sorted.tsv"
    output:
        tsv = config["tool"] + "/hmm/searches_agg_best/{other_species}_hmm_best.tsv"
    threads: 1
    run:
        df = pd.read_csv(input.tsv, header = None, sep = "\t")
        df.columns = ["qseqid", "sseqid", "pident", "length", "mismatch",
                      "gapopen", "qstart", "qend", "sstart", "send",
                      "evalue", "bitscore"]
        # A stable sort keeps hits with equal bitscores in their input order.
        #  keep="first" below then picks the same hit on every run instead of
        #  whichever one an unstable sort happens to put first.
        df = df.sort_values("bitscore", ascending = False, kind = "mergesort")
        df = df.drop_duplicates(subset = ["sseqid"], keep = "first")
        df = df.drop_duplicates(subset = ["qseqid"], keep = "first")
        df = df.reset_index(drop=True)
        df.to_csv(output.tsv, header = False, index = False, sep="\t")

rule collate_best_hits_into_best_hits_file:
    """
    Adds the best HMM hit of every other species to the input rbh table and
     writes everything to a single rbh file.

    Each other species contributes a {}_gene, a {}_scaf, and a {}_pos column.
     Columns are ordered as: rbh, annotation columns, rbh species columns,
     other species columns. Empty cells are written as "None".
    """
    input:
        rbh_file = config["rbh_file"],
        best_hits_files = expand(config["tool"] + "/hmm/searches_agg_best/{other_species}_hmm_best.tsv",
                                 other_species = other_species),
        chrom = lambda wildcards: [config["species"][x]["chrom"] for x in
                                   other_species]
    output:
        rbh = config["tool"] + "/output/" + species_string + "_rbhhmm_plus_other_species.rbh"
    run:
        blast6_header = ["qseqid", "sseqid", "pident", "length", "mismatch",
                         "gapopen", "qstart", "qend", "sstart", "send",
                         "evalue", "bitscore"]
        table = pd.read_csv(input.rbh_file, header = 0, index_col = None, sep = "\t")

        for sp in other_species:
            # gene -> scaffold and gene -> position, read from the chrom file.
            #  Field 0 is the gene, field 1 the scaffold, field 3 the position.
            gene_to_scaf = {}
            gene_to_pos = {}
            with open(config["species"][sp]["chrom"], "r") as chrom_handle:
                for raw_line in chrom_handle:
                    stripped = raw_line.strip()
                    if not stripped:
                        continue
                    parts = stripped.split()
                    gene_to_scaf[parts[0]] = parts[1]
                    gene_to_pos[parts[0]] = int(parts[3])

            # rbh name -> best gene of this species
            best_path = config["tool"] + "/hmm/searches_agg_best/{}_hmm_best.tsv".format(sp)
            best_hits = pd.read_csv(best_path, header = None, sep = "\t")
            best_hits.columns = blast6_header
            rbh_to_gene = dict(zip(best_hits["qseqid"], best_hits["sseqid"]))

            # rbhs without a hit, and genes absent from the chrom file, stay
            #  empty here and are filled in just before writing
            gene_col = "{}_gene".format(sp)
            table[gene_col] = table["rbh"].map(rbh_to_gene)
            table["{}_scaf".format(sp)] = table[gene_col].map(gene_to_scaf)
            table["{}_pos".format(sp)] = table[gene_col].map(gene_to_pos)

        # a column belongs to a species if its text before the first "_" is
        #  that species' name
        rbh_cols = sorted(c for c in table.columns
                          if c.split("_")[0] in rbh_species)
        other_cols = sorted(c for c in table.columns
                            if c.split("_")[0] in other_species)
        already_placed = ["rbh"] + rbh_cols + other_cols
        annot_cols = sorted(c for c in table.columns
                            if c not in already_placed)

        final_order = ["rbh"] + annot_cols + rbh_cols + other_cols
        table = table.reindex(final_order, axis=1)
        table = table.fillna(value = "None")
        table.to_csv(output.rbh,  sep="\t", header = True, index = False)


#rule best_gene_for_each_hmm_LEGACY:
#    """
#    finds the best hmm for each gene
#    """
#    input:
#        tsv = config["tool"] + "/hmm/searches_agg/{other_species}_hmm_results.sorted.tsv"
#        chrom = lambda wildcards: config["xaxisspecies"][wildcards.species]["chrom"],
#        rbh = "rbh/reciprocal_best_hits.tsv"
#    output:
#        tsv = "hmm/searches_agg_best/{species}_hmm_best.tsv",
#        rbhinfo = "hmm/rbh_plus_HMM_agg_best/{species}_hmm_chrom_info.tsv",
#        groupby = "hmm/rbh_plus_HMM_agg_best/{species}_hmm_chrom_info.groupby.tsv"
#    params:
#        species = lambda wildcards: wildcards.species
#    threads: 1
#    run:
#        species_gene_to_chrom = {}
#        species_gene_to_pos = {}
#        with open(input.chrom, "r") as f:
#            for line in f:
#                line = line.strip()
#                if line:
#                    fields = line.split()
#                    species_gene_to_chrom[fields[0]] = fields[1]
#                    species_gene_to_pos[fields[0]] = int(fields[3])
#        df = pd.read_csv(input.tsv, index_col = 0, sep = "\t")
#        df = df.sort_values(["sseqid", "evalue"], ascending=[True, True])
#        df = df.drop_duplicates(subset = ["sseqid"])
#        df = df.sort_values(["qseqid", "evalue"], ascending=[True, True])
#        df = df.drop_duplicates(subset = ["qseqid"])
#        df = df.reset_index(drop=True)
#        df.to_csv(output.tsv, header = False, index = False, sep="\t")
#        df = df.iloc[:, 0:2]
#        df.columns = ["rbh", params.species]
#        rbh = pd.read_csv(input.rbh, index_col = 0, sep = "\t")
#        rbh["{}_gene".format(params.species)] = rbh["rbh"].map(
#                    dict(zip(df["rbh"],
#                             df[params.species])) )
#        rbh["{}_scaf".format(params.species)] = rbh[
#            "{}_gene".format(params.species)].map(
#            species_gene_to_chrom)
#        rbh["{}_pos".format(params.species)] = rbh[
#            "{}_gene".format(params.species)].map(
#            species_gene_to_pos)
#        rbh.to_csv(output.rbhinfo,  sep="\t")
#
#        groupbycols = [x for x in rbh.columns if "_scaf" in x]
#        grouped_multiple = rbh.groupby(groupbycols).agg(list).reset_index()
#        # get the size
#        grouped_multiple["count"] = grouped_multiple.rbh.str.len()
#        grouped_multiple.to_csv(output.groupby,  sep="\t")
