import os
import numpy as np
import xarray as xr
from scipy.stats import linregress

# =========================================================
# code_1
# Calculate linear trend (slope) for each anomaly dataset
# Variables: chl-a, SST, SSHF, MLD, WS, NO3, PO4, Fe, Si
# =========================================================

# Directory holding the anomaly NetCDF inputs, and directory receiving the
# per-variable slope outputs (created up front if it does not exist yet).
INPUT_DIR = "./anomaly_data"
OUTPUT_DIR = "./trend_data"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Variable label -> path of its anomaly file, "<label>_anomaly.nc" in INPUT_DIR.
# The tuple order is the order in which the variables are processed.
file_vars = {
    label: os.path.join(INPUT_DIR, f"{label}_anomaly.nc")
    for label in ("chl_a", "SST", "SSHF", "MLD", "WS", "NO3", "PO4", "Fe", "Si")
}

def get_main_variable(ds):
    """Return the first data variable stored in *ds*.

    "First" is the first name produced when iterating ``ds.data_vars``;
    the variable is then fetched with ``ds[name]``. In this script *ds* is
    the object returned by ``xr.open_dataset``.

    Raises:
        IndexError: if ``ds.data_vars`` is empty.
    """
    variable_names = list(ds.data_vars)
    first_name = variable_names[0]
    return ds[first_name]

def _slope_map(cube):
    """Return the per-cell linear-regression slope of a 3-D array.

    Args:
        cube: array whose first axis is the regression axis and whose last
            two axes index the grid cells (the script passes values ordered
            as time, lat, lon).

    Returns:
        float32 array shaped like the last two axes of *cube*. Each cell
        holds the least-squares slope of its series against the step index
        0..n-1, i.e. change per step along the first axis. Cells with fewer
        than two finite samples are NaN.
    """
    n_steps, n_lat, n_lon = cube.shape
    steps = np.arange(n_steps)
    slopes = np.full((n_lat, n_lon), np.nan, dtype=np.float32)

    for i, j in np.ndindex(n_lat, n_lon):
        series = cube[:, i, j]
        # Regress only on finite samples; NaN/inf mark missing data.
        valid = np.isfinite(series)
        if np.count_nonzero(valid) > 1:
            slopes[i, j] = linregress(steps[valid], series[valid]).slope

    return slopes


# Compute and save one slope map per variable; missing inputs are reported
# and skipped so the remaining variables are still processed.
for var, path in file_vars.items():
    if not os.path.exists(path):
        print(f"Input file not found: {path}")
        continue

    # The context manager closes the NetCDF handle once this variable is
    # done instead of leaving one open file per iteration.
    with xr.open_dataset(path) as ds:
        # Put lat/lon last so the leading axis is the one regressed over,
        # regardless of the order in which the file stores its dimensions.
        da = get_main_variable(ds).transpose(..., "lat", "lon")

        slope_da = xr.DataArray(
            _slope_map(da.values),
            coords={"lat": da["lat"], "lon": da["lon"]},
            dims=("lat", "lon"),
            name="slope",
        )

        out_ds = xr.Dataset({"slope": slope_da})
        out_path = os.path.join(OUTPUT_DIR, f"{var}_slope.nc")
        out_ds.to_netcdf(out_path)

    print(f"Saved: {out_path}")