#!/usr/bin/env python3

from boututils.run_wrapper import build_and_log, launch_safe
from boutdata.collect import collect
from sys import exit

from numpy import max, abs


# Build the test executable, run it once per parallel-transform type, and
# compare each computed derivative field against the reference field written
# by the executable. Exits with status 1 if any comparison exceeds TOLERANCE,
# otherwise 0.
build_and_log("parallel slices test")

# Largest acceptable max-abs difference between a field and its reference.
TOLERANCE = 2e-5

# (computed variable, reference variable) pairs read back from the "data"
# directory. Identical for every transform type, so built once up front.
var_pairs = [("ddy", "ddy_check"), ("ddy2", "ddy_check")]

failed = False

for run_index, shifttype in enumerate(["shifted", "shiftedinterp"]):
    # The returned status is not inspected here; only the captured output
    # is used (it is written to the log below).
    _status, out = launch_safe(
        "./test_yupdown mesh:paralleltransform:type=" + shifttype,
        nproc=1,
        pipe=True,
        verbose=True,
    )

    # Truncate the log on the first run only and append afterwards, so the
    # output of an earlier transform type is not overwritten by a later one.
    log_mode = "w" if run_index == 0 else "a"
    with open("run.log", log_mode) as f:
        f.write(out)

    for v, v_check in var_pairs:
        print("Testing %s and %s ... " % (v, v_check))
        ddy = collect(v, path="data", xguards=False, yguards=False, info=False)
        ddy_check = collect(
            v_check, path="data", xguards=False, yguards=False, info=False
        )

        # `max` and `abs` are the numpy versions imported at file level,
        # so this is the maximum absolute element-wise difference.
        diff = max(abs(ddy - ddy_check))

        if diff < TOLERANCE:
            print(shifttype + " passed (Max difference %e)" % (diff))
        else:
            print(shifttype + " failed (Max difference %e)" % (diff))
            failed = True

if failed:
    exit(1)
exit(0)
