#!/usr/bin/env python3

# Python script to run and analyse MMS test

from __future__ import division
from __future__ import print_function

try:
    from builtins import zip
    from builtins import str
except:
    pass

from boututils.run_wrapper import shell, build_and_log, launch_safe
from boutdata.collect import collect

import pickle

from sys import stdout

from numpy import sqrt, max, abs, mean, array, log, concatenate, pi


# NOTE(review): build_and_log is imported from boututils.run_wrapper; it
# presumably builds the ./wave executable launched below and logs the build
# output under the given description -- confirm against boututils.
build_and_log("Making MMS wave test")

# Resolutions to scan: each entry is a value of ny (handed to the run as
# "mesh:ny")
nylist = [8, 16, 32, 64, 128, 256]

# Output count and output time step given to every run
nout, timestep = 1, 1

# Process count passed to the launcher
nproc = 1

# Variables whose error is checked, with the plot marker and legend label
# used for each of them
varlist = ["f", "g"]
markers = ["bo", "r^"]
labels = ["f", "g"]

# Error norms keyed by variable name; one value is appended per resolution
error_2 = {name: [] for name in varlist}  # The L2 error (RMS)
error_inf = {name: [] for name in varlist}  # The maximum error

for ny in nylist:
    # Spacing such that ny points cover a length of 2*pi
    dy = 2.0 * pi / ny

    # Command-line overrides for this resolution, as "key=value" pairs
    settings = [
        ("mesh:ny", ny),
        ("mesh:dy", dy),
        ("nout", nout),
        ("timestep", timestep),
    ]
    args = " ".join(key + "=" + str(value) for key, value in settings)

    print("Running with " + args)

    # Remove dump files left behind by an earlier run
    shell("rm data/BOUT.dmp.*.nc")

    # Launch the executable using MPI, capturing what it prints
    status, run_output = launch_safe("./wave " + args, nproc=nproc, pipe=True)

    # Keep a separate log file for every resolution
    with open("run.log." + str(ny), "w") as logfile:
        logfile.write(run_output)

    for var in varlist:
        # Read the error field "E_<var>" at output step nout and keep a 1D
        # slice of it (index 0 on axes 0, 1 and 3; presumably the axes are
        # (t, x, y, z), so this is the profile along y -- confirm)
        field = collect("E_" + var, tind=[nout, nout], info=False, path="data")
        err = field[0, 0, :, 0]

        # Error norms over the slice: root-mean-square and largest magnitude
        l2 = sqrt(mean(err**2))
        linf = max(abs(err))

        error_2[var].append(l2)
        error_inf[var].append(linf)

        print("Error norm %s: l-2 %f l-inf %f" % (var, l2, linf))

# Store the resolutions and both sets of error norms in wave.pkl, as three
# consecutive pickles in that order
with open("wave.pkl", "wb") as output:
    for item in (nylist, error_2, error_inf):
        pickle.dump(item, output)

# Grid spacing for every resolution. This uses the same 2*pi / ny formula as
# the "mesh:dy" value given to the runs, so the plot's x-axis shows the real
# mesh spacing. The convergence order only depends on the ratio of spacings.
dy = 2.0 * pi / array(nylist)

# Estimate the convergence order of each variable from the L2 error on the
# two finest grids: order = log(e_fine / e_coarse) / log(dy_fine / dy_coarse)
success = True
orders = {}  # Convergence order per variable, reused for the plot legend
for var in varlist:
    order = log(error_2[var][-1] / error_2[var][-2]) / log(dy[-1] / dy[-2])
    orders[var] = order
    stdout.write("%s Convergence order = %f" % (var, order))

    if 1.8 < order < 2.2:  # Should be second order accurate
        print("............ PASS")
    else:
        success = False
        print("............ FAIL")

# plot errors
try:
    import matplotlib.pyplot as plt

    for var, mark, label in zip(varlist, markers, labels):
        # Solid line: L2 norm, labelled with this variable's own order.
        # Dashed line: maximum norm.
        plt.plot(
            dy, error_2[var], "-" + mark, label="%s order=%.2f" % (label, orders[var])
        )
        plt.plot(dy, error_inf[var], "--" + mark)

    plt.legend(loc="upper left")
    plt.grid()

    plt.yscale("log")
    plt.xscale("log")

    plt.xlabel(r"Mesh spacing $\delta y$")
    plt.ylabel("Error norm")

    plt.savefig("norm.pdf")

    # plt.show()
    plt.close()
except Exception as err:
    # Plotting is optional and could fail for any number of reasons; the
    # actual error raised may depend on, among other things, the current
    # matplotlib backend. Catch every ordinary error so the test verdict is
    # unaffected, but let KeyboardInterrupt / SystemExit propagate.
    print("Skipping plot: %s" % err)

# Print the overall verdict and use it as the exit status
# (0 when every variable converged at the expected order, 1 otherwise)
print(" => Test passed" if success else " => Test failed")
exit(0 if success else 1)
