#!/usr/bin/env python3

#
# Run the test, compare results against the benchmark
#

# Requires: netcdf
# Cores: 4

# Names of the output variables that are compared against the benchmark
vars = [
    "yavg2d",
    "yavg3d",
    "sm3d",
]
# Absolute tolerance for the comparison; benchmark values are floats
tol = 1.0e-7

from boututils.run_wrapper import build_and_log, shell, launch_safe
from boutdata.collect import collect
import numpy as np
from sys import stdout, exit


build_and_log("smoothing operator test")

# Read the reference values of every compared variable from the
# "benchmark"-prefixed files in data/.
print("Reading benchmark data")
bmk = {v: collect(v, path="data", prefix="benchmark", info=False) for v in vars}

print("Running smoothing operator test")
success = True
cmd = "./test_smooth"

# Run the test executable once per (NXPE, NYPE) processor layout and compare
# each output variable against the benchmark.
for nype in [1, 2]:
    for nxpe in [1, 2]:
        nproc = nxpe * nype

        # Remove dump files left by a previous run so that stale output is
        # never collected; -f keeps rm quiet when there is nothing to delete.
        shell("rm -f data/BOUT.dmp.*.nc")

        print("   %d processor (%d x %d)...." % (nproc, nxpe, nype))
        # NOTE(review): the returned status is not inspected; launch_safe
        # presumably raises on a non-zero exit status -- confirm.
        s, out = launch_safe(cmd + " NXPE=" + str(nxpe), nproc=nproc, pipe=True)

        # Key the log on the layout rather than on nproc alone: nproc == 2
        # occurs for both 1x2 and 2x1, which would overwrite the earlier log.
        with open("run.log.%dx%d" % (nxpe, nype), "w") as f:
            f.write(out)

        # Collect output data and compare it with the benchmark
        for v in vars:
            # Flush so the progress message is visible while collect() runs
            print("      Checking variable " + v + " ... ", end="", flush=True)
            result = collect(v, path="data", info=False)

            if np.shape(bmk[v]) != np.shape(result):
                print("Fail, wrong shape")
                success = False
                continue

            diff = np.max(np.abs(bmk[v] - result))
            # Written as "not (diff <= tol)" so that a NaN difference counts
            # as a failure; "diff > tol" is False for NaN and would pass.
            if not (diff <= tol):
                print("Fail, maximum difference = " + str(diff))
                success = False
            else:
                print("Pass")

# Report the overall outcome through the process exit status
if success:
    print(" => All smoothing operator tests passed")
    exit(0)
else:
    print(" => Some failed tests")
    exit(1)
