import math

import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl.function as fn

def pad_from_mask(feats, mask):
    """Re-pack flat per-node features (total_nodes, d_model) into a padded batch (B, max_len, d_model).

    `mask` is (B, max_len) with 1/True marking real residues; rows of `feats` are assumed
    to be ordered protein by protein, matching the row order of `mask`.
    """
    B, max_len = mask.shape
    _, d_model = feats.shape

    lengths = mask.sum(dim=1).long().tolist()
    pad_feats, start = [], 0
    for seq_len in lengths:
        end = start + seq_len
        pad_len = max_len - seq_len
        # Zero-pad each protein's block of features up to max_len.
        pad_tensor = torch.zeros(pad_len, d_model, device=feats.device, dtype=feats.dtype)
        feat = torch.cat([feats[start:end], pad_tensor], dim=0)
        pad_feats.append(feat)
        start = end

    return torch.stack(pad_feats, dim=0)
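
# Quick illustration (hypothetical toy values): two proteins of lengths 3 and 2, stored
# DGL-style as a flat (5, d_model) tensor, are re-packed into a padded (2, 3, d_model) batch.
#   feats = torch.randn(5, 8)
#   mask = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.bool)
#   pad_from_mask(feats, mask).shape  ->  torch.Size([2, 3, 8])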

class MLP(nn.Module):
    """Three-layer perceptron with ReLU activations."""
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )
    def forward(self, x):
        return self.mlp(x)

class ProtMPNN(nn.Module):
    """Message-passing layer that updates both node and edge features on a DGL protein graph.

    `num_heads` is accepted for interface compatibility but is not used by this block.
    """
    def __init__(self, hidden_dim, num_heads=4, dropout=0.1, residual=True):
        super().__init__()
        self.pre_h = nn.Linear(hidden_dim, hidden_dim)
        self.pre_e = nn.Linear(hidden_dim, hidden_dim)
        self.node_mlp = MLP(hidden_dim*2, hidden_dim, hidden_dim)
        self.edge_mlp = MLP(hidden_dim*3, hidden_dim, hidden_dim)
        self.norm_h = nn.LayerNorm(hidden_dim)
        self.norm_e = nn.LayerNorm(hidden_dim)
        self.drop = nn.Dropout(dropout)
        self.residual = residual

    def forward(self, g, h, e):
        h = self.pre_h(h)
        e = self.pre_e(e)
        g.ndata['h'] = h
        g.edata['e'] = e
        # Message = [source node state, edge state]; aggregate incoming messages by sum.
        g.apply_edges(lambda E: {'msg': torch.cat([E.src['h'], E.data['e']], -1)})
        g.update_all(fn.copy_e('msg', 'm'), fn.sum('m', 'agg'))
        h_new = self.drop(self.node_mlp(g.ndata.pop('agg')))
        h = self.norm_h((h + h_new) if self.residual else h_new)
        g.ndata['h'] = h
        # Edge update conditions on both endpoint states and the current edge state.
        g.apply_edges(lambda E: {'e_new': self.drop(self.edge_mlp(torch.cat([E.src['h'], E.dst['h'], E.data['e']], -1)))})
        e_new = g.edata.pop('e_new')
        e = self.norm_e((e + e_new) if self.residual else e_new)
        return h, e


class ProtMPNN_NoEdge(nn.Module):
    """Variant of ProtMPNN that updates node features only; edge features pass through unchanged."""
    def __init__(self, hidden_dim, num_heads=4, dropout=0.1, residual=True):
        super().__init__()
        self.pre_h = nn.Linear(hidden_dim, hidden_dim)
        self.node_mlp = MLP(hidden_dim * 2, hidden_dim, hidden_dim)
        self.norm_h = nn.LayerNorm(hidden_dim)
        self.drop = nn.Dropout(dropout)
        self.residual = residual

    def forward(self, g, h, e):
        h = self.pre_h(h)
        g.ndata['h'] = h
        # Sum-aggregate neighbour node states; edge features are returned untouched.
        g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'agg'))
        h_new = self.drop(self.node_mlp(torch.cat([h, g.ndata.pop('agg')], -1)))
        h = self.norm_h((h + h_new) if self.residual else h_new)
        return h, e


class Cov2ProtMPNN(nn.Module):
    """Two stacked ProtMPNN layers used as the protein graph encoder."""
    def __init__(self, hidden_dim, num_heads=4, dropout=0.1, residual=True):
        super().__init__()
        self.layer1 = ProtMPNN(hidden_dim, num_heads, dropout, residual)
        self.layer2 = ProtMPNN(hidden_dim, num_heads, dropout, residual)
    def forward(self, g, h, e):
        h, e = self.layer1(g, h, e)
        h, e = self.layer2(g, h, e)
        return h, e

class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads=4, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0
        self.h = num_heads
        self.d = d_model // num_heads
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )
        self.dropout = nn.Dropout(dropout)
    def forward(self, q, k, v, k_mask=None):
        B, Lq, D = q.shape
        _, Lk, _ = k.shape
        Q = self.W_Q(q)
        K = self.W_K(k)
        V = self.W_V(v)
        Q = Q.reshape(B, Lq, self.h, self.d).transpose(1, 2)  # [B, h, Lq, d]
        K = K.reshape(B, Lk, self.h, self.d).transpose(1, 2)  # [B, h, Lk, d]
        V = V.reshape(B, Lk, self.h, self.d).transpose(1, 2)  # [B, h, Lk, d]
        attn_logits = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d)
        if k_mask is not None:
            # k_mask marks valid key positions; cast to bool so `~` inverts it correctly.
            mask = k_mask.bool().unsqueeze(1).unsqueeze(2)  # [B, 1, 1, Lk]
            attn_logits = attn_logits.masked_fill(~mask, float('-inf'))
        attn = F.softmax(attn_logits, dim=-1)
        attn = self.dropout(attn)
        out = torch.matmul(attn, V)  # [B, h, Lq, d]
        out = out.transpose(1, 2).contiguous().reshape(B, Lq, self.h * self.d)
        out = self.W_O(out)
        # Post-norm residuals: attention sublayer, then feed-forward sublayer.
        q = q + self.dropout(out)
        q = self.norm1(q)
        ffn_out = self.ffn(q)
        q = q + self.dropout(ffn_out)
        q = self.norm2(q)
        return q

class MMAlloSite(nn.Module):
    """Predicts per-residue pocket logits from the protein graph alone (no ligand input)."""
    def __init__(self, in_dim_dict={'prot_plm':320, 'prot_edge':15}, 
                 max_protein_length=1500, dropout=0.1, hidden_dim=512):
        super().__init__()
        self.prot_node_mlp = MLP(in_dim_dict['prot_plm'], hidden_dim, hidden_dim)
        self.prot_edge_mlp = MLP(in_dim_dict['prot_edge'], hidden_dim, hidden_dim)
        self.prot_graph_encoder = Cov2ProtMPNN(hidden_dim=hidden_dim, num_heads=4, dropout=dropout)
        self.prot_sequence_encoder = TransformerBlock(d_model=hidden_dim, num_heads=4, dropout=dropout)
        self.pocket_predictor = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 1)
        )
        self.max_protein_length = max_protein_length
        self.return_feature = True

    def forward(self, prot_graph, mask_prot):
        # Project PLM node embeddings and local edge descriptors into the hidden space.
        prot_g_node = self.prot_node_mlp(prot_graph.ndata['plm'])
        prot_g_edge = self.prot_edge_mlp(prot_graph.edata['local'])
        # Structure-aware message passing, then re-pack node features to a padded [B, L, D] batch.
        prot_g_node, prot_g_edge = self.prot_graph_encoder(prot_graph, prot_g_node, prot_g_edge)
        prot_node = pad_from_mask(prot_g_node, mask_prot)
        # Sequence-level self-attention; padded positions are masked out and zeroed.
        prot_node = self.prot_sequence_encoder(q=prot_node, k=prot_node, v=prot_node, k_mask=mask_prot) * mask_prot.unsqueeze(-1).float()
        pocket_logits = self.pocket_predictor(prot_node)
        return pocket_logits.squeeze(-1)  # [B, L] per-residue logits

class MMAlloBind(nn.Module):
    """Predicts per-residue pocket logits conditioned on a ligand fingerprint via cross-attention."""
    def __init__(self, in_dim_dict={'prot_plm':320, 'lig_fp':1024, 'prot_edge':15}, 
                 max_protein_length=1500, dropout=0.1, hidden_dim=512):
        super().__init__()
        self.lig_mlp = MLP(in_dim_dict['lig_fp'], hidden_dim, hidden_dim)
        self.prot_node_mlp = MLP(in_dim_dict['prot_plm'], hidden_dim, hidden_dim)
        self.prot_edge_mlp = MLP(in_dim_dict['prot_edge'], hidden_dim, hidden_dim)
        self.prot_graph_encoder = Cov2ProtMPNN(hidden_dim=hidden_dim, num_heads=4, dropout=dropout)
        self.cross_encoder = TransformerBlock(d_model=hidden_dim, num_heads=4, dropout=dropout)
        self.prot_sequence_encoder = TransformerBlock(d_model=hidden_dim, num_heads=4, dropout=dropout)
        self.pocket_predictor = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1)
        )
        self.max_protein_length = max_protein_length
        self.return_feature = True

    def forward(self, lig_fp, prot_graph, mask_prot):
        # The ligand fingerprint becomes a single "token" for cross-attention.
        lig_feat = self.lig_mlp(lig_fp).unsqueeze(1)  # [B, 1, D]
        prot_g_node = self.prot_node_mlp(prot_graph.ndata['plm'])
        prot_g_edge = self.prot_edge_mlp(prot_graph.edata['local'])
        prot_g_node, prot_g_edge = self.prot_graph_encoder(prot_graph, prot_g_node, prot_g_edge)
        prot_node = pad_from_mask(prot_g_node, mask_prot)
        # Residues attend to the ligand token, then to each other.
        prot_node = self.cross_encoder(q=prot_node, k=lig_feat, v=lig_feat, k_mask=None) * mask_prot.unsqueeze(-1).float()
        prot_node = self.prot_sequence_encoder(q=prot_node, k=prot_node, v=prot_node, k_mask=mask_prot) * mask_prot.unsqueeze(-1).float()
        pocket_logits = self.pocket_predictor(prot_node)
        return pocket_logits.squeeze(-1)

class MMAlloMutate(nn.Module):
    """Scores a single queried residue site using graph- and sequence-encoded protein features."""
    def __init__(self, in_dim_dict={'prot_plm':320, 'prot_edge':15}, 
                 max_protein_length=1500, dropout=0.1, hidden_dim=512):
        super().__init__()
        self.prot_node_mlp = MLP(in_dim_dict['prot_plm'], hidden_dim, hidden_dim)
        self.prot_edge_mlp = MLP(in_dim_dict['prot_edge'], hidden_dim, hidden_dim)
        self.prot_graph_encoder = Cov2ProtMPNN(hidden_dim=hidden_dim, num_heads=4, dropout=dropout)
        self.prot_sequence_encoder = TransformerBlock(d_model=hidden_dim, num_heads=4, dropout=dropout)
        self.site_predictor = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1)
        )
        self.max_protein_length = max_protein_length
        self.return_feature = True

    def forward(self, prot_graph, site, mask_prot):
        prot_g_node = self.prot_node_mlp(prot_graph.ndata['plm'])
        prot_g_edge = self.prot_edge_mlp(prot_graph.edata['local'])
        prot_g_node, prot_g_edge = self.prot_graph_encoder(prot_graph, prot_g_node, prot_g_edge)
        prot_node = pad_from_mask(prot_g_node, mask_prot)
        prot_node = self.prot_sequence_encoder(q=prot_node, k=prot_node, v=prot_node, k_mask=mask_prot) * mask_prot.unsqueeze(-1).float()
        # Gather the feature vector at the queried site; `site` holds [B, 1] residue indices.
        idx = site.unsqueeze(-1).expand(-1, -1, prot_node.size(-1))  # [B, 1, hidden_dim]
        site_feat = torch.gather(prot_node, 1, idx).squeeze(1)       # [B, hidden_dim]

        site_logits = self.site_predictor(site_feat)
        return site_logits.squeeze(-1)

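# Minimal smoke-test sketch, not the authors' training pipeline. Assumptions: toy sizes and
# random features; the real pipeline presumably supplies 'plm' node features from a protein
# language model (dim 320) and 'local' edge descriptors (dim 15), matching the defaults above.
if __name__ == "__main__":
    import dgl

    num_res, num_edges = 50, 400  # hypothetical toy protein size
    src = torch.randint(0, num_res, (num_edges,))
    dst = torch.randint(0, num_res, (num_edges,))
    g = dgl.graph((src, dst), num_nodes=num_res)
    g.ndata['plm'] = torch.randn(num_res, 320)     # matches in_dim_dict['prot_plm']
    g.edata['local'] = torch.randn(num_edges, 15)  # matches in_dim_dict['prot_edge']
    mask = torch.ones(1, num_res, dtype=torch.bool)  # single un-padded protein in the batch

    model = MMAlloSite().eval()
    with torch.no_grad():
        logits = model(g, mask)
    print(logits.shape)  # expected: torch.Size([1, 50])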