# encoder.py

"""
This file implements the encoder components of the ContraBin model.
Encoders process source code, binary code, and comments to generate embeddings.
"""

import torch
from torch import nn
from transformers import AutoModel

class EncoderAnchor(nn.Module):
    """
    A frozen encoder class to extract embeddings from pre-trained models
    without updating their parameters. Used for comparison purposes.

    Args:
        model_name: Identifier handed to ``AutoModel.from_pretrained``.
        device: Device the loaded model is moved to via ``.to(device)``.
    """
    def __init__(self, model_name, device):
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name).to(device)  # Load pretrained model
        self.target_token_idx = 0  # Sequence position used as the representation (CLS token)

        # Freeze all parameters of the model
        for param in self.model.parameters():
            param.requires_grad = False

        # NOTE(review): the wrapped model still follows this module's
        # train()/eval() switch, so any dropout it contains is active in
        # train mode -- confirm that is intended for the anchor.

    def forward(self, input_ids, attention_mask=None):
        """
        Forward pass for the frozen encoder.
        Args:
            input_ids (Tensor): Tokenized input sequences.
            attention_mask (Tensor, optional): Mask forwarded unchanged to the
                wrapped model so padded positions can be ignored. When None
                (the default) the wrapped model falls back to its own default
                handling, which matches calling it with ``input_ids`` only.
        Returns:
            Tensor: ``last_hidden_state`` sliced at ``target_token_idx`` along
            the second axis (the CLS token embedding).
        """
        # The anchor is never trained, so skip autograd bookkeeping entirely.
        with torch.no_grad():
            output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        return output.last_hidden_state[:, self.target_token_idx, :]


class EncoderTrainable(nn.Module):
    """
    A trainable encoder class to extract embeddings while allowing fine-tuning
    on the downstream task.

    Args:
        model_name: Identifier handed to ``AutoModel.from_pretrained``.
        device: Device the loaded model is moved to via ``.to(device)``.
    """
    def __init__(self, model_name, device):
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name).to(device)  # Load pretrained model
        self.target_token_idx = 0  # Sequence position used as the representation (CLS token)

    def forward(self, input_ids, attention_mask=None):
        """
        Forward pass for the trainable encoder.
        Args:
            input_ids (Tensor): Tokenized input sequences.
            attention_mask (Tensor, optional): Mask forwarded unchanged to the
                wrapped model so padded positions can be ignored. When None
                (the default) the wrapped model falls back to its own default
                handling, which matches calling it with ``input_ids`` only.
        Returns:
            Tensor: ``last_hidden_state`` sliced at ``target_token_idx`` along
            the second axis (the CLS token embedding).
        """
        # No no_grad() here: gradients must reach the wrapped model's parameters.
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        return output.last_hidden_state[:, self.target_token_idx, :]


if __name__ == "__main__":
    # Smoke test: load both encoders from the project config and check that
    # each returns one embedding per input sequence.
    from configs import configs

    # Initialize encoders
    device = configs.device
    anchor_encoder = EncoderAnchor(model_name=configs.model_name, device=device)
    trainable_encoder = EncoderTrainable(model_name=configs.model_name, device=device)

    # Test input: 2 sequences of 512 simulated token IDs drawn from [0, 100),
    # created directly on the target device to avoid a host-to-device copy.
    dummy_input = torch.randint(0, 100, (2, 512), device=device)

    # Only shapes are inspected, so no autograd graph is needed.
    with torch.no_grad():
        anchor_output = anchor_encoder(dummy_input)
        trainable_output = trainable_encoder(dummy_input)

    print("Anchor Encoder Output Shape:", anchor_output.shape)
    print("Trainable Encoder Output Shape:", trainable_output.shape)