# visualization.py

"""
This file provides visualization utilities for ContraBin.
Visualizations include dataset distributions and performance metrics.
"""

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

def plot_length_distribution(source_lengths, binary_lengths, comment_lengths,
                             save_path="length_distribution.pdf", bins=20):
    """
    Plot overlaid histograms (with KDE curves) of log10 lengths for source
    code, binary code, and comment samples, then save and show the figure.

    All three histograms share one set of bin edges so their bars line up and
    can be compared directly. Because log10 is undefined for non-positive
    values, any length <= 0 (e.g. a sample with no comments) is dropped before
    plotting instead of becoming -inf/NaN, which would make binning fail.

    Args:
        source_lengths (list): Lengths of source code samples.
        binary_lengths (list): Lengths of binary code samples.
        comment_lengths (list): Lengths of comment samples.
        save_path (str): Path to save the generated plot.
        bins (int): Number of histogram bins shared by all three series.

    Returns:
        None
    """
    sns.set_style("whitegrid")
    fig, ax = plt.subplots(figsize=(8, 6))

    series = [
        ("Source Code", "blue", source_lengths),
        ("Binary Code", "green", binary_lengths),
        ("Comments", "orange", comment_lengths),
    ]

    # Convert to log scale for better visualization; keep only positive
    # lengths so the log is finite.
    log_series = []
    for label, color, lengths in series:
        values = np.asarray(lengths, dtype=float).ravel()
        log_series.append((label, color, np.log10(values[values > 0])))

    # Compute bin edges once over all series so the overlaid bars align.
    all_logs = np.concatenate([logs for _, _, logs in log_series])
    edges = np.histogram_bin_edges(all_logs, bins=bins)

    for label, color, logs in log_series:
        if logs.size == 0:
            continue  # nothing to draw for this series
        sns.histplot(logs, kde=True, label=label, color=color, alpha=0.6,
                     bins=edges, ax=ax)

    ax.set_xlabel("Log Length")
    ax.set_ylabel("Frequency")
    ax.set_title("Program Length Distribution")
    # Only build a legend when something labelled was actually drawn.
    if ax.get_legend_handles_labels()[0]:
        ax.legend()
    fig.savefig(save_path, bbox_inches="tight", transparent=True)
    plt.show()
    # Release the figure; pyplot keeps it alive otherwise, so repeated calls
    # accumulate open figures under non-interactive backends.
    plt.close(fig)

def plot_training_loss(train_loss, val_loss, save_path="training_loss.pdf"):
    """
    Plot training and validation loss over epochs, then save and show it.

    Each curve is plotted against its own 1-based epoch index, so the two
    lists do not have to be the same length (e.g. a run stopped before its
    final validation pass). With equal-length inputs the plot is unchanged.

    Args:
        train_loss (list): Training loss values, one per epoch.
        val_loss (list): Validation loss values, one per epoch.
        save_path (str): Path to save the generated plot.

    Returns:
        None
    """
    from matplotlib.ticker import MaxNLocator

    fig, ax = plt.subplots(figsize=(8, 6))

    # Separate x ranges avoid matplotlib's "x and y must have same first
    # dimension" error when the two series differ in length.
    ax.plot(range(1, len(train_loss) + 1), train_loss,
            label="Training Loss", color="blue", marker="o")
    ax.plot(range(1, len(val_loss) + 1), val_loss,
            label="Validation Loss", color="orange", marker="x")

    # Epochs are whole numbers; keep fractional ticks off the x axis.
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
    ax.set_xlabel("Epochs")
    ax.set_ylabel("Loss")
    ax.set_title("Training and Validation Loss")
    ax.legend()
    fig.savefig(save_path, bbox_inches="tight", transparent=True)
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

if __name__ == "__main__":
    # Smoke test: draw both plots from small synthetic inputs.

    # Evenly spaced fake lengths for each modality.
    demo_source = list(range(100, 501, 100))
    demo_binary = list(range(150, 551, 100))
    demo_comments = list(range(50, 251, 50))
    plot_length_distribution(
        source_lengths=demo_source,
        binary_lengths=demo_binary,
        comment_lengths=demo_comments,
    )

    # Monotonically decreasing fake loss curves.
    demo_train = [0.9, 0.7, 0.5, 0.3]
    demo_val = [1.0, 0.8, 0.6, 0.4]
    plot_training_loss(train_loss=demo_train, val_loss=demo_val)