#!/usr/bin/env python3

from boututils.run_wrapper import build_and_log, launch_safe
from boutdata.collect import collect
from sys import stdout, exit

from numpy import max, abs

# Test driver: run ./test_yupdown_weights once per parallel-transform type and
# check that the two fields it writes ("ddy" and "ddy2") agree to within 1e-8.
# The process exits with status 1 if any comparison fails, 0 otherwise.
build_and_log("parallel slices and weights test")

failed = False
for shifttype in ["shiftedinterp"]:
    # The transform type is selected through a command-line option override.
    command = "./test_yupdown_weights mesh:paralleltransform:type=" + shifttype
    status, output = launch_safe(command, nproc=1, pipe=True, verbose=True)

    # Keep the text returned by the run; the file is rewritten on every pass.
    with open("run.log", "w") as log_file:
        log_file.write(output)

    # Each pair names two output variables that are expected to match.
    for name_a, name_b in [("ddy", "ddy2")]:
        stdout.write("Testing %s and %s ... " % (name_a, name_b))
        field_a = collect(
            name_a, path="data", xguards=False, yguards=False, info=False
        )
        field_b = collect(
            name_b, path="data", xguards=False, yguards=False, info=False
        )

        # `max` and `abs` are the numpy versions imported at the top of the
        # file, so this is the largest element-wise absolute difference.
        diff = max(abs(field_a - field_b))

        # A NaN difference compares False here and is reported as a failure.
        passed = diff < 1e-8
        verdict = "passed" if passed else "failed"
        print("%s %s (Max difference %e)" % (shifttype, verdict, diff))
        if not passed:
            failed = True

exit(1 if failed else 0)
