#!/usr/bin/env python3

#
# Run the test, compare results against the benchmark
#

# requires: not metric_3d
# Requires: netcdf
# Cores: 4

# Variables to compare
from __future__ import print_function

# Use the `builtins.str` provided by the `future` package when it is
# available (Python 2); otherwise fall back to the interpreter's own `str`.
try:
    from builtins import str
except ImportError:
    # Only a missing module is expected here; anything else (including
    # KeyboardInterrupt) should propagate rather than be silently ignored.
    pass

# Names of the output variables that are collected and compared against
# the benchmark data.
# NOTE(review): this name shadows the builtin `vars()`; it is kept because
# the rest of the script refers to it.
vars = ["pade1", "pade2"]

# Largest allowed value of max(|benchmark - result|) for a variable to pass.
tol = 1e-7  # Absolute tolerance, benchmark values are floats

from boututils.run_wrapper import build_and_log, shell, launch_safe
from boutdata.collect import collect
import numpy as np
from sys import stdout, exit


build_and_log("Gyro-average inversion test")

# Read benchmark values once; every run below is compared against them.
print("Reading benchmark data")
bmk = {}
for v in vars:
    bmk[v] = collect(v, path="data", prefix="benchmark", info=False, xguards=False)

print("Running Gyro-average inversion test")
success = True

for nproc in [1, 2, 4]:
    # Only the largest run is split across two processors in X.
    nxpe = 2 if nproc > 2 else 1

    cmd = "./test_gyro NXPE=" + str(nxpe)

    # Remove dump files from any previous run so that stale output cannot
    # be collected by mistake.
    shell("rm data/BOUT.dmp.*.nc")

    print("   %d processors (nxpe = %d)...." % (nproc, nxpe))
    # The return status is not inspected here; the captured output is
    # saved to a per-run log file.
    s, out = launch_safe(cmd, nproc=nproc, pipe=True)
    with open("run.log." + str(nproc), "w") as f:
        f.write(out)

    # Collect output data
    for v in vars:
        stdout.write("      Checking variable " + v + " ... ")
        result = collect(v, path="data", info=False, xguards=False)

        # Compare benchmark and output
        if np.shape(bmk[v]) != np.shape(result):
            print("Fail, wrong shape")
            success = False
            # Arrays of different shape cannot be compared element-wise:
            # the subtraction would either raise or silently broadcast
            # and report a second, misleading verdict.
            continue

        diff = np.max(np.abs(bmk[v] - result))
        # A NaN difference compares False against any tolerance, so it
        # must be checked explicitly or it would be reported as a pass.
        if np.isnan(diff) or diff > tol:
            print("Fail, maximum difference = " + str(diff))
            success = False
        else:
            print("Pass")

if success:
    print(" => All Gyro-average tests passed")
    exit(0)
else:
    print(" => Some failed tests")
    exit(1)
