import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import NNConv, Set2Set, global_mean_pool


class Encoder(nn.Module):
    """Graph encoder: NNConv + GRU message passing with a Set2Set readout.

    Each row of ``data.x`` carries an integer type index in column 0 (fed to
    an ``nn.Embedding``) and a scalar feature in column 1 (called "charge"
    here). Edge attributes are single scalars, since the edge network's first
    layer takes one input. ``forward`` returns one ``dim``-sized vector per
    graph id in ``data.batch``.
    """

    def __init__(self, N_types, embed_dim, dim):
        super().__init__()
        self.N_types = N_types
        self.embed_dim = embed_dim
        self.dim = dim

        # Type embedding; lin0 takes the embedding plus the scalar column.
        self.embed = nn.Embedding(N_types, embed_dim)
        self.lin0 = nn.Linear(embed_dim + 1, dim)

        # Edge network producing a flattened dim x dim weight per edge.
        edge_net = nn.Sequential(
            nn.Linear(1, 128),
            nn.LeakyReLU(),
            nn.Linear(128, dim * dim),
        )
        self.conv = NNConv(dim, dim, edge_net, aggr='mean', root_weight=False)
        self.gru = nn.GRU(dim, dim)

        self.set2set = Set2Set(dim, processing_steps=3)
        self.lin1 = nn.Linear(2 * dim, dim)

    def forward(self, data):
        """Encode a batched graph into one feature vector per graph."""
        type_emb = self.embed(data.x[:, 0])
        charge = data.x[:, 1].float().unsqueeze(1)
        node_h = F.leaky_relu(self.lin0(torch.cat([type_emb, charge], dim=1)))

        # GRU hidden state has a leading sequence-length axis of 1.
        hidden = node_h.unsqueeze(0)
        # Five rounds of message passing that share the same conv/GRU weights.
        for _ in range(5):
            messages = F.relu(
                self.conv(node_h, data.edge_index, data.edge_attr))
            node_h, hidden = self.gru(messages.unsqueeze(0), hidden)
            node_h = node_h.squeeze(0)

        pooled = self.set2set(node_h, data.batch)
        return F.leaky_relu(self.lin1(pooled))


class Encoder_simple(nn.Module):
    """Lightweight graph encoder: two NNConv layers then mean pooling.

    Input layout matches ``Encoder``: column 0 of ``data.x`` is a type index
    fed to an ``nn.Embedding``, column 1 is a scalar feature, and edge
    attributes are single scalars. Returns one ``dim``-sized vector per graph
    id in ``data.batch``.
    """

    def __init__(self, N_types, embed_dim, dim):
        super().__init__()
        self.N_types = N_types
        self.embed_dim = embed_dim
        self.dim = dim

        self.embed = nn.Embedding(N_types, embed_dim)
        self.lin0 = nn.Linear(embed_dim + 1, dim)

        def make_edge_net():
            # Maps a 1-d edge attribute to a flattened dim x dim weight.
            return nn.Sequential(nn.Linear(1, 128), nn.LeakyReLU(),
                                 nn.Linear(128, dim * dim))

        # Two convolutions, each with its own edge network.
        self.conv1 = NNConv(dim, dim, make_edge_net(), aggr='mean')
        self.conv2 = NNConv(dim, dim, make_edge_net(), aggr='mean')

        self.lin1 = nn.Linear(dim, dim)

    def forward(self, data):
        """Encode a batched graph into one feature vector per graph."""
        node_in = torch.cat(
            [self.embed(data.x[:, 0]), data.x[:, 1].float().unsqueeze(1)],
            dim=1)
        h = F.leaky_relu(self.lin0(node_in))
        for conv in (self.conv1, self.conv2):
            h = F.leaky_relu(conv(h, data.edge_index, data.edge_attr))
        return F.leaky_relu(self.lin1(global_mean_pool(h, data.batch)))


class Decoder(nn.Module):
    """Decode a latent vector into node features, adjacency and edge values.

    Two decoding modes are supported:

    * ``one_shot=True``: MLP heads map the latent vector straight to every
      node feature and every adjacency (and optionally edge) entry.
    * ``one_shot=False``: node features are decoded first. Each node pair
      ``(i, j)`` with ``j >= i`` is then scored from the latent vector plus the
      mean of the two node representations, and mirrored to ``(j, i)``.

    ``forward`` returns ``(nodes, adj)``, or ``(nodes, adj, edge)`` when
    ``pred_edge`` is set. ``nodes`` has shape
    ``(batch, max_nodes, N_types + 1)``: ``N_types`` unnormalised type scores
    followed by one charge channel. ``adj`` holds logits (the adjacency loss
    takes logits). ``edge`` holds sigmoid outputs.
    """

    def __init__(self, N_types, max_nodes, x_rep_dim, latent_dim, node_rep_dim,
                 device, undirected, one_shot, pred_edge, multi_ae,
                 multi_node):
        """Build the decoder heads.

        Args:
            N_types: number of node types.
            max_nodes: number of node slots in every decoded graph.
            x_rep_dim: hidden width of the MLP heads.
            latent_dim: size of the latent vector fed to ``forward``.
            node_rep_dim: width of the per-node representation used by the
                pairwise (non one-shot) adjacency/edge heads.
            device: device on which helper tensors and outputs are created.
            undirected: one-shot mode only. Predict the upper triangle and
                mirror it, instead of predicting the full matrix. The pairwise
                decoder always produces symmetric output.
            one_shot: use the one-shot decoder instead of the pairwise one.
            pred_edge: also predict an edge-value matrix.
            multi_ae: pairwise mode only. Use one head per node pair instead
                of a single shared head.
            multi_node: pairwise mode only. Use one node head per node slot
                instead of one head for all slots.
        """
        super(Decoder, self).__init__()

        self.N_types = N_types
        self.max_nodes = max_nodes
        self.x_rep_dim = x_rep_dim
        self.latent_dim = latent_dim
        self.node_rep_dim = node_rep_dim
        self.device = device
        self.undirected = undirected
        self.one_shot = one_shot
        self.pred_edge = pred_edge
        self.multi_node = multi_node
        self.multi_ae = multi_ae

        if self.one_shot:
            self.triu_mask = torch.ones(max_nodes, max_nodes,
                                        device=device).triu_().bool()
            # Row-major (row, col) coordinates of the upper triangle, diagonal
            # included. This is the same order that boolean indexing with
            # ``triu_mask`` enumerates, i.e. the layout of the undirected
            # head outputs.
            self.triu_rows, self.triu_cols = self.triu_mask.nonzero(
                as_tuple=True)

            # Undirected graphs only need the upper triangle. Directed graphs
            # need all max_nodes * max_nodes entries so ``forward`` can view
            # the head output as a square matrix.
            if self.undirected:
                n_pair_out = max_nodes * (max_nodes + 1) // 2
            else:
                n_pair_out = max_nodes * max_nodes

            self.fin_nf = nn.Sequential(
                nn.Linear(latent_dim, x_rep_dim), nn.LeakyReLU(),
                nn.Linear(x_rep_dim, x_rep_dim), nn.LeakyReLU(),
                nn.Linear(x_rep_dim, max_nodes * (N_types + 1)))

            self.fin_adj = nn.Sequential(
                nn.Linear(latent_dim, x_rep_dim), nn.LeakyReLU(),
                nn.Linear(x_rep_dim, x_rep_dim), nn.LeakyReLU(),
                nn.Linear(x_rep_dim, n_pair_out))

            if self.pred_edge:
                self.fin_ef = nn.Sequential(
                    nn.Linear(latent_dim, x_rep_dim), nn.LeakyReLU(),
                    nn.Linear(x_rep_dim, x_rep_dim), nn.LeakyReLU(),
                    nn.Linear(x_rep_dim, n_pair_out))
        else:
            if self.multi_node:
                self.fin_nf = nn.ModuleList([
                    nn.Sequential(nn.Linear(latent_dim, x_rep_dim // 2),
                                  nn.LeakyReLU(),
                                  nn.Linear(x_rep_dim // 2, x_rep_dim),
                                  nn.LeakyReLU(),
                                  nn.Linear(x_rep_dim, x_rep_dim // 2),
                                  nn.LeakyReLU(),
                                  nn.Linear(x_rep_dim // 2, N_types))
                    for i in range(max_nodes)
                ])
                self.fin_charge = nn.ModuleList([
                    nn.Sequential(nn.Linear(latent_dim, x_rep_dim),
                                  nn.LeakyReLU(),
                                  nn.Linear(x_rep_dim, x_rep_dim // 2),
                                  nn.LeakyReLU(), nn.Linear(x_rep_dim // 2, 1))
                    for i in range(max_nodes)
                ])
            else:
                self.fin_nf = nn.Sequential(
                    nn.Linear(latent_dim, x_rep_dim // 2), nn.LeakyReLU(),
                    nn.Linear(x_rep_dim // 2, x_rep_dim), nn.LeakyReLU(),
                    nn.Linear(x_rep_dim, max_nodes * N_types))
                self.fin_charge = nn.Sequential(
                    nn.Linear(latent_dim, x_rep_dim // 2), nn.LeakyReLU(),
                    nn.Linear(x_rep_dim // 2, x_rep_dim), nn.LeakyReLU(),
                    nn.Linear(x_rep_dim, max_nodes))

            self.adj_x = nn.Sequential(nn.Linear(latent_dim, x_rep_dim // 2),
                                       nn.LeakyReLU(),
                                       nn.Linear(x_rep_dim // 2, x_rep_dim))
            self.node_rep_adj = nn.Sequential(
                nn.Linear((N_types + 1), node_rep_dim // 2), nn.LeakyReLU(),
                nn.Linear(node_rep_dim // 2, node_rep_dim))
            # NOTE(review): ``forward`` only visits
            # max_nodes * (max_nodes + 1) // 2 pairs, so half of the per-pair
            # heads below are never used. The count is kept so existing
            # checkpoints still load.
            if self.multi_ae:
                self.fin_af = nn.ModuleList([
                    nn.Sequential(
                        nn.Linear(node_rep_dim + x_rep_dim, x_rep_dim // 2),
                        nn.LeakyReLU(), nn.Linear(x_rep_dim // 2, x_rep_dim),
                        nn.LeakyReLU(), nn.Linear(x_rep_dim, x_rep_dim // 2),
                        nn.LeakyReLU(), nn.Linear(x_rep_dim // 2, 1))
                    for i in range(max_nodes * (max_nodes + 1))
                ])
            else:
                self.fin_af = nn.Sequential(
                    nn.Linear(node_rep_dim + x_rep_dim, x_rep_dim // 2),
                    nn.LeakyReLU(), nn.Linear(x_rep_dim // 2, x_rep_dim),
                    nn.LeakyReLU(), nn.Linear(x_rep_dim, x_rep_dim // 2),
                    nn.LeakyReLU(), nn.Linear(x_rep_dim // 2, 1))

            if self.pred_edge:
                self.edge_x = nn.Sequential(
                    nn.Linear(latent_dim, x_rep_dim // 2), nn.LeakyReLU(),
                    nn.Linear(x_rep_dim // 2, x_rep_dim))
                self.node_rep_edge = nn.Sequential(
                    nn.Linear((N_types + 1), node_rep_dim // 2),
                    nn.LeakyReLU(), nn.Linear(node_rep_dim // 2, node_rep_dim))
                if self.multi_ae:
                    self.fin_ef = nn.ModuleList([
                        nn.Sequential(
                            nn.Linear(node_rep_dim + x_rep_dim,
                                      x_rep_dim // 2), nn.LeakyReLU(),
                            nn.Linear(x_rep_dim // 2,
                                      x_rep_dim), nn.LeakyReLU(),
                            nn.Linear(x_rep_dim, x_rep_dim // 2),
                            nn.LeakyReLU(), nn.Linear(x_rep_dim // 2, 1))
                        for i in range(max_nodes * (max_nodes + 1))
                    ])
                else:
                    self.fin_ef = nn.Sequential(
                        nn.Linear(node_rep_dim + x_rep_dim, x_rep_dim // 2),
                        nn.LeakyReLU(), nn.Linear(x_rep_dim // 2, x_rep_dim),
                        nn.LeakyReLU(), nn.Linear(x_rep_dim, x_rep_dim // 2),
                        nn.LeakyReLU(), nn.Linear(x_rep_dim // 2, 1))

    def _symmetric_from_triu(self, values, batch_size):
        """Scatter flat upper-triangle ``values`` into symmetric matrices.

        ``values`` holds one row per graph, laid out in the row-major order of
        ``triu_rows``/``triu_cols``. Entry ``(i, j)`` and its mirror ``(j, i)``
        both receive the same value.
        """
        out = torch.zeros(batch_size,
                          self.max_nodes,
                          self.max_nodes,
                          device=self.device)
        out[:, self.triu_rows, self.triu_cols] = values
        # Mirror through the transposed coordinates. Boolean indexing with
        # ``triu_mask.t()`` would enumerate the lower triangle in its own
        # row-major order, which does not pair (j, i) with (i, j) once
        # max_nodes >= 3.
        out[:, self.triu_cols, self.triu_rows] = values
        return out

    def _decode_pairwise(self, x, nodes, batch_size):
        """Score every node pair (i, j), j >= i, and mirror it to (j, i).

        Returns ``(adj, edge)``; ``edge`` is ``None`` unless ``pred_edge``.
        """
        # Per-node type probabilities (softmax over the last, type axis, so
        # each node slot gets its own distribution) plus a tanh-squashed
        # charge channel.
        new_nodes = torch.cat(
            (F.softmax(nodes[:, :, :-1], dim=-1),
             torch.tanh(nodes[:, :, -1]).view(-1, self.max_nodes, 1)),
            dim=-1)

        adj = torch.empty(batch_size,
                          self.max_nodes,
                          self.max_nodes,
                          device=self.device)
        edge = None
        if self.pred_edge:
            edge = torch.empty(batch_size,
                               self.max_nodes,
                               self.max_nodes,
                               device=self.device)

        adj_x = F.leaky_relu(self.adj_x(x))
        node_rep_adj = F.leaky_relu(self.node_rep_adj(new_nodes))
        if self.pred_edge:
            edge_x = F.leaky_relu(self.edge_x(x))
            # NOTE(review): unlike ``node_rep_adj``, this consumes the raw
            # ``nodes`` scores rather than ``new_nodes``. Confirm intended.
            node_rep_edge = F.leaky_relu(self.node_rep_edge(nodes))

        k = 0  # flat pair index; selects the per-pair head when multi_ae
        for i in range(self.max_nodes):
            for j in range(i, self.max_nodes):
                node_in_adj = (node_rep_adj[:, i] +
                               node_rep_adj[:, j]) / 2.
                node_in_adj = torch.cat((adj_x, node_in_adj), dim=1)
                adj_head = self.fin_af[k] if self.multi_ae else self.fin_af
                # adj loss takes logits as input
                adjX = adj_head(node_in_adj).view(-1)
                adj[:, i, j] = adjX
                adj[:, j, i] = adjX

                if self.pred_edge:
                    node_in_edge = (node_rep_edge[:, i] +
                                    node_rep_edge[:, j]) / 2.
                    node_in_edge = torch.cat((edge_x, node_in_edge), dim=1)
                    edge_head = (self.fin_ef[k]
                                 if self.multi_ae else self.fin_ef)
                    edgeX = torch.sigmoid(edge_head(node_in_edge).view(-1))
                    edge[:, i, j] = edgeX
                    edge[:, j, i] = edgeX

                k += 1

        return adj, edge

    def forward(self, x):
        """Decode a batch of latent vectors into graphs.

        Args:
            x: latent tensor whose first dimension is the batch. Presumably
                ``(batch, latent_dim)``; TODO confirm against callers.

        Returns:
            ``(nodes, adj)``, or ``(nodes, adj, edge)`` when ``pred_edge`` is
            set. ``adj`` and ``edge`` are ``(batch, max_nodes, max_nodes)``.
        """
        batch_size = x.size(0)

        if self.one_shot:
            nodes = self.fin_nf(x).view(-1, self.max_nodes, self.N_types + 1)
        elif self.multi_node:
            nodes = torch.empty(batch_size,
                                self.max_nodes,
                                self.N_types + 1,
                                device=self.device)
            for i in range(self.max_nodes):
                nodes[:, i, :-1] = self.fin_nf[i](x)
                nodes[:, i, -1] = self.fin_charge[i](x).view(-1)
        else:
            nodes = self.fin_nf(x).view(-1, self.max_nodes, self.N_types)
            charges = self.fin_charge(x).view(-1, self.max_nodes, 1)
            nodes = torch.cat((nodes, charges), dim=-1)

        edge = None
        if self.one_shot:
            if self.undirected:
                # adj loss takes logits as input
                adj = self._symmetric_from_triu(self.fin_adj(x), batch_size)
                if self.pred_edge:
                    edge = torch.sigmoid(
                        self._symmetric_from_triu(self.fin_ef(x), batch_size))
            else:
                adj = self.fin_adj(x).view(-1, self.max_nodes, self.max_nodes)
                if self.pred_edge:
                    edge = torch.sigmoid(self.fin_ef(x)).view(
                        -1, self.max_nodes, self.max_nodes)
        else:
            adj, edge = self._decode_pairwise(x, nodes, batch_size)

        if not self.pred_edge:
            return nodes, adj
        return nodes, adj, edge


class VAE(nn.Module):
    """Variational autoencoder wrapper around a graph encoder and decoder.

    Encoder features are projected to the mean and log-variance of a
    diagonal Gaussian. A latent sample is drawn with the reparameterisation
    trick and decoded. The most recent ``mu``, ``logvar`` and ``z`` are stored
    on the module after every ``forward`` (presumably for the KL term of the
    loss; TODO confirm).
    """

    def __init__(self, Encoder, Decoder):
        super().__init__()

        self.Encoder = Encoder
        self.Decoder = Decoder

        self.dim = Encoder.dim
        self.latent_dim = Decoder.latent_dim
        self.pred_edge = Decoder.pred_edge

        # Linear heads: encoder features -> Gaussian mean / log-variance.
        self.mu_enc = nn.Linear(self.dim, self.latent_dim)
        self.logvar_enc = nn.Linear(self.dim, self.latent_dim)

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        Returns the decoder outputs: ``(nodes, adj, edge)`` when
        ``pred_edge`` is set, otherwise ``(nodes, adj)``.
        """
        graph_feats = self.Encoder(x)

        self.mu = self.mu_enc(graph_feats)
        self.logvar = self.logvar_enc(graph_feats)
        self.z = self.reparameterize(self.mu, self.logvar)

        decoded = self.Decoder(self.z)
        if not self.pred_edge:
            nodes, adj = decoded
            return nodes, adj
        nodes, adj, edge = decoded
        return nodes, adj, edge

    def reparameterize(self, mu, logvar):
        """Return ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma


class AE(nn.Module):
    """Deterministic autoencoder wrapper around a graph encoder and decoder.

    Encoder features are linearly projected to the latent code ``z`` and
    decoded. ``mu`` is set to the very same tensor as ``z``, presumably so
    code written for ``VAE`` can read ``mu`` from either model (TODO
    confirm).
    """

    def __init__(self, Encoder, Decoder):
        super().__init__()

        self.Encoder = Encoder
        self.Decoder = Decoder

        self.dim = Encoder.dim
        self.latent_dim = Decoder.latent_dim
        self.pred_edge = Decoder.pred_edge

        # Single linear projection: encoder features -> latent code.
        self.enc = nn.Linear(self.dim, self.latent_dim)

    def forward(self, x):
        """Encode ``x`` to a latent code and decode it.

        Returns ``(nodes, adj, edge)`` when ``pred_edge`` is set, otherwise
        ``(nodes, adj)``.
        """
        self.z = self.enc(self.Encoder(x))
        self.mu = self.z

        decoded = self.Decoder(self.z)
        if not self.pred_edge:
            nodes, adj = decoded
            return nodes, adj
        nodes, adj, edge = decoded
        return nodes, adj, edge
