# main.py

"""
Main training script for ContraBin.

This script integrates all components for training, validating, and analyzing the ContraBin model.
It includes dataset preparation, model initialization, training, validation, and result visualization.

Highlights:
- Implements contrastive learning with simplex interpolation.
- Focuses on leveraging source code, binary code, and comments for representation learning.
- Designed to ensure reproducibility and scalability.
"""

import os
import torch
import random
import numpy as np
from torch import nn
import torch.optim as optim
from torch.utils.data import DataLoader
from configs import Configs
from dataset import Dataset
from encoders import EncoderAnchor, EncoderTrainable
from heads import LinearHead, NonLinearHead
from metrics import compute_bleu, compute_accuracy
from visualization import plot_training_loss, plot_length_distribution

# Seed every random number generator this script uses (Python, NumPy,
# PyTorch CPU and CUDA) at import time so repeated runs start from the same
# random state.
# NOTE(review): seeding alone may not give bit-for-bit reproducible GPU runs;
# no cuDNN / deterministic-algorithm settings are configured here -- confirm
# whether stricter determinism is required.
SEED = 42
random.seed(SEED)  # Python's built-in `random` module
np.random.seed(SEED)  # NumPy's global RNG
torch.manual_seed(SEED)  # PyTorch's default generator
torch.cuda.manual_seed_all(SEED)  # all CUDA devices

def train_epoch(model, train_loader, optimizer, scheduler, config, interpolation_flag):
    """
    Train the model for one epoch.

    Each batch is passed to the model as ``model(batch, interpolation_flag)``;
    the returned loss is back-propagated and the optimizer is stepped once per
    batch. The scheduler is stepped once, after all batches have been processed.

    Args:
        model (nn.Module): The ContraBin model. Its forward pass must return
            the loss tensor.
        train_loader (DataLoader): DataLoader for training data.
        optimizer (torch.optim.Optimizer): Optimizer for training.
        scheduler (torch.optim.lr_scheduler): Learning rate scheduler.
        config (Configs): Configuration object. Not read by this function;
            kept so existing callers keep working.
        interpolation_flag (str): Interpolation type ('naive', 'linear', 'nonlinear').

    Returns:
        float: Average training loss for the epoch (sum of per-batch losses
        divided by ``len(train_loader)``).

    Raises:
        ZeroDivisionError: If ``train_loader`` has length zero.
    """
    model.train()
    total_loss = 0.0

    for batch in train_loader:
        optimizer.zero_grad()
        loss = model(batch, interpolation_flag)  # Interpolation logic in forward pass

        # When the model is wrapped in nn.DataParallel and runs on several
        # GPUs, each replica's scalar loss is gathered into a 1-D tensor.
        # backward() and item() both need a scalar, so average the replicas.
        if loss.dim() > 0:
            loss = loss.mean()

        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    scheduler.step()
    return total_loss / len(train_loader)


def validate_epoch(model, val_loader, interpolation_flag):
    """
    Validate the model for one epoch.

    Puts the model in evaluation mode and, with gradient tracking disabled,
    accumulates the loss returned by ``model(batch, interpolation_flag)`` for
    every batch.

    Args:
        model (nn.Module): The ContraBin model. Its forward pass must return
            the loss tensor.
        val_loader (DataLoader): DataLoader for validation data.
        interpolation_flag (str): Interpolation type ('naive', 'linear', 'nonlinear').

    Returns:
        float: Average validation loss for the epoch (sum of per-batch losses
        divided by ``len(val_loader)``).

    Raises:
        ZeroDivisionError: If ``val_loader`` has length zero.
    """
    model.eval()
    total_loss = 0.0

    with torch.no_grad():
        for batch in val_loader:
            loss = model(batch, interpolation_flag)  # Validation with specific interpolation

            # nn.DataParallel on several GPUs returns one loss per replica as
            # a 1-D tensor; item() requires a single element, so average them.
            if loss.dim() > 0:
                loss = loss.mean()

            total_loss += loss.item()

    return total_loss / len(val_loader)


def main():
    """
    Main training pipeline for ContraBin.

    Steps, in order:
        1. Load the configuration and build the training/validation datasets
           and data loaders.
        2. Build the model (wrapped in ``nn.DataParallel``), the AdamW
           optimizer and a StepLR scheduler.
        3. Train for ``config.num_epochs`` epochs, switching the interpolation
           mode from 'naive' to 'linear' to 'nonlinear' in thirds, and save a
           checkpoint whenever the validation loss improves.
        4. Plot the loss curves and the length distributions of the training
           samples.
        5. Report mean BLEU and accuracy over the validation set.
    """
    # Load configurations
    config = Configs()

    # Initialize tokenizer and dataset
    tokenizer = config.tokenizer
    train_data = Dataset(config.train_data_path, tokenizer)
    val_data = Dataset(config.val_data_path, tokenizer)

    # Data loaders
    train_loader = DataLoader(
        train_data, batch_size=config.batch_size, shuffle=True, num_workers=config.num_workers
    )
    val_loader = DataLoader(
        val_data, batch_size=config.batch_size, shuffle=False, num_workers=config.num_workers
    )

    # Initialize the model
    model = nn.DataParallel(config.get_model())
    model.to(config.device)

    # Optimizer and scheduler (learning rate is halved every 5 epochs)
    optimizer = optim.AdamW(model.parameters(), lr=config.learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

    # Initialize variables for training
    best_loss = float("inf")
    train_losses, val_losses = [], []

    # Training loop
    for epoch in range(config.num_epochs):
        print(f"Epoch {epoch + 1}/{config.num_epochs}")

        # Interpolation curriculum: first third of the epochs 'naive',
        # second third 'linear', remainder 'nonlinear'.
        if epoch < config.num_epochs // 3:
            interpolation_flag = "naive"  # No interpolation, baseline
        elif epoch < 2 * config.num_epochs // 3:
            interpolation_flag = "linear"  # Simple linear interpolation
        else:
            interpolation_flag = "nonlinear"  # Full simplex interpolation

        # Train and validate
        train_loss = train_epoch(model, train_loader, optimizer, scheduler, config, interpolation_flag)
        val_loss = validate_epoch(model, val_loader, interpolation_flag)

        train_losses.append(train_loss)
        val_losses.append(val_loss)

        print(f"Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")

        # Save the best model
        # NOTE(review): this saves the DataParallel wrapper's state dict, so
        # keys are presumably prefixed with "module." -- confirm that whatever
        # loads these checkpoints expects that.
        if val_loss < best_loss:
            best_loss = val_loss
            torch.save(model.state_dict(), f"best_model_epoch_{epoch + 1}.pth")
            print("Saved Best Model!")

    # Visualization
    plot_training_loss(train_losses, val_losses)

    # Dataset statistics visualization: collect all three length lists in a
    # single pass instead of iterating the dataset once per field.
    source_lengths, binary_lengths, comment_lengths = [], [], []
    for sample in train_data:
        source_lengths.append(len(sample["source"]))
        binary_lengths.append(len(sample["binary"]))
        comment_lengths.append(len(sample["comment"]))
    plot_length_distribution(source_lengths, binary_lengths, comment_lengths)

    # Final evaluation metrics
    print("Evaluating BLEU and accuracy on validation set...")
    bleu_scores, accuracies = [], []

    # Evaluation is inference only: make sure the model is in eval mode even
    # if the training loop never ran, and skip autograd bookkeeping.
    model.eval()
    with torch.no_grad():
        for batch in val_loader:
            predictions = model(batch, flag="evaluate")  # Evaluation mode
            references = batch["comment"]
            for ref, pred in zip(references, predictions):
                bleu_scores.append(compute_bleu(ref, pred))
                accuracies.append(compute_accuracy(pred, ref))

    print(f"BLEU Score: {np.mean(bleu_scores):.4f}")
    print(f"Accuracy: {np.mean(accuracies):.4f}")


# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()