import os
import sys
import logging
from pathlib import Path
from typing import List
import pandas as pd
from imblearn.over_sampling import SMOTENC
from sklearn.preprocessing import LabelEncoder


######################################################################################################
# 0. Paths & Logging
######################################################################################################

# Project root: the grandparent directory of this (resolved) source file.
BASE_DIR = Path(__file__).resolve().parents[1]
DATA_DIR = BASE_DIR.joinpath("data")
DATA_EXT_DIR = BASE_DIR.joinpath("data_extensions")
LOGS_DIR = BASE_DIR.joinpath("code", "logs")

# Create the output and log directories up front (no-op if they already exist).
for _dir in (DATA_EXT_DIR, LOGS_DIR):
    _dir.mkdir(parents=True, exist_ok=True)
del _dir


def setup_logging(log_dir: Path, log_filename: str) -> Path:
    """
    Configure root logging to write both to a log file and to stdout.

    If the root logger has no handlers yet, it is configured through
    ``logging.basicConfig`` at INFO level with a file handler and a stdout
    handler. Otherwise only a file handler for ``log_dir / log_filename`` is
    attached, unless the root logger already has one for that same file.

    Args:
        log_dir: Directory for the log file; created if it does not exist.
        log_filename: Name of the log file inside ``log_dir``.

    Returns:
        The log file path, ``log_dir / log_filename``.
    """
    log_format = "%(asctime)s - %(levelname)s - %(message)s"
    log_dir.mkdir(parents=True, exist_ok=True)
    log_path = log_dir / log_filename
    root = logging.getLogger()

    if not root.handlers:
        logging.basicConfig(
            level=logging.INFO,
            format=log_format,
            handlers=[
                # UTF-8 so non-ASCII log text (e.g. "→") is written correctly
                # regardless of the platform's default encoding.
                logging.FileHandler(log_path, encoding="utf-8"),
                logging.StreamHandler(sys.stdout),
            ],
        )
        return log_path

    # Avoid duplicate file handlers. FileHandler stores an *absolute* path in
    # ``baseFilename``, so compare against the absolute form of ``log_path``;
    # otherwise a relative ``log_dir`` would never match and would add a new
    # handler on every call.
    target = os.path.abspath(log_path)
    existing_files = {
        getattr(h, "baseFilename", None)
        for h in root.handlers
        if isinstance(h, logging.FileHandler)
    }
    if target not in existing_files:
        file_handler = logging.FileHandler(log_path, encoding="utf-8")
        # Match the format applied by the basicConfig branch above.
        file_handler.setFormatter(logging.Formatter(log_format))
        root.addHandler(file_handler)

    return log_path


# Configure logging at import time; the returned log file path is stored here.
LOG_FILE_PATH = setup_logging(log_dir=LOGS_DIR, log_filename="data_smote_adder.log")


######################################################################################################
# 1. Feature Columns
######################################################################################################

# Columns passed to SMOTENC, in this order. "AQUEOUS_ENC" is created by
# label-encoding the AQUEOUS column inside run_smote_augmentation, and its
# position in this list is used as SMOTENC's categorical feature index.
FEATURE_COLS: List[str] = [
    "ESM", "HSPC", "CHOL", "PEG", "TFR", "FRR", "SIZE", "PDI",
    "AQUEOUS_ENC",
]


######################################################################################################
# 2. Main SMOTE Augmentation Pipeline
######################################################################################################

def run_smote_augmentation(
    input_filename: str = "formulations.csv",
    output_filename: str = "formulations_extended_with_SMOTE.csv",
    sample_frac: float = 0.25,
    k_neighbors: int = 5,
) -> None:
    """
    Oversample the formulation dataset with SMOTENC, using AQUEOUS as the class.

    Pipeline: load ``DATA_DIR / input_filename`` → randomly subsample
    ``sample_frac`` of the rows (``random_state=42``) → validate required
    columns → drop rows with NaN in any required column (Option A) →
    label-encode AQUEOUS → run SMOTENC with AQUEOUS as both the target and the
    single categorical feature → decode AQUEOUS → enforce ESM/HSPC mutual
    exclusivity → write ``DATA_EXT_DIR / output_filename``.

    Args:
        input_filename: CSV file name inside ``DATA_DIR``.
        output_filename: CSV file name written inside ``DATA_EXT_DIR``.
        sample_frac: Fraction of rows kept by the initial random subsample.
        k_neighbors: Requested SMOTENC neighbour count. It is lowered (with a
            warning) when the smallest class that must be oversampled has too
            few rows to support it.

    Returns:
        None. Validation problems are logged and cause an early return;
        unexpected exceptions are logged with their traceback and re-raised.
    """

    logging.info("STARTING DATA AUGMENTATION USING SMOTENC (SMOTE + categorical)...")
    file_path = DATA_DIR / input_filename
    logging.info(f"---> Loading dataset from {file_path}")

    if not file_path.exists():
        logging.error(f"Input file {file_path} not found. Aborting.")
        return

    try:
        df = pd.read_csv(file_path)
        logging.info(f"Dataset loaded successfully with {df.shape[0]} rows and {df.shape[1]} columns.")

        # Reproducible random subsample of the rows.
        df = df.sample(frac=sample_frac, random_state=42)
        logging.info(f"Subsampled dataset to {sample_frac*100:.0f}% → {df.shape[0]} rows.")

        # -------------------------------------------------------------------------------------------
        # Validate columns BEFORE using them, so a missing column produces a clear log
        # message instead of a KeyError from dropna / column indexing.
        # -------------------------------------------------------------------------------------------
        if "AQUEOUS" not in df.columns:
            logging.error("Column 'AQUEOUS' not found. Cannot perform SMOTENC.")
            return

        # AQUEOUS_ENC is derived below; every other feature must come from the CSV.
        numeric_cols = [c for c in FEATURE_COLS if c != "AQUEOUS_ENC"]
        missing_features = [c for c in numeric_cols if c not in df.columns]
        if missing_features:
            logging.error(f"Missing required columns for SMOTENC: {missing_features}")
            return

        # -------------------------------------------------------------------------------------------
        # OPTION A: DROP NaNs BEFORE SMOTE
        # Dropped on the *raw* AQUEOUS column, before encoding: depending on the
        # scikit-learn version, LabelEncoder either raises on missing values or
        # encodes them as an extra class, so dropping on the encoded column would
        # not remove those rows.
        # -------------------------------------------------------------------------------------------
        required_cols = numeric_cols + ["AQUEOUS"]
        before_drop = df.shape[0]
        df = df.dropna(subset=required_cols)
        after_drop = df.shape[0]
        logging.info(f"Dropped {before_drop - after_drop} rows containing NaN before SMOTENC.")

        if df.empty:
            logging.error("All rows contain NaN after cleaning → cannot run SMOTENC.")
            return

        # Encode AQUEOUS (now guaranteed non-null).
        logging.info("Encoding 'AQUEOUS' column.")
        label_encoder = LabelEncoder()
        df = df.assign(AQUEOUS_ENC=label_encoder.fit_transform(df["AQUEOUS"]))
        logging.info(f"AQUEOUS categories: {list(label_encoder.classes_)}")

        # SMOTENC cannot balance a single class.
        class_counts = df["AQUEOUS_ENC"].value_counts()
        if len(class_counts) < 2:
            logging.error(
                f"SMOTENC needs at least 2 AQUEOUS categories, found {len(class_counts)}. Aborting."
            )
            return

        # With the default 'auto' (= 'not majority') strategy, every class smaller
        # than the largest one is oversampled, and each of those classes needs more
        # than k_neighbors rows for the nearest-neighbour search.
        effective_k = k_neighbors
        minority_counts = class_counts[class_counts < class_counts.max()]
        if not minority_counts.empty:
            smallest = int(minority_counts.min())
            if smallest < 2:
                logging.error(
                    "An AQUEOUS category to oversample has only 1 row → cannot run SMOTENC."
                )
                return
            if smallest <= k_neighbors:
                effective_k = smallest - 1
                logging.warning(
                    f"Smallest minority class has {smallest} rows; "
                    f"reducing k_neighbors from {k_neighbors} to {effective_k}."
                )

        # SMOTE setup
        categorical_indices = [FEATURE_COLS.index("AQUEOUS_ENC")]
        smote = SMOTENC(
            categorical_features=categorical_indices,
            k_neighbors=effective_k,
            random_state=42,
        )

        logging.info("Applying SMOTENC to balance dataset...")
        X_resampled, y_resampled = smote.fit_resample(df[FEATURE_COLS], df["AQUEOUS_ENC"])
        logging.info(f"SMOTENC completed. Original: {df.shape[0]} rows → Resampled: {X_resampled.shape[0]} rows.")

        # Build DataFrame
        df_resampled = pd.DataFrame(X_resampled, columns=FEATURE_COLS)

        # Decode AQUEOUS back to its original labels and drop the encoded helper column.
        df_resampled["AQUEOUS"] = label_encoder.inverse_transform(
            df_resampled["AQUEOUS_ENC"].astype(int)
        )
        df_resampled.drop(columns=["AQUEOUS_ENC"], inplace=True)

        # Reorder columns
        column_order = [
            "ESM", "HSPC", "CHOL", "PEG",
            "TFR", "FRR", "AQUEOUS",
            "SIZE", "PDI"
        ]
        df_resampled = df_resampled[column_order]

        # Enforce mutual exclusivity ESM ↔ HSPC.
        # NOTE: HSPC takes precedence. After the first assignment, every row with
        # ESM > 0 already has HSPC <= 0, so the second line only clamps such rows
        # to HSPC == 0.
        df_resampled.loc[df_resampled["HSPC"] > 0, "ESM"] = 0
        df_resampled.loc[df_resampled["ESM"] > 0, "HSPC"] = 0

        # Save augmented dataset
        output_path = DATA_EXT_DIR / output_filename
        df_resampled.to_csv(output_path, index=False)
        logging.info(f"Saved augmented dataset → {output_path}")
        logging.info("...DONE!\n")

    except Exception as e:
        # Top-level boundary: record the full traceback, then propagate.
        logging.exception(f"Error during SMOTENC augmentation: {e}")
        raise


######################################################################################################
# 3. Run script as executable
######################################################################################################

if __name__ == "__main__":
    # When executed as a script, run the pipeline with its default arguments.
    run_smote_augmentation()
