# -*- coding: utf-8 -*-
"""Untitled21.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1z4dF4B3YhRodtzEILiyesPNr_NEE2FOr
"""

# ==========================
# ML Extrapolation with Uncertainty Bands + Multi-panel Figure
# ==========================

# Notebook-only setup step. "!pip ..." is IPython shell syntax and raises a
# SyntaxError when this file is run as a plain Python module, so it is kept
# here as a comment. If packages are missing, run the command below (without
# the leading "# ") in a notebook cell, or without the "!" in a shell.
# !pip install xgboost shap scikit-learn pandas matplotlib seaborn tensorflow

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestRegressor
from sklearn.multioutput import MultiOutputRegressor
import xgboost as xgb
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.callbacks import EarlyStopping
from sklearn.preprocessing import StandardScaler

# ========================================
# 1. Dataset (Table 5)
# ========================================

# Raw table, grouped by compound. Each tuple is one row in the order
# (Pressure, Density, v_t, v_l, v_m, theta_D, T_m). Pressure is in GPa and
# theta_D / T_m in K according to the axis labels used when plotting; the
# units of the remaining columns are not stated in this file.
_rows_by_compound = {
    "LiMnH3": [
        (0, 2.77, 4181.29, 6553.59, 4595.50, 690.14, 1457.66),
        (5, 2.97, 4583.37, 7199.90, 5038.71, 774.75, 1655.36),
        (10, 3.13, 4763.68, 7647.89, 5249.60, 821.71, 1872.61),
        (15, 3.27, 4975.80, 7996.34, 5483.94, 870.51, 2070.51),
        (20, 3.38, 5142.44, 8239.31, 5665.76, 909.81, 2172.73),
        (25, 3.48, 5350.82, 8350.06, 5877.91, 953.19, 2356.72),
        (30, 3.57, 5457.69, 8748.15, 6013.37, 983.51, 2540.71),
    ],
    "NaMnH3": [
        (0, 2.89, 3756.25, 5916.35, 4130.67, 584.84, 1118.22),
        (5, 3.18, 4429.84, 7112.72, 4881.76, 713.62, 1671.00),
        (10, 3.46, 4720.60, 7866.38, 5222.05, 784.60, 2019.50),
        (15, 3.71, 5167.43, 7972.71, 5668.85, 872.46, 2080.47),
        (20, 3.96, 5201.66, 8085.77, 5711.47, 898.04, 2172.08),
        (25, 4.20, 5478.54, 8627.09, 6024.48, 965.71, 2555.21),
        (30, 4.42, 5524.80, 8622.17, 6069.08, 990.09, 2605.76),
    ],
    "KMnH3": [
        (0, 2.73, 3393.85, 5412.08, 3737.30, 488.50, 937.31),
        (5, 3.03, 3777.47, 6114.22, 4166.42, 564.02, 1106.98),
        (10, 3.27, 4126.42, 6657.39, 4549.73, 632.15, 1337.15),
        (15, 3.49, 4293.26, 6983.45, 4737.76, 672.01, 1472.24),
        (20, 3.67, 4483.78, 7163.07, 4938.51, 712.56, 1582.16),
        (25, 3.83, 4612.80, 7364.93, 5080.29, 743.80, 1743.61),
        (30, 3.98, 4729.90, 7435.87, 5200.25, 771.13, 1832.20),
    ],
}

# Flatten to one list per row with the compound name in front; dict insertion
# order keeps the compounds in the order LiMnH3, NaMnH3, KMnH3.
data = [
    [compound, *row]
    for compound, rows in _rows_by_compound.items()
    for row in rows
]

columns = ["Compound", "Pressure", "Density", "v_t", "v_l", "v_m", "theta_D", "T_m"]
df = pd.DataFrame(data=data, columns=columns)

# ========================================
# 2. Features and Targets
# ========================================

# Model inputs: pressure plus the four per-row descriptors.
input_columns = ["Pressure", "Density", "v_t", "v_l", "v_m"]
# Model outputs: both targets are predicted jointly.
target_columns = ["theta_D", "T_m"]

X = df.loc[:, input_columns]
y = df.loc[:, target_columns]

# Standardise the inputs; the fitted scaler is reused later so that the
# extrapolated rows are transformed with the same mean and variance.
scaler = StandardScaler().fit(X)
X_scaled = scaler.transform(X)

# ========================================
# 3. Train Models (RF, XGB, NN)
# ========================================

# Seed Python, NumPy and TensorFlow so the network (weight initialisation,
# dropout masks, batch order) is repeatable, in line with the fixed
# random_state already used for the RF and XGB models.
tf.keras.utils.set_random_seed(42)

# Random forest: MultiOutputRegressor fits one independent forest per target.
rf = MultiOutputRegressor(RandomForestRegressor(n_estimators=300, random_state=42))
rf.fit(X_scaled, y)

# Gradient-boosted trees, again one model per target.
xgb_base = xgb.XGBRegressor(n_estimators=300, learning_rate=0.05, max_depth=5, random_state=42)
xgb_multi = MultiOutputRegressor(xgb_base)
xgb_multi.fit(X_scaled, y)

# Small fully connected network with a linear output layer (one unit per
# target). The input shape is declared with an explicit Input object instead
# of the deprecated `input_shape=` keyword on the first Dense layer.
model = Sequential([
    tf.keras.Input(shape=(X_scaled.shape[1],)),
    Dense(64, activation='relu'),
    Dropout(0.2),
    Dense(32, activation='relu'),
    Dense(y.shape[1])
])
model.compile(optimizer='adam', loss='mse')
early_stop = EarlyStopping(monitor='val_loss', patience=50, restore_best_weights=True)

# Keras takes the validation split from the *last* rows of the arrays, before
# any shuffling. The table is ordered by compound and pressure, so without a
# shuffle the validation set would consist only of the final KMnH3 rows.
# Permute the rows once (fixed seed) so the split is a random sample.
shuffle_idx = np.random.RandomState(42).permutation(X_scaled.shape[0])
X_train_nn = np.asarray(X_scaled)[shuffle_idx]
y_train_nn = y.to_numpy()[shuffle_idx]

model.fit(X_train_nn, y_train_nn, validation_split=0.2,
          epochs=500, batch_size=8, verbose=0, callbacks=[early_stop])

# ========================================
# 4. Extrapolation (40–60 GPa)
# ========================================

# Pressures (GPa) outside the tabulated range at which predictions are wanted.
new_pressures = [40, 50, 60]
# Input descriptors that must themselves be extrapolated before the models
# can be evaluated at the new pressures.
features = ["Density", "v_t", "v_l", "v_m"]
extrap_dfs = []

for comp in df["Compound"].unique():
    sub = df[df["Compound"] == comp]

    # Straight-line (degree-1) least-squares fit of each descriptor against
    # pressure, fitted separately for every compound.
    coeffs = {
        feat: np.polyfit(sub["Pressure"], sub[feat], 1)
        for feat in features
    }

    # One row per new pressure: compound, pressure, then the fitted line
    # evaluated at that pressure for each descriptor.
    extrap_rows = [
        [comp, P] + [np.polyval(coeffs[feat], P) for feat in features]
        for P in new_pressures
    ]
    extrap_df = pd.DataFrame(extrap_rows, columns=["Compound", "Pressure"] + features)
    extrap_dfs.append(extrap_df)

# ignore_index gives every extrapolated row a unique label; without it each
# compound's rows would all be labelled 0..len(new_pressures)-1.
extrap_features = pd.concat(extrap_dfs, ignore_index=True)

# Reuse the scaler fitted on the training table (same column order).
X_extra = scaler.transform(extrap_features[["Pressure"] + features])

# ========================================
# 5. Predictions + Uncertainty
# ========================================

# RF uncertainty
# rf.estimators_ holds one fitted forest per target. For each forest, collect
# the prediction of every individual tree: array of shape (n_trees, n_rows).
per_target_tree_preds = [
    np.stack([member.predict(X_extra) for member in forest.estimators_], axis=0)
    for forest in rf.estimators_
]

# Mean over trees is the forest prediction; the spread (population standard
# deviation) across trees is used as the uncertainty band.
rf_preds_all = [tree_preds.mean(axis=0) for tree_preds in per_target_tree_preds]
rf_stds = [tree_preds.std(axis=0) for tree_preds in per_target_tree_preds]

# Arrange as (n_rows, n_targets) so column 0 / 1 follow the order of `y`.
rf_mean = np.column_stack(rf_preds_all)
rf_std = np.column_stack(rf_stds)

# XGB bootstrapping
# Each replicate is trained on a resample (with replacement) of the training
# rows. Varying only `random_state` is not enough: with XGBoost's default
# settings (no row/column subsampling) the seed does not inject randomness,
# so every replicate would be the same model and the spread would be ~0.
X_train_arr = np.asarray(X_scaled)
theta_train = y["theta_D"].to_numpy()
Tm_train = y["T_m"].to_numpy()
n_train = X_train_arr.shape[0]

boot_preds = []
for seed in range(5):
    # Seeded generator so the set of bootstrap samples is reproducible.
    boot_idx = np.random.default_rng(seed).integers(0, n_train, size=n_train)
    X_boot = X_train_arr[boot_idx]

    xgb_temp = xgb.XGBRegressor(n_estimators=300, learning_rate=0.05,
                                max_depth=5, random_state=seed)
    # The same resample is used for both targets; refitting the estimator
    # replaces the previous model.
    xgb_temp.fit(X_boot, theta_train[boot_idx])
    theta_temp = xgb_temp.predict(X_extra)
    xgb_temp.fit(X_boot, Tm_train[boot_idx])
    Tm_temp = xgb_temp.predict(X_extra)
    boot_preds.append(np.column_stack([theta_temp, Tm_temp]))

# Shape (n_replicates, n_rows, 2); reduce over the replicate axis.
boot_preds = np.stack(boot_preds, axis=0)
xgb_mean = boot_preds.mean(axis=0)
xgb_std = boot_preds.std(axis=0)

# NN predictions
# A single forward pass; no uncertainty estimate is computed for the network.
nn_preds = model.predict(X_extra)

# Collect results
# Column 0 of every prediction array is theta_D and column 1 is T_m, matching
# the column order of `y`. Columns are added in the order listed here.
result_columns = [
    ("theta_RF", rf_mean[:, 0]),
    ("T_m_RF", rf_mean[:, 1]),
    ("theta_RF_std", rf_std[:, 0]),
    ("T_m_RF_std", rf_std[:, 1]),
    ("theta_XGB", xgb_mean[:, 0]),
    ("T_m_XGB", xgb_mean[:, 1]),
    ("theta_XGB_std", xgb_std[:, 0]),
    ("T_m_XGB_std", xgb_std[:, 1]),
    ("theta_NN", nn_preds[:, 0]),
    ("T_m_NN", nn_preds[:, 1]),
]
for column_name, column_values in result_columns:
    extrap_features[column_name] = column_values

# ========================================
# 6. Multi-panel Figure (3x2)
# ========================================

compounds = df["Compound"].unique()
n_rows = len(compounds)

# One row of panels per compound, two columns (θD, Tm). squeeze=False keeps
# `axes` two-dimensional even when there is only a single compound.
fig, axes = plt.subplots(n_rows, 2, figsize=(12, 4 * n_rows), squeeze=False)
# '(a)', '(b)', ... one label per panel, in row-major order.
panel_labels = [f"({chr(ord('a') + k)})" for k in range(2 * n_rows)]

# Per column: (tabulated column in df, prefix of the prediction columns in
# extrap_features, y-axis label, short name used in the panel title).
panel_specs = [
    ("theta_D", "theta", "θD (K)", "θD"),
    ("T_m", "T_m", "Tm (K)", "Tm"),
]
# Models that come with an uncertainty band, and the band colour for each.
band_models = [("RF", "blue"), ("XGB", "orange")]

for idx, comp in enumerate(compounds):
    sub = df[df["Compound"] == comp]
    sub_ext = extrap_features[extrap_features["Compound"] == comp]

    for col, (dft_col, prefix, ylabel, short) in enumerate(panel_specs):
        ax = axes[idx, col]

        # Tabulated reference points. The label is quantity-neutral because
        # one shared legend is used for both the θD and the Tm panels.
        ax.scatter(sub["Pressure"], sub[dft_col], c="k", marker="o", label="DFT")

        # Extrapolated mean as a dashed line with a ±1 std band around it.
        for tag, band_color in band_models:
            center = sub_ext[f"{prefix}_{tag}"]
            spread = sub_ext[f"{prefix}_{tag}_std"]
            ax.plot(sub_ext["Pressure"], center, "--", label=tag)
            ax.fill_between(sub_ext["Pressure"],
                            center - spread,
                            center + spread,
                            alpha=0.2, color=band_color)

        # The network has no uncertainty estimate, so only a line is drawn.
        ax.plot(sub_ext["Pressure"], sub_ext[f"{prefix}_NN"], "--", label="NN")

        ax.set_xlabel("Pressure (GPa)")
        ax.set_ylabel(ylabel)
        ax.set_title(f"{panel_labels[2 * idx + col]} {comp} – {short}")

# Single legend for all panels, taken from the first panel (every panel draws
# the same labelled artists in the same order).
handles, labels = axes[0, 0].get_legend_handles_labels()
fig.legend(handles, labels, loc="upper center", ncol=4, bbox_to_anchor=(0.5, 1.02))

plt.tight_layout()
plt.show()