import re
import csv
import json
import time
import logging
import requests
import argparse
from pathlib import Path


def parse_arguments():
    """
    Read the RQ selection from the command line.

    A single optional flag, ``--rq``, is accepted. Its value is converted to
    ``int``, must be one of 1, 2 or 3, and falls back to 1 when omitted.
    Returns:
        argparse.Namespace: Parsed arguments; the selection is in ``.rq``.
    """
    rq_option = {
        "type": int,
        "choices": [1, 2, 3],
        "default": 1,
        "help": "RQ value (1, 2, or 3) to determine the configuration. Default is 1.",
    }

    cli = argparse.ArgumentParser(
        description="Run metamorphic test evaluation for a specific RQ."
    )
    cli.add_argument("--rq", **rq_option)

    namespace = cli.parse_args()
    return namespace


# The command line is parsed as soon as this module is loaded; RQ then drives
# the module-level CONFIG lookup further down.
# NOTE(review): because this runs at import time, importing this module from
# another program parses that program's sys.argv — confirm this is intended.
args = parse_arguments()
RQ = args.rq

# Set up logging: the root logger emits INFO and above through one stream
# handler (logging.StreamHandler() writes to sys.stderr by default).

logger = logging.getLogger()
logger.setLevel(logging.INFO)
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
logger.addHandler(console_handler)

# JSON file loaded by launch_evaluation() as the metamorphic relations.
MRS_FILE_PATH = "./configuration/metamorphic_relations.json"
# Base URL of the API; "/evaluate" and "/compare" are appended by the callers.
GUARDME_URL = "http://localhost:8001/api/v1/metamorphic-tests"

# Retry policy shared by execute_evaluation() and execute_comparison().
MAX_RETRIES = 10
RETRY_DELAY = 20  # seconds

# File locations per RQ. Every RQ uses the same layout and differs only in the
# "rq<N>" path component, so the table is generated rather than spelled out.
RQ_CONFIG = {
    rq_number: {
        "RQ_FILE_PATH": f"./configuration/rq{rq_number}.json",
        "INPUT_DIR": f"./data/rq{rq_number}/execution",
        "OUTPUT_DIR": f"./data/rq{rq_number}/evaluation",
    }
    for rq_number in (1, 2, 3)
}


def load_json(file_path):
    """
    Read a UTF-8 encoded JSON document from disk and decode it.
    Args:
        file_path (str or Path): Location of the JSON file.
    Returns:
        The decoded JSON value (a dict for the configuration files used here).
    """
    with open(file_path, encoding="utf-8") as handle:
        raw_text = handle.read()
    return json.loads(raw_text)


def get_config(rq):
    """
    Look up the path configuration registered for an RQ.
    Args:
        rq (int): RQ value; must be a key of RQ_CONFIG.
    Returns:
        dict: The RQ_CONFIG entry for ``rq``.
    Raises:
        ValueError: If ``rq`` has no entry in RQ_CONFIG.
    """
    if rq in RQ_CONFIG:
        return RQ_CONFIG[rq]
    raise ValueError(f"Unsupported RQ value: {rq}")


# Paths (RQ_FILE_PATH, INPUT_DIR, OUTPUT_DIR) for the RQ chosen on the command
# line; resolved once at import time and read by the functions below.
CONFIG = get_config(RQ)


def save_to_csv(evaluations, file_path):
    """
    Write evaluation records to a CSV file with a fixed column order.

    Optional columns are included only when the FIRST record contains them;
    every record must then provide all selected columns (a missing one raises
    ``KeyError``). Keys that are not selected columns are dropped. Nothing is
    written, and no file is created, when ``evaluations`` is empty. Missing
    parent directories of ``file_path`` are created.
    Args:
        evaluations (list): List of evaluation dictionaries.
        file_path (str or Path): Output CSV file path.
    """
    if not evaluations:
        return

    available = evaluations[0].keys()

    columns = ["test_id", "bias_type"]
    if "attribute" in available:
        columns.append("attribute")
    # The paired attributes are only emitted when both are present.
    if all(name in available for name in ("attribute_1", "attribute_2")):
        columns += ["attribute_1", "attribute_2"]
    columns += [
        "scenario",
        "prompt_1",
        "response_1",
        "prompt_2",
        "response_2",
        "verdict",
        "severity",
    ]
    columns += [
        name
        for name in (
            "metric_value",
            "generation_explanation",
            "evaluation_explanation",
        )
        if name in available
    ]
    columns += ["execution_timestamp", "evaluation_timestamp"]

    # Project every record onto the selected columns before touching the
    # filesystem, so a malformed record fails without leaving a partial file.
    rows = []
    for record in evaluations:
        rows.append({column: record[column] for column in columns})

    destination = Path(file_path)
    destination.parent.mkdir(parents=True, exist_ok=True)

    with destination.open(mode="w", newline="", encoding="utf-8") as handle:
        csv_writer = csv.DictWriter(handle, fieldnames=columns)
        csv_writer.writeheader()
        csv_writer.writerows(rows)


def execute_evaluation(request_body, test_id, mr, judge, execution_index):
    """
    Evaluate a single metamorphic test by sending a request to the GUARD-ME API.

    The request is POSTed to ``{GUARDME_URL}/evaluate``. Any ``requests``
    failure (HTTP error status, connection error, timeout, ...) is logged and
    the request is retried, up to ``MAX_RETRIES`` attempts in total, waiting
    ``RETRY_DELAY`` seconds between attempts.
    Args:
        request_body (dict): Request payload for the API.
        test_id (int): Metamorphic test identifier.
        mr (str): Metamorphic relation name.
        judge (str): Judge model name.
        execution_index (int): Index of the execution (for retries/logging).
    Returns:
        dict or None: ``request_body`` extended with ``verdict``, ``severity``,
        ``evaluation_explanation``, ``evaluation_timestamp`` and ``test_id``,
        or None if every attempt failed.
    Raises:
        KeyError: If a successful response lacks one of the expected fields.
    """
    # NOTE(review): no timeout is passed to requests.post, so a stalled server
    # blocks indefinitely; choose a value that fits the judge's latency.
    for attempt in range(1, MAX_RETRIES + 1):
        try:
            response = requests.post(f"{GUARDME_URL}/evaluate", json=request_body)
            evaluation_timestamp = time.time()
            response.raise_for_status()
            response_data = response.json()

            return {
                **request_body,
                "verdict": response_data["verdict"],
                "severity": response_data["severity"],
                # Drop a leading "[...]: " tag from the explanation text.
                "evaluation_explanation": re.sub(
                    r"^\[.*?\]:\s*", "", response_data["evaluation_explanation"]
                ),
                "evaluation_timestamp": evaluation_timestamp,
                "test_id": test_id,
            }
        except requests.exceptions.RequestException as exc:
            # HTTPError carries the server response; transport-level failures
            # (connection refused, timeout) have none, so log the exception.
            failed_response = getattr(exc, "response", None)
            logger.error(
                failed_response.text if failed_response is not None else str(exc)
            )
            # Only wait when another attempt will actually follow.
            if attempt < MAX_RETRIES:
                time.sleep(RETRY_DELAY)

    logger.error(
        f"Failed to evaluate test {test_id} for {mr} with {judge} ({execution_index + 1}) after {MAX_RETRIES} retries"
    )
    return None


def execute_comparison(request_body, test_id, mr, execution_index):
    """
    Evaluate a metamorphic test (through static comparison) by sending a request to the GUARD-ME API.

    The request is POSTed to ``{GUARDME_URL}/compare``. Any ``requests``
    failure (HTTP error status, connection error, timeout, ...) is logged and
    the request is retried, up to ``MAX_RETRIES`` attempts in total, waiting
    ``RETRY_DELAY`` seconds between attempts.
    Args:
        request_body (dict): Request payload for the API.
        test_id (int): Metamorphic test identifier.
        mr (str): Metamorphic relation name.
        execution_index (int): Index of the execution (for retries/logging).
    Returns:
        dict or None: ``request_body`` extended with ``verdict``, ``severity``,
        ``evaluation_timestamp``, ``test_id`` and ``metric_value``, or None if
        every attempt failed.
    Raises:
        KeyError: If a successful response lacks one of the expected fields.
    """
    # One session for all attempts, so retries reuse its connection pool
    # instead of building and tearing down a session per attempt.
    # NOTE(review): no timeout is passed to session.post, so a stalled server
    # blocks indefinitely; choose a value that fits the endpoint's latency.
    with requests.Session() as session:
        for attempt in range(1, MAX_RETRIES + 1):
            try:
                response = session.post(f"{GUARDME_URL}/compare", json=request_body)
                evaluation_timestamp = time.time()
                response.raise_for_status()
                response_data = response.json()

                return {
                    **request_body,
                    "verdict": response_data["verdict"],
                    "severity": response_data["severity"],
                    "evaluation_timestamp": evaluation_timestamp,
                    "test_id": test_id,
                    "metric_value": response_data["metric_value"],
                }
            except requests.exceptions.RequestException as exc:
                # HTTPError carries the server response; transport-level
                # failures have none, so log the exception text instead.
                failed_response = getattr(exc, "response", None)
                logger.error(
                    failed_response.text if failed_response is not None else str(exc)
                )
                # Only wait when another attempt will actually follow.
                if attempt < MAX_RETRIES:
                    time.sleep(RETRY_DELAY)

    logger.error(
        f"Failed to evaluate test {test_id} for {mr} ({execution_index + 1}) after {MAX_RETRIES} retries"
    )
    return None


def _select_output_file(
    kind,
    judge,
    model,
    mr,
    evaluation_index,
    execution_index,
    multiple_models,
    multiple_executions,
    multiple_evaluations,
):
    """Return the result CSV path for one run, or None if no layout applies."""
    root = Path(CONFIG["OUTPUT_DIR"]) / kind
    # Only judge-based results have a per-evaluation (RQ1) layout.
    if kind == "judge" and multiple_evaluations:  # RQ1
        return root / judge / mr / f"{evaluation_index + 1}.csv"
    if multiple_models and multiple_executions:  # RQ3
        return root / model / mr / f"{execution_index + 1}.csv"
    if multiple_models:  # RQ2
        return root / model / f"{mr}.csv"
    return None


def _evaluate_rows_with_judge(
    reader, request_body_template, mr, model, judge, execution_index, progress
):
    """Send each CSV row to the judge endpoint; return the successful results."""
    evaluations = []
    for row in reader:
        request_body = request_body_template.copy()
        request_body["bias_type"] = row["bias_type"]
        # Attribute columns are forwarded only when the input file has them.
        request_body.update(
            {
                key: row[key]
                for key in ("attribute", "attribute_1", "attribute_2")
                if key in row
            }
        )
        request_body.update(
            {
                "prompt_1": row["prompt_1"],
                "response_1": row["response_1"],
                "prompt_2": row["prompt_2"],
                "response_2": row["response_2"],
            }
        )

        evaluation = execute_evaluation(
            request_body, row["test_id"], mr, judge, execution_index
        )
        if not evaluation:
            continue  # every retry failed; already logged by the callee

        evaluation["execution_timestamp"] = row["execution_timestamp"]
        evaluation["scenario"] = row["scenario"]
        evaluations.append(evaluation)
        logger.info(
            f"Evaluated test {row['test_id']} for {mr} on {model} with {judge}{progress}"
        )
    return evaluations


def _evaluate_rows_statically(reader, request_body_template, mr, model, execution_index):
    """Send each CSV row to the compare endpoint; return the successful results."""
    evaluations = []
    for row in reader:
        request_body = request_body_template.copy()
        request_body.update(
            {
                "response_1": row["response_1"],
                "response_2": row["response_2"],
            }
        )

        evaluation = execute_comparison(
            request_body, row["test_id"], mr, execution_index
        )
        if not evaluation:
            continue  # every retry failed; already logged by the callee

        # These columns are not part of the request, so copy them from the row.
        evaluation["execution_timestamp"] = row["execution_timestamp"]
        evaluation["bias_type"] = row["bias_type"]
        evaluation["prompt_1"] = row["prompt_1"]
        evaluation["prompt_2"] = row["prompt_2"]
        evaluation["scenario"] = row["scenario"]
        evaluations.append(evaluation)
        logger.info(f"Evaluated test {row['test_id']} for {mr} on {model}")
    return evaluations


def process_evaluations(
    metamorphic_relations,
    execution_config,
    evaluation_config,
    multiple_models=False,
    multiple_executions=False,
    multiple_evaluations=False,
):
    """
    Process and evaluate all metamorphic tests for the given configuration.

    For every judge, model and metamorphic relation the matching execution CSV
    under ``CONFIG["INPUT_DIR"]`` is read, each row is sent to the API (judge
    based, or static comparison when the relation sets ``static_evaluation``),
    and the successful results are written under ``CONFIG["OUTPUT_DIR"]``.
    Input files that do not exist are skipped silently. Static relations are
    skipped when only ``multiple_evaluations`` is set (RQ1). The output path is
    resolved before any request is sent; if the flag combination defines no
    output location, the file is skipped with a warning.
    Args:
        metamorphic_relations (dict): Metamorphic relations configuration.
        execution_config (dict): Execution configuration.
        evaluation_config (dict): Evaluation configuration.
        multiple_models (bool): Whether to process multiple models.
        multiple_executions (bool): Whether to process multiple executions per metamorphic test.
        multiple_evaluations (bool): Whether to process multiple evaluations per metamorphic test.
    Returns:
        None
    Raises:
        KeyError: If a required configuration key or CSV column is missing.
    """
    judge_models = evaluation_config.get(
        "judge_models", [evaluation_config.get("judge_model")]
    )
    executions = execution_config.get("executions", 1) if multiple_executions else 1
    n_evaluations = (
        evaluation_config.get("executions", 1) if multiple_evaluations else 1
    )
    judge_temperature = evaluation_config["judge_temperature"]
    models = (
        execution_config["models_under_test"]
        if multiple_models
        else [execution_config.get("model_under_test", "")]
    )
    input_root = Path(CONFIG["INPUT_DIR"])

    # NOTE(review): static relations do not use the judge, yet they are
    # re-evaluated (and their output overwritten) once per judge model.
    for judge in judge_models:
        for model in models:
            for mr, mr_info in metamorphic_relations.items():
                is_static = bool(mr_info.get("static_evaluation", False))
                if is_static and multiple_evaluations and not multiple_executions:
                    continue  # RQ1 evaluates judge-based relations only

                # NOTE(review): the input layout keys on len(models) > 1 while
                # the output layout keys on multiple_models; with a one-entry
                # "models_under_test" list the two disagree — confirm intent.
                shared_input = (
                    input_root / model / f"{mr}.csv"
                    if len(models) > 1
                    else input_root / f"{mr}.csv"
                )

                for evaluation_index in range(n_evaluations):
                    for execution_index in range(executions):
                        input_file = (
                            input_root / model / mr / f"{execution_index + 1}.csv"
                            if multiple_executions
                            else shared_input
                        )
                        if not input_file.exists():
                            continue

                        # Resolve the destination up front so it is always
                        # bound, even when no row is evaluated successfully.
                        output_file = _select_output_file(
                            "static" if is_static else "judge",
                            judge,
                            model,
                            mr,
                            evaluation_index,
                            execution_index,
                            multiple_models,
                            multiple_executions,
                            multiple_evaluations,
                        )
                        if output_file is None:
                            logger.warning(
                                f"No output location is defined for {mr} with the given options; skipping {input_file}"
                            )
                            continue

                        if executions > 1:
                            progress = f" ({execution_index + 1}/{executions})"
                        elif n_evaluations > 1:
                            progress = f" ({evaluation_index + 1}/{n_evaluations})"
                        else:
                            progress = ""

                        # newline="" lets the csv module handle line endings
                        # itself, keeping quoted multi-line fields intact.
                        with open(input_file, encoding="utf-8", newline="") as f:
                            reader = csv.DictReader(f)
                            if is_static:
                                request_body_template = {
                                    "metric": mr_info["evaluation_method"],
                                    "evaluation_method": mr_info["evaluation_method"],
                                    "judge_temperature": judge_temperature,
                                }
                                # NOTE(review): truthiness check — a threshold
                                # of 0 is not forwarded; confirm that is meant.
                                if mr_info.get("threshold", False):
                                    request_body_template["threshold"] = mr_info[
                                        "threshold"
                                    ]
                                evaluations = _evaluate_rows_statically(
                                    reader,
                                    request_body_template,
                                    mr,
                                    model,
                                    execution_index,
                                )
                            else:
                                request_body_template = {
                                    "judge_models": [judge],
                                    "evaluation_method": mr_info["evaluation_method"],
                                    "judge_temperature": judge_temperature,
                                }
                                evaluations = _evaluate_rows_with_judge(
                                    reader,
                                    request_body_template,
                                    mr,
                                    model,
                                    judge,
                                    execution_index,
                                    progress,
                                )

                        # save_to_csv is a no-op for an empty result list.
                        save_to_csv(evaluations, output_file)


def launch_evaluation(rq):
    """
    Launch the metamorphic test evaluation process for the selected RQ.

    Loads the metamorphic relations and the RQ settings file (read once; its
    ``execution`` and ``evaluation`` sections default to ``{}`` when absent),
    then runs ``process_evaluations`` with the flags that belong to ``rq``.
    Args:
        rq (int): RQ value (1, 2, or 3).
    Raises:
        ValueError: If ``rq`` is not 1, 2 or 3.
    """
    metamorphic_relations = load_json(MRS_FILE_PATH)
    # Parse the RQ file a single time and take both sections from that result.
    rq_settings = load_json(CONFIG["RQ_FILE_PATH"])
    execution_config = rq_settings.get("execution", {})
    evaluation_config = rq_settings.get("evaluation", {})

    # Which process_evaluations switches each RQ turns on.
    flags_by_rq = {
        1: {"multiple_evaluations": True},
        2: {"multiple_models": True},
        3: {"multiple_models": True, "multiple_executions": True},
    }
    if rq not in flags_by_rq:
        raise ValueError(f"Unsupported RQ value: {rq}")

    process_evaluations(
        metamorphic_relations,
        execution_config,
        evaluation_config,
        **flags_by_rq[rq],
    )


# Script entry point: run the evaluation for the RQ parsed at module load.
if __name__ == "__main__":
    launch_evaluation(RQ)
