import os
import re
import glob
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error, r2_score
from torch.utils.data import TensorDataset, DataLoader
import random

# --- Reproducibility ---
# One seed drives every RNG the pipeline touches (torch, NumPy, `random`);
# cuDNN is pinned to deterministic kernels with autotuning disabled.
seed = 29
for _rng_seed in (torch.manual_seed, np.random.seed, random.seed):
    _rng_seed(seed)
del _rng_seed
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# --- Constants ---
# Directory scanned for the per-station ``*.stm`` input files.
base_path = "/home/users/abhilash/IISERB/Diffusion/Final_Data/FinalStations"
# Output root (one folder per seed); created as soon as this module is imported.
base_result_path = f"/gws/nopw/j04/naiar/abhilash/tmp/Results_{seed}"
os.makedirs(base_result_path, exist_ok=True)

# Depth key as parsed from a file name ("<from>_<to>", in metres) ->
# (depth label used in column / dict names, depth in cm).
# Both the "0-0.05" and the "0.05-0.05" keys are treated as the "5cm" layer.
depth_map = {
    "0.000000_0.050000": ("5cm", 5),
    "0.050000_0.050000": ("5cm", 5),
    "0.100000_0.100000": ("10cm", 10),
    "0.200000_0.200000": ("20cm", 20),
    "0.400000_0.400000": ("40cm", 40),
}

# --- Parse station and depth from filename ---
def parse_filename(filename):
    """Split a ``.stm`` file name into its station name and depth key.

    Any directory part of *filename* is ignored. The expected shape is
    ``<station>_sm_<from>_<to>_<anything>.stm``.

    Returns:
        ``(station_name, depth_key)`` with ``depth_key == "<from>_<to>"``, or
        ``(None, None)`` when the name does not follow that pattern.
    """
    found = re.search(r"(.+?)_sm_([0-9.]+_[0-9.]+)_.*\.stm$", os.path.basename(filename))
    if not found:
        return None, None
    station_name, depth_key = found.group(1, 2)
    return station_name, depth_key

# --- Find the per-depth files of every station ---
def get_station_files():
    """Map every station found under ``base_path`` to its per-depth ``.stm`` files.

    Files whose names ``parse_filename`` cannot parse, or whose depth key is
    not in ``depth_map``, are reported and skipped.

    Returns:
        dict mapping station name -> {depth label (e.g. ``"5cm"``) -> file path}.
    """
    # glob() returns entries in file-system order, which is not guaranteed to
    # be stable. Sort so that, when several files map to the same station and
    # depth label (two depth keys map to "5cm"), the same file is chosen on
    # every run.
    all_files = sorted(glob.glob(os.path.join(base_path, "*.stm")))
    station_depth_files = {}

    for path in all_files:
        fname = os.path.basename(path)
        station, depth_key = parse_filename(fname)

        if station is None or depth_key not in depth_map:
            print(f"⚠️ Skipping file (unmatched or unknown depth): {fname}")
            continue

        depth_label = depth_map[depth_key][0]
        depth_files = station_depth_files.setdefault(station, {})
        if depth_label in depth_files:
            # Last file in sorted order wins; make the choice visible.
            print(f"⚠️ Multiple files for station {station} at {depth_label}: "
                  f"using {fname} instead of {os.path.basename(depth_files[depth_label])}")
        # glob already returned base_path joined with the file name.
        depth_files[depth_label] = path

    return station_depth_files

# --- Load data from a file ---
def extract_sm_column(filepath, depth_label):
    """Read one whitespace-separated ``.stm`` file into a two-column frame.

    The first three columns of the file are taken as date, time and soil
    moisture; any further columns are dropped.

    Args:
        filepath: path of the file to read.
        depth_label: suffix for the value column, e.g. ``"10cm"``.

    Returns:
        DataFrame with columns ``["Datetime", f"SM_{depth_label}"]``.
    """
    sm_col = f"SM_{depth_label}"
    # Raw string: "\s" in a plain literal is an invalid escape sequence
    # (DeprecationWarning; SyntaxWarning since Python 3.12).
    # NOTE(review): the first line of the file is consumed as the header row;
    # confirm the .stm files have exactly one header line.
    df = pd.read_csv(filepath, sep=r"\s+")
    df = df.rename(columns={df.columns[0]: "Date", df.columns[1]: "Time", df.columns[2]: sm_col})
    df["Datetime"] = pd.to_datetime(df["Date"] + " " + df["Time"])
    return df[["Datetime", sm_col]]

# --- Load and merge data ---
def load_station_data(file_dict):
    """Load a station's 5/10/20/40 cm soil-moisture series aligned on time.

    Args:
        file_dict: mapping depth label -> file path, as produced by
            ``get_station_files``.

    Returns:
        DataFrame with ``Datetime`` and ``SM_5cm``, ``SM_10cm``, ``SM_20cm``,
        ``SM_40cm`` columns, restricted to timestamps present in every file,
        with NaN rows dropped and the index reset; ``None`` if any of the
        four depths is missing from *file_dict*.
    """
    depths = ["5cm", "10cm", "20cm", "40cm"]
    # Check availability up front so no file is read for an incomplete station.
    if any(depth not in file_dict for depth in depths):
        return None

    df = extract_sm_column(file_dict[depths[0]], depths[0])
    for depth in depths[1:]:
        # merge() defaults to an inner join: keep timestamps shared by all depths.
        df = df.merge(extract_sm_column(file_dict[depth], depth), on="Datetime")
    return df.dropna().reset_index(drop=True)

# --- Diffusion model utilities ---
def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    """Return ``timesteps`` noise variances spaced linearly from start to end.

    The defaults (1e-4 -> 0.02) are the linear schedule of the original DDPM
    paper (Ho et al., 2020).

    Args:
        timesteps: number of diffusion steps.
        beta_start: variance at the first step.
        beta_end: variance at the last step.

    Returns:
        1-D float tensor of length ``timesteps``.
    """
    return torch.linspace(beta_start, beta_end, timesteps)

class Denoiser(nn.Module):
    """Small MLP that predicts the noise added to a target vector.

    The network input is ``[x, t / 1000, cond]`` concatenated along the last
    dimension; the output has ``target_dim`` features.

    Args:
        cond_dim: number of conditioning features.
        target_dim: number of target (noised) features.
    """

    def __init__(self, cond_dim, target_dim):
        super().__init__()
        in_features = target_dim + 1 + cond_dim  # x, scaled timestep, condition
        hidden = 64
        layers = [
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, target_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x, t, cond):
        # The timestep enters as one scalar feature; dividing by 1000 appears
        # to assume the default 1000-step schedule used elsewhere in this file.
        t_feat = t.float().unsqueeze(-1) / 1000.0
        return self.net(torch.cat((x, t_feat, cond), dim=-1))

@torch.no_grad()
def sample(model, cond_input, timesteps=1000, target_dim=3):
    """Draw one sample per conditioning row by DDPM ancestral sampling.

    Args:
        model: noise predictor called as ``model(x, t, cond)``.
        cond_input: conditioning tensor, one row per requested sample.
        timesteps: length of the diffusion chain; the beta schedule must match
            the one used in training.
        target_dim: number of features in each generated sample.

    Returns:
        Tensor of shape ``(cond_input.shape[0], target_dim)`` on the model's device.
    """
    model.eval()
    device = next(model.parameters()).device
    betas = linear_beta_schedule(timesteps).to(device)
    alphas = 1.0 - betas
    alpha_hat = torch.cumprod(alphas, dim=0)
    cond = cond_input.to(device)  # transfer once, not at every step

    x = torch.randn((cond.shape[0], target_dim)).to(device)
    for t in reversed(range(timesteps)):
        t_batch = torch.full((x.shape[0],), t, device=device)
        noise_pred = model(x, t_batch, cond)
        # DDPM reverse step (Ho et al. 2020, Algorithm 2):
        #   x_{t-1} = (x_t - beta_t / sqrt(1 - alpha_hat_t) * eps) / sqrt(alpha_t) + sigma_t * z
        # Using sqrt(1 - alpha_t) as the eps coefficient (as before) only agrees
        # at t = 0 and strips far too much noise once alpha_hat_t << 1.
        eps_coef = betas[t] / (1.0 - alpha_hat[t]).sqrt()
        x = (x - eps_coef * noise_pred) / alphas[t].sqrt()
        if t > 0:
            # sigma_t^2 = beta_t; no noise is added on the final step.
            x = x + betas[t].sqrt() * torch.randn_like(x)
    return x

def train_diffusion_combined(model, loader, out_dir, timesteps=1000, epochs=1000, device="cuda", lambda_smooth=0.1, lambda_fick=0.1, lr=1e-4):
    """Train ``model`` as a conditional DDPM noise predictor with two regularisers.

    Per batch the loss is::

        MSE(noise_pred, noise)
        + lambda_smooth * mean(first difference of noise_pred along features ** 2)
        + lambda_fick   * mean(second difference of noise_pred along features ** 2)

    After each epoch, the batch-averaged values of the three terms are printed.
    At the end they are written to ``<out_dir>/training_loss_log.csv``.

    Args:
        model: noise predictor called as ``model(y_noisy, t, x_cond)``.
        loader: iterable of ``(x_cond, y_target)`` batches.
        out_dir: existing directory for the loss log.
        timesteps: diffusion chain length (the sampler must use the same value).
        epochs: number of passes over ``loader``.
        device: device to train on.
        lambda_smooth: weight of the first-difference penalty.
        lambda_fick: weight of the second-difference ("Fickian") penalty.
        lr: Adam learning rate.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.to(device)
    model.train()  # sample() switches the model to eval mode
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    betas = linear_beta_schedule(timesteps).to(device)
    alphas = 1.0 - betas
    alpha_hat = torch.cumprod(alphas, dim=0)
    mse_loss = nn.MSELoss()
    loss_log = []

    for epoch in range(epochs):
        # Running sums of (data, smoothness, curvature) kept on-device so the
        # host sync happens once per epoch rather than once per batch.
        totals = torch.zeros(3, device=device)
        n_batches = 0
        for x_cond, y_target in loader:
            x_cond, y_target = x_cond.to(device), y_target.to(device)
            t = torch.randint(0, timesteps, (y_target.shape[0],), device=device)
            a_hat = alpha_hat[t].unsqueeze(-1)
            noise = torch.randn_like(y_target)
            # Forward process q(y_t | y_0) in closed form.
            y_noisy = torch.sqrt(a_hat) * y_target + torch.sqrt(1 - a_hat) * noise
            noise_pred = model(y_noisy, t, x_cond)

            loss_data = mse_loss(noise_pred, noise)
            # NOTE(review): these penalties act on the predicted noise, whose
            # target is i.i.d. across features, so they bias eps-prediction.
            # If the intent is a smooth depth profile, they probably belong on
            # the implied y_0 estimate instead — confirm.
            smoothness = torch.mean((noise_pred[:, 1:] - noise_pred[:, :-1]) ** 2)
            curvature = torch.mean((noise_pred[:, 2:] - 2 * noise_pred[:, 1:-1] + noise_pred[:, :-2]) ** 2)
            total_loss = loss_data + lambda_smooth * smoothness + lambda_fick * curvature

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            totals += torch.stack((loss_data, smoothness, curvature)).detach()
            n_batches += 1

        if n_batches == 0:
            raise ValueError("train_diffusion_combined: loader yielded no batches")

        # Log epoch means rather than the last batch's values.
        data_mean, smooth_mean, fick_mean = (totals / n_batches).tolist()
        loss_log.append({"Epoch": epoch + 1, "DataLoss": data_mean, "Smoothness": smooth_mean, "Fickian": fick_mean})
        print(f"[{epoch+1}/{epochs}] DataLoss: {data_mean:.4f}, Smooth: {smooth_mean:.4f}, Fickian: {fick_mean:.4f}")

    pd.DataFrame(loss_log).to_csv(os.path.join(out_dir, "training_loss_log.csv"), index=False)

# --- Process One Station ---
def _depth_metrics(y_true, y_pred, depths):
    """Return RMSE, std-normalised RMSE and R^2 per column, keyed ``<metric>_<depth>``."""
    metrics = {}
    for i, d in enumerate(depths):
        rmse = np.sqrt(mean_squared_error(y_true[:, i], y_pred[:, i]))
        std = np.std(y_true[:, i])
        # A constant observed series has std 0; report NaN instead of inf.
        nrmse = rmse / std if std > 0 else float("nan")
        r2 = r2_score(y_true[:, i], y_pred[:, i])
        metrics.update({f"RMSE_{d}": rmse, f"nRMSE_{d}": nrmse, f"R2_{d}": r2})
    return metrics


def _ensemble_frame(datetimes, y_true, mean, lo, hi, std, depths):
    """Build the per-timestep table of observations and ensemble statistics."""
    columns = {"Datetime": datetimes}
    for i, d in enumerate(depths):
        columns[f"Observed_{d}"] = y_true[:, i]
        columns[f"Pred_Mean_{d}"] = mean[:, i]
        columns[f"Pred_Min_{d}"] = lo[:, i]
        columns[f"Pred_Max_{d}"] = hi[:, i]
        columns[f"Pred_Std_{d}"] = std[:, i]
    return pd.DataFrame(columns)


def _plot_ensemble(y_true, mean, std, depths, n_members, path):
    """Save one observed-vs-predicted panel per depth, with +/-1 std error bars."""
    fig, axes = plt.subplots(1, len(depths), figsize=(18, 5))
    for i, d in enumerate(depths):
        ax = axes[i]
        ax.errorbar(
            y_true[:, i], mean[:, i], yerr=std[:, i],
            fmt='o', alpha=0.6, ecolor='gray', elinewidth=1.2, capsize=3
        )
        lo, hi = y_true[:, i].min(), y_true[:, i].max()
        ax.plot([lo, hi], [lo, hi], 'r--')  # 1:1 reference line
        ax.set_title(f"Observed vs Predicted at {d} (N={n_members})")
        ax.set_xlabel("Observed")
        ax.set_ylabel("Predicted")
        ax.grid(True)
    fig.tight_layout()
    fig.savefig(path, dpi=300)
    plt.close(fig)


def run_pipeline_for_station(station, file_dict, ensemble_sizes=range(5, 105, 5)):
    """Train a diffusion model for one station and evaluate several ensemble sizes.

    SM_5cm conditions the model and SM_10cm, SM_20cm and SM_40cm are generated.
    Training uses the first 70% of rows in time order; the remaining rows are
    the test set. The training loss log goes to ``<base_result_path>/<station>/``.

    Ensembles are nested: the N-member ensemble is the first N draws of one
    shared pool. For each N, metrics, a prediction table and a scatter plot go
    to ``<base_result_path>/N_<N>/<station>/``.

    Args:
        station: station name, used for output paths and messages.
        file_dict: depth label -> file path for this station.
        ensemble_sizes: positive ensemble sizes to evaluate.
    """
    print(f"🔍 Processing station: {station}")
    df = load_station_data(file_dict)
    if df is None:
        print(f"⚠️  Missing depths for station {station}, skipping.")
        return

    target_cols = ["SM_10cm", "SM_20cm", "SM_40cm"]
    depths = ["10cm", "20cm", "40cm"]

    # Chronological split: the test period follows the training period.
    n = int(0.7 * len(df))
    df_train, df_test = df.iloc[:n], df.iloc[n:]
    if df_train.empty or df_test.empty:
        print(f"⚠️  Not enough data for station {station} ({len(df)} rows), skipping.")
        return

    scaler_x = MinMaxScaler()
    scaler_y = MinMaxScaler()
    x_train = scaler_x.fit_transform(df_train[["SM_5cm"]].values)
    y_train = scaler_y.fit_transform(df_train[target_cols].values)
    x_test = scaler_x.transform(df_test[["SM_5cm"]].values)
    y_test = df_test[target_cols].values  # kept in original units for evaluation

    x_train_tensor = torch.tensor(x_train, dtype=torch.float32)
    y_train_tensor = torch.tensor(y_train, dtype=torch.float32)
    # NOTE(review): shuffle=False makes every batch a contiguous time window;
    # the model is pointwise, so shuffling would decorrelate gradients — confirm intent.
    train_loader = DataLoader(TensorDataset(x_train_tensor, y_train_tensor), batch_size=128, shuffle=False)

    model = Denoiser(cond_dim=1, target_dim=3)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")
    station_dir = os.path.join(base_result_path, station)
    os.makedirs(station_dir, exist_ok=True)

    train_diffusion_combined(model, train_loader, out_dir=station_dir, device=device)

    x_test_tensor = torch.tensor(x_test, dtype=torch.float32).to(device)

    # Draw the largest ensemble once and evaluate nested prefixes of it.
    # Resampling independently for every N would cost sum(sizes) reverse
    # chains (1050 for the defaults) instead of max(sizes) (100).
    sizes = list(ensemble_sizes)
    n_draws = max(sizes, default=0)
    print(f"   🔁 Drawing {n_draws} ensemble members")
    draws = [
        scaler_y.inverse_transform(sample(model, x_test_tensor).cpu().numpy())
        for _ in range(n_draws)
    ]

    for N in sizes:
        print(f"   🔁 Evaluating ensemble N={N}")
        out_dir = os.path.join(base_result_path, f"N_{N}", station)
        os.makedirs(out_dir, exist_ok=True)

        ensemble_outputs = np.stack(draws[:N], axis=0)  # (N, n_test, 3)
        y_pred_mean = ensemble_outputs.mean(axis=0)
        y_pred_std = ensemble_outputs.std(axis=0)
        y_pred_min = ensemble_outputs.min(axis=0)
        y_pred_max = ensemble_outputs.max(axis=0)

        metrics = _depth_metrics(y_test, y_pred_mean, depths)
        pd.DataFrame([metrics]).to_csv(os.path.join(out_dir, "evaluation_metrics.csv"), index=False)

        pred_df = _ensemble_frame(df_test["Datetime"].values, y_test, y_pred_mean,
                                  y_pred_min, y_pred_max, y_pred_std, depths)
        pred_df.to_csv(os.path.join(out_dir, "soil_moisture_predictions_with_uncertainty.csv"), index=False)

        _plot_ensemble(y_test, y_pred_mean, y_pred_std, depths, N,
                       os.path.join(out_dir, "soil_moisture_ensemble_uncertainty.png"))
    print(f"✅ Completed processing for station: {station}")

# --- MAIN EXECUTION ---
if __name__ == "__main__":
    # Only stations with a file at every required depth go through the pipeline.
    required_depths = {"5cm", "10cm", "20cm", "40cm"}
    for station, file_dict in get_station_files().items():
        if not required_depths.issubset(file_dict):
            continue
        run_pipeline_for_station(station, file_dict)
