"""
This script performs mutual-best protein diamond BLAST searches, n-ways,
  then shows how many pairs are conserved between them

This script does not support y-axis species
"""
# This block imports fasta-parser as fasta
import os
import sys
snakefile_path = os.path.dirname(os.path.realpath(workflow.snakefile))
dependencies_path = os.path.join(snakefile_path, "../dependencies/fasta-parser")
sys.path.insert(1, dependencies_path)
import fasta

# ODP-specific imports
scripts_path = os.path.join(snakefile_path, "../scripts")
sys.path.insert(1, scripts_path)
import odp_functions as odpf

# other standard python libraries
from itertools import groupby
from itertools import combinations
import math
from operator import itemgetter
import statistics

# non-standard dependencies
import matplotlib
from matplotlib import pyplot as plt
import pandas as pd
import numpy as np


configfile: "config.yaml"

odpf.check_legality(config)

# Sample names are joined with "_" to build the output file names, and
#  n_ways_reciprocal_best splits those names on "_" again, so a sample
#  name itself must not contain the character.
for thissample in config["xaxisspecies"]:
    if "_" not in thissample:
        continue
    raise IOError("Sample names can't have '_' char: {}".format(thissample))

# A mutual-best comparison needs at least two samples.
if len(config["xaxisspecies"]) < 2:
    raise IOError("There must be at least 2 samples")
# nways is a mandatory config entry...
if "nways" not in config:
    raise IOError("you must specify nways in the config file. must be at least 2.")
# ...and it can't be smaller than 2...
if config["nways"] < 2:
    raise IOError("nways must be at least 2")

# ...nor larger than the number of species that are available.
if len(config["xaxisspecies"]) < config["nways"]:
    raise IOError("The number of species must be greater than or equal to nways")

# Enumerate the analyses: for every group size from 2 up to the number of
#  species, every combination of species (each combination sorted by name).
all_species = list(config["xaxisspecies"])
analyses_dicts = {
    group_size: [sorted(combo) for combo in combinations(all_species, group_size)]
    for group_size in range(2, len(all_species) + 1)
}

print("Here is an example of the first few analyses: {}".format(analyses_dicts))
num_possible_combinations = sum(len(combos) for combos in analyses_dicts.values())
print("There are {} possible combinations.".format(num_possible_combinations))

# This is specifically for the trio odp: the y-axis species are simply
#  the x-axis species, so every sample is searched against every other one.
config["yaxisspecies"] = config["xaxisspecies"]

rule all:
    input:
        # one list of acceptable (mutual-best) proteins per species
        #  combination, for every combination size in analyses_dicts
        odpf.flatten([
            ["synteny_analysis/blastp_results/reciprocal_best/{}_acceptable_prots_{}_ways.txt".format(
                "_".join(species_combo), num_ways)
             for species_combo in combos_of_this_size]
            for num_ways, combos_of_this_size in analyses_dicts.items()]),
        # The targets below are disabled.
        #["synteny_analysis/blastp_results/reciprocal_best/{}_edges.txt".format(
        #    "_".join(thisanalysis)) for thisanalysis in analyses_list],
        #["synteny_analysis/MBH/{}_mutual_best_hits.tsv".format(
        #    "_".join(thisanalysis)) for thisanalysis in analyses_list],
        ## now perform the pairs analysis
        #["synteny_analysis/pair_analysis/{}_pairs_analysis.tsv".format(
        #    "_".join(thisanalysis)) for thisanalysis in analyses_list],
        #"synteny_analysis/clustering/MST_quartets.pdf",
        #"synteny_analysis/clustering/UPGMA_quartets.pdf",
        #"synteny_analysis/clustering/WPGMA_quartets.pdf",
        #"synteny_analysis/clustering/UPGMC_quartets.pdf",
        #"synteny_analysis/clustering/WPGMC_quartets.pdf",
        #"synteny_analysis/clustering/Ward_quartets.pdf",
        #"synteny_analysis/clustering/graph.pdf"

rule filter_prots_x:
    """
    The protein fasta of a sample may contain sequences that are missing
     from its chrom file. This rule writes a protein fasta that only holds
     the proteins listed in the chrom file.
    """
    input:
        prots = lambda wc: config["xaxisspecies"][wc.xsample]["proteins"],
        chrom = lambda wc: config["xaxisspecies"][wc.xsample]["chrom"]
    output:
        pep = "synteny_analysis/db/xaxis/{xsample}_prots.pep"
    threads: 1
    run:
        chrom_file = input.chrom
        protein_file = input.prots
        odpf.filter_fasta_chrom(chrom_file, protein_file, output.pep)

rule filter_prots_y:
    """
    Sometimes the prot file will have sequences that are not present in
     the chrom file. Make a prot file of only the proteins in the chrom file.

    Same as filter_prots_x, but for the samples listed under yaxisspecies.
    """
    input:
        prots = lambda wildcards: config["yaxisspecies"][wildcards.ysample]["proteins"],
        chrom = lambda wildcards: config["yaxisspecies"][wildcards.ysample]["chrom"]
    output:
        pep = "synteny_analysis/db/yaxis/{ysample}_prots.pep"
    threads: 1
    run:
        # the helper module is imported as odpf; the former "opdf" spelling
        #  raised a NameError as soon as this rule was executed
        odpf.filter_fasta_chrom(input.chrom, input.prots, output.pep)

rule make_diamonddb_x:
    """
    Builds a DIAMOND database from the filtered protein fasta of an
     x-axis sample.
    """
    input:
        # prots is not used by the command itself
        # NOTE(review): presumably listed so the database is rebuilt when the
        #  original protein file changes - confirm
        prots = lambda wildcards: config["xaxisspecies"][wildcards.xsample]["proteins"],
        pep = "synteny_analysis/db/xaxis/{xsample}_prots.pep"
    output:
        dmnd = "synteny_analysis/db/xaxis/dmnd/{xsample}_prots.dmnd"
    # leave one core free, but never ask for fewer than one thread
    #  (workflow.cores - 1 is 0 when the workflow runs on a single core)
    threads: max(1, workflow.cores - 1)
    shell:
        """
        diamond makedb --in {input.pep} --db {output.dmnd} --threads {threads}
        """

rule make_diamonddb_y:
    """
    Builds a DIAMOND database from the filtered protein fasta of a
     y-axis sample.
    """
    input:
        # prots is not used by the command itself
        # NOTE(review): presumably listed so the database is rebuilt when the
        #  original protein file changes - confirm
        prots = lambda wildcards: config["yaxisspecies"][wildcards.ysample]["proteins"],
        pep = "synteny_analysis/db/yaxis/{ysample}_prots.pep"
    output:
        dmnd = "synteny_analysis/db/yaxis/dmnd/{ysample}_prots.dmnd"
    # leave one core free, but never ask for fewer than one thread
    #  (workflow.cores - 1 is 0 when the workflow runs on a single core)
    threads: max(1, workflow.cores - 1)
    shell:
        """
        diamond makedb --in {input.pep} --db {output.dmnd} --threads {threads}
        """

rule diamond_blast_x_to_y:
    """
    Searches the filtered proteins of an x-axis sample against the DIAMOND
     database of a y-axis sample. Hits are written in tabular format
     (--outfmt 6) and limited to an e-value of 1E-5 or better.
    """
    input:
        xpep = "synteny_analysis/db/xaxis/{xsample}_prots.pep",
        ydmnd = "synteny_analysis/db/yaxis/dmnd/{ysample}_prots.dmnd",
    output:
        blastp = "synteny_analysis/blastp_results/xtoy/{xsample}_against_{ysample}.blastp",
    # leave one core free, but never ask for fewer than one thread
    #  (workflow.cores - 1 is 0 when the workflow runs on a single core)
    threads: max(1, workflow.cores - 1)
    shell:
        """
        diamond blastp --query {input.xpep} --db {input.ydmnd} \
          --threads {threads} --evalue 1E-5 --outfmt 6 --out {output.blastp}
        """

rule diamond_blast_y_to_x:
    """
    Searches the filtered proteins of a y-axis sample against the DIAMOND
     database of an x-axis sample. Hits are written in tabular format
     (--outfmt 6) and limited to an e-value of 1E-5 or better.
    """
    input:
        ypep = "synteny_analysis/db/yaxis/{ysample}_prots.pep",
        xdmnd = "synteny_analysis/db/xaxis/dmnd/{xsample}_prots.dmnd",
    output:
        blastp = "synteny_analysis/blastp_results/ytox/{ysample}_against_{xsample}.blastp",
    # leave one core free, but never ask for fewer than one thread
    #  (workflow.cores - 1 is 0 when the workflow runs on a single core)
    threads: max(1, workflow.cores - 1)
    shell:
        """
        diamond blastp --query {input.ypep} --db {input.xdmnd} \
          --threads {threads} --evalue 1E-5 --outfmt 6 --out {output.blastp}
        """

rule absolute_best_from_blast_x_to_y:
    """
    Reduces the x -> y blastp table to one row per query.

    The awk command prints a row only when its first column (the query id)
     differs from the first column of the previous row, so the first row of
     every run of consecutive rows with the same query is kept.
    NOTE(review): this is the best hit only if the blastp file lists the
     hits of each query together and best-first - confirm for the diamond
     version in use.
    """
    input:
        blastp = "synteny_analysis/blastp_results/xtoy/{xsample}_against_{ysample}.blastp",
    output:
        blastp = "synteny_analysis/blastp_results/xtoybest/{xsample}_against_{ysample}.blastp",
    threads: 1
    shell:
        """
        awk 'BEGIN{{former = ""}} {{if ($1 != former){{print($0)}}; former=$1}}' {input.blastp} > {output.blastp}
        """

rule absolute_best_from_blast_y_to_x:
    """
    Reduces the y -> x blastp table to one row per query.

    The awk command prints a row only when its first column (the query id)
     differs from the first column of the previous row, so the first row of
     every run of consecutive rows with the same query is kept.
    NOTE(review): this is the best hit only if the blastp file lists the
     hits of each query together and best-first - confirm for the diamond
     version in use.
    """
    input:
        blastp = "synteny_analysis/blastp_results/ytox/{ysample}_against_{xsample}.blastp",
    output:
        blastp = "synteny_analysis/blastp_results/ytoxbest/{ysample}_against_{xsample}.blastp",
    threads: 1
    shell:
        """
        awk 'BEGIN{{former = ""}} {{if ($1 != former){{print($0)}}; former=$1}}' {input.blastp} > {output.blastp}
        """

rule reciprocal_best_hits:
    """
    finds the reciprocal best hits.
    reports it in the form of the blastp results from x -> y search

    Both inputs are the one-row-per-query tables of the two search
     directions. A row of the x -> y table is written to the output only
     if the y -> x table contains the mirrored pair, i.e. a row whose
     query is this row's subject and whose subject is this row's query.

    Raises IOError if the same pair occurs twice in the y -> x table.
    """
    input:
        ytoxblastp = "synteny_analysis/blastp_results/ytoxbest/{ysample}_against_{xsample}.blastp",
        xtoyblastp = "synteny_analysis/blastp_results/xtoybest/{xsample}_against_{ysample}.blastp",
    output:
        xtoyblastp = "synteny_analysis/blastp_results/reciprocal_best/{xsample}_and_{ysample}_recip.temp.blastp",
    threads: 1
    run:
        # First look in y_to_x. The pairs are stored as (subject, query),
        #  which is (x protein, y protein), so that they can be compared
        #  directly with the (query, subject) of the x_to_y rows.
        pairs = set()
        with open(input.ytoxblastp, "r") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                splitd = line.split("\t")
                add_this = (splitd[1], splitd[0])
                if add_this in pairs:
                    raise IOError("This set was already found")
                pairs.add(add_this)
        # Now go through x_to_y and keep the rows with a mirrored hit.
        #  Both files are opened in with-blocks so that the output handle
        #  is closed even if reading or parsing the input raises.
        with open(output.xtoyblastp, "w") as out_handle:
            with open(input.xtoyblastp, "r") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    splitd = line.split("\t")
                    if (splitd[0], splitd[1]) in pairs:
                        print(line, file=out_handle)

rule n_ways_reciprocal_best:
    """
    Gets the mutual-best hits in n samples.

              A      B
             / \    /|\
            B___C  A-+-C , et cetera
                    \|/
                     D

    A group of proteins is only reported if it contains as many proteins
     as there are species in the analysis and every protein of the group
     is a reciprocal best hit of every other protein of the group, i.e.
     the group is a complete graph like the ones drawn above.

    The outputs of this rule are:
      - acceptable_prots: the proteins that belong to a reported group,
         one protein id per line
      - blast_network:    the edges between those proteins, one
         tab-separated pair per line
      - MBH:              a tab-separated table with one row per group,
         holding a group id, the protein of each species, and the
         scaffold of each protein as looked up in the chrom files

    Raises IOError if a protein id appears to be shared by two species.
    """
    input:
        xtoyblastp = odpf.expand_avoid_matching_x_and_y(
            "synteny_analysis/blastp_results/reciprocal_best/{}_and_{}_recip.temp.blastp",
                config["xaxisspecies"], config["yaxisspecies"]),
        chrom = [config["xaxisspecies"][x]["chrom"]
                 for x in config["xaxisspecies"]]
    output:
        acceptable_prots = "synteny_analysis/blastp_results/reciprocal_best/{}_acceptable_prots_{{n}}_ways.txt".format(
            "_".join(["{{sample{}}}".format(i) for i in range(config["nways"])])),
        blast_network      = "synteny_analysis/blastp_results/reciprocal_best/{}_edges_{{n}}ways.txt".format(
            "_".join(["{{sample{}}}".format(i) for i in range(config["nways"])])),
        MBH                = "synteny_analysis/MBH/{}_mutual_best_hits_{{n}}ways.tsv".format(
            "_".join(["{{sample{}}}".format(i) for i in range(config["nways"])]))
    threads: 1
    params:
        num_ways = lambda wildcards: wildcards.n
    run:
        import networkx as nx
        # The species of this analysis are encoded in the output file name:
        #  <sp1>_<sp2>_..._acceptable_prots_<n>_ways.txt
        # Sample names can't contain "_", so everything in front of
        #  "_acceptable_prots_" is the "_"-joined list of species. (Stripping
        #  "_acceptable_prots.txt" never matched this file name and turned
        #  the suffix parts into bogus species.)
        species_string = output.acceptable_prots.split("/")[-1].split("_acceptable_prots_")[0]
        all_species = species_string.split("_")
        num_species = len(all_species)
        # one reciprocal-best file per pair of species: [file, species a, species b]
        combos = list(combinations(all_species, 2))
        blastfiles = [["synteny_analysis/blastp_results/reciprocal_best/{}_and_{}_recip.temp.blastp".format(x[0], x[1]),
             x[0], x[1]] for x in combos]
        gene_to_species = {}
        gene_list = set()

        # get the chrom files
        chrom_dicts = {}
        for thisspecies in all_species:
            chrom_dicts[thisspecies] = pd.read_csv(
                config["xaxisspecies"][thisspecies]["chrom"],
                header=None, sep = "\t")
            chrom_dicts[thisspecies].columns = ["prot",
                "scaf", "direction", "start", "stop"]

        # initialize the graph: proteins are nodes, reciprocal best hits are edges
        G = nx.Graph()
        checked_names = set()
        for analysis in blastfiles:
            print(analysis)
            thisfile = analysis[0]
            a = analysis[1]
            b = analysis[2]
            with open(thisfile, "r") as f:
                for line in f:
                    line = line.strip()
                    if line:
                        splitb = line.split()
                        agene = splitb[0]
                        bgene = splitb[1]
                        # A protein id that is already known while its species
                        #  has not been registered yet must come from another
                        #  species. The check only runs until the species is
                        #  added to checked_names below.
                        if a not in checked_names:
                            if agene in gene_list:
                                raise IOError("""We saw a gene twice. {}.
                                This means that two species have the same prot ids.""".format(agene))
                        gene_list.add(agene)
                        if b not in checked_names:
                            if bgene in gene_list:
                                raise IOError("""We saw a gene twice. {}.
                                This means that two species have the same prot ids.""".format(bgene))
                        gene_list.add(bgene)
                        gene_to_species[agene] = a
                        gene_to_species[bgene] = b
                        #add these since we've added the genes already
                        checked_names.add(a)
                        checked_names.add(b)
                        # now add the edge
                        G.add_edge(agene, bgene)
        remnodes = set()
        #get rid of things that couldn't possibly be an n-way best
        for thisentry in list(nx.connected_components(G)):
            if len(thisentry) < num_species:
                remnodes.update(thisentry)
        G.remove_nodes_from(remnodes)
        remnodes.clear()
        # now get rid of nodes that don't have the correct degree
        #  to be n-connected
        for thisnode in G.nodes:
            if G.degree[thisnode] != (num_species - 1):
                remnodes.add(thisnode)
        G.remove_nodes_from(remnodes)
        remnodes.clear()
        # Now keep only the n-connected components. Removing nodes above may
        #  have taken away neighbors of the surviving nodes, so a component
        #  with n nodes is not necessarily complete anymore (for example a
        #  chain a-b-c whose a and c had their third partner removed). The
        #  degree is therefore checked again on the pruned graph.
        for thisentry in list(nx.connected_components(G)):
            is_complete = all(G.degree[node] == (num_species - 1)
                              for node in thisentry)
            if len(thisentry) != num_species or not is_complete:
                remnodes.update(thisentry)
        #cleanup the graph
        G.remove_nodes_from(remnodes)
        remnodes.clear()
        # print out the graph
        uniquenodes = set()
        with open(output.blast_network, "w") as f:
            for thisedge in G.edges:
                print("{}\t{}".format(thisedge[0], thisedge[1]), file = f)
                uniquenodes.add(thisedge[0])
                uniquenodes.add(thisedge[1])
        with open(output.acceptable_prots, "w") as f:
            for thisnode in uniquenodes:
                print(thisnode, file = f)
        # print out the list of genes, one row per n-way group
        CCs = []
        for thisentry in list(nx.connected_components(G)):
            ccdict = {"MBH": "MBH{}way_{}_{}".format(
                num_species, "_".join(all_species), len(CCs)+1)}
            for node in thisentry:
                ccdict[gene_to_species[node]] = node
            CCs.append(ccdict)
        genesdf = pd.DataFrame(CCs)
        if genesdf.empty:
            # Without any group the frame has no columns at all and the
            #  per-species lookups below would raise a KeyError. Write a
            #  table that only consists of the header instead.
            genesdf = pd.DataFrame(columns=["MBH"] + all_species)

        # now add the other info: the scaffold that each protein is on
        for thisspecies in all_species:
            thischrom = chrom_dicts[thisspecies]
            prot_to_scaf = dict(zip(thischrom.prot, thischrom.scaf))
            genesdf["{}_scaf".format(thisspecies)] = genesdf[thisspecies].map(prot_to_scaf)
        print(genesdf)
        genesdf.to_csv(output.MBH, sep="\t")

#rule pair_analysis:
#    input:
#        MBH = "synteny_analysis/MBH/{}_mutual_best_hits.tsv".format(
#            "_".join(["{{sample{}}}".format(i) for i in range(config["nways"])]))
#    output:
#        MBH = "synteny_analysis/pair_analysis/{}_pairs_analysis.tsv".format(
#            "_".join(["{{sample{}}}".format(i) for i in range(config["nways"])]))
#    threads: 1
#    run:
#        species_string = output.MBH.split("/")[-1].replace("_pairs_analysis.tsv", "")
#        all_species = [x for x in species_string.split("_")]
#        print(all_species)
#        MBHdf = pd.read_csv(input.MBH, sep = "\t", index_col = 0)
#        print(MBHdf)
#        matches_counts = {}
#        num_analyses = 0
#        num_pairs = len(list(combinations(range(len(MBHdf)),2)))
#        for i in range(len(MBHdf)):
#            for ii in range(i+1, len(MBHdf)):
#                subdf = MBHdf.iloc[[i, ii]]
#                matches = []
#                for thisspecies in all_species:
#                    if len(subdf["{}_scaf".format(thisspecies)].unique()) == 1:
#                        matches.append(thisspecies)
#                match_string = "_".join(matches)
#                if match_string == "":
#                    match_string = "NoMatches"
#                if match_string not in matches_counts:
#                    matches_counts[match_string] = 0
#                matches_counts[match_string] += 1
#                num_analyses += 1
#                if num_analyses % 10 == 0:
#                    print("  - Finished {}/{} ({:.2f}%) pairs.  ".format(num_analyses, num_pairs, (num_analyses/num_pairs)*100), end = "\r")
#        print(matches_counts)
#        with open(output.MBH, "w") as f:
#            for key in dict(
#                    sorted(matches_counts.items(), key=lambda item: item[0].count("_"))):
#                print("{}\t{}".format(key, matches_counts[key]), file=f)

#rule generate_distance_matrix:
#    """
#    column 1 is source, column 2 is target, column 3 is counts
#    """
#    input:
#        all_pairs = ["synteny_analysis/pair_analysis/{}_pairs_analysis.tsv".format(
#            "_".join(thisanalysis)) for thisanalysis in analyses_dicts[]]
#    output:
#        pairwise_matrix = "synteny_analysis/matrices/pairwise_matrix_quartets_{n}_ways.tsv"
#    threads: 1
#    run:
#        pairwise_counts = {}
#        for thispairfile in input.all_pairs:
#            with open(thispairfile, "r") as f:
#                for line in f:
#                    line = line.strip()
#                    if line:
#                        source = line.split()[0]
#                        count  = int(line.split()[1])
#                        if source.count("_") == 1:
#                            if source not in pairwise_counts:
#                                pairwise_counts[source] = 0
#                            pairwise_counts[source] += count
#        with open(output.pairwise_matrix, "w") as f:
#            for entry in sorted(pairwise_counts):
#                source = entry.split("_")[0]
#                target = entry.split("_")[1]
#                print("{}\t{}\t{}".format(
#                    source, target, pairwise_counts[entry]),
#                      file = f)
#
#def distance_matrix(matrix_file):
#    """
#    Sets up a distance matrix using the source file
#    """
#    rows = {}
#    with open(matrix_file, "r") as f:
#        for line in f:
#            line = line.strip()
#            if line:
#                source = line.split()[0]
#                target = line.split()[1]
#                count  = int(line.split()[2])
#                if source not in rows:
#                    rows[source] = {}
#                if target not in rows:
#                    rows[target] = {}
#                rows[source][target] = count
#                rows[target][source] = count
#    df = pd.DataFrame(rows)
#    df = df.sort_index()
#    df = 1/df
#    df = df.fillna(0)
#    return df
#
#rule hierarchy_quartets:
#    input:
#        pairwise_matrix = "synteny_analysis/matrices/pairwise_matrix_quartets.tsv"
#    output:
#        MST   = "synteny_analysis/clustering/MST_quartets.pdf",
#        UPGMA = "synteny_analysis/clustering/UPGMA_quartets.pdf",
#        WPGMA = "synteny_analysis/clustering/WPGMA_quartets.pdf",
#        UPGMC = "synteny_analysis/clustering/UPGMC_quartets.pdf",
#        WPGMC = "synteny_analysis/clustering/WPGMC_quartets.pdf",
#        Ward =  "synteny_analysis/clustering/Ward_quartets.pdf"
#    threads: 1
#    threads: 1
#    run:
#        from scipy.cluster.hierarchy import dendrogram, linkage
#        from scipy.spatial.distance import squareform
#
#        #sys.exit()
#        df = distance_matrix(input.pairwise_matrix)
#        dists = squareform(df)
#        method_dict = {"single": output.MST,
#                       "average": output.UPGMA,
#                       "weighted": output.WPGMA,
#                       "centroid": output.UPGMC,
#                       "median":   output.WPGMC,
#                       "ward":     output.Ward}
#        for thismethod in method_dict:
#            print(dists)
#            linkage_matrix = linkage(dists, thismethod)
#            print("linkage_matrix")
#            print(linkage_matrix)
#            fig = dendrogram(linkage_matrix, labels=df.columns,
#                             orientation = "left",
#                             leaf_font_size = 6)
#            print(fig)
#            plt.title("test")
#            plt.savefig(method_dict[thismethod])
#            plt.close()
#
#rule plot_graph:
#    input:
#        pairwise_matrix = "synteny_analysis/matrices/pairwise_matrix_quartets.tsv"
#    output:
#        graph = "synteny_analysis/clustering/graph.pdf"
#    threads: 1
#    run:
#        import networkx as nx
#        source_target_weight = []
#        highest_count = 0
#        with open(input.pairwise_matrix, "r") as f:
#            for line in f:
#                line = line.strip()
#                if line:
#                    source = line.split()[0]
#                    target = line.split()[1]
#                    count  = int(line.split()[2])
#                    if count > highest_count:
#                        highest_count = count
#                    source_target_weight.append([source, target, count])
#        g = nx.Graph()
#
#        maxLW = 20
#        for entry in source_target_weight:
#            g.add_edge(entry[0],
#                       entry[1], weight = (entry[2]/highest_count)*maxLW)
#        pos=nx.spring_layout(g) # pos = nx.nx_agraph.graphviz_layout(G)
#        edges = g.edges()
#        weights = [g[u][v]['weight'] for u,v in edges]
#        nx.draw_networkx(g,pos, edges = edges, width = weights)
#        plt.savefig(output.graph)
