"""
Train and serialise the best-performing model for each emission output, fitted
on the full 45-sample dataset. These models are intended for downstream use
(prediction on new operating points, virtual-sensor deployment, etc.).

NOTE: These models are NOT cross-validated. Each one is trained on all 45
samples and serialised to disk, so its in-sample fit says nothing about error
on unseen data; for that, refer to the LOOCV results in
`../results/table6_loocv.json`.

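Output: ../trained_models/{model}_{output}.joblib (names lower-cased, e.g.
gb_co.joblib), an extra gb_{output}.joblib for every output whose best model
is not GB, and ../trained_models/all_models_summary.json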

Run as:
    python train_final_models.py
"""

from __future__ import annotations

import json
import sys
from pathlib import Path

import joblib
import pandas as pd

# Resolve this script's own directory and put it at the front of sys.path so
# the sibling `models` module can be imported from any working directory.
SCRIPT_DIR = Path(__file__).resolve().parents[0]
sys.path[:0] = [str(SCRIPT_DIR)]

from models import OUTPUTS, make_features, make_models  # noqa: E402

# Input CSV and output directory, both resolved relative to the parent of this
# script's directory. The output directory is created at import time if it
# does not already exist.
DATA_PATH = SCRIPT_DIR.parent.joinpath("data", "emission_data.csv")
MODELS_DIR = SCRIPT_DIR.parent.joinpath("trained_models")
MODELS_DIR.mkdir(exist_ok=True)


# Winning model for each emission output (output column -> model key in
# make_models()). Per the manuscript, these are the Table 6 entries with the
# highest LOOCV R² for each output.
BEST_MODEL_PER_OUTPUT = dict(
    CO="GB",
    HC="GB",
    CO2="Poly",
    O2="GB",
    NOx="RF",
    Lambda="GB",
)


def main():
    print(f"Loading dataset: {DATA_PATH}")
    df = pd.read_csv(DATA_PATH)
    X = make_features(df)
    print(f"Loaded {len(df)} samples; X shape = {X.shape}")
    print()

    summary = {
        "dataset": "VCR diesel engine emission, 45 conditions",
        "feature_columns": ["CR", "load_pct", "fuel_RSO", "fuel_Algae"],
        "feature_note": (
            "Categorical fuel encoded as two indicator columns. "
            "Diesel is the reference category (both indicators = 0). "
            "RSO: fuel_RSO=1, fuel_Algae=0. "
            "Algae: fuel_RSO=0, fuel_Algae=1."
        ),
        "outputs": {},
    }

    for output_name, model_name in BEST_MODEL_PER_OUTPUT.items():
        print(f"Training {model_name} on {output_name} (full dataset)...")
        y = df[output_name].values.astype(float)
        model = make_models()[model_name]
        model.fit(X, y)

        out_path = MODELS_DIR / f"{model_name.lower()}_{output_name.lower()}.joblib"
        joblib.dump(model, out_path)

        # Also dump the alternative GB models for outputs where another model wins
        # (so users have GB available for every output, which is the manuscript's
        # primary recommendation)
        if model_name != "GB":
            gb_model = make_models()["GB"]
            gb_model.fit(X, y)
            gb_path = MODELS_DIR / f"gb_{output_name.lower()}.joblib"
            joblib.dump(gb_model, gb_path)
            also_saved = f"gb_{output_name.lower()}.joblib"
        else:
            also_saved = None

        summary["outputs"][output_name] = {
            "best_model": model_name,
            "trained_file": out_path.name,
            "also_saved": also_saved,
            "loocv_r2": None,  # filled in from results JSON if available
        }
        print(f"  saved -> {out_path.name}")

    # Annotate with LOOCV R² values from results JSON if available
    results_json = SCRIPT_DIR.parent / "results" / "table6_loocv.json"
    if results_json.exists():
        with open(results_json) as f:
            t6 = json.load(f)
        for output_name in summary["outputs"]:
            best_m = summary["outputs"][output_name]["best_model"]
            summary["outputs"][output_name]["loocv_r2"] = t6[output_name][best_m]["R2"]
            summary["outputs"][output_name]["loocv_rmse"] = t6[output_name][best_m]["RMSE"]
            summary["outputs"][output_name]["loocv_mape_pct"] = t6[output_name][best_m]["MAPE"]

    with open(MODELS_DIR / "all_models_summary.json", "w") as f:
        json.dump(summary, f, indent=2)
    print()
    print(f"Wrote all_models_summary.json -> {MODELS_DIR}/")


if __name__ == "__main__":
    main()
