from pathlib import Path
import warnings

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import to_rgb
from matplotlib.patches import Patch


# ============================================================
# USER INSTRUCTIONS
# ============================================================
# Please edit ONLY the following sections as required:
#
# 1. INPUT FILE NAMES
#    Update the names of the model-importance and feature workbook files.
#
# 2. OUTPUT FILE NAMES
#    Modify the desired names for exported PNG and PDF figures.
#
# 3. MODEL SHEET NAMES
#    Update only if your workbook uses different sheet names.
#
# 4. GROUP / FEATURE SETTINGS
#    Update only if your climate features or custom group labels differ.
#
# 5. FIGURE / LAYOUT SETTINGS
#    Adjust subplot spacing, font sizes, and point/bar settings if needed.
#
# 6. TRAINING INFO IN LEGEND
#    Update the training sample counts if needed.
#
# 7. PALETTE / STYLE SETTINGS
#    Update only if needed for your own visualization style.
#
#
# Recommended folder structure:
# project/
# ├── Grouped_Feature_Importance_Analysis.py
# ├── data/
# │   ├── model_importance_avg.xlsx
# │   ├── model_importance_trend.xlsx
# │   ├── profiling_workbook.xlsx
# │   └── climate_workbook.xlsx
# └── outputs/
# ============================================================

# ============================================================
# PATHS AND USER-EDITABLE SETTINGS
# ============================================================

# All paths are resolved relative to the folder that contains this script,
# not relative to the current working directory.
BASE_DIR = Path(__file__).resolve().parent
DATA_DIR = BASE_DIR / "data"
# Created by main() if it does not exist yet.
OUTPUT_DIR = BASE_DIR / "outputs"

# ----------------------------
# INPUT FILES
# ----------------------------
# Model-importance workbooks: one sheet per entry in MODEL_SHEETS.  On each
# sheet the first column is read as the feature name and every column whose
# name starts with "Fold_" as one fold's importance value.
AVG_MODEL_IMPORTANCE_FILE = DATA_DIR / "model_importance_avg.xlsx"
TREND_MODEL_IMPORTANCE_FILE = DATA_DIR / "model_importance_trend.xlsx"

# Feature workbooks: each sheet name becomes a group label and each column
# (other than City / Latitude / Longitude / Year) becomes a feature of that
# group.
PROFILE_FEATURES_FILE = DATA_DIR / "profiling_workbook.xlsx"
CLIMATE_FEATURES_FILE = DATA_DIR / "climate_workbook.xlsx"

# ----------------------------
# OUTPUT FILES
# ----------------------------
# File names only; both figures are written inside OUTPUT_DIR.
OUTPUT_PDF_NAME = "FIG_grouped_feature_importance.pdf"
OUTPUT_PNG_NAME = "FIG_grouped_feature_importance.png"

# ----------------------------
# MODEL SHEETS
# ----------------------------
# Sheet names read from both model-importance workbooks.  The list order is
# also the left-to-right bar order within each feature and the order of the
# legend entries.
MODEL_SHEETS = ["LASSO", "ElasticNet", "Bayesian Ridge", "XGBoost", "Random Forest"]

# ----------------------------
# CUSTOM GROUP SETTINGS
# Update only if your feature/group names differ
# ----------------------------
# Features listed here are additionally assigned to CLIMATE_GROUP_NAME,
# on top of whatever sheet-based group the feature workbooks give them.
CLIMATE_FEATURES = ["T2M_AVG", "RH_AVG", "SolarRadiation_AVG"]
CLIMATE_GROUP_NAME = "CLIMATE"
# Single feature that forms its own group; the "\n" in the group name puts
# the label on two lines above the panel.
COMBINED_GROUP_FEATURE = "Exceedance"
COMBINED_GROUP_NAME = "CLIMATE+\nURBAN"

# Groups to exclude in the filtered panels
# (the filtered panels are b and d in create_figure).
EXCLUDED_GROUPS = [CLIMATE_GROUP_NAME, COMBINED_GROUP_NAME]

# ----------------------------
# TRAINING DATA / LEGEND INFO
# Update only if needed
# ----------------------------
# These counts are only printed in the legend title; they are not used in
# any computation.
N_INDIAN = 88
N_GLOBAL = 52
TOTAL_TRAIN_POINTS = N_INDIAN + N_GLOBAL

# ----------------------------
# GLOBAL STYLE & FONT CONTROLS
# ----------------------------
# Font sizes are passed straight to matplotlib (points).
AXIS_LABEL_FONTSIZE = 7
TICK_LABEL_FONTSIZE = 7
LEGEND_FONTSIZE = 7
LEGEND_TITLE_FONTSIZE = 7
SUBPLOT_LABEL_FONTSIZE = 14   # panel letters a-d
XTICK_FONTSIZE = 7            # rotated feature names on the x-axis
SHEET_LABEL_FONTSIZE = 7      # group names drawn above each panel
SHEET_ARROW_WIDTH = 0.9       # line width of the double-headed group arrows

# ----------------------------
# DATA-POINT OVERLAY CONTROLS
# ----------------------------
POINT_SIZE = 6                # scatter marker size for fold-wise values
POINT_ALPHA = 0.95
# Horizontal jitter is drawn uniformly from
# [-bar_width * POINT_JITTER_FRAC, +bar_width * POINT_JITTER_FRAC].
POINT_JITTER_FRAC = 0.14
BAR_ALPHA = 0.45
# Seed of the shared random generator that produces the jitter.
RNG_SEED = 42

# ----------------------------
# FIGURE / LAYOUT CONTROLS
# ----------------------------
# Passed as figsize=(FIG_W, FIG_H); matplotlib interprets figsize in inches.
FIG_W = 7.0866
FIG_H = 8.27
# Used both when the figure is created and when it is saved.
DPI = 600

# Only the five keys below are read by create_figure.
SUBPLOTS_ADJUST = {
    "left": 0.08,
    "right": 0.995,
    "top": 0.955,
    "bottom": 0.025,
    "hspace": 1.6
}

# legend position below each subplot
# (axes-fraction y of the legend's upper-centre anchor; TOP is used for
# panels a and b, BOTTOM for panels c and d)
LEGEND_Y_ANCHOR_TOP = -0.9
LEGEND_Y_ANCHOR_BOTTOM = -1.08

# ----------------------------
# STYLE SETTINGS
# Update palette/settings only if needed
# ----------------------------
# Bar colour per model for the panels built from AVG_MODEL_IMPORTANCE_FILE
# (panels a and b).  Keys should match the entries of MODEL_SHEETS.
AVG_MODEL_COLORS = {
    "LASSO": "#d2e6f3",
    "ElasticNet": "#6c94d0",
    "XGBoost": "#d25f60",
    "Random Forest": "#a73f43",
    "Bayesian Ridge": "#4c72b0"
}

# Bar colour per model for the panels built from TREND_MODEL_IMPORTANCE_FILE
# (panels c and d).  Keys should match the entries of MODEL_SHEETS.
TREND_MODEL_COLORS = {
    "LASSO": "#d6e5c9",
    "ElasticNet": "#a3c3a0",
    "XGBoost": "#d6a673",
    "Random Forest": "#b56b3f",
    "Bayesian Ridge": "#92b388"
}


# ============================================================
# GLOBAL SETTINGS
# ============================================================
# NOTE(review): the font rcParams below are written after sns.set(...);
# seaborn's theme call presumably writes rcParams of its own, so keep this
# order unless that is confirmed not to matter.
sns.set(style="whitegrid")

plt.rcParams["font.family"] = "Arial"
plt.rcParams["font.size"] = 6
# Font type 42 in the PDF/PS backends; per matplotlib's documentation this
# selects TrueType embedding (text stays editable in vector editors).
plt.rcParams["pdf.fonttype"] = 42
plt.rcParams["ps.fonttype"] = 42

# Single seeded generator shared by all panels for the point jitter, so the
# figure is reproducible as long as the panels are drawn in the same order.
RNG = np.random.default_rng(RNG_SEED)


# ============================================================
# HELPERS
# ============================================================

def validate_files():
    """Check that all four input workbooks exist before any work is done.

    Raises:
        FileNotFoundError: if one or more workbooks are absent; the message
            lists every missing path on its own line.
    """
    absent = []
    for workbook in (
        AVG_MODEL_IMPORTANCE_FILE,
        TREND_MODEL_IMPORTANCE_FILE,
        PROFILE_FEATURES_FILE,
        CLIMATE_FEATURES_FILE,
    ):
        if workbook.exists():
            continue
        absent.append(str(workbook))

    if not absent:
        return

    raise FileNotFoundError("Missing required files:\n" + "\n".join(absent))


def darken_color(color, factor=0.60):
    """Return ``color`` as an (r, g, b) tuple with every channel scaled by ``factor``.

    Args:
        color: Any colour specification accepted by matplotlib's ``to_rgb``.
        factor: Multiplier applied to each channel; values below 1 darken.
    """
    return tuple(channel * factor for channel in to_rgb(color))


def extract_groups_from_excel(excel_path):
    """Map every feature column of a workbook to the sheet it appears on.

    Each sheet name is used as a group label.  Every column on that sheet,
    except the identifier columns City / Latitude / Longitude / Year, is
    recorded as one feature of that group.

    Args:
        excel_path: Path of the workbook to scan.

    Returns:
        DataFrame with the columns "Feature" and "Group", one row per
        (column, sheet) pair in workbook order.  Both columns are present
        even when no feature column is found, so the result can always be
        concatenated and merged on "Feature".
    """
    ignore_cols = {"City", "Latitude", "Longitude", "Year"}
    rows = []

    # Open the workbook once and parse every sheet from the same handle;
    # the context manager closes the file even if a sheet fails to parse.
    with pd.ExcelFile(excel_path) as workbook:
        for sheet in workbook.sheet_names:
            columns = workbook.parse(sheet_name=sheet).columns
            rows.extend(
                {"Feature": col, "Group": sheet}
                for col in columns
                if col not in ignore_cols
            )

    # Explicit columns keep the schema stable for an empty result.
    return pd.DataFrame(rows, columns=["Feature", "Group"])


def build_feature_group_table():
    """Assemble the Feature/Group table used to label and order the bars.

    The table is the row-wise concatenation, in this order, of:
      1. the sheet-derived groups of PROFILE_FEATURES_FILE,
      2. the sheet-derived groups of CLIMATE_FEATURES_FILE,
      3. one row per entry of CLIMATE_FEATURES, all in CLIMATE_GROUP_NAME,
      4. a single row placing COMBINED_GROUP_FEATURE in COMBINED_GROUP_NAME.

    Returns:
        DataFrame with "Feature" and "Group" columns and a fresh 0..n-1 index.
    """
    workbook_tables = [
        extract_groups_from_excel(source)
        for source in (PROFILE_FEATURES_FILE, CLIMATE_FEATURES_FILE)
    ]

    climate_block = pd.DataFrame({
        "Feature": CLIMATE_FEATURES,
        "Group": [CLIMATE_GROUP_NAME for _ in CLIMATE_FEATURES],
    })

    combined_block = pd.DataFrame({
        "Feature": [COMBINED_GROUP_FEATURE],
        "Group": [COMBINED_GROUP_NAME],
    })

    return pd.concat(
        [*workbook_tables, climate_block, combined_block],
        ignore_index=True,
    )


def prepare_model_dataframe(model_importance_file, df_feature_groups):
    """Load fold-wise importances for every model and summarise them per feature.

    For each sheet named in MODEL_SHEETS the first column is taken as the
    feature name and every column whose name starts with "Fold_" as one
    fold's importance value.  Features are matched to groups through
    ``df_feature_groups``; features without a group are dropped.  A feature
    that belongs to several groups yields one row per group.

    Args:
        model_importance_file: Workbook with one sheet per model.
        df_feature_groups: DataFrame with "Feature" and "Group" columns.

    Returns:
        Tuple ``(df_all, df_filtered, full_feature_list, filtered_list,
        n_folds_detected)`` where

        * ``df_all`` has one row per (Group, Feature, Model) with the columns
          MeanImportance, StdErr (sample std / sqrt(N), NaN when N < 2),
          DataPoints (array of non-NaN fold values) and N,
        * ``df_filtered`` is ``df_all`` without the EXCLUDED_GROUPS,
        * ``full_feature_list`` is the ordered list of (group, feature)
          pairs, following the order of ``df_feature_groups``,
        * ``filtered_list`` is that list without the EXCLUDED_GROUPS,
        * ``n_folds_detected`` is the number of "Fold_" columns on the first
          model sheet.

    Raises:
        ValueError: if no feature on any sheet could be matched to a group.
    """
    # Repeated (Feature, Group) rows would multiply the merge output below,
    # and the plotting code only draws a bar when it finds exactly one row
    # per (Group, Feature, Model) -- so duplicates would silently vanish.
    group_table = df_feature_groups.drop_duplicates(subset=["Feature", "Group"])

    # Open the workbook once instead of once per model sheet.
    with pd.ExcelFile(model_importance_file) as workbook:
        model_frames = [
            (sheet, workbook.parse(sheet_name=sheet)) for sheet in MODEL_SHEETS
        ]

    rows = []
    n_folds_detected = None

    for sheet, df_model in model_frames:
        feature_col = df_model.columns[0]
        df_model = df_model.rename(columns={feature_col: "Feature"})

        fold_cols = [c for c in df_model.columns if str(c).startswith("Fold_")]
        if not fold_cols:
            warnings.warn(
                f"Sheet '{sheet}' in {model_importance_file} has no 'Fold_' "
                f"columns; its importances will be NaN and no bars are drawn.",
                stacklevel=2,
            )
        if n_folds_detected is None:
            n_folds_detected = len(fold_cols)
        elif len(fold_cols) != n_folds_detected:
            # The legend reports a single fold count (taken from the first
            # sheet), so make a mismatch visible instead of hiding it.
            warnings.warn(
                f"Sheet '{sheet}' in {model_importance_file} has "
                f"{len(fold_cols)} 'Fold_' columns but the first sheet has "
                f"{n_folds_detected}.",
                stacklevel=2,
            )

        df_model = df_model.dropna(subset=["Feature"])
        df_model = df_model.merge(group_table, on="Feature", how="left")
        df_model = df_model.dropna(subset=["Group"])

        for _, row in df_model.iterrows():
            vals = row[fold_cols].values.astype(float)
            vals = vals[~np.isnan(vals)]
            n_vals = len(vals)

            rows.append({
                "Group": row["Group"],
                "Feature": row["Feature"],
                "Model": sheet,
                "MeanImportance": np.mean(vals) if n_vals > 0 else np.nan,
                "StdErr": (np.std(vals, ddof=1) / np.sqrt(n_vals)) if n_vals > 1 else np.nan,
                "DataPoints": vals,
                "N": n_vals
            })

    df_all = pd.DataFrame(rows)

    # An empty frame also covers the case of an empty MODEL_SHEETS list, so
    # n_folds_detected is always set past this point.
    if df_all.empty:
        raise ValueError(f"No features loaded from Excel: {model_importance_file}")

    # Plot order: groups in order of first appearance in the group table,
    # features in table order within each group, restricted to features
    # that were actually loaded.
    loaded_features = set(df_all["Feature"])
    seen_pairs = set()
    feature_order = []

    for g in group_table["Group"].drop_duplicates().tolist():
        feats = group_table.loc[group_table["Group"] == g, "Feature"]
        for f in feats:
            if f in loaded_features and (g, f) not in seen_pairs:
                seen_pairs.add((g, f))
                feature_order.append((g, f))

    # Safety net: append any loaded pair the table walk above did not emit.
    for g, f in df_all[["Group", "Feature"]].drop_duplicates().itertuples(index=False):
        if (g, f) not in seen_pairs:
            seen_pairs.add((g, f))
            feature_order.append((g, f))

    full_feature_list = feature_order.copy()
    df_filtered = df_all[~df_all["Group"].isin(EXCLUDED_GROUPS)]
    filtered_list = [(g, f) for (g, f) in full_feature_list if g not in EXCLUDED_GROUPS]

    return df_all, df_filtered, full_feature_list, filtered_list, n_folds_detected


# ============================================================
# PLOTTING
# ============================================================

def plot_feature_level_grouped(ax, df_subset, feature_list, custom_colors, group_label_height=0.05):
    """Draw one panel of grouped bars with fold-wise points and group labels.

    For every (group, feature) pair in ``feature_list`` one bar per model in
    MODEL_SHEETS is drawn at the feature's x position: bar height is
    MeanImportance, the error bar is StdErr, and the individual fold values
    (DataPoints) are overlaid as jittered dots in a darker shade of the bar
    colour.  Group names are written above the panel with a double-headed
    arrow spanning the group's features, and dashed vertical lines separate
    neighbouring groups.

    Args:
        ax: Matplotlib axes to draw on.
        df_subset: Frame with the columns Group, Feature, Model,
            MeanImportance, StdErr, DataPoints and N.
        feature_list: Ordered (group, feature) pairs; one x position each.
            Group sizes are counted over the whole list, so the separators
            and labels assume each group's features are adjacent.
        custom_colors: Mapping model name -> bar colour; models missing from
            the mapping are drawn in "#999999".
        group_label_height: Gap between the top of the data range and the
            group labels, as a fraction of the y-range.

    A bar is only drawn when ``df_subset`` holds exactly one row for the
    (group, feature, model) combination and its mean is not NaN.
    """
    groups = [g for g, _ in feature_list]
    unique_groups = []
    group_sizes = []

    for g in groups:
        if g not in unique_groups:
            unique_groups.append(g)
            group_sizes.append(groups.count(g))

    n_features = len(feature_list)
    n_models = len(MODEL_SHEETS)
    x = np.arange(n_features)
    # 0.15 for up to five models (0.75 of each feature slot in total); with
    # more models the bars shrink so a cluster never exceeds 0.8 of a slot
    # and cannot overlap the neighbouring feature.
    width = min(0.15, 0.8 / max(n_models, 1))

    for i, model in enumerate(MODEL_SHEETS):
        df_m = df_subset[df_subset["Model"] == model]

        means, errs, point_vals_all, ns = [], [], [], []

        for g, f in feature_list:
            r = df_m[(df_m["Group"] == g) & (df_m["Feature"] == f)]
            if len(r) == 1:
                means.append(r["MeanImportance"].values[0])
                errs.append(r["StdErr"].values[0])
                point_vals_all.append(r["DataPoints"].values[0])
                ns.append(r["N"].values[0])
            else:
                # Missing (or ambiguous) entry: leave the slot empty.
                means.append(np.nan)
                errs.append(np.nan)
                point_vals_all.append(np.array([]))
                ns.append(0)

        # Centre the cluster of model bars on the feature's x position.
        xpos = x + (i - n_models / 2) * width + width / 2

        bar_color = custom_colors.get(model, "#999999")
        dot_color = darken_color(bar_color, factor=0.60)

        for xi, mean_val, err_val, vals, n_here in zip(xpos, means, errs, point_vals_all, ns):
            if np.isnan(mean_val):
                continue

            ax.bar(
                xi, mean_val,
                width=width,
                yerr=err_val,
                capsize=2,
                color=bar_color,
                alpha=BAR_ALPHA,
                edgecolor="black",
                linewidth=0
            )

            if n_here > 0:
                # Jitter comes from the module-level RNG, so the dot layout
                # depends on the order in which panels are drawn.
                jitter = RNG.uniform(
                    -width * POINT_JITTER_FRAC,
                    width * POINT_JITTER_FRAC,
                    size=n_here
                )
                x_points = np.full(n_here, xi) + jitter
                y_points = vals

                ax.scatter(
                    x_points,
                    y_points,
                    s=POINT_SIZE,
                    color=dot_color,
                    alpha=POINT_ALPHA,
                    zorder=5,
                    linewidths=0
                )

    # Dashed separators between consecutive groups.
    cum = 0
    for size in group_sizes[:-1]:
        cum += size
        ax.axvline(cum - 0.5, color="lightgray", linestyle="--", linewidth=1.2)

    ax.set_xticks(x)
    ax.set_xticklabels([f for (_, f) in feature_list], fontsize=XTICK_FONTSIZE, rotation=90)

    y_min, y_max = ax.get_ylim()
    if y_min == y_max:
        y_max = y_min + 1

    y_label = y_max + (y_max - y_min) * group_label_height

    # Offsets for the arrow and the new upper limit are taken relative to
    # the axis range rather than to the absolute label height, so the layout
    # stays proportional when the axis does not start at zero (e.g. negative
    # importances).  For y_min == 0 these equal y_label * 0.985 and
    # y_label * 1.04.
    label_span = y_label - y_min
    arrow_y = y_label - 0.015 * label_span
    y_top = y_label + 0.04 * label_span

    cum = 0
    for g, size in zip(unique_groups, group_sizes):
        left = cum
        right = cum + size - 1
        mid = (left + right) / 2

        ax.text(
            mid, y_label, g,
            ha="center", va="bottom",
            fontsize=SHEET_LABEL_FONTSIZE, fontweight="bold"
        )

        ax.annotate(
            "",
            xy=(left, arrow_y),
            xytext=(right, arrow_y),
            arrowprops=dict(arrowstyle="<->", color="black", lw=SHEET_ARROW_WIDTH)
        )

        cum += size

    ax.set_ylim(y_min, y_top)
    ax.set_ylabel("Feature \nImportance", fontsize=AXIS_LABEL_FONTSIZE, fontweight="bold")
    ax.tick_params(axis="y", labelsize=TICK_LABEL_FONTSIZE)
    ax.yaxis.grid(True, linestyle="--", linewidth=0.5, alpha=0.4)
    ax.xaxis.grid(False)

    # Keep only the left and bottom spines.
    for pos, spine in ax.spines.items():
        if pos in ["left", "bottom"]:
            spine.set_linewidth(0.9)
            spine.set_color("black")
        else:
            spine.set_visible(False)


def add_subplot_legend(ax, custom_colors, n_folds_detected, y_anchor=-0.78):
    """Attach a model-colour legend with an explanatory title below ``ax``.

    Args:
        ax: Axes the legend is attached to.
        custom_colors: Mapping model name -> colour.  Models missing from
            the mapping get "#999999", the same fallback the bar-drawing
            code uses, so legend and bars always agree.
        n_folds_detected: Fold count quoted in the legend title.
        y_anchor: Axes-fraction y coordinate of the legend's upper-centre
            anchor; negative values place the legend below the axes.
    """
    legend_handles = [
        Patch(
            facecolor=custom_colors.get(m, "#999999"),
            edgecolor="black",
            alpha=1.0,
            label=m
        )
        for m in MODEL_SHEETS
    ]

    legend_title_text = (
        f"Bars = mean feature importance across {n_folds_detected} folds; "
        f"error bars = ±SE across folds; "
        f"dots = fold-wise values; "
        f"models trained on {N_INDIAN} Indian + {N_GLOBAL} global data "
        f"(total n = {TOTAL_TRAIN_POINTS})"
    )

    ax.legend(
        handles=legend_handles,
        labels=MODEL_SHEETS,
        fontsize=LEGEND_FONTSIZE,
        title=legend_title_text,
        title_fontsize=LEGEND_TITLE_FONTSIZE,
        ncol=5,
        loc="upper center",
        bbox_to_anchor=(0.5, y_anchor),
        frameon=True,
        borderaxespad=0.2,
        columnspacing=0.8,
        handletextpad=0.4
    )


def create_figure(df_all_avg, df_filtered_avg, full_feature_list_avg, filtered_list_avg, n_folds_avg,
                  df_all_trend, df_filtered_trend, full_feature_list_trend, filtered_list_trend, n_folds_trend):
    """Build the four-panel figure (a-d) stacked in a single column.

    Panels a and b show the "avg" importances (all groups / without the
    excluded groups) in AVG_MODEL_COLORS; panels c and d show the same two
    views for the "trend" importances in TREND_MODEL_COLORS.

    Returns:
        The matplotlib Figure holding the four panels.
    """
    fig, axes = plt.subplots(4, 1, figsize=(FIG_W, FIG_H), dpi=DPI)

    # One entry per panel, top to bottom:
    # (letter, data, feature order, palette, fold count, legend anchor)
    panels = [
        ("a", df_all_avg, full_feature_list_avg, AVG_MODEL_COLORS, n_folds_avg, LEGEND_Y_ANCHOR_TOP),
        ("b", df_filtered_avg, filtered_list_avg, AVG_MODEL_COLORS, n_folds_avg, LEGEND_Y_ANCHOR_TOP),
        ("c", df_all_trend, full_feature_list_trend, TREND_MODEL_COLORS, n_folds_trend, LEGEND_Y_ANCHOR_BOTTOM),
        ("d", df_filtered_trend, filtered_list_trend, TREND_MODEL_COLORS, n_folds_trend, LEGEND_Y_ANCHOR_BOTTOM),
    ]

    # Bars first, strictly in a-d order: the jitter of the overlaid points
    # is drawn from one shared generator, so the order fixes the layout.
    for ax, (_, frame, features, palette, _, _) in zip(axes, panels):
        plot_feature_level_grouped(ax, frame, features, palette, group_label_height=0.06)

    for ax, (_, _, _, palette, n_folds, anchor) in zip(axes, panels):
        add_subplot_legend(ax, palette, n_folds, y_anchor=anchor)

    # Bold panel letter just above the top-left corner of each axes.
    for ax, panel in zip(axes, panels):
        ax.text(
            -0.02, 1.02, panel[0],
            transform=ax.transAxes,
            fontsize=SUBPLOT_LABEL_FONTSIZE,
            fontweight="bold",
        )

    margins = {
        key: SUBPLOTS_ADJUST[key]
        for key in ("left", "right", "top", "bottom", "hspace")
    }
    plt.subplots_adjust(**margins)

    return fig


# ============================================================
# MAIN
# ============================================================

def main():
    """Run the full pipeline: validate inputs, build the figure, save and show it.

    The figure is written to OUTPUT_DIR as PDF and PNG (in that order), then
    displayed; the saved paths are printed once the display call returns.
    """
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    validate_files()

    group_table = build_feature_group_table()

    # Each call returns the five values create_figure expects for one
    # importance type, already in positional order.
    avg_inputs = prepare_model_dataframe(AVG_MODEL_IMPORTANCE_FILE, group_table)
    trend_inputs = prepare_model_dataframe(TREND_MODEL_IMPORTANCE_FILE, group_table)

    fig = create_figure(*avg_inputs, *trend_inputs)

    pdf_path = OUTPUT_DIR / OUTPUT_PDF_NAME
    png_path = OUTPUT_DIR / OUTPUT_PNG_NAME

    for fmt, target in (("pdf", pdf_path), ("png", png_path)):
        fig.savefig(
            target,
            format=fmt,
            dpi=DPI,
            bbox_inches="tight",
            pad_inches=0.02,
        )

    plt.show()

    print(f"Figure saved as PDF: {pdf_path}")
    print(f"Figure saved as PNG: {png_path}")


if __name__ == "__main__":
    main()