#!/usr/bin/env python3

from __future__ import division
from __future__ import print_function
from builtins import str

from boututils.run_wrapper import shell, build_and_log, launch_safe
from boutdata.collect import collect

from numpy import sqrt, max, abs, mean, array, log

# Plotting is optional: when matplotlib cannot be loaded, plt is set to None
# and every plotting section below is skipped.
try:
    import matplotlib.pyplot as plt
except Exception:
    # Catch Exception rather than using a bare except so that
    # KeyboardInterrupt / SystemExit still propagate.
    plt = None

# requires: all_tests


# Build step, delegated to the boututils helper (presumably runs make and
# logs under this message -- confirm against boututils.run_wrapper)
build_and_log("Making MMS d2dx2 test")

# Process count handed to every launch below
nproc = 1

# Error norms; one entry is appended per resolution in the scan
error_2, error_inf = [], []  # L2 (RMS) error and maximum error

# mesh:nx values to scan: powers of two from 8 up to 1024
nxlist = [2**power for power in range(3, 11)]

# Run the test executable once per resolution and record the error norms
for nx in nxlist:
    # Command-line overrides: the resolution, and a spacing chosen so that
    # (nx - 2) * dx = 1
    args = "mesh:nx=" + str(nx) + " mesh:dx=" + str(1.0 / (nx - 2))

    # Delete old data so the collect() calls below read this run's output
    shell("rm data/BOUT.dmp.*.nc")

    # Command to run
    cmd = "./test_d2dx2 " + args

    # Launch using MPI; with pipe=True the run's output is returned in `out`
    s, out = launch_safe(cmd, nproc=nproc, pipe=True)

    # Save output to a per-resolution log file. The context manager closes
    # the file even if the write raises.
    with open("run.log." + str(nx), "w") as f:
        f.write(out)

    # Collect data
    result = collect("result", path="data", info=False)
    solution = collect("solution", path="data", info=False)

    if plt:
        try:
            plt.figure()
            plt.plot(result[1:-1, 0, 0], label="Result")
            plt.plot(solution[1:-1, 0, 0], label="Solution")
            plt.legend()
            plt.xlabel("X")
            plt.savefig("result." + str(nx) + ".pdf")
            plt.close()
        except Exception as e:
            # Plotting is best-effort: report why it failed, then disable it
            # for the remaining resolutions and the summary plot.
            print("Plotting disabled: %s" % (e,))
            plt = None

    # Pointwise error along the first index, dropping the first and last
    # points (presumably boundary/guard cells -- confirm)
    err = result[1:-1, 0, 0] - solution[1:-1, 0, 0]

    # Error norms: RMS (l-2) and maximum absolute value (l-inf)
    l2 = sqrt(mean(err**2))
    linf = max(abs(err))

    error_2.append(l2)
    error_inf.append(linf)

    print("Error norm: l-2 %f l-inf %f" % (l2, linf))


# Grid spacing for every resolution in the scan, matching the mesh:dx value
# passed on the command line
dx = 1.0 / (array(nxlist) - 2.0)

# Estimate the convergence order from the two finest resolutions as the
# slope of log(error) against log(dx)
error_ratio = error_2[-1] / error_2[-2]
spacing_ratio = dx[-1] / dx[-2]
order = log(error_ratio) / log(spacing_ratio)
print("Convergence order = %f" % (order))


# Plot both error norms against mesh spacing on log-log axes, together with
# a reference line of the measured order anchored at the finest resolution
if plt:
    try:
        plt.plot(dx, error_2, "-o", label=r"$l^2$")
        plt.plot(dx, error_inf, "-x", label=r"$l^\infty$")

        plt.plot(
            dx, error_2[-1] * (dx / dx[-1]) ** order, "--", label="Order %.1f" % (order)
        )

        plt.legend(loc="upper left")
        plt.grid()

        plt.yscale("log")
        plt.xscale("log")

        plt.xlabel(r"Mesh spacing $\delta x$")
        plt.ylabel("Error norm")

        plt.savefig("norm.pdf")

        # plt.show()
        plt.close()
    except Exception as e:
        # The plot is optional and must not affect the pass/fail verdict,
        # but report the failure instead of hiding it. Catching Exception
        # (not a bare except) lets KeyboardInterrupt propagate.
        print("Could not plot error norms: %s" % (e,))

# Check for success: the measured order must lie strictly between 1.8 and
# 2.2. A NaN order fails both comparisons and is reported as a failure.
# SystemExit is raised directly because the exit() helper is injected by the
# site module and is not guaranteed to exist (e.g. under `python -S`).
if 2.2 > order > 1.8:
    print(" => Test passed")
    raise SystemExit(0)
else:
    print(" => Test failed")
    raise SystemExit(1)
