# demo.py
import argparse
import json
import numpy as np
import onnxruntime as ort


def normalize_minmax(x: np.ndarray, min_v: np.ndarray, max_v: np.ndarray, clamp: bool = False) -> np.ndarray:
    """Min-max scale ``x`` into ``[0, 1]`` using per-feature bounds.

    Computes ``(x - min_v) / (max_v - min_v)`` in float64 and returns float32.
    ``min_v`` and ``max_v`` are broadcast against ``x`` with normal NumPy rules,
    so they may be per-feature vectors or scalars, and ``x`` may have any rank.

    Args:
        x: Raw feature values.
        min_v: Lower normalization bound(s), broadcastable against ``x``.
        max_v: Upper normalization bound(s), broadcastable against ``x``.
        clamp: If True, clip the result to ``[0, 1]``.

    Returns:
        The normalized array as float32. Features whose range
        ``max_v - min_v`` is close to zero (``np.isclose`` default tolerances)
        are set to 0 instead of producing inf/nan.
    """
    x64 = np.asarray(x, dtype=np.float64)
    min64 = np.asarray(min_v, dtype=np.float64)
    max64 = np.asarray(max_v, dtype=np.float64)

    span = max64 - min64
    constant = np.isclose(span, 0.0)
    # Divide constant features by 1 to avoid division by ~0; they are zeroed below.
    # np.where (rather than in-place masked assignment) also works when the
    # bounds are scalars, where `span` is a NumPy scalar that cannot be assigned to.
    safe_span = np.where(constant, 1.0, span)
    y = (x64 - min64) / safe_span

    # Broadcasting the mask aligns it with the same axes the bounds were applied
    # to, for any rank of x (the former `y[:, mask]` only handled 2-D input).
    y = np.where(constant, 0.0, y)

    if clamp:
        y = np.clip(y, 0.0, 1.0)

    return y.astype(np.float32)


def rmse(a: np.ndarray, b: np.ndarray):
    """Return the root-mean-square error between ``a`` and ``b`` for each row.

    The squared element-wise difference is averaged over axis 1, so a 2-D
    input of shape ``(N, F)`` yields an array of ``N`` scores.
    """
    residual = a - b
    mean_squared = np.mean(np.square(residual), axis=1)
    return np.sqrt(mean_squared)


def _parse_args():
    """Parse the demo's command-line options."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", type=str, default="model/autoencoder.onnx")
    ap.add_argument("--config", type=str, default="model/model_config.json")
    # Default is None so the random-sample fallback in main() is reachable.
    ap.add_argument(
        "--csv",
        type=str,
        default=None,
        help="CSV with header; numeric columns only. If omitted, a random normalized sample is used.",
    )
    ap.add_argument("--row", type=int, default=0)
    return ap.parse_args()


def _load_csv_row(path, row, feature_size):
    """Load data row ``row`` of a headered numeric CSV as a (1, feature_size) float32 array.

    Raises:
        ValueError: if ``row`` is out of range, the column count differs from
            ``feature_size``, or the row has missing/non-numeric cells.
    """
    # load CSV numeric data (skip header)
    data = np.genfromtxt(path, delimiter=",", skip_header=1).astype(np.float32)
    if data.ndim < 2:
        # genfromtxt drops singleton dimensions: one data row comes back as (F,),
        # one column as (N,), one cell as (). Restore the (rows, features) layout.
        data = data.reshape(-1, 1) if feature_size == 1 else data.reshape(1, -1)

    n_rows = data.shape[0]
    if not -n_rows <= row < n_rows:
        raise ValueError(f"--row {row} out of range for CSV with {n_rows} data rows")
    # List indexing keeps the 2-D (1, F) shape and, unlike the slice
    # data[row:row + 1], is also correct for negative row indices.
    x_raw = data[[row]]

    if x_raw.shape[1] != feature_size:
        raise ValueError(f"CSV feature dim mismatch: got {x_raw.shape[1]}, expected {feature_size}")
    # genfromtxt turns empty/non-numeric cells into NaN; a NaN score would compare
    # False against the threshold and be silently reported as "normal".
    if np.isnan(x_raw).any():
        raise ValueError(f"CSV row {row} contains missing or non-numeric values")
    return x_raw


def main():
    """Score one sample with the ONNX autoencoder and print the anomaly verdict.

    Reads the feature size, min-max normalization bounds and RMSE threshold from
    ``--config``. Takes row ``--row`` of ``--csv`` (normalized with those bounds),
    or a random already-normalized sample when ``--csv`` is omitted. Runs the
    model's ``reconstruction`` output and prints the RMSE together with
    ``anomaly_label = int(rmse > threshold)``.

    Raises:
        ValueError: propagated from CSV loading (bad row index, column-count
            mismatch, missing/non-numeric values).
    """
    args = _parse_args()

    with open(args.config, "r", encoding="utf-8") as f:
        cfg = json.load(f)

    feature_size = int(cfg["architecture"]["feature_size"])
    thres = float(cfg["postprocessing"]["threshold"])
    norm_cfg = cfg["preprocessing"]["normalization"]
    min_v = np.asarray(norm_cfg["min"], dtype=np.float32)
    max_v = np.asarray(norm_cfg["max"], dtype=np.float32)
    clamp = bool(norm_cfg.get("clamp", False))

    if args.csv is None:
        # fallback: random normalized sample (already in [0, 1], so no normalization)
        x_norm = np.random.rand(1, feature_size).astype(np.float32)
        print("[INFO] No --csv provided. Using a random normalized sample.")
    else:
        x_raw = _load_csv_row(args.csv, args.row, feature_size)
        x_norm = normalize_minmax(x_raw, min_v, max_v, clamp=clamp)

    sess = ort.InferenceSession(args.model, providers=["CPUExecutionProvider"])
    out = sess.run(["reconstruction"], {"x": x_norm})[0].astype(np.float32)

    score = float(rmse(out, x_norm)[0])
    label = int(score > thres)

    print("=== Inference Result ===")
    print(f"feature_size: {feature_size}")
    print(f"rmse: {score:.6f}")
    print(f"threshold: {thres:.6f}")
    print(f"anomaly_label: {label}  (1=anomaly, 0=normal)")


if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not when it is imported.
    main()
