#!/usr/bin/env python3

#
# Run the ZInterpolation test and check the convergence order of the error
#

from boututils.run_wrapper import build_and_log, shell, launch_safe
from boutdata import collect
from numpy import sqrt, max, abs, mean, array, log, polyfit
from sys import stdout, exit

# When True, also show each figure on screen instead of only saving it.
# Only consulted by the plotting section of the script.
show_plot = False

# Grid resolutions (NX) scanned in the convergence study
nxlist = [16, 32, 64, 128]

# The test only involves 2D (x, z) slices, so a single processor suffices
nproc = 1

# Names of the variables whose interpolation error is checked, together
# with the plot marker and the LaTeX legend label that belong to each
varlist = ["a", "b", "c"]
markers = ["bo", "r^", "kx"]
labels = ["${}$".format(name) for name in varlist]

# Interpolation methods under test, mapped to the convergence order that
# is expected of each
methods = dict(hermitespline=4)


# Build the test; build_and_log is provided by boututils.run_wrapper
build_and_log("ZInterpolation test")

print("Running ZInterpolation test")
success = True

# Plot generation is switched off. Set to True to write an error-scaling
# PDF for each interpolation method (requires matplotlib).
make_plots = False

for method in methods:
    print("------------------------------")
    print("Using {} interpolation".format(method))

    # Error norms per variable, one entry appended for each value in nxlist
    error_2 = {var: [] for var in varlist}  # L2 error (RMS)
    error_inf = {var: [] for var in varlist}  # Maximum error

    for nx in nxlist:
        # Grid spacing handed to the executable for this resolution
        dx = 1.0 / nx

        # NOTE(review): nx + 4 presumably accounts for two guard cells on
        # each x boundary -- confirm against the test's input file.
        args = (
            " mesh:nx={nx4} mesh:dx={dx} MZ={nx} zinterpolation:type={method}".format(
                nx4=nx + 4, dx=dx, nx=nx, method=method
            )
        )

        cmd = "./test_interpolate" + args

        # Delete dump files from any previous run before collecting new
        # data; -f keeps rm quiet when there is nothing to delete.
        shell("rm -f data/BOUT.dmp.*.nc")

        # The returned status is not inspected here.
        # NOTE(review): launch_safe is expected to raise on a non-zero exit
        # status -- confirm against boututils.run_wrapper.
        _, out = launch_safe(cmd, nproc=nproc, pipe=True)
        with open("run.log.{}.{}".format(method, nx), "w") as f:
            f.write(out)

        # Collect output data
        for var in varlist:
            interp = collect(var + "_interp", path="data", xguards=False, info=False)
            solution = collect(
                var + "_solution", path="data", xguards=False, info=False
            )

            # Pointwise difference between interpolated and reference fields
            E = interp - solution

            l2 = float(sqrt(mean(E**2)))
            linf = float(max(abs(E)))

            error_2[var].append(l2)
            error_inf[var].append(linf)

            print("{0:s} : l-2 {1:.8f} l-inf {2:.8f}".format(var, l2, linf))

    # Grid spacing for every resolution, used by the fit and by the plot
    dx_values = 1.0 / array(nxlist)

    for var in varlist:
        # The slope of log(error) against log(dx) is the observed order
        fit = polyfit(log(dx_values), log(error_2[var]), 1)
        order = fit[0]
        stdout.write("{0:s} Convergence order = {1:.2f}".format(var, order))

        # Make sure scaling is at least 90% of expected order. The check is
        # written with ">=" so that a NaN order (for example from NaN in
        # the output data) counts as a failure: "order < threshold" is
        # False for NaN and would report PASS.
        if order >= 0.9 * methods[method]:
            print("............ PASS")
        else:
            print("............ FAIL")
            success = False

    if make_plots:
        try:
            import matplotlib.pyplot as plt

            # Plot errors
            plt.figure()

            # Solid line: L2 error; dashed line: maximum error
            for var, mark, label in zip(varlist, markers, labels):
                plt.plot(dx_values, error_2[var], "-" + mark, label=label)
                plt.plot(dx_values, error_inf[var], "--" + mark)

            plt.legend(loc="upper left")
            plt.grid()

            plt.yscale("log")
            plt.xscale("log")

            plt.xlabel(r"Mesh spacing $\delta x$")
            plt.ylabel("Error norm")
            plt.title("Error scaling for {}".format(method))

            name = "error_scaling_{}.pdf".format(method)
            plt.savefig(name)
            print("Plot saved to {}".format(name))

            if show_plot:
                plt.show()
            plt.close()
        except ImportError:
            print("No matplotlib available")
    else:
        print("Plotting disabled")

# Summarise the outcome and hand it to the caller through the exit status:
# 0 when every convergence check passed, 1 otherwise
print("------------------------------")
if not success:
    print(" => Some failed tests")
    exit(1)

print(" => All Interpolation tests passed")
exit(0)
