#!/usr/bin/env python3

# Python script to run and analyse MMS test
#
# Outputs PDF figures in each subdirectory
# Checks that the convergence is 2nd order
# Exits with status 1 if any test failed

# requires: all_tests
# cores: 2

from __future__ import division
from __future__ import print_function
from builtins import str

from boututils.run_wrapper import shell, build_and_log, launch_safe
from boutdata.collect import collect

from numpy import sqrt, max, abs, mean, array, log

from os.path import join

import matplotlib.pyplot as plt


# Build step for the test, logged under the test's name
build_and_log("MMS diffusion test")

# Test cases as (input directory, [options that are set to the resolution]).
# Only the X case is enabled; the disabled cases are kept here for reference:
#   ("Y", ["mesh:ny"])
#   ("XYZ", ["mesh:nx", "mesh:ny", "MZ"])
inputs = [("X", ["mesh:nx"])]

# Resolutions to scan: the powers of two 8, 16, ..., 1024
nxlist = [2**exponent for exponent in range(3, 11)]

# Passed to the executable as "nout=..."; also the time index that is collected
nout = 1

# Number of processes handed to launch_safe
nproc = 2

# Overall result; cleared as soon as one case fails the convergence check
success = True

# Run the resolution scan for every test case, check that the measured
# convergence order is close to 2, and save a plot of the error norms.
# The loop variable is called "directory" so the builtin dir() is not shadowed.
for directory, sizes in inputs:
    print("Running test in '%s'" % (directory))

    error_2 = []  # The L2 error (RMS)
    error_inf = []  # The maximum error

    for nx in nxlist:
        # Options: input directory, number of outputs, and every size option
        # of this case set to the current resolution
        args = "-d " + directory + " nout=" + str(nout)
        for size_option in sizes:
            args += " " + size_option + "=" + str(nx)

        print("  Running with " + args)

        # Delete old data
        shell("rm " + directory + "/BOUT.dmp.*.nc")

        # Command to run
        cmd = "./diffusion " + args
        # Launch using MPI; the returned status is not used here
        _, out = launch_safe(cmd, nproc=nproc, pipe=True)

        # Save output to log file; "with" closes it even if the write fails
        with open("run.log." + str(nx), "w") as f:
            f.write(out)

        # Collect the error field at the final output time
        E_N = collect("E_N", tind=[nout, nout], path=directory, info=False)

        # Remove guard cells (first and last point along the second axis)
        E_N = E_N[:, 1:-1, :, :]

        # Average error over domain
        l2 = sqrt(mean(E_N**2))
        linf = max(abs(E_N))

        error_2.append(l2)
        error_inf.append(linf)

        print("  -> Error norm: l-2 %f l-inf %f" % (l2, linf))

    # Calculate grid spacing
    dx = 1.0 / (array(nxlist) - 2.0)

    # Calculate convergence order from the two finest resolutions
    order = log(error_2[-1] / error_2[-2]) / log(dx[-1] / dx[-2])
    print("Convergence order = %f" % (order))

    # A NaN order compares False against both bounds, so it fails too
    if not (1.8 < order < 2.2):
        success = False
        print("=> FAILED\n")

    # Plot errors. Plotting is best-effort: a failure here must not change the
    # test result, but it is reported instead of being silently discarded, and
    # KeyboardInterrupt / SystemExit are no longer swallowed.
    try:
        plt.figure()

        plt.plot(dx, error_2, "-o", label=r"$l^2$")
        plt.plot(dx, error_inf, "-x", label=r"$l^\infty$")

        # Reference line with the measured order, anchored at the finest grid
        plt.plot(
            dx, error_2[-1] * (dx / dx[-1]) ** order, "--", label="Order %.1f" % (order)
        )

        plt.legend(loc="upper left")
        plt.grid()

        plt.yscale("log")
        plt.xscale("log")

        plt.xlabel(r"Mesh spacing $\delta x$")
        plt.ylabel("Error norm")

        plt.savefig(join(directory, "norm.pdf"))

        plt.close()
    except Exception as err:
        print("Could not create plot: %s" % (err,))

# Report the overall result and set the process exit status
# (0 = all tests passed, 1 = at least one failure).
# SystemExit is raised directly because the exit() helper is injected by the
# "site" module for interactive use and is not guaranteed to be defined.
if success:
    print(" => All tests passed")
else:
    print(" => Some failed tests")
raise SystemExit(0 if success else 1)
