import os
import sys
import logging
from pathlib import Path
from typing import List
import numpy as np
import pandas as pd


######################################################################################################
# 0. Paths & Logging
######################################################################################################

# BASE_DIR is the parent of the directory that contains this file.
BASE_DIR = Path(__file__).resolve().parents[1]

# Input data is read from here; this script does not create the directory.
DATA_DIR = BASE_DIR.joinpath("data")

# Output directory for augmented datasets; created on import if absent.
DATA_EXT_DIR = BASE_DIR.joinpath("data_extensions")
DATA_EXT_DIR.mkdir(parents=True, exist_ok=True)

# Directory for log files; created on import if absent.
LOGS_DIR = BASE_DIR.joinpath("code", "logs")
LOGS_DIR.mkdir(parents=True, exist_ok=True)


def setup_logging(log_dir: Path, log_filename: str) -> Path:
    """
    Basic logging setup on the root logger: file output plus stdout.

    If the root logger has no handlers yet, it is configured at INFO level
    with a file handler for ``log_dir / log_filename`` and a stdout stream
    handler. If handlers already exist, only a file handler for that path is
    added (with the same message format), and only when no existing
    ``FileHandler`` already targets the same file. The level and the other
    handlers of an already-configured root logger are left untouched.

    Args:
        log_dir: directory for the log file; created if it does not exist
        log_filename: name of the log file inside ``log_dir``

    Returns:
        ``log_dir / log_filename`` (not made absolute)
    """
    log_dir.mkdir(parents=True, exist_ok=True)
    log_path = log_dir / log_filename
    log_format = "%(asctime)s - %(levelname)s - %(message)s"
    root_logger = logging.getLogger()

    if not root_logger.handlers:
        logging.basicConfig(
            level=logging.INFO,
            format=log_format,
            handlers=[
                logging.FileHandler(log_path),
                logging.StreamHandler(sys.stdout),
            ],
        )
        return log_path

    # FileHandler stores its target in ``baseFilename`` as an absolute path,
    # so the comparison must use the absolute form of log_path as well.
    # Comparing against str(log_path) would never match for a relative
    # log_dir and every call would attach one more duplicate handler.
    target = os.path.abspath(log_path)
    attached_files = {
        getattr(handler, "baseFilename", None)
        for handler in root_logger.handlers
        if isinstance(handler, logging.FileHandler)
    }
    if target not in attached_files:
        file_handler = logging.FileHandler(log_path)
        # Keep the file output format identical to the basicConfig branch.
        file_handler.setFormatter(logging.Formatter(log_format))
        root_logger.addHandler(file_handler)

    return log_path


# Logging is configured as soon as this module is imported; the path of the
# log file is kept in a module-level constant.
LOG_FILE_PATH = setup_logging(log_dir=LOGS_DIR, log_filename="data_gn_adder.log")


######################################################################################################
# 1. Gaussian Noise Function
######################################################################################################

# Columns passed to add_gaussian_noise() by the augmentation pipeline.
NUMERICAL_COLS: List[str] = [
    "ESM", "HSPC", "CHOL", "PEG",
    "TFR", "FRR",
    "SIZE", "PDI",
]


def add_gaussian_noise(
    df: pd.DataFrame,
    numerical_cols: List[str],
    noise_level: float = 0.01,
) -> pd.DataFrame:
    """
    Add zero-mean Gaussian noise to specified numerical columns in a DataFrame.

    For each column one noise value per row is drawn from a normal
    distribution with standard deviation ``column.std() * noise_level``.

    - ESM: noise only where HSPC == 0
    - HSPC: noise only where ESM == 0 (checked against the current ESM
      values, i.e. after ESM noise if ESM was processed earlier in the list)
    - All other numerical columns: noise everywhere

    Columns are skipped (with a warning) when they are missing from ``df`` or
    when their standard deviation is zero or not finite, e.g. all-NaN columns
    or columns with fewer than two non-missing values. If ESM or HSPC is
    requested but its partner column is absent, no row restriction can be
    applied and noise is added to every row of that column.

    Args:
        df: input DataFrame (not modified)
        numerical_cols: columns to which noise will be added
        noise_level: proportion of the standard deviation used as noise scale

    Returns:
        A copy of ``df`` with noise added; every processed column is returned
        as float.
    """
    df = df.copy()
    logging.info(
        "Adding Gaussian noise with noise level %s to numerical columns: %s",
        noise_level,
        numerical_cols,
    )

    # Column -> partner column; noise is only applied where the partner is 0.
    restricted_by = {"ESM": "HSPC", "HSPC": "ESM"}

    for col in numerical_cols:
        if col not in df.columns:
            logging.warning("Column '%s' not found. Skipping.", col)
            continue

        col_std = df[col].std()
        # A NaN std (too few non-missing values) would make the noise scale
        # NaN and turn the whole column into NaN, so it is skipped like a
        # zero std.
        if not np.isfinite(col_std) or col_std == 0:
            logging.warning(
                "Column '%s' has zero/undefined std. Skipping noise addition.", col
            )
            continue

        # One draw per row, even for restricted columns, so the sequence of
        # random numbers consumed does not depend on the row restriction.
        noise = np.random.normal(
            loc=0,
            scale=col_std * noise_level,
            size=df[col].shape,
        )

        partner = restricted_by.get(col)
        if partner is not None and partner not in df.columns:
            logging.warning(
                "Column '%s' not found; adding noise to all rows of '%s'.",
                partner,
                col,
            )
            partner = None

        if partner is not None:
            # Zero out the noise on rows where the partner column is non-zero.
            allowed = (df[partner] == 0).to_numpy()
            noise = np.where(allowed, noise, 0.0)

        # Cast first so integer columns are upcast explicitly instead of
        # relying on assignment of floats into an integer column.
        df[col] = df[col].astype(float) + noise

    return df


######################################################################################################
# 2. Pipeline
######################################################################################################

def run_gaussian_augmentation(
    input_filename: str = "formulations.csv",
    output_filename: str = "formulations_extended_with_gn.csv",
    sample_frac: float = 0.25,
    noise_level: float = 0.01,
) -> None:
    """
    Build a noisy subsample of a CSV dataset and write it to DATA_EXT_DIR.

    Steps: read ``DATA_DIR / input_filename``, draw a ``sample_frac`` fraction
    of the rows (fixed random_state of 42), add Gaussian noise to
    NUMERICAL_COLS, zero out ESM wherever HSPC > 0 and then HSPC wherever
    ESM > 0, move the expected columns to the front, and save the result as
    ``DATA_EXT_DIR / output_filename`` without the index.

    If the input file does not exist, an error is logged and the function
    returns without raising. Any exception raised during processing is logged
    and re-raised.
    """
    logging.info("Starting data augmentation using Gaussian noise...".upper())
    source_csv = DATA_DIR / input_filename
    logging.info("---> Loading dataset from %s", source_csv)

    # Guard clause: nothing to do without an input file.
    if not source_csv.exists():
        logging.error("Input file %s not found. Aborting.", source_csv)
        return

    try:
        frame = pd.read_csv(source_csv)
        n_rows, n_cols = frame.shape
        logging.info(
            "Dataset loaded successfully with %d rows and %d columns.",
            n_rows,
            n_cols,
        )

        # Subsample first, so noise is only generated for the retained rows.
        frame = frame.sample(frac=sample_frac, random_state=42)
        logging.info(
            "Subsampled dataset to %.0f%%: now %d rows.",
            sample_frac * 100,
            len(frame),
        )

        frame = add_gaussian_noise(frame, NUMERICAL_COLS, noise_level=noise_level)

        # Re-impose exclusivity: HSPC > 0 clears ESM, then ESM > 0 clears HSPC.
        if {"ESM", "HSPC"}.issubset(frame.columns):
            hspc_positive = frame["HSPC"] > 0
            frame.loc[hspc_positive, "ESM"] = 0
            esm_positive = frame["ESM"] > 0
            frame.loc[esm_positive, "HSPC"] = 0
            logging.info(
                "Ensured mutual exclusivity of 'ESM' and 'HSPC' after noise addition."
            )

        # Expected columns go first (in this order), any others keep their
        # relative order behind them.
        preferred = (
            "ESM",
            "HSPC",
            "CHOL",
            "PEG",
            "TFR",
            "FRR",
            "AQUEOUS",
            "SIZE",
            "PDI",
        )
        present = []
        absent = []
        for name in preferred:
            if name in frame.columns:
                present.append(name)
            else:
                absent.append(name)
        if absent:
            logging.warning(
                "Some expected columns are missing in the augmented dataset: %s",
                absent,
            )
        extras = [name for name in frame.columns if name not in present]
        frame = frame[present + extras]

        destination = DATA_EXT_DIR / output_filename
        frame.to_csv(destination, index=False)
        logging.info("Saved augmented data to %s", destination)
        logging.info("...DONE!\n\n")
    except Exception as e:
        logging.error("Error during Gaussian augmentation: %s", str(e))
        raise


# Script entry point: run the augmentation pipeline with its default
# arguments when this file is executed directly (not when imported).
if __name__ == "__main__":
    run_gaussian_augmentation()
