# model.py

"""
This file defines the core model architecture for ContraBin, including the encoder models
and the projection heads used in contrastive learning. It integrates source code, binary 
code, and comments for generating embeddings.
"""

import torch
from torch import nn
from transformers import AutoModel

class EncoderAnchor(nn.Module):
    """
    EncoderAnchor wraps a pre-trained Hugging Face model as an (by default) frozen encoder
    and returns the hidden state of the first token as the sequence embedding.
    """

    def __init__(self, model_name=None, trainable=False):
        """
        Initializes the anchor encoder.

        Args:
            model_name (str, optional): Name or path passed to ``AutoModel.from_pretrained``.
                Defaults to ``configs.model_name``, read when the encoder is constructed
                (not when this module is imported), so later changes to ``configs`` are
                honored.
            trainable (bool): If False (the default), the encoder's parameters have
                ``requires_grad`` disabled and the forward pass runs under
                ``torch.no_grad()``. If True, gradients flow through the encoder.
        """
        super().__init__()
        if model_name is None:
            model_name = configs.model_name
        self.trainable = trainable
        self.model = AutoModel.from_pretrained(model_name).to(configs.device)
        for param in self.model.parameters():
            param.requires_grad = trainable  # Freeze the parameters if not trainable

        # NOTE(review): freezing parameters does not change train()/eval() mode. This
        # module follows its parent's mode, so any dropout layers inside the pretrained
        # model stay active while training. Confirm this is intended for the anchor.

        # Index of the token whose hidden state is used as the embedding: the first
        # token ([CLS] / <s> for BERT- / RoBERTa-style tokenizers).
        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask=None):
        """
        Forward pass for generating embeddings.

        Args:
            input_ids (torch.Tensor): Input token IDs; moved to ``configs.device``.
            attention_mask (torch.Tensor, optional): Forwarded to the model when given
                (also moved to ``configs.device``). When omitted, the model is called with
                ``input_ids`` only, exactly as before.

        Returns:
            torch.Tensor: ``last_hidden_state[:, target_token_idx, :]``, one vector per
            sequence in the batch.
        """
        device = configs.device
        model_inputs = {"input_ids": input_ids.to(device)}
        if attention_mask is not None:
            model_inputs["attention_mask"] = attention_mask.to(device)

        # Gradients are disabled only for a frozen encoder. A decorator-level no_grad
        # would also block gradients when trainable=True, making that flag a no-op.
        if self.trainable:
            output = self.model(**model_inputs)
        else:
            with torch.no_grad():
                output = self.model(**model_inputs)
        return output.last_hidden_state[:, self.target_token_idx, :]


class EncoderTrainable(nn.Module):
    """
    EncoderTrainable wraps a pre-trained Hugging Face model as a fine-tunable encoder and
    returns the hidden state of the first token as the sequence embedding.
    """

    def __init__(self, model_name=None, trainable=True):
        """
        Initializes the trainable encoder.

        Args:
            model_name (str, optional): Name or path passed to ``AutoModel.from_pretrained``.
                Defaults to ``configs.model_name``, read when the encoder is constructed
                (not when this module is imported), so later changes to ``configs`` are
                honored.
            trainable (bool): Value assigned to ``requires_grad`` on every model
                parameter. If True (the default), the encoder is fine-tuned.
        """
        super().__init__()
        if model_name is None:
            model_name = configs.model_name
        self.model = AutoModel.from_pretrained(model_name).to(configs.device)
        for param in self.model.parameters():
            param.requires_grad = trainable  # Enable (or disable) parameter updates

        # Index of the token whose hidden state is used as the embedding: the first
        # token ([CLS] / <s> for BERT- / RoBERTa-style tokenizers).
        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask=None):
        """
        Forward pass for generating embeddings.

        Args:
            input_ids (torch.Tensor): Input token IDs; moved to ``configs.device``.
            attention_mask (torch.Tensor, optional): Forwarded to the model when given
                (also moved to ``configs.device``). When omitted, the model is called with
                ``input_ids`` only, exactly as before.

        Returns:
            torch.Tensor: ``last_hidden_state[:, target_token_idx, :]``, one vector per
            sequence in the batch.
        """
        device = configs.device
        model_inputs = {"input_ids": input_ids.to(device)}
        if attention_mask is not None:
            model_inputs["attention_mask"] = attention_mask.to(device)
        output = self.model(**model_inputs)
        return output.last_hidden_state[:, self.target_token_idx, :]


class ProjectionHead(nn.Module):
    """
    Maps encoder embeddings into the shared contrastive space.

    The pipeline is Linear -> GELU -> Dropout -> LayerNorm, so the output has
    ``projection_dim`` features normalized over the last dimension.
    """

    def __init__(self, embedding_dim, projection_dim, dropout):
        """
        Builds the projection layers.

        Args:
            embedding_dim (int): Size of the last dimension of the incoming embeddings.
            projection_dim (int): Size of the last dimension of the projected output.
            dropout (float): Drop probability applied after the activation.
        """
        super().__init__()
        # These attribute names are the state_dict keys; keep them stable so existing
        # checkpoints keep loading.
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        """
        Projects ``x`` into the contrastive space.

        Args:
            x (torch.Tensor): Embeddings whose last dimension is ``embedding_dim``.

        Returns:
            torch.Tensor: Same leading shape as ``x``, last dimension ``projection_dim``.
        """
        hidden = self.dropout(self.activation(self.projection(x)))
        return self.layer_norm(hidden)


class ContraBinModel(nn.Module):
    """
    Tri-modal ContraBin model: encodes source code, binary code and comments, then maps
    each embedding into a shared space with a modality-specific projection head.

    Source code and comments share the frozen ``EncoderAnchor``; binary code goes through
    the fine-tunable ``EncoderTrainable``.
    """

    def __init__(self):
        """
        Builds both encoders and the three projection heads from ``configs``.
        """
        super().__init__()
        # Encoders are registered first and in this order, which keeps the state_dict
        # layout (and parameter initialization order) unchanged.
        self.encoder_update = EncoderTrainable()
        self.encoder_stop = EncoderAnchor()

        # One projection head per modality, all targeting the same output size.
        out_dim = configs.projection_dim
        drop = configs.dropout
        self.source_projection = ProjectionHead(configs.source_embedding, out_dim, drop)
        self.binary_projection = ProjectionHead(configs.binary_embedding, out_dim, drop)
        self.comment_projection = ProjectionHead(configs.comment_embedding, out_dim, drop)

    def forward(self, batch):
        """
        Encodes all three modalities, then projects each one.

        Args:
            batch: Indexable with entries 0 (source), 1 (binary) and 2 (comment); each
                entry is subscripted with ``"input_ids"``.

        Returns:
            tuple: ``(source_projected, binary_projected, comment_projected)``.
        """
        # NOTE(review): only "input_ids" reaches the encoders. Hugging Face models default
        # to an all-ones attention mask, so padding tokens in a padded batch are attended
        # to. Confirm whether the batch's attention mask should be forwarded too.
        # All encodings run before any projection (tuple elements evaluate left to right).
        embeddings = (
            self.encoder_stop(batch[0]["input_ids"]),    # source  -> frozen anchor
            self.encoder_update(batch[1]["input_ids"]),  # binary  -> trainable encoder
            self.encoder_stop(batch[2]["input_ids"]),    # comment -> frozen anchor
        )
        heads = (self.source_projection, self.binary_projection, self.comment_projection)
        return tuple(head(emb) for head, emb in zip(heads, embeddings))