#!/usr/bin/env python3

# requires: not metric_3d
# requires: petsc

#
# Run the test, compare results against the benchmark
#

from boututils.run_wrapper import shell, launch_safe, getmpirun, build_and_log
from boutdata.collect import collect
import numpy as np
from sys import exit

# Largest allowed value of max|f2 - f| for a run to count as a pass
tol = 1e-10  # Absolute tolerance

# Command used to launch the MPI runs; handed to launch_safe() as `runcmd`
MPIRUN = getmpirun()

# NOTE(review): presumably builds the test executable and logs the build
# output under the given name -- confirm against boututils.run_wrapper
build_and_log("LaplaceXZ test")

print("Running LaplaceXZ test")
# Overall result: starts True and is set to False by any failing run
success = True

# Run the test executable on 1, 2 and 4 processors and check that the two
# output fields "f" and "f2" agree to within `tol` for each run.
for nproc in [1, 2, 4]:
    # One processor per X sub-domain
    nxpe = nproc

    cmd = "./test-laplacexz nxpe=" + str(nxpe)

    # Remove dump files from any previous run so that stale data cannot be
    # collected; -f keeps rm quiet when there is nothing to delete
    shell("rm -f data/BOUT.dmp.*.nc")

    print("   %d processors (nxpe = %d)...." % (nproc, nxpe))
    s, out = launch_safe(cmd, runcmd=MPIRUN, nproc=nproc, mthread=1, pipe=True)
    # Keep the captured output of each run in its own log file
    with open("run.log." + str(nproc), "w") as log:
        log.write(out)

    # Collect output data
    f = collect("f", path="data", info=False)
    f2 = collect("f2", path="data", info=False)
    print("      Checking tolerance... ")

    # Compare the two fields. A shape mismatch is already a failure, and the
    # element-wise difference below would either raise or silently
    # broadcast, so skip it for this run.
    if np.shape(f) != np.shape(f2):
        print("Fail, wrong shape")
        success = False
        continue

    diff = np.max(np.abs(f2 - f))
    # Written as "pass if diff <= tol" so that a NaN difference (every
    # comparison with NaN is False) is reported as a failure, not a pass
    if diff <= tol:
        print("Pass")
    else:
        print("Fail, maximum difference = " + str(diff))
        success = False

# Report the overall outcome and set the process exit status:
# 1 if any run failed, 0 otherwise.
if not success:
    print(" => LaplaceXZ test failed")
    exit(1)

print(" => LaplaceXZ inversion test passed")
exit(0)
