import operator
import numpy as np
import dendropy as dp

## Import
# from pp_range import *

## Settings
# tree_fn = 'my_tree.txt'
# pp_range_fn = 'my_run.area_states.txt'
# taxon_name_1,taxon_name_2 = 't1','t50'

## Create data structures
# range_dict = make_range_dict(pp_range_fn, fburn=0.25)
# idx_list = get_node_idx_list(pp_range_fn)
# node_key = get_node_key(tree_fn, taxon_name_1, taxon_name_2, idx_list)
# best_ranges = get_best_ranges(range_dict, node_key)
# prob_best_area_pairs = make_posterior_area_pairs(range_dict, node_key)
# best_area_pairs = get_best_area_pairs(prob_best_area_pairs)
# best_areas = get_best_area_pairs(prob_best_area_pairs, show_marginal=True)

## List ten most probable MRCA ranges
# best_ranges[0:10]

## How much probability do the top 10 MRCA ranges account for?
# sum(p[1] for p in best_ranges[0:10])

## What are the 10 most probable pairs of areas co-occupied by MRCA (excluding marginal area probs)?
# best_area_pairs[0:10]

## What are the 10 most probable pairs of areas co-occupied by MRCA (including marginal area probs)?
# best_areas[0:10]

## Plot the heatmap of best_area_pairs
# title_str = 'Prob(areas i,j co-occur in MRCA('+taxon_name_1+','+taxon_name_2+')\'s range)'
# plot_posterior_area_pair(prob_best_area_pairs, title_str)

def make_range_dict(fn,fburn=0.25):
    '''
    Tally how often each range was sampled for each node after burn-in.

    fn is the filename containing .area_states.txt; the first line is a
        header, and each later line holds the cycle number (1st column),
        the node id (3rd column) and the range string (4th column)
    fburn is the fraction of MCMC cycles to discard; the cutoff is
        fburn times the cycle number on the last non-blank line
    Returns a dict mapping node id (str) to a dict mapping range (str)
        to the number of post-burn-in samples of that range
    '''
    with open(fn, 'r') as f:                    # File is closed when done
        # Drop blank lines so a trailing newline cannot break parsing
        rows = [l.split() for l in f if l.strip()]

    d = {}
    if len(rows) < 2:                           # Header only / empty: no samples
        return d

    burn = fburn * float(rows[-1][0])           # Begin sampling after burn
    for fields in rows[1:]:                     # Skip the header (first line)
        if burn > int(fields[0]):               # Skip cycles before burnin
            continue
        counts = d.setdefault(fields[2], {})    # Each node gets its own dict
        states = fields[3]
        counts[states] = counts.get(states, 0) + 1  # Increment node-range count
    return d                                    # Return node-range dict

def get_best_ranges(d,k='mrca'):
    '''
    Rank the sampled ranges of one node by posterior probability.

    d is the dictionary containing node-range counts (as built by
        make_range_dict); it is NOT modified
    k is the node (key) for the dictionary; 'mrca' selects min(d.keys()),
        i.e. the lexicographically smallest node-key string
    Returns list of (range,posterior) pairs for node k, most probable first
    '''
    if k == 'mrca': k = min(d.keys())           # Get the mrca key (default)
    counts = d[k]
    n_samples = float(sum(counts.values()))     # Get the normalizing sum

    # Build (range, frequency) pairs without overwriting the caller's counts
    freqs = [(r, c / n_samples) for r, c in counts.items()]

    # Reverse-sort by posterior probability
    return sorted(freqs, key=operator.itemgetter(1), reverse=True)

def make_posterior_area_pairs(d,k='mrca'):
    '''
    Build the matrix of posterior probabilities that two areas co-occur.

    d is the dictionary containing node-range counts (or frequencies)
    k is the node (key) for the dictionary; 'mrca' selects min(d.keys()),
        i.e. the lexicographically smallest node-key string
    Returns an (n_areas x n_areas) numpy array whose entry [i,j] is the
        posterior prob. areas i and j are both marked '1' in node k's range;
        the diagonal [i,i] is the marginal prob. area i is occupied
    '''
    if k == 'mrca': k = min(d.keys())           # Get the mrca key (default)
    counts = d[k]
    n_samples = sum(counts.values())            # Get the normalizing sum
    n_areas = len(next(iter(counts)))           # Range string length = n areas
    m = np.zeros([n_areas, n_areas])            # Initialize matrix full of zeroes
    for s, c in counts.items():                 # Go over sampled ranges
        present = [j for j, a in enumerate(s) if a == '1']
        for j in present:                       # Every ordered pair of present
            for l in present:                   #   areas gains this range's
                m[j, l] += c                    #   sample weight
    m /= n_samples                              # Turn counts into frequencies
    return m                                    # Returns the matrix

def get_best_area_pairs(m,show_marginal=False):
    '''
    Rank pairs of areas by the posterior probability that they co-occur.

    m is a square matrix containing the posterior prob. two areas
        co-occur in the range of interest (e.g. from make_posterior_area_pairs)
    show_marginal is a flag to also include the diagonal (i,i) entries,
        i.e. the marginal probability each single area is occupied
    Returns list of ((i,j),posterior) pairs with i <= j, most probable first
    '''
    n_areas = m.shape[0]                        # Get number of areas
    offset = 0 if show_marginal else 1          # offset 1 skips the diagonal
    pairs = [((i, j), m[i, j])                  # Upper triangle only: m is symmetric
             for i in range(n_areas)
             for j in range(i + offset, n_areas)]

    # Reverse-sort by posterior probability
    return sorted(pairs, key=operator.itemgetter(1), reverse=True)

def plot_posterior_area_pair(m,title='',fn='range_area_pair_pp.pdf'):
    '''
    Draw a heatmap of area-pair posterior probabilities and save it to a file.

    m is a square matrix containing the posterior prob. two areas
        co-occur in the range of interest
    title is a string to appear in the figure
    fn is the output filename (default keeps the original hard-coded path)
    Returns the matplotlib (figure, axis) pair so callers can adjust or close it
    '''
    from matplotlib import pyplot as plt        # For plotting

    fig, ax = plt.subplots()                    # Create figure and axis
    # Values are probabilities, so pin the colour scale to [0,1]
    hm = ax.pcolor(m,
                   cmap=plt.get_cmap('Blues', 20),  # Use 20 shades of blue
                   vmin=0., vmax=1.)
    fig.colorbar(hm, ax=ax)                     # Show colorbar
    fig.suptitle(title,                         # Show figure title
                 fontsize=14, fontweight='bold')
    n_areas = m.shape[0]                        # Get number of areas

    # Label every fifth area; pcolor cells span [i,i+1], so centre ticks at +0.5
    tick_txt = [str(s) if s % 5 == 0 else '' for s in range(n_areas)]
    ticks = np.arange(n_areas) + 0.5
    ax.set_xticks(ticks, minor=False)
    ax.set_yticks(ticks, minor=False)
    ax.set_xticklabels(tick_txt)
    ax.set_yticklabels(tick_txt)
    fig.savefig(fn)                             # Save (format from extension)
    return fig, ax

def get_area_coords(fn):
    '''
    Read the latitude/longitude of each area from a geo.txt file.

    fn is the filename for the geo.txt file; the first line is a header and
        each following line holds lat and lon as its first two fields
    Returns a dictionary of area_index keys (0-based, in file order),
        {'lat':float,'lon':float} values
    '''
    with open(fn, 'r') as f:                    # File is closed when done
        # Drop blank lines so a trailing newline cannot break parsing
        rows = [line.split() for line in f if line.strip()]

    # rows[0] is the header; area i is the (i+1)-th data line
    return {i: {'lat': float(fields[0]), 'lon': float(fields[1])}
            for i, fields in enumerate(rows[1:])}

def get_node_key(fn,tn1,tn2,idx_list):
    '''
    Find the range-dict key of the MRCA of two taxa.

    fn is the filename for the newick tree.txt file
    tn1 and tn2 are the taxon names whose MRCA you seek
    idx_list is the list of internal-node id values returned by
        get_node_idx_list(); presumably in the same post-order as DendroPy's
        postorder_node_iter() over internal nodes -- TODO confirm
    Returns the key (str) for the node that is the MRCA of tn1 and tn2
    '''
    # get_from_stream works on DendroPy 3 and 4 (the Tree(stream=...)
    # constructor was removed in 4); 'with' closes the file afterwards
    with open(fn, 'r') as f:
        t = dp.Tree.get_from_stream(f, schema='newick')

    # Get a list of internal nodes, sorted in postorder traversal
    int_nodes = [n for n in t.postorder_node_iter() if not n.is_leaf()]

    # Position of the MRCA among internal nodes -> id used in the samples file
    idx = int_nodes.index(t.mrca(taxon_labels=[tn1, tn2]))
    return str(idx_list[idx])                   # Return the key (str)

def get_node_idx_list(fn):
    '''
    List the internal-node ids recorded in the first sampled cycle.

    fn is the filename for the area_states.txt file; the first line is a
        header, then each line holds the cycle number (1st column) and the
        node id (3rd column)
    Returns the node ids of the first cycle in file order (presumed
        post-order), with the tips dropped. Assumes a rooted binary tree
        (2n-1 nodes for n tips) whose n tips are listed first -- TODO confirm
    '''
    idx_list = []
    first_sample = None
    with open(fn, 'r') as f:                    # File is closed when done
        next(f, None)                           # Skip the header
        # Iterate lazily: only the first cycle is needed, not the whole log
        for line in f:
            fields = line.split()               # Split lines by whitespace
            if not fields:                      # Ignore blank lines
                continue
            if first_sample is None:            # Get the first cycle number
                first_sample = fields[0]
            if fields[0] != first_sample:       # Stop past first cycle
                break
            idx_list.append(int(fields[2]))     # Add the id to list
    # Floor division: 2n-1 nodes -> n tips (and an int index on Python 3)
    n_taxa = (len(idx_list) + 2) // 2
    return idx_list[n_taxa:]                    # Return only the internal nodes
