#!/usr/bin/env python3

#
# Run the multigrid Laplacian inversion test and check the reported errors
#

from __future__ import print_function

# Optional compatibility import: use `str` from the `builtins` module when it
# is importable, otherwise keep the interpreter's own `str`.
try:
    from builtins import str
except ImportError:
    # Only a missing `builtins` module is tolerated; any other exception
    # (including KeyboardInterrupt) is allowed to propagate.
    pass

# Absolute tolerance: a reported maximum error above this value is a failure
tol = 2.0e-6
# Number of max_error<i> values read back per run; we test 4 different
# boundary conditions (with slightly different inputs for each)
numTests = 4

from boututils.run_wrapper import shell, build_and_log, launch_safe
from boutdata.collect import collect
from sys import exit


# Build the test executable, then run it for every combination of processor
# count and input file, checking each reported maximum error against `tol`.
build_and_log("Multigrid Laplacian inversion test")

print("Running multigrid Laplacian inversion test")
success = True

for nproc in [1, 2, 4]:
    for inputfile in ["BOUT_jy4.inp", "BOUT_jy63.inp", "BOUT_jy127.inp"]:
        # Set NXPE on the command line: we only use the solution from one
        # point in y, so splitting in the y-direction is redundant (and also
        # doesn't help test the multigrid solver).
        cmd = "./test_multigrid_laplace -f " + inputfile + " NXPE=" + str(nproc)

        # Delete dump files from any earlier run before launching; presumably
        # so that collect() below only reads output written by this run.
        shell("rm data/BOUT.dmp.*.nc")

        print("   %d processors, input file is %s" % (nproc, inputfile))
        # The returned status is not inspected here; only the captured output
        # is kept for the log file.
        _, out = launch_safe(cmd, nproc=nproc, pipe=True)
        # NOTE(review): the log name depends only on nproc, so each input file
        # overwrites the log written for the previous one. Kept as-is to avoid
        # changing the files this test leaves behind.
        with open("run.log." + str(nproc), "w") as f:
            f.write(out)

        # Collect errors: one "max_error<i>" variable per test, i = 1..numTests
        errors = [
            collect("max_error" + str(i), path="data") for i in range(1, numTests + 1)
        ]

        # Number the checks from 1 so the printed index matches max_error<i>.
        for i, e in enumerate(errors, start=1):
            print("Checking test " + str(i))
            if e < 0.0:
                # A negative value is treated as "solver did not converge";
                # report only that, never a "Pass" for the same check.
                print("Fail, solver did not converge")
                success = False
            elif not (e <= tol):
                # Written as a negated <= so that a NaN error, which compares
                # false against everything, is reported as a failure too.
                print("Fail, maximum absolute error = " + str(e))
                success = False
            else:
                print("Pass")

# The exit status reports the overall result to whatever invoked the script.
if success:
    print(" => All multigrid Laplacian inversion tests passed")
    exit(0)
else:
    print(" => Some failed tests")
    exit(1)
