#!/usr/bin/env python3

# Python script to run and analyse MMS test
#
# Outputs PDF figures in each subdirectory
# Checks that the convergence is 2nd order
# Exits with status 1 if any test failed

# requires: all_tests
# cores: 2

from __future__ import division
from __future__ import print_function
from builtins import str

from boututils.run_wrapper import shell, build_and_log, launch_safe
from boutdata.collect import collect

from numpy import sqrt, max, abs, mean, array, log

from os.path import join


# NOTE(review): presumably builds the test executable (the scan below launches
# ./cyto) and logs the build output under this label -- confirm against
# boututils.run_wrapper.build_and_log.
build_and_log("MMS diffusion test")

# Values handed to each run: nout and timestep go on the command line as
# "nout=..." and "timestep=...", nproc is the process count given to the
# MPI launcher.
nproc = 2
nout = 1
timestep = 0.1

# Resolutions scanned for every case: 4, 8, ..., 256 (successive doublings).
# A finer scan, currently disabled, is [128, 256, 512, 1024, 2048, 4096].
nxlist = [2**power for power in range(2, 9)]

# Test cases as (input directory, option names set to the resolution value).
# A combined ("XYZ", ["mesh:nx", "mesh:ny", "MZ"]) case is currently disabled.
inputs = [
    ("X", ["mesh:nx"]),
    ("Y", ["mesh:ny"]),
    ("Z", ["MZ"]),
]

# Overall result; cleared as soon as one case fails its convergence check.
success = True

# For each case: run the executable at every resolution, record the error
# norms of E_N, estimate the convergence order from the two finest
# resolutions, mark the test as failed unless the order lies strictly between
# 1.8 and 2.2, and finally try to plot the error norms into the case directory.
for directory, sizes in inputs:
    print("Running test in '%s'" % (directory))

    error_2 = []  # The L2 error (RMS)
    error_inf = []  # The maximum error

    for nx in nxlist:
        # Every option listed in `sizes` is set to the same resolution value
        settings = ["nout=" + str(nout), "timestep=" + str(timestep)]
        settings.extend(option + "=" + str(nx) for option in sizes)
        args = "-d " + directory + " " + " ".join(settings)

        print("  Running with " + args)

        # Delete old data
        shell("rm " + directory + "/BOUT.dmp.*.nc")

        # Launch using MPI; only the captured output is needed here
        _, out = launch_safe("./cyto " + args, nproc=nproc, pipe=True)

        # Save output to log file.
        # NOTE(review): the file name does not contain the case directory, so
        # every case overwrites the logs written by the previous one.
        with open("run.log." + str(nx), "w") as logfile:
            logfile.write(out)

        # Collect E_N at time index nout only
        E_N = collect("E_N", tind=[nout, nout], path=directory, info=False)

        # Error norms over everything collect() returned.
        # NOTE(review): the original comment said guard cells are excluded;
        # that depends on collect()'s default guard-cell handling -- confirm.
        # max and abs are the numpy versions imported at the top of the file.
        l2 = sqrt(mean(E_N**2))
        linf = max(abs(E_N))

        error_2.append(l2)
        error_inf.append(linf)

        print("  -> Error norm: l-2 %f l-inf %f" % (l2, linf))

    # Calculate grid spacing
    # This is only correct in the x-direction if MXG = 1. In the other directions
    # dy = 1/ny, dz = 1/(MZ-1)
    # NOTE(review): the same expression is nevertheless used for every case.
    dx = 1.0 / (array(nxlist) - 2.0)

    # Convergence order from the two finest resolutions: the slope of
    # log(error) against log(spacing)
    order = log(error_2[-1] / error_2[-2]) / log(dx[-1] / dx[-2])
    print("Convergence order = %f" % (order))

    # A NaN order compares False on both sides and is therefore a failure too
    if not 1.8 < order < 2.2:
        success = False
        print("=> FAILED\n")

    # Plot errors. Plotting is best-effort: a missing or unusable matplotlib
    # must not change the test result, so ordinary errors are reported and
    # skipped. KeyboardInterrupt and SystemExit are deliberately not caught.
    try:
        import matplotlib.pyplot as plt

        plt.figure()
        try:
            plt.plot(dx, error_2, "-o", label=r"$l^2$")
            plt.plot(dx, error_inf, "-x", label=r"$l^\infty$")

            # Reference line with the measured order, anchored at the finest
            # resolution
            plt.plot(
                dx,
                error_2[-1] * (dx / dx[-1]) ** order,
                "--",
                label="Order %.1f" % (order),
            )

            plt.legend(loc="upper left")
            plt.grid()

            plt.yscale("log")
            plt.xscale("log")

            plt.xlabel(r"Mesh spacing $\delta x$")
            plt.ylabel("Error norm")

            plt.savefig(join(directory, "norm.pdf"))
        finally:
            # Release the figure even when plotting or saving failed
            plt.close()
    except Exception as err:
        print("  Skipping plot for '%s': %s" % (directory, err))

# Report the overall result and set the process exit status: 0 when every case
# converged at the expected order, 1 otherwise. SystemExit is raised directly
# because the exit() helper is injected by the site module for interactive use
# and is not guaranteed to exist (for example under "python -S").
if success:
    print(" => All tests passed")
    raise SystemExit(0)

print(" => Some failed tests")
raise SystemExit(1)
