#!/usr/bin/env python3
r"""
prepare_data.py — build the analysis-ready tables from the raw Zenodo deposit
=============================================================================

Converts the raw files in Data.zip into the two table sets consumed by
Analysize.py and the figN_*.py figure scripts:

    FA_Data/F0.xlsx ... F100.xlsx       ->  LA_data/F0_nmol.csv ... F100_nmol.csv
    FA_Data/Muscle.xlsx  (optional)     ->  LA_data/Muscle_nmol.csv
    NIR_Data/F0/F0-1.csv ...            ->  NIR_data/merged_NIR_spectra.csv

What it does
------------
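* Fatty-acid workbooks: each lists 16 fatty acids (rows) x samples (columns).
  The first column (fatty-acid names) is written out under the header "LA",
  with labels normalized via FA_RENAME (e.g. "16:0" -> "16:00"); the sample
  columns (e.g. F0-1E) are passed through unchanged, except that samples
  listed in EXCLUDE_IDS / --exclude are dropped.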
* NIR per-fish files: each holds the 25 spectra measured from one eye as
  25 (wavelength, reflectance) column pairs (50 columns) x 512 rows. The 25
  spectra are aggregated into ONE reflectance spectrum per fish (default =
  mean of all 25 acquisition points; see AGG below) and written as one row
  per fish with the 512 bands as columns w<wavelength>. File F0-N.csv ->
  sample_id F0-N (which matches fatty-acid sample F0-NE).

>>> ASSUMPTIONS YOU MAY NEED TO CHANGE (verify against Materials and Methods) <<<
    AGG          how the 25 eye spectra are combined:
                   'all'   = mean of all 25 acquisition points   (default)
                   'iris'  = mean of the 20 iris points  (cols 1-40)
                   'pupil' = mean of the 5 pupil points  (cols 41-50)
    EXCLUDE_IDS  base sample IDs to drop so the tables match the 259-fish
                 modeling set (Methods exclude 1 x F0 + 2 x F100). EMPTY by
                 default -> all fish are kept. Fill in the three IDs (no -E/-M
                 suffix, e.g. "F0-12") to reproduce the modeling set exactly.

Usage
-----
    python prepare_data.py                       # defaults: ../FA_Data, ../NIR_Data
    python prepare_data.py --agg iris
    python prepare_data.py --exclude F0-12 F100-7 F100-33
    python prepare_data.py --fa-in raw/FA_Data --nir-in raw/NIR_Data \
                           --fa-out LA_data --nir-out NIR_data
"""

import argparse, glob, os, re
import numpy as np
import pandas as pd

# Group labels. Each one is both the workbook stem (<group>.xlsx) and the NIR
# sub-folder / file prefix (<group>/<group>-N.csv).
GROUPS = [f"F{level}" for level in (0, 25, 50, 75, 100)]

# Layout of the acquisition pairs inside one per-fish NIR file.
N_IRIS_POINTS = 20           # pairs 1-20  -> file columns 1-40
N_PUPIL_POINTS = 5           # pairs 21-25 -> file columns 41-50

# ---- assumptions (edit to match Materials and Methods) ----
AGG = "all"                  # 'all' | 'iris' | 'pupil'
EXCLUDE_IDS = []             # e.g. ["F0-12", "F100-7", "F100-33"]

# Raw-deposit fatty-acid labels -> the zero-padded names that Analysize.py and
# the figure scripts expect (the deposit has e.g. "16:0", the analysis code
# wants "16:00"). Add entries here if your workbook uses other exact strings.
FA_RENAME = {
    "16:0": "16:00",
    "16:1": "16:01",
    "18:0": "18:00",
    "18:1": "18:01",
}

# The 16 fatty-acid names the analysis expects; only used to warn when a
# workbook does not contain one of them after renaming.
EXPECTED_FA = [
    "16:00", "16:01", "18:00", "18:01",
    "18:2n-6", "18:3n-3", "18:3n-6", "18:4n-3",
    "20:2n-6", "20:3n-3", "20:3n-6", "20:4n-3", "20:4n-6",
    "DHA", "DPA", "EPA",
]


def log(m=""):
    """Print one status line to stdout; with no argument, print a blank line."""
    print(m, end="\n")


def strip_suffix(sid):
    """Return the sample ID without one trailing 'E' or 'M'.

    The value is converted to ``str`` and surrounding whitespace is removed
    first, e.g. 'F0-1E' -> 'F0-1' and ' F0-1M ' -> 'F0-1'.
    """
    text = str(sid).strip()
    return re.sub(r"[EM]$", "", text)


def fish_num(path):
    """Return the integer N from a file name ending in '-N.csv', or 0 if absent.

    Only the base name of *path* is inspected; used as a sort key so that
    'F0-2.csv' comes before 'F0-10.csv'.
    """
    name = os.path.basename(path)
    match = re.search(r"-(\d+)\.csv$", name)
    if match is None:
        return 0
    return int(match.group(1))


# ----------------------- fatty-acid workbooks -----------------------
def convert_fa(fa_in, fa_out, workbooks, exclude):
    """Convert fatty-acid workbooks into ``<label>_nmol.csv`` tables.

    Args:
        fa_in:     directory containing the .xlsx workbooks.
        fa_out:    output directory (created if missing).
        workbooks: iterable of ``(label, filename, sheet)`` tuples; ``sheet`` is
                   passed to ``pd.read_excel`` as ``sheet_name``.
        exclude:   collection of base sample IDs; a sample column is dropped
                   when its name, minus a trailing E/M, is in this collection.

    Workbooks that are missing or contain no columns are skipped with a log
    line. If the requested sheet cannot be read, the first sheet is used
    instead and the fallback is logged. Returns None.
    """
    os.makedirs(fa_out, exist_ok=True)
    for label, fname, sheet in workbooks:
        path = os.path.join(fa_in, fname)
        if not os.path.exists(path):
            log(f"  [skip] {fname} not found in {fa_in}")
            continue
        try:
            df = pd.read_excel(path, sheet_name=sheet)
        except Exception as err:
            # Deliberately broad: the error raised for a missing sheet depends
            # on the pandas/engine version. The fallback is kept, but reported
            # so that converting the wrong sheet never goes unnoticed. A file
            # that is unreadable altogether still raises on the retry below.
            log(f"  [warn] {label}: cannot read sheet {sheet!r} ({err}); "
                f"falling back to the first sheet")
            df = pd.read_excel(path)                      # fall back to first sheet
        if df.shape[1] == 0:
            # Nothing to rename or export; df.columns[0] would raise IndexError.
            log(f"  [skip] {label}: {fname} has no columns")
            continue
        # first column = fatty-acid names -> rename to "LA"
        df = df.rename(columns={df.columns[0]: "LA"})
        df["LA"] = df["LA"].astype(str).str.strip().replace(FA_RENAME)
        present = set(df["LA"])                           # built once, not per name
        missing = [fa for fa in EXPECTED_FA if fa not in present]
        if missing:
            log(f"  [warn] {label}: FA names not matched after rename: {missing} "
                f"(edit FA_RENAME at the top of prepare_data.py)")
        # drop excluded samples if present in this workbook
        drop = [c for c in df.columns if c != "LA" and strip_suffix(c) in exclude]
        if drop:
            log(f"  [{label}] dropping excluded samples: {drop}")
            df = df.drop(columns=drop)
        out = os.path.join(fa_out, f"{label}_nmol.csv")
        df.to_csv(out, index=False)
        log(f"  [{label}] {path} -> {out}  "
            f"({df.shape[0]} FAs x {df.shape[1] - 1} samples)")


# ----------------------- NIR per-fish spectra -----------------------
def aggregate_one_fish(path, agg):
    """Reduce one per-fish NIR file to a single spectrum.

    The CSV is read without a header, every cell is coerced to a number, and
    rows that contain no number at all (such as a header row) are discarded.
    Even-positioned columns (0, 2, ...) are taken as wavelengths and
    odd-positioned columns (1, 3, ...) as reflectances.

    Args:
        path: CSV file of one fish.
        agg:  'iris' averages the first N_IRIS_POINTS reflectance columns,
              'pupil' the following N_PUPIL_POINTS; any other value averages
              all reflectance columns.

    Returns:
        (waves, refl): two 1-D arrays with one value per kept row, the
        NaN-ignoring row mean of all wavelength columns and of the selected
        reflectance columns.
    """
    table = pd.read_csv(path, header=None).apply(pd.to_numeric, errors="coerce")
    has_number = ~table.isna().all(axis=1)                # False for e.g. a header row
    data = table.loc[has_number].values.astype(float)
    n_pairs = data.shape[1] // 2

    # Column window (in reflectance-column units) for the requested region.
    if agg == "iris":
        first, last = 0, N_IRIS_POINTS
    elif agg == "pupil":
        first, last = N_IRIS_POINTS, N_IRIS_POINTS + N_PUPIL_POINTS
    else:
        first, last = 0, n_pairs

    wavelengths = data[:, 0::2]
    reflectance = data[:, 1::2][:, first:last]
    band_grid = np.nanmean(wavelengths, axis=1)           # shared band grid
    spectrum = np.nanmean(reflectance, axis=1)            # one spectrum per fish
    return band_grid, spectrum


def build_nir(nir_in, nir_out, agg, exclude):
    """Merge the per-fish NIR files into ``merged_NIR_spectra.csv``.

    For every group in GROUPS the files ``<nir_in>/<group>/<group>-*.csv`` are
    visited in fish-number order, reduced to one spectrum each by
    aggregate_one_fish(), and written as one row per fish with the columns
    ``group``, ``sample_id`` and one ``w<wavelength>`` column per band.

    Args:
        nir_in:  directory containing one sub-folder per group.
        nir_out: output directory (created if missing).
        agg:     aggregation mode handed to aggregate_one_fish().
        exclude: collection of sample IDs (file stems, e.g. "F0-12") to skip.

    The band labels are taken from the first usable file. A file is skipped
    with a warning when its spectrum is empty or entirely NaN, or when its
    band count differs from the first file's. A file whose wavelengths deviate
    from the band labels by more than 0.01 is kept but reported. Returns None.
    """
    os.makedirs(nir_out, exist_ok=True)
    rows, wcols, grid = [], [], None
    counts = {g: 0 for g in GROUPS}
    for grp in GROUPS:
        gdir = os.path.join(nir_in, grp)
        files = sorted(glob.glob(os.path.join(gdir, f"{grp}-*.csv")), key=fish_num)
        if not files:
            log(f"  [warn] no CSVs found in {gdir}")
        for f in files:
            sid = os.path.splitext(os.path.basename(f))[0]    # 'F0-1'
            if sid in exclude:
                log(f"  [excluded] {sid}")
                continue
            waves, refl = aggregate_one_fish(f, agg)
            # An empty or all-NaN spectrum carries no information (it occurs
            # e.g. when the requested region lies beyond the columns present
            # in the file). Skip it *before* it can define the band grid, so
            # one bad first file cannot cause every later file to be rejected.
            if len(refl) == 0 or np.isnan(refl).all():
                log(f"  [warn] {sid}: no usable reflectance values for "
                    f"agg='{agg}'; skipped")
                continue
            if grid is None:
                grid = np.round(waves, 2)
                wcols = [f"w{w:.2f}" for w in grid]
            if len(refl) != len(grid):
                log(f"  [warn] {sid}: {len(refl)} bands != {len(grid)}; skipped")
                continue
            # Same length does not guarantee the same wavelengths. Values are
            # filed under the first file's labels, so make a mismatch visible.
            if not np.allclose(waves, grid, rtol=0.0, atol=0.01, equal_nan=True):
                log(f"  [warn] {sid}: wavelengths differ from the first file's grid; "
                    f"values are stored under the first file's band labels")
            row = {"group": grp, "sample_id": sid}
            row.update(zip(wcols, refl))
            rows.append(row)
            counts[grp] += 1
    df = pd.DataFrame(rows, columns=["group", "sample_id"] + wcols)
    out = os.path.join(nir_out, "merged_NIR_spectra.csv")
    df.to_csv(out, index=False)
    log(f"  NIR -> {out}  ({len(df)} fish x {len(wcols)} bands)")
    if grid is not None:
        log(f"  band range: {grid.min():.2f}-{grid.max():.2f} nm")
    log("  per-group spectra: " + ", ".join(f"{g}={counts[g]}" for g in GROUPS))


# ----------------------------- main -----------------------------
def main():
    """Parse the command line, then build the fatty-acid and NIR tables.

    The exclusion set is EXCLUDE_IDS plus any --exclude arguments. Each entry
    is also added in its suffix-stripped form, so "F0-12E" or "F0-12M" excludes
    fish F0-12 just like "F0-12" does. The eye workbooks (one per group) are
    always attempted; Muscle.xlsx is converted only when it exists. Progress
    is reported through log(). Returns None.
    """
    ap = argparse.ArgumentParser(description="Build analysis tables from the raw deposit.")
    ap.add_argument("--fa-in",   default="../FA_Data")
    ap.add_argument("--nir-in",  default="../NIR_Data")
    ap.add_argument("--fa-out",  default="../LA_data")
    ap.add_argument("--nir-out", default="../NIR_data")
    ap.add_argument("--agg", default=AGG, choices=["all", "iris", "pupil"],
                    help="how the 25 eye spectra are combined")
    ap.add_argument("--exclude", nargs="*", default=[],
                    help="extra base sample IDs to drop (added to EXCLUDE_IDS)")
    args = ap.parse_args()

    requested = set(EXCLUDE_IDS) | set(args.exclude)
    # convert_fa() compares suffix-stripped column names and build_nir()
    # compares file stems, so an ID typed with its -E/-M suffix would silently
    # match nothing. Add the base form of every requested ID; the entries as
    # typed are kept, so anything that matched before still matches.
    exclude = requested | {strip_suffix(sid) for sid in requested}
    agg_desc = {"all": "mean of all 25 points",
                "iris": "mean of the 20 iris points",
                "pupil": "mean of the 5 pupil points"}[args.agg]

    log("== prepare_data ==")
    log(f"  eye-spectra aggregation: {args.agg}  ({agg_desc})")
    log(f"  excluded IDs: {sorted(exclude) if exclude else 'none (keeping all fish)'}")

    log("Fatty-acid workbooks:")
    eye_books = [(g, f"{g}.xlsx", 0) for g in GROUPS]
    convert_fa(args.fa_in, args.fa_out, eye_books, exclude)
    # Checked here (although convert_fa also checks) so that an absent
    # optional muscle workbook produces no "[skip]" line.
    if os.path.exists(os.path.join(args.fa_in, "Muscle.xlsx")):
        convert_fa(args.fa_in, args.fa_out, [("Muscle", "Muscle.xlsx", "M_nmol")], exclude)
        log("  (note: the muscle table is provided for completeness; "
            "Analysize.py and the figure scripts use eye data only)")

    log("NIR spectra:")
    build_nir(args.nir_in, args.nir_out, args.agg, exclude)

    log("")
    log("Done. Tables written. Next:")
    log("  python Analysize.py        # numerical results log")
    log("  python figN_*.py           # individual figures")


# Run the conversion only when this file is executed as a script, so that
# importing it (e.g. to reuse a helper) has no side effects.
if __name__ == "__main__":
    main()
