import os
import numpy as np
import xarray as xr

# =========================================================
# code_3
# Save anomaly datasets by quantile regime
# 0 = lower 20%, 1 = middle 60%, 2 = upper 20%
# =========================================================

# Input/output locations (relative to the current working directory).
CLASS_FILE = "./quantile_data/chl_a_quantile_classes.nc"
INPUT_DIR = "./anomaly_data"
OUTPUT_ROOT = "./quantile_regime_data"
os.makedirs(OUTPUT_ROOT, exist_ok=True)

# Variable label -> anomaly file name inside INPUT_DIR.
anom_files = {
    "chl_a": "chl_a_anomaly.nc",
    "SST":   "SST_anomaly.nc",
    "SSHF":  "SSHF_anomaly.nc",
    "MLD":   "MLD_anomaly.nc",
    "WS":    "WS_anomaly.nc",
    "NO3":   "NO3_anomaly.nc",
    "PO4":   "PO4_anomaly.nc",
    "Fe":    "Fe_anomaly.nc",
    "Si":    "Si_anomaly.nc",
}

# Read the class map once and pull it into memory so the underlying file
# handle can be released immediately; `class_map` stays usable afterwards.
with xr.open_dataset(CLASS_FILE) as ds_class:
    class_map = ds_class["quantile_class"].load()

# Quantile class id -> label used for the output sub-directory and file names.
regime_names = {
    0: "lower20",
    1: "middle60",
    2: "upper20",
}

for var, fname in anom_files.items():
    path = os.path.join(INPUT_DIR, fname)
    if not os.path.exists(path):
        print(f"Input file not found: {path}")
        continue

    # The context manager closes the input file even if masking/writing fails.
    with xr.open_dataset(path) as ds:
        data_var_names = list(ds.data_vars)
        if not data_var_names:
            # Nothing to mask; skip instead of failing with a bare IndexError.
            print(f"No data variables found in: {path}")
            continue

        # Only the first data variable of each file is processed.
        # NOTE(review): presumably laid out as (time, lat, lon) -- confirm.
        da = ds[data_var_names[0]]

        for cid, cname in regime_names.items():
            regime_dir = os.path.join(OUTPUT_ROOT, cname)
            os.makedirs(regime_dir, exist_ok=True)

            # Keep values where the class map equals this regime id; all
            # other positions are replaced by the default fill value (NaN).
            masked = da.where(class_map == cid)
            masked.name = f"{var}_{cname}"

            out_path = os.path.join(regime_dir, f"{var}_{cname}.nc")
            # Written while the source dataset is still open, because the
            # data is read lazily from disk.
            masked.to_netcdf(out_path)
            print(f"Saved: {out_path}")