#!/usr/bin/env python3

#
# Run the test, compare results against the benchmark
#

# requires: not metric_3d
# Requires: netcdf
# Cores: 4

from __future__ import print_function

# Take `str` from the `builtins` module when it can be imported; if it cannot,
# keep the `str` already in scope.  Only a failed import is tolerated here, so
# that Ctrl-C or an unrelated error is not silently swallowed.
# (presumably a Python 2 compatibility shim -- confirm before removing)
try:
    from builtins import str
except ImportError:
    pass

# Variables to compare: the four base names, repeated once per suffix.
# The suffix loop is the outer one, so the order is
# flag0, flag3, flagis, flagos, flag0a, flag3a, ... , flagisad, flagosad.
vars = [
    base + suffix
    for suffix in ("", "a", "ac", "ad")
    for base in ("flag0", "flag3", "flagis", "flagos")
]

# Absolute tolerance on the largest element-wise difference
tol = 1.0e-6

from boututils.run_wrapper import build_and_log, shell, launch_safe
from boutdata.collect import collect
import numpy as np
from sys import stdout, exit


build_and_log("Laplacian inversion test")

# Load the reference value of every compared variable from the files in
# "data" whose prefix is "benchmark", keyed by variable name.
print("Reading benchmark data")
bmk = {
    name: collect(name, path="data", prefix="benchmark", info=False)
    for name in vars
}

print("Running Laplacian inversion test")
success = True

# Run the test executable once per (solver, processor count) pair and compare
# every variable in `vars` against the benchmark values held in `bmk`.
# `success` is cleared on the first mismatch and never set back to True.
for solver in ["cyclic", "pcr", "pcr_thomas"]:
    for nproc in [1, 2, 4]:
        # Only the 4-processor run uses NXPE=2; the others use NXPE=1
        nxpe = 2 if nproc > 2 else 1

        cmd = "./test_laplace NXPE=" + str(nxpe) + " laplace:type=" + solver

        # Delete the dump files of the previous run
        # (presumably so that collect() cannot pick up stale output)
        shell("rm data/BOUT.dmp.*.nc")

        print("   %s solver with %d processors (nxpe = %d)...." % (solver, nproc, nxpe))
        # The first returned value was never used, so it is discarded here
        _, out = launch_safe(cmd, nproc=nproc, mthread=1, pipe=True)
        # NOTE(review): the log name depends only on nproc, so each solver
        # overwrites the log written by the previous solver.
        with open("run.log." + str(nproc), "w") as f:
            f.write(out)

        # Collect output data
        for v in vars:
            stdout.write("      Checking variable " + v + " ... ")
            result = collect(v, path="data", info=False)

            # Compare benchmark and output
            if np.shape(bmk[v]) != np.shape(result):
                print("Fail, wrong shape")
                success = False
                # The element-wise difference is meaningless for mismatched
                # shapes (it may raise or broadcast), so report once and move on
                continue

            diff = np.max(np.abs(bmk[v] - result))
            # Written as "not <=" rather than ">" so that a NaN difference,
            # for which every comparison is False, counts as a failure
            if not diff <= tol:
                print("Fail, maximum difference = " + str(diff))
                success = False
            else:
                print("Pass")

# Report the overall verdict and pass it on as the process exit status:
# 1 if any comparison failed, 0 otherwise.
if not success:
    print(" => Some failed tests")
    exit(1)

print(" => All Laplacian inversion tests passed")
exit(0)
