# heads.py

"""
This file implements the projection heads used in ContraBin.
Projection heads map high-dimensional encoder embeddings to a common space,
facilitating contrastive learning.
"""

import torch
from torch import nn

class LinearHead(nn.Module):
    """
    Projection head with a single hidden transformation.

    The input is first projected to ``projection_dim`` features. That
    projection then goes through GELU -> Linear -> Dropout, is added back
    onto the projection as a residual, and the sum is layer-normalized.
    """
    def __init__(self, embedding_dim, projection_dim, dropout):
        """
        Build the sub-layers of the head.
        Args:
            embedding_dim: Size of the last dimension of the incoming embeddings.
            projection_dim: Size of the last dimension of the output.
            dropout: Drop probability handed to ``nn.Dropout``.
        """
        super().__init__()
        # Entry projection: embedding_dim -> projection_dim.
        self.projection = nn.Linear(embedding_dim, projection_dim)
        # Non-linearity applied to the projected features.
        self.gelu = nn.GELU()
        # Square layer that keeps the feature size at projection_dim.
        self.fc = nn.Linear(projection_dim, projection_dim)
        # Regularization on the transformed branch only.
        self.dropout = nn.Dropout(dropout)
        # Normalization over the last (projection_dim) dimension.
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        """
        Project the embeddings and refine them through a residual branch.
        Args:
            x (Tensor): Embeddings whose last dimension is ``embedding_dim``.
        Returns:
            Tensor: Normalized projections whose last dimension is ``projection_dim``.
        """
        # The raw projection doubles as the skip connection.
        residual = self.projection(x)
        # GELU -> fc -> dropout, evaluated inside-out.
        branch = self.dropout(self.fc(self.gelu(residual)))
        return self.layer_norm(branch + residual)


class NonLinearHead(nn.Module):
    """
    Projection head with two hidden transformations.

    The input is first projected to ``projection_dim`` features. That
    projection then goes through GELU -> Linear -> GELU -> Linear -> Dropout,
    is added back onto the projection as a residual, and the sum is
    layer-normalized. Compared to LinearHead, the branch contains one extra
    activation and one extra linear layer.
    """
    def __init__(self, embedding_dim, projection_dim, dropout):
        """
        Build the sub-layers of the head.
        Args:
            embedding_dim: Size of the last dimension of the incoming embeddings.
            projection_dim: Size of the last dimension of the output.
            dropout: Drop probability handed to ``nn.Dropout``.
        """
        super().__init__()
        # Entry projection: embedding_dim -> projection_dim.
        self.projection = nn.Linear(embedding_dim, projection_dim)
        # One GELU module, reused at both activation points in forward().
        self.gelu = nn.GELU()
        # Two square layers that keep the feature size at projection_dim.
        self.fc1 = nn.Linear(projection_dim, projection_dim)
        self.fc2 = nn.Linear(projection_dim, projection_dim)
        # Regularization on the transformed branch only.
        self.dropout = nn.Dropout(dropout)
        # Normalization over the last (projection_dim) dimension.
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        """
        Project the embeddings and refine them through a two-layer residual branch.
        Args:
            x (Tensor): Embeddings whose last dimension is ``embedding_dim``.
        Returns:
            Tensor: Normalized projections whose last dimension is ``projection_dim``.
        """
        # The raw projection doubles as the skip connection.
        residual = self.projection(x)
        # First stage: GELU -> fc1 -> GELU.
        hidden = self.gelu(self.fc1(self.gelu(residual)))
        # Second stage: fc2 -> dropout.
        branch = self.dropout(self.fc2(hidden))
        return self.layer_norm(branch + residual)


if __name__ == "__main__":
    # Smoke test: build both heads from the project configuration, push a
    # random batch through them, and print the resulting shapes.
    # `configs` is a project-local module; this script reads its
    # `source_embedding`, `projection_dim`, `dropout` and `device` attributes.
    from configs import configs

    device = configs.device

    # Initialize heads on the same device as the input. Without the explicit
    # `.to(device)` the parameters stay on the CPU while the input is moved to
    # `configs.device`, which raises a device-mismatch error on non-CPU devices.
    linear_head = LinearHead(
        embedding_dim=configs.source_embedding,
        projection_dim=configs.projection_dim,
        dropout=configs.dropout,
    ).to(device)
    nonlinear_head = NonLinearHead(
        embedding_dim=configs.source_embedding,
        projection_dim=configs.projection_dim,
        dropout=configs.dropout,
    ).to(device)

    # Test input: a batch of 2 simulated embeddings, allocated directly on the
    # target device.
    dummy_input = torch.randn(2, configs.source_embedding, device=device)
    linear_output = linear_head(dummy_input)
    nonlinear_output = nonlinear_head(dummy_input)

    print("Linear Head Output Shape:", linear_output.shape)
    print("NonLinear Head Output Shape:", nonlinear_output.shape)