import os
import numpy as np
import xarray as xr

# =========================================================
# code_4
# Pointwise multivariate EOF analysis (computed independently at each grid
# cell over time, on standardized variables)
# Regimes: lower20, upper20
# Variables: chl-a, SST, SSHF, MLD, WS, NO3, PO4, Fe, Si
# =========================================================

# Root folder holding one sub-directory of NetCDF inputs per regime.
INPUT_ROOT = "./quantile_regime_data"
# Folder that receives one EOF result file per regime (created if missing).
OUTPUT_ROOT = "./eof_results"
os.makedirs(OUTPUT_ROOT, exist_ok=True)

# Quantile regimes to process, in order.
regimes = ["lower20", "upper20"]

# File stems of the input variables; this order is the order in which the
# variables are stacked and labelled in the output.
var_files = "chl_a SST SSHF MLD WS NO3 PO4 Fe Si".split()

def load_regime_data(regime, input_root=None, variables=None):
    """Load every variable file for one regime and stack them on a new last axis.

    Files are read from ``<input_root>/<regime>/<var>_<regime>.nc``; from each
    file the first data variable is used.

    Parameters
    ----------
    regime : str
        Regime name (e.g. ``"lower20"``), used both as sub-directory and suffix.
    input_root : str, optional
        Root input directory. Defaults to the module-level ``INPUT_ROOT``.
    variables : sequence of str, optional
        Variable file stems in stacking order. Defaults to ``var_files``.

    Returns
    -------
    ref : xarray.DataArray
        The first variable's array, fully loaded into memory; callers use its
        coordinates as the reference grid for outputs.
    X : numpy.ndarray
        The variables' values stacked along a new trailing axis, in the order
        of ``variables``. Callers treat this as (time, lat, lon, var), which
        assumes each file stores its data as (time, lat, lon) -- TODO confirm.

    Raises
    ------
    ValueError
        If a file contains no data variables, if variable shapes disagree, or
        if ``variables`` is empty.
    """
    if input_root is None:
        input_root = INPUT_ROOT
    if variables is None:
        variables = var_files

    data_list = []
    ref = None

    for var in variables:
        path = os.path.join(input_root, regime, f"{var}_{regime}.nc")
        # Close each file once read; .load() pulls data and coordinates into
        # memory first so the returned DataArray stays usable after closing.
        with xr.open_dataset(path) as ds:
            names = list(ds.data_vars)
            if not names:
                raise ValueError(f"No data variables found in {path}")
            da = ds[names[0]].load()

        if ref is None:
            ref = da
        elif da.shape != ref.shape:
            raise ValueError(
                f"Shape mismatch in {path}: {da.shape} vs reference {ref.shape}"
            )
        data_list.append(da.values)

    # Stack along a new trailing "variable" axis.
    X = np.stack(data_list, axis=-1)
    return ref, X

# Minimum number of complete (all-variables-finite) time steps per grid cell.
# NOTE: with nvar variables, fewer than nvar + 1 samples gives a
# rank-deficient covariance, so the trailing modes there are not well
# determined (their eigenvalues are ~0).
MIN_VALID_STEPS = 3


def _pointwise_eof(Xi, min_valid=MIN_VALID_STEPS):
    """Compute the multivariate EOF decomposition for one grid cell.

    Parameters
    ----------
    Xi : numpy.ndarray, shape (time, var)
        Raw values of every variable at one grid cell.
    min_valid : int
        Minimum number of time steps with all variables finite.

    Returns
    -------
    tuple or None
        ``(w, V, mask, pc)``:
        ``w`` eigenvalues in descending order (clipped at 0);
        ``V`` eigenvectors as columns, shape (var, mode);
        ``mask`` boolean (time,) selector of the complete time steps used;
        ``pc`` principal components for those steps, shape (n_valid, mode).
        ``None`` if there are too few complete steps, or a variable is
        constant or has a non-finite standard deviation.
    """
    mask = np.isfinite(Xi).all(axis=1)
    if np.count_nonzero(mask) < min_valid:
        return None

    Xv = Xi[mask, :]

    # Standardize each variable (sample std, ddof=1); NaNs were already removed.
    mu = Xv.mean(axis=0)
    sd = Xv.std(axis=0, ddof=1)
    if np.any(sd == 0) or not np.all(np.isfinite(sd)):
        return None
    Z = (Xv - mu) / sd

    # np.cov also uses ddof=1, so C is the correlation matrix of the variables.
    # atleast_2d guards the single-variable case, where np.cov returns 0-d.
    C = np.atleast_2d(np.cov(Z, rowvar=False))

    w, V = np.linalg.eigh(C)
    order = np.argsort(w)[::-1]
    w = w[order]
    V = V[:, order]

    # C is positive semi-definite; round-off can still yield tiny negative
    # eigenvalues, which would produce negative explained-variance ratios.
    w = np.clip(w, 0.0, None)

    # Eigenvector signs are arbitrary. Without a convention, loadings and PCs
    # can flip sign between neighbouring cells. Make the largest-magnitude
    # loading of each mode positive.
    pivot = np.argmax(np.abs(V), axis=0)
    signs = np.sign(V[pivot, np.arange(V.shape[1])])
    signs[signs == 0] = 1.0
    V = V * signs

    return w, V, mask, Z @ V


for regime in regimes:
    ref_da, X = load_regime_data(regime)

    T, nlat, nlon, nvar = X.shape
    n_modes = nvar

    # Cells that are skipped keep NaN in every output.
    eigvals = np.full((nlat, nlon, nvar), np.nan, dtype=np.float32)
    evr = np.full((nlat, nlon, nvar), np.nan, dtype=np.float32)
    eof = np.full((nlat, nlon, nvar, n_modes), np.nan, dtype=np.float32)
    pcs = np.full((nlat, nlon, T, n_modes), np.nan, dtype=np.float32)

    for i, j in np.ndindex(nlat, nlon):
        result = _pointwise_eof(X[:, i, j, :])
        if result is None:
            continue
        w, V, mask, pc = result

        eigvals[i, j, :] = w
        evr[i, j, :] = w / np.sum(w)
        eof[i, j, :, :] = V
        # Time steps with any missing variable stay NaN in the PC series.
        pcs[i, j, mask, :] = pc

    ds_out = xr.Dataset(
        {
            "eigvals": (("lat", "lon", "mode"), eigvals),
            "evr":     (("lat", "lon", "mode"), evr),
            "eof":     (("lat", "lon", "variable", "mode"), eof),
            "pcs":     (("lat", "lon", "time", "mode"), pcs),
        },
        coords={
            "lat": ref_da["lat"],
            "lon": ref_da["lon"],
            "time": ref_da["time"],
            "variable": var_files,
            "mode": np.arange(1, n_modes + 1),
        },
        attrs={
            "description": "Pointwise multivariate EOF of standardized variables",
            "eof_sign_convention": "largest-magnitude loading of each mode is positive",
            "min_valid_steps": MIN_VALID_STEPS,
        },
    )

    out_path = os.path.join(OUTPUT_ROOT, f"multivariate_eof_{regime}.nc")
    ds_out.to_netcdf(out_path)
    print(f"Saved: {out_path}")