# metrics.py

"""
This file provides evaluation metrics for ContraBin.
Metrics include BLEU scores for summarization and accuracy for classification tasks.
"""

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import numpy as np

def compute_bleu(reference, candidate, weights=(0.25, 0.25, 0.25, 0.25)):
    """
    Score a predicted sentence against a single reference with sentence-level BLEU.

    Both texts are tokenized by splitting on whitespace, and the score is
    computed by NLTK's ``sentence_bleu`` using ``SmoothingFunction().method1``.

    Args:
        reference (str): Ground truth text.
        candidate (str): Predicted text.
        weights (tuple): Per-n-gram-order weights passed through to
            ``sentence_bleu``.

    Returns:
        float: BLEU score as returned by ``sentence_bleu``.
    """
    smoother = SmoothingFunction()
    reference_tokens = reference.split()
    candidate_tokens = candidate.split()
    # sentence_bleu takes a list of references; only one is supplied here.
    return sentence_bleu(
        [reference_tokens],
        candidate_tokens,
        weights=weights,
        smoothing_function=smoother.method1,
    )

def compute_accuracy(predictions, labels):
    """
    Compute accuracy for classification tasks.

    Accuracy is the number of positions where the prediction equals the
    label, divided by the number of labels.

    Args:
        predictions (list): Model predictions; must support ``len()``.
        labels (list): Ground truth labels, the same length as ``predictions``.

    Returns:
        float: Accuracy score.

    Raises:
        ValueError: If ``predictions`` and ``labels`` differ in length, or if
            they are empty (accuracy is undefined with no samples).
    """
    total = len(labels)
    # zip() stops at the shorter input, so without this check a length
    # mismatch would be scored silently instead of being reported.
    if len(predictions) != total:
        raise ValueError(
            f"predictions and labels must have the same length "
            f"(got {len(predictions)} and {total})"
        )
    if total == 0:
        raise ValueError("cannot compute accuracy on empty inputs")

    correct = sum(pred == label for pred, label in zip(predictions, labels))
    return correct / total

if __name__ == "__main__":
    # Smoke run when executed as a script: score one sentence pair with
    # BLEU, then one small batch of predictions with accuracy.
    sample_reference = "This is a sample reference"
    sample_candidate = "This is a sample candidate"
    print("BLEU Score:", compute_bleu(sample_reference, sample_candidate))

    # The two lists below differ only at index 2.
    sample_predictions = [1, 0, 1, 1, 0]
    sample_labels = [1, 0, 0, 1, 0]
    print("Accuracy:", compute_accuracy(sample_predictions, sample_labels))