import scipy as sp

from sklearn import metrics

import torch

import torch_geometric
from torch_geometric.data import Data
from torch_geometric.nn import  GCNConv, MessagePassing
import pathpy as pp

from torch.nn import Linear, ModuleList, Module
import torch.nn.functional as F

import pickle
import json





def load_paths(data, train_frac=0.5):
    """Unpickle the training and validation path collections of a dataset.

    The files are read relative to the current working directory from
    ``data/<data>_paths_train_<train_frac>.pickle`` and
    ``data/<data>_paths_val_<train_frac>.pickle``.

    Note: ``pickle.load`` can execute arbitrary code; only load files you
    produced yourself.

    Returns:
        Tuple ``(paths_train, paths_val)`` with the unpickled objects.
    """
    def _unpickle(template):
        # Fill the dataset name and split fraction into the file template.
        with open(template.format(data, train_frac), 'rb') as handle:
            return pickle.load(handle)

    paths_train = _unpickle('data/{0}_paths_train_{1}.pickle')
    paths_val = _unpickle('data/{0}_paths_val_{1}.pickle')
    return paths_train, paths_val

def time_split(network, train_frac=0.5):
    """Split a temporal network's edges at a time cutoff.

    The cutoff is the earliest timestamp plus ``train_frac`` times the
    network's observation length.

    Returns:
        ``(before, after)``: the results of ``network.filter_edges`` for
        edges with ``t < cutoff`` and ``t >= cutoff`` respectively.
    """
    t_start = min(network.ordered_times)
    cutoff = t_start + network.observation_length() * train_frac
    before = network.filter_edges(lambda v, w, t: t < cutoff)
    after = network.filter_edges(lambda v, w, t: t >= cutoff)
    return before, after

def get_paths(net, delta=1):
    """Sample causal paths from a temporal network with pathpy's DAG-based sampler.

    Uses 20 sampled roots and sub-paths up to length 2. ``delta`` is passed
    through to pathpy unchanged (presumably the maximum time gap between
    consecutive edges of a causal path; see the pathpy documentation).
    """
    sampler = pp.path_extraction.sample_paths_from_temporal_network_dag
    return sampler(net, delta=delta, num_roots=20, max_subpath_length=2)

def centralities(causal_paths, centrality):
    """Compute a path-based node centrality with pathpy.

    Args:
        causal_paths: pathpy paths object passed to the pathpy centrality function.
        centrality: either ``'betweenness'`` or ``'closeness'``.

    Returns:
        Whatever ``pp.algorithms.centralities.betweenness`` / ``closeness``
        returns for ``causal_paths`` (indexed by node in ``pp_to_pyg``).

    Raises:
        ValueError: if ``centrality`` is not one of the supported names.
            (Previously an unknown name surfaced as an UnboundLocalError.)
    """
    if centrality == 'betweenness':
        return pp.algorithms.centralities.betweenness(causal_paths)
    if centrality == 'closeness':
        return pp.algorithms.centralities.closeness(causal_paths)
    raise ValueError(
        "Unknown centrality {0!r}; expected 'betweenness' or 'closeness'".format(centrality)
    )


def get_edge_index(network, directed):
    """Convert a pathpy network into PyG ``edge_index`` / ``edge_weight`` tensors.

    Args:
        network: pathpy ``Network`` or ``HigherOrderNetwork``.
        directed: if False, edges are symmetrised with
            ``torch_geometric.utils.to_undirected``.

    Returns:
        ``(edge_index, edge_weight)``: a long tensor of shape ``[2, E]`` with
        node indices taken from ``network.node_to_name_map()``, and a float32
        tensor of shape ``[E]``.
    """
    node_index = network.node_to_name_map()
    # Higher-order networks store an array-valued weight; element 0 is used
    # (presumably the sub-path count in pathpy 2 — verify against pathpy docs).
    is_higher_order = isinstance(network, pp.HigherOrderNetwork)

    sources = []
    targets = []
    weights = []
    for (v, w) in network.edges:
        sources.append(node_index[v])
        targets.append(node_index[w])
        weight = network.edges[(v, w)]['weight']
        weights.append(weight[0] if is_higher_order else weight)

    # Explicit dtypes keep an edgeless network well-typed:
    # torch.tensor([[], []]) would otherwise be float32, not an index tensor.
    edge_index = torch.tensor([sources, targets], dtype=torch.long)
    edge_weight = torch.tensor(weights).float()

    if not directed:
        edge_index, edge_weight = torch_geometric.utils.to_undirected(edge_index, edge_weight)
    return edge_index, edge_weight


def pp_to_pyg(paths, c, max_nodes, directed):
    """Turn a pathpy path collection into a PyG ``Data`` graph with centrality targets.

    Args:
        paths: pathpy paths; the first-order graph is ``pp.Network.from_paths(paths)``.
        c: node -> centrality score lookup, written into ``y``.
        max_nodes: width of the one-hot feature vectors (expected to be at
            least the number of nodes in the network).
        directed: forwarded to ``get_edge_index``.

    Returns:
        ``Data`` whose ``x`` rows are one-hot vectors drawn from a randomly
        permuted ``max_nodes`` identity matrix, plus ``edge_index``,
        ``edge_weight`` and the centrality targets ``y``.
    """
    network = pp.Network.from_paths(paths)
    edge_index, edge_weight = get_edge_index(network, directed=directed)
    position = network.node_to_name_map()
    num_nodes = network.ncount()

    targets = torch.zeros(num_nodes)
    for node in network.nodes:
        targets[position[node]] = c[node]

    identity = torch.nn.functional.one_hot(torch.arange(0, max_nodes)).float()
    shuffled = identity[torch.randperm(max_nodes)]
    return Data(x=shuffled[:num_nodes], edge_index=edge_index, y=targets, edge_weight=edge_weight)

def get_data(exp_name, exp, max_nodes, train_frac, centrality):
    """Load the train/validation paths of an experiment and build PyG graphs for both.

    Each split's centralities (computed by ``centralities``) become the
    regression targets of the graph produced by ``pp_to_pyg``; the
    experiment's ``'directed'`` flag controls edge symmetrisation.

    Returns:
        ``(data_train, data_val)``.
    """
    splits = load_paths(exp_name, train_frac)
    scores = [centralities(split_paths, centrality) for split_paths in splits]
    data_train, data_val = (
        pp_to_pyg(split_paths, split_scores, max_nodes, exp['directed'])
        for split_paths, split_scores in zip(splits, scores)
    )
    return data_train, data_val


def train_model(data_train, model, lr=0.01, epochs=3000, weight_decay=5e-4):
    """Fit ``model`` to the per-node targets ``data_train.y`` with full-batch MSE.

    Args:
        data_train: graph passed to ``model``; must expose ``y``.
        model: module whose forward returns ``(prediction, embedding)``.
        lr: Adam learning rate.
        epochs: number of full-batch optimisation steps.
        weight_decay: Adam L2 penalty (default matches the previous hard-coded value).

    Returns:
        The same ``model`` object, trained in place and left in train mode.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    model.train()
    for _epoch in range(epochs):
        optimizer.zero_grad()
        out, _ = model(data_train)
        # squeeze(-1) drops only the channel axis, so a single-node graph
        # still yields shape [1] and matches y instead of a 0-d scalar.
        loss = torch.nn.functional.mse_loss(out.squeeze(-1), data_train.y)
        loss.backward()
        optimizer.step()
        # Per-epoch losses are deliberately not collected: copying them to
        # the host each epoch forced a device sync and they were never used.

    return model


def hits_in_k(gt, pred, k):
    """Count nodes that are in the top-``k`` of both the ground truth and the prediction.

    Args:
        gt: ground-truth scores, one per node (flattened before ranking).
        pred: predicted scores, one per node (flattened before ranking).
        k: number of top-ranked nodes to compare. It is clamped to the number
            of nodes, so graphs smaller than ``k`` no longer make
            ``torch.topk`` raise.

    Returns:
        Size of the intersection of the two top-``k`` index sets.
    """
    # reshape(-1) instead of squeeze(): with k == 1 squeeze() produced a 0-d
    # tensor whose .tolist() is a plain int, which set() cannot consume.
    gt_flat = gt.detach().reshape(-1)
    pred_flat = pred.detach().reshape(-1)
    k = min(k, gt_flat.numel(), pred_flat.numel())

    top_gt = set(torch.topk(gt_flat, k=k).indices.cpu().tolist())
    top_pred = set(torch.topk(pred_flat, k=k).indices.cpu().tolist())
    return len(top_gt & top_pred)


def eval_model(data_val, model):
    """Evaluate a trained centrality regressor on a validation graph.

    Args:
        data_val: graph passed to ``model``; ``data_val.y`` holds the true scores.
        model: module whose forward returns ``(prediction, embedding)``.

    Returns:
        Dict with ``'MAE'``, Spearman and Kendall statistics/p-values
        (``'Spearmanr_statistic'``, ``'Spearmanr_p'``, ``'Kendalltau_statistic'``,
        ``'Kendalltau_p'``) and top-k overlaps ``'hitsIn30'``, ``'hitsIn10'``,
        ``'hitsIn5'``.
    """
    model.eval()
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        predicted_centralities = model(data_val)[0].squeeze().detach()

    y_true = data_val.y.squeeze().cpu().numpy()
    y_pred = predicted_centralities.squeeze().cpu().numpy()
    # Each rank correlation is computed once and both fields read from it.
    spearman = sp.stats.spearmanr(y_true, y_pred)
    kendall = sp.stats.kendalltau(y_true, y_pred)

    results = {}
    results['MAE'] = float(metrics.mean_absolute_error(y_true, y_pred))
    results['Spearmanr_statistic'] = spearman.statistic
    results['Spearmanr_p'] = spearman.pvalue
    results['Kendalltau_statistic'] = kendall.statistic
    results['Kendalltau_p'] = kendall.pvalue
    results['hitsIn30'] = hits_in_k(data_val.y, predicted_centralities, k=30)
    results['hitsIn10'] = hits_in_k(data_val.y, predicted_centralities, k=10)
    results['hitsIn5'] = hits_in_k(data_val.y, predicted_centralities, k=5)

    return results


class GCN(torch.nn.Module):
    """Two-layer GCN that predicts ``out_ch`` values per node (here: a centrality score).

    Args:
        num_features: size of the input node features.
        hidden_dim_1: output size of the first graph convolution.
        out_ch: number of output channels per node.
        dropout: dropout probability applied after each convolution.
        hidden_dim_2: output size of the second graph convolution
            (defaults to 8, the previously hard-coded value).
    """

    def __init__(self, num_features, hidden_dim_1=16, out_ch=1, dropout=0.2, hidden_dim_2=8):
        super().__init__()

        self.input_to_hidden = torch_geometric.nn.GCNConv(num_features, hidden_dim_1)
        self.hidden_to_hidden2 = torch_geometric.nn.GCNConv(hidden_dim_1, hidden_dim_2)
        self.hidden_to_output = torch.nn.Linear(hidden_dim_2, out_ch)

        self.p = dropout

    def forward(self, data):
        """Run the network on ``data`` (reads ``x``, ``edge_index``, ``edge_weight``).

        Returns:
            ``(elu(linear(h)), h)``, where ``h`` is the node representation
            after the second convolution, sigmoid and dropout.
        """
        # first graph convolution -> hidden_dim_1 dimensions
        x = self.input_to_hidden(data.x, data.edge_index, data.edge_weight)
        x = torch.sigmoid(x)
        x = torch.nn.functional.dropout(x, self.p, training=self.training)

        # second graph convolution -> hidden_dim_2 dimensions
        x = self.hidden_to_hidden2(x, data.edge_index, data.edge_weight)
        x = torch.sigmoid(x)
        x = torch.nn.functional.dropout(x, self.p, training=self.training)

        embedding = x
        x = self.hidden_to_output(x)

        # regression output (not class probabilities): ELU of the linear read-out
        return torch.nn.functional.elu(x), embedding
    
    
class BipartiteGraphOperator(MessagePassing):
    """Message passing along a bipartite graph from source nodes to target nodes.

    Source features are transformed by ``lin1`` and target features by
    ``lin2``. Each edge carries the sum of both transformed endpoints, and
    messages are summed ('add' aggregation) at the target nodes.
    """

    def __init__(self, in_ch, out_ch):
        super(BipartiteGraphOperator, self).__init__('add')
        self.lin1 = Linear(in_ch, out_ch)
        self.lin2 = Linear(in_ch, out_ch)

    def forward(self, x, bipartite_index, N, M):
        """Propagate from ``x[0]`` (``N`` source nodes) to ``x[1]`` (``M`` target nodes).

        Args:
            x: pair ``(x_source, x_target)`` of feature matrices with ``in_ch`` columns.
            bipartite_index: ``[2, E]`` index; row 0 indexes sources, row 1 targets.
            N: number of source nodes.
            M: number of target nodes.

        Returns:
            Aggregated messages, one row per target node.
        """
        x = (self.lin1(x[0]), self.lin2(x[1]))
        return self.propagate(bipartite_index, size=(N, M), x=x)

    def message(self, x_i, x_j):
        # x_j: transformed source features, x_i: transformed target features.
        # Without this override PyG's default message returns only x_j, so
        # lin2 (applied to the targets in forward) would never be used.
        return x_i + x_j

class DBGNN(Module):
    """Implementation of time-aware graph neural network DBGNN
    Reference paper: https://openreview.net/pdf?id=Dbkqs1EhTr

    Stacks of GCN layers run on the first-order graph and on the higher-order
    graph. A bipartite layer then passes messages from higher-order nodes to
    first-order nodes, followed by a linear read-out per first-order node.

    Args:
        out_channels: int - number of output channels per first-order node
        num_features: list - number of features for first order and higher order nodes,
            e.g. [first_order_num_features, second_order_num_features]
        hidden_dims: list - ``hidden_dims[:-1]`` are the output sizes of the stacked GCN
            layers (same sizes in the first- and higher-order branches);
            ``hidden_dims[-1]`` is the output size of the bipartite layer.
            Needs at least two entries.
        p_dropout: float - drop-out probability

    Raises:
        ValueError: if ``hidden_dims`` has fewer than two entries.
    """
    def __init__(
        self,
        out_channels,
        num_features,
        hidden_dims,
        p_dropout=0.0
        ):
        super().__init__()

        if len(hidden_dims) < 2:
            raise ValueError(
                'hidden_dims needs at least two entries (GCN output size and '
                'bipartite output size), got {0!r}'.format(hidden_dims))

        self.num_features = num_features
        self.out_channels = out_channels
        self.hidden_dims = hidden_dims
        self.p_dropout = p_dropout

        # higher-order layers
        self.higher_order_layers = ModuleList()
        self.higher_order_layers.append(GCNConv(self.num_features[1], self.hidden_dims[0]))

        # first-order layers
        self.first_order_layers = ModuleList()
        self.first_order_layers.append(GCNConv(self.num_features[0], self.hidden_dims[0]))

        # Further GCN layers up to hidden_dims[-2]; the last entry is reserved
        # for the bipartite layer below.
        for dim in range(1, len(self.hidden_dims)-1):
            # higher-order layers
            self.higher_order_layers.append(GCNConv(self.hidden_dims[dim-1], self.hidden_dims[dim]))
            # first-order layers
            self.first_order_layers.append(GCNConv(self.hidden_dims[dim-1], self.hidden_dims[dim]))

        self.bipartite_layer = BipartiteGraphOperator(self.hidden_dims[-2], self.hidden_dims[-1])

        # Linear layer
        self.lin = torch.nn.Linear(self.hidden_dims[-1], out_channels)

    def forward(self, data):
        """Predict per-node outputs for the first-order graph.

        Reads ``data.x``, ``data.x_h``, ``data.edge_index``, ``data.edge_weight``,
        ``data.edge_index_higher_order``, ``data.edge_weight_higher_order``,
        ``data.bipartite_edge_index``, ``data.num_ho_nodes`` and ``data.num_nodes``.

        Returns:
            ``(elu(linear(h)), h)`` where ``h`` is the bipartite-layer output
            after ELU and dropout.
        """
        x = data.x
        x_h = data.x_h

        # First-order convolutions
        for layer in self.first_order_layers:
            x = F.dropout(x, p=self.p_dropout, training=self.training)
            x = torch.sigmoid(layer(x, data.edge_index, data.edge_weight))
        x = F.dropout(x, p=self.p_dropout, training=self.training)

        # Higher-order convolutions
        for layer in self.higher_order_layers:
            x_h = F.dropout(x_h, p=self.p_dropout, training=self.training)
            x_h = torch.sigmoid(layer(x_h, data.edge_index_higher_order, data.edge_weight_higher_order))
        x_h = F.dropout(x_h, p=self.p_dropout, training=self.training)

        # Bipartite message passing: higher-order nodes (sources) -> first-order nodes (targets)
        x = F.elu(self.bipartite_layer((x_h, x), data.bipartite_edge_index, N=data.num_ho_nodes, M=data.num_nodes))
        x = F.dropout(x, p=self.p_dropout, training=self.training)

        embedding = x
        # Linear layer
        x = self.lin(x)

        return F.elu(x), embedding
    
def get_bipartite_edge_index(g2, g1):
    """Link every node of ``g2`` to the ``g1`` node its path ends in.

    For each node ``v`` of ``g2``, one edge runs from ``v``'s index in ``g2``
    to the ``g1`` index of the last element of
    ``g2.higher_order_node_to_path(v)``.

    Returns:
        Long tensor of shape ``[2, number of g2 nodes]``; row 0 holds ``g2``
        indices and row 1 holds ``g1`` indices.
    """
    index_2 = g2.node_to_name_map()
    index_1 = g1.node_to_name_map()

    sources = []
    targets = []
    for v in g2.nodes:
        sources.append(index_2[v])
        targets.append(index_1[g2.higher_order_node_to_path(v)[-1]])

    # Explicit dtype: for an empty g2, torch.tensor([[], []]) would be float32.
    return torch.tensor([sources, targets], dtype=torch.long)


def train_dbgnn(model, data, n_epochs, lr, device, weight_decay=5e-4):
    """Fit ``model`` to the per-node targets ``data.y`` with full-batch MSE.

    Args:
        model: module whose forward returns ``(prediction, embedding)``;
            expected to already live on ``device``.
        data: graph object with ``y`` and a ``to(device)`` method.
        n_epochs: number of full-batch optimisation steps.
        lr: Adam learning rate.
        device: target device for ``data``.
        weight_decay: Adam L2 penalty (default matches the previous hard-coded value).

    Returns:
        The same ``model`` object, trained in place (previously returned None).
    """
    model.train()

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    data = data.to(device)

    for _epoch in range(n_epochs):
        # Clear gradients before the forward pass so gradients already present
        # on the parameters cannot leak into the first optimiser step.
        optimizer.zero_grad()
        output, _ = model(data)
        # squeeze(-1) keeps a [1]-shaped output for single-node graphs.
        loss = torch.nn.functional.mse_loss(output.squeeze(-1), data.y)
        loss.backward()
        optimizer.step()

    return model


def eval_dbgnn(data_val, model):
    """Evaluate a trained DBGNN's centrality predictions on validation data.

    Args:
        data_val: ``Data`` consumed by ``model``; ``data_val.y`` holds the true scores.
        model: module whose forward returns ``(prediction, embedding)``.

    Returns:
        Dict with ``'MAE'``, ``'Spearmanr_stat'``, ``'Spearmanr_p'``,
        ``'Kendalltau_stat'``, ``'Kendalltau_p'``, ``'hitsIn30'``,
        ``'hitsIn10'`` and ``'hitsIn5'``.
    """
    model.eval()
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        predicted_centralities = model(data_val)[0].squeeze().detach()

    y_true = data_val.y.squeeze().cpu().numpy()
    y_pred = predicted_centralities.squeeze().cpu().numpy()
    # Each rank correlation is computed once and both fields read from it.
    spearman = sp.stats.spearmanr(y_true, y_pred)
    kendall = sp.stats.kendalltau(y_true, y_pred)

    results = {}
    results['MAE'] = float(metrics.mean_absolute_error(y_true, y_pred))
    results['Spearmanr_stat'] = spearman.statistic
    results['Spearmanr_p'] = spearman.pvalue
    results['Kendalltau_stat'] = kendall.statistic
    results['Kendalltau_p'] = kendall.pvalue
    results['hitsIn30'] = hits_in_k(data_val.y, predicted_centralities, k=30)
    results['hitsIn10'] = hits_in_k(data_val.y, predicted_centralities, k=10)
    results['hitsIn5'] = hits_in_k(data_val.y, predicted_centralities, k=5)

    return results



def _dbgnn_data(g1, g2, max_fo_nodes, max_ho_nodes, edge_index, edge_weight, y, permute):
    """Bundle first-order, higher-order and bipartite tensors into the ``Data`` that ``DBGNN`` reads.

    Features are rows of identity matrices of width ``max_fo_nodes`` /
    ``max_ho_nodes``; with ``permute=True`` the rows are randomly shuffled
    first (first-order permutation drawn before the higher-order one).
    """
    fo_features = torch.eye(max_fo_nodes, max_fo_nodes)
    ho_features = torch.eye(max_ho_nodes, max_ho_nodes)
    if permute:
        fo_features = fo_features[torch.randperm(max_fo_nodes)]
        ho_features = ho_features[torch.randperm(max_ho_nodes)]

    ho_index, ho_weights = get_edge_index(g2, directed=True)

    return Data(
        num_nodes=g1.ncount(),
        num_ho_nodes=g2.ncount(),
        x=fo_features[:g1.ncount()],
        x_h=ho_features[:g2.ncount()],
        edge_index=edge_index,
        edge_weight=edge_weight,
        edge_index_higher_order=ho_index,
        edge_weight_higher_order=ho_weights.float(),
        bipartite_edge_index=get_bipartite_edge_index(g2, g1),
        y=y,
    )


def run_exp(exp_name, exp, defaults, centrality, lr, device):
    """Train and evaluate a GCN baseline and a DBGNN on one experiment's path data.

    Args:
        exp_name: dataset name used to locate the pickled path files.
        exp: experiment config; ``'train_fracs'`` and ``'epochs'`` fall back
            to ``defaults`` when missing or falsy; ``'directed'`` is required.
        defaults: config providing fallback ``'train_fracs'`` and ``'epochs'``.
        centrality: centrality name forwarded to ``get_data``.
        lr: learning rate for both models.
        device: torch device used for models and data.

    Returns:
        ``(results_gcn, results_dbgnn)`` metric dicts from ``eval_model`` and
        ``eval_dbgnn``.
    """
    train_frac = exp.get('train_fracs') or defaults['train_fracs']
    n_epochs = exp.get('epochs') or defaults['epochs']

    paths_train, paths_val = load_paths(exp_name, train_frac)

    # First-order networks are built once and reused for the DBGNN below.
    g1_train = pp.Network.from_paths(paths_train)
    g1_val = pp.Network.from_paths(paths_val)

    max_fo_nodes = max(g1_train.ncount(), g1_val.ncount())
    # Higher-order feature width is bounded by the first-order edge count;
    # this assumes every second-order node corresponds to an observed edge.
    max_ho_nodes = max(g1_train.ecount(), g1_val.ecount())

    print('===\nRunning experiment {0} for lr {1}'.format(exp_name, lr))

    data_train, data_val = get_data(exp_name, exp, max_nodes=max_fo_nodes, train_frac=train_frac, centrality=centrality)

    # GCN baseline
    data_val.to(device)
    data_train.to(device)
    model = GCN(data_train.num_node_features, dropout=0.0)
    model.to(device)
    model = train_model(data_train, model, lr=lr, epochs=n_epochs)
    results_gcn = eval_model(data_val, model)

    # Train DBGNN on the training split
    g2_train = pp.MultiOrderModel(paths_train, max_order=2).layers[2]
    data = _dbgnn_data(
        g1_train, g2_train, max_fo_nodes, max_ho_nodes,
        edge_index=data_train.edge_index,
        edge_weight=data_train.edge_weight,
        y=data_train.y,
        permute=True,
    )

    model_dbgnn = DBGNN(
        num_features=[max_fo_nodes, max_ho_nodes],
        out_channels=1,
        hidden_dims=[16, 8],
        p_dropout=0.0
        ).to(device)

    train_dbgnn(model_dbgnn, data=data, n_epochs=n_epochs, lr=lr, device=device)

    # Evaluate DBGNN on the validation split
    g2_val = pp.MultiOrderModel(paths_val, max_order=2).layers[2]
    fo_index, fo_weights = get_edge_index(g1_val, directed=exp['directed'])
    # NOTE(review): validation features are unpermuted identity rows while the
    # training features (and both GCN splits) are randomly permuted. Kept as-is
    # to preserve existing results — confirm this asymmetry is intended.
    data_val_dbgnn = _dbgnn_data(
        g1_val, g2_val, max_fo_nodes, max_ho_nodes,
        edge_index=fo_index,
        edge_weight=fo_weights,
        y=data_val.y,
        permute=False,
    )
    data_val_dbgnn.to(device)

    results_dbgnn = eval_dbgnn(data_val_dbgnn, model_dbgnn)

    return results_gcn, results_dbgnn
    


def run_experiment(exp_name, centrality, lr, device):
    """Run one named experiment described in ``experiments.json``.

    The ``'default_params'`` entry of the JSON file supplies fallback
    settings; the entry named ``exp_name`` is the experiment config.

    Returns:
        ``(results_gcn, results_dbgnn)`` as produced by ``run_exp``.
    """
    with open('experiments.json', 'r') as config_file:
        config = json.load(config_file)

    default_params = config.pop('default_params')
    gcn_metrics, dbgnn_metrics = run_exp(
        exp_name, config[exp_name], default_params, centrality, lr, device
    )
    return gcn_metrics, dbgnn_metrics