# -*- coding: utf-8 -*-
"""script-1.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1VfQIgiRGfHwSaQFgSAzZgWoGjW_LW8hU
"""

# ==========================
# Hydrogen Perovskite ML Workflow
# Direct vs Derived properties + Compact Panel Figures + Ranking + Labels
# ==========================

!pip install xgboost shap scikit-learn pandas matplotlib seaborn tensorflow

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
from sklearn.ensemble import RandomForestRegressor
from sklearn.multioutput import MultiOutputRegressor
import xgboost as xgb
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.callbacks import EarlyStopping
from sklearn.preprocessing import StandardScaler

# ========================================
# 1. Dataset (Tables 4 & 5 merged)
# ========================================

# Each row: [Compound, Pressure, B/G (Pugh ratio), Poisson ratio,
# Cauchy pressure, elastic anisotropy factor, Vickers hardness HV,
# density, transverse sound velocity v_t, longitudinal v_l, mean v_m,
# Debye temperature theta_D, melting temperature T_m].
# Values are transcribed from Tables 4 & 5 of the source study;
# units are not recorded here — presumably GPa for pressure and
# K for theta_D / T_m (TODO confirm against the original tables).
data = [
    # LiMnH3
    ["LiMnH3",0, 1.12,0.16,-31.17,0.49,13.88, 2.77,4181.29,6553.59,4595.50,690.14,1457.66],
    ["LiMnH3",5, 1.13,0.16,-37.00,0.57,16.37, 2.97,4583.37,7199.90,5038.71,774.75,1655.36],
    ["LiMnH3",10,1.24,0.18,-34.95,0.55,15.76, 3.13,4763.68,7647.89,5249.60,821.71,1872.61],
    ["LiMnH3",15,1.25,0.18,-39.85,0.54,17.14, 3.27,4975.80,7996.34,5483.94,870.51,2070.51],
    ["LiMnH3",20,1.23,0.18,-43.71,0.59,18.67, 3.38,5142.44,8239.31,5665.76,909.81,2172.73],
    ["LiMnH3",25,1.10,0.15,-64.52,0.52,23.37, 3.48,5350.82,8350.06,5877.91,953.19,2356.72],
    ["LiMnH3",30,1.24,0.18,-53.80,0.54,20.95, 3.57,5457.69,8748.15,6013.37,983.51,2540.71],

    # NaMnH3
    ["NaMnH3",0, 1.15,0.16,-21.45,1.20,11.91, 2.89,3756.25,5916.35,4130.67,584.84,1118.22],
    ["NaMnH3",5, 1.24,0.18,-29.30,0.61,14.38, 3.18,4429.84,7112.72,4881.76,713.62,1671.00],
    ["NaMnH3",10,1.44,0.22,-20.74,0.62,13.52, 3.46,4720.60,7866.38,5222.05,784.60,2019.50],
    ["NaMnH3",15,1.05,0.14,-62.79,0.77,24.89, 3.71,5167.43,7972.71,5668.85,872.46,2080.47],
    ["NaMnH3",20,1.05,0.14,-62.79,0.77,25.06, 3.96,5201.66,8085.77,5711.47,898.04,2172.08],
    ["NaMnH3",25,1.15,0.16,-67.03,0.78,25.86, 4.20,5478.54,8627.09,6024.48,965.71,2555.21],
    ["NaMnH3",30,1.10,0.15,-76.92,0.85,28.47, 4.42,5524.80,8622.17,6069.08,990.09,2605.76],

    # KMnH3
    ["KMnH3",0, 1.21,0.18,-17.67,2.08,9.03, 2.73,3393.85,5412.08,3737.30,488.50,937.31],
    ["KMnH3",5, 1.29,0.19,-20.50,2.00,10.49, 3.03,3777.47,6114.22,4166.42,564.02,1106.98],
    ["KMnH3",10,1.27,0.19,-23.21,1.36,12.90, 3.27,4126.42,6657.39,4549.73,632.15,1337.15],
    ["KMnH3",15,1.31,0.20,-23.97,1.36,13.61, 3.49,4293.26,6983.45,4737.76,672.01,1472.24],
    ["KMnH3",20,1.22,0.18,-34.02,1.30,16.64, 3.67,4483.78,7163.07,4938.51,712.56,1582.16],
    ["KMnH3",25,1.22,0.18,-36.94,1.11,17.89, 3.83,4612.80,7364.93,5080.29,743.80,1743.61],
    ["KMnH3",30,1.14,0.16,-47.15,1.06,20.77, 3.98,4729.90,7435.87,5200.25,771.13,1832.20],
]

# Column names matching the row layout documented above `data`.
columns = [
    "Compound","Pressure","B/G","Poisson","CauchyP","Anisotropy","HV",
    "Density","v_t","v_l","v_m","theta_D","T_m"
]

df = pd.DataFrame(data, columns=columns)

# ========================================
# 2. Define groups of targets
# ========================================

# Targets split into two groups: quantities read directly from the DFT
# elastic-constant calculation vs. quantities derived from them.
groups = {
    "Direct DFT": ["B/G", "Poisson", "CauchyP", "Anisotropy"],
    "DFT-derived": ["HV", "theta_D", "T_m"],
}

# Convenience aliases used when assembling the feature matrix later on.
direct_targets = groups["Direct DFT"]
derived_targets = groups["DFT-derived"]

# ========================================
# 3. Helper: Train & Evaluate Models
# ========================================

def train_and_evaluate(X, y, target_group_name):
    """Train three regressors on (X, y) and return per-target metric rows.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Feature matrix; expected to be pre-scaled (the NN benefits from it).
    y : pandas.DataFrame
        Multi-output target frame, one column per property.
    target_group_name : str
        Group label stored in the first field of every result row.

    Returns
    -------
    list of lists
        One row per (property, model) pair:
        [group, property, model, R2, MAE, RMSE, MAE_norm(%), RMSE_norm(%)].
    """
    results = []
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.25, random_state=42)

    # Random Forest, wrapped for multi-output targets.
    rf = MultiOutputRegressor(RandomForestRegressor(n_estimators=300, random_state=42))
    rf.fit(X_train, y_train)
    y_pred_rf = rf.predict(X_test)

    # XGBoost: MultiOutputRegressor fits one booster per target column.
    xgb_base = xgb.XGBRegressor(n_estimators=300, learning_rate=0.05,
                                max_depth=5, random_state=42)
    xgb_multi = MultiOutputRegressor(xgb_base)
    xgb_multi.fit(X_train, y_train)
    y_pred_xgb = xgb_multi.predict(X_test)

    # Small dense NN with early stopping on an internal validation split.
    model = Sequential([
        Dense(64, activation='relu', input_shape=(X_train.shape[1],)),
        Dropout(0.2),
        Dense(32, activation='relu'),
        Dense(y_train.shape[1])  # one linear output per target
    ])
    model.compile(optimizer='adam', loss='mse')
    early_stop = EarlyStopping(monitor='val_loss', patience=50,
                               restore_best_weights=True)
    model.fit(X_train, y_train, validation_split=0.2,
              epochs=500, batch_size=8, verbose=0, callbacks=[early_stop])
    y_pred_nn = model.predict(X_test)

    # Per-target metrics for every model. Normalized errors are scaled by
    # the MAGNITUDE of the test-set mean: CauchyP has a negative mean, and
    # dividing by the raw mean would report negative "percentages".
    for i, col in enumerate(y.columns):
        scale = abs(np.mean(y_test[col]))  # normalization baseline
        for model_name, pred in zip(
                ["Random Forest", "XGBoost", "Neural Network"],
                [y_pred_rf, y_pred_xgb, y_pred_nn]):
            mae = mean_absolute_error(y_test[col], pred[:, i])
            rmse = np.sqrt(mean_squared_error(y_test[col], pred[:, i]))
            results.append([target_group_name, col, model_name,
                            r2_score(y_test[col], pred[:, i]),
                            mae, rmse,
                            mae / scale * 100,
                            rmse / scale * 100])
    return results

# ========================================
# 4. Run for both groups
# ========================================

# Features are every numeric column that is not a target: the compound
# label and both target groups are dropped.
X_features = df.drop(columns=["Compound"] + direct_targets + derived_targets)

# Standardize once; both target groups share the same scaled inputs.
X_scaled = StandardScaler().fit_transform(X_features)

# Collect metric rows from every (group, targets) run.
all_results = []
for group_name, targets in groups.items():
    all_results.extend(train_and_evaluate(X_scaled, df[targets], group_name))

results_df = pd.DataFrame(
    all_results,
    columns=["Group", "Property", "Model", "R2", "MAE", "RMSE",
             "MAE_norm(%)", "RMSE_norm(%)"])
print(results_df)

# ========================================
# 5. Panel Heatmaps + Overall Ranking with Labels
# ========================================

# Panels (a)-(e): one heatmap per metric (R² is centered at 0 so the
# diverging colormap splits positive/negative scores). Panel (f): bar
# chart ranking the models by their mean scores.
metric_specs = [
    ("R2",           "R²",                  "coolwarm", 0),
    ("MAE",          "MAE",                 "YlGnBu",   None),
    ("RMSE",         "RMSE",                "YlOrRd",   None),
    ("MAE_norm(%)",  "Normalized MAE (%)",  "Blues",    None),
    ("RMSE_norm(%)", "Normalized RMSE (%)", "Reds",     None),
]
panel_labels = ['(a)', '(b)', '(c)', '(d)', '(e)', '(f)']

fig, axes = plt.subplots(2, 3, figsize=(18, 10))
axes = axes.flatten()

for ax, label, (metric, title, cmap, center) in zip(axes, panel_labels, metric_specs):
    pivot_df = results_df.pivot_table(index=["Property"], columns="Model", values=metric)
    sns.heatmap(pivot_df, annot=True, cmap=cmap, center=center, fmt=".2f", ax=ax)
    ax.set_title(f"{label} {title}", fontsize=14)

# Last panel: per-model averages of R² and the normalized errors.
ranking = results_df.groupby("Model")[["R2", "MAE_norm(%)", "RMSE_norm(%)"]].mean().reset_index()
ranking_melt = ranking.melt(id_vars="Model",
                            value_vars=["R2", "MAE_norm(%)", "RMSE_norm(%)"],
                            var_name="Metric", value_name="Score")
sns.barplot(data=ranking_melt, x="Metric", y="Score", hue="Model", ax=axes[-1])
axes[-1].set_title(f"{panel_labels[-1]} Overall Ranking", fontsize=14)

plt.suptitle("Performance Heatmaps + Overall Ranking", fontsize=16, y=1.02)
plt.tight_layout()
plt.show()

# ========================================
# 6. Panel Bar Plots for R² per Property with Labels
# ========================================

# One R² bar panel per property, laid out on a 2-row grid; each panel
# compares the three models (hue distinguishes the property's group).
properties = results_df["Property"].unique()
n_props = len(properties)

fig, axes = plt.subplots(2, (n_props + 1) // 2, figsize=(18, 8))
# atleast_1d guards the n_props == 1 case where subplots returns a
# (2,) Axes array rather than a 2-D grid; flatten is then uniform.
axes = np.atleast_1d(axes).flatten()
panel_labels_props = [f"({chr(97 + i)})" for i in range(n_props)]  # (a),(b),(c)...

for i, prop in enumerate(properties):
    subset = results_df[results_df["Property"] == prop]
    sns.barplot(data=subset, x="Model", y="R2", hue="Group", palette="Set2", ax=axes[i])
    axes[i].axhline(0, color="k", linestyle="--", linewidth=0.7)
    axes[i].set_title(f"{panel_labels_props[i]} {prop}", fontsize=13)
    axes[i].set_ylim(-1, 1)

# Hide unused subplots. Index from n_props instead of the leaked loop
# variable `i`, which would be undefined if `properties` were empty.
for j in range(n_props, len(axes)):
    fig.delaxes(axes[j])

plt.suptitle("R² Comparison per Property (All Models)", fontsize=16, y=1.02)
plt.tight_layout()
plt.show()