#!/usr/bin/env python3

# Python script to run and analyse MMS test

from __future__ import division
from __future__ import print_function
from builtins import zip
from builtins import str

from boututils.run_wrapper import shell, launch_safe, build_and_log
from boutdata.collect import collect

from numpy import sqrt, max, abs, mean, array, log, concatenate

import pickle


# NOTE(review): presumably builds the test executable and logs the build
# output under the given label -- confirm against boututils.run_wrapper.
build_and_log("MMS test")

# Resolutions to scan (a 256 entry was previously listed but is disabled)
nxlist = [8, 16, 32, 64, 128]

# Data directory: passed to the executable with "-d"; the dump files and the
# run logs are read from / written to this directory
path = "mms-slab3d"

# Number of processes handed to the launcher
nproc = 4

# Domain extents, held constant while the resolution changes
nxdx = 128 * 2e-5  # Nx * dx
nydy = 64 * 1e-3  # Ny * dy

# Variables whose error fields ("E_<name>") are analysed, and the matplotlib
# format string used for each of them when plotting
varlist = ["Ne", "Te", "Vort", "VePsi", "Vi"]
markers = ["bo", "r^", "gs", "k+", "mx"]

# Overall pass/fail flag; cleared when a variable fails the convergence check
success = True

# Error norms per variable; one value is appended for every resolution run
error_2 = {name: [] for name in varlist}  # The L2 error (RMS)
error_inf = {name: [] for name in varlist}  # The maximum error

# Run the model once per resolution and record the error norms of every
# variable in error_2 / error_inf
for nx in nxlist:
    # mesh:nx is nx + 4 (the analysis below strips two x points from each
    # end); y and z use nx points. dx and dy are scaled as 1/nx so that
    # Nx * dx and Ny * dy stay constant.
    args = " ".join(
        [
            "-d",
            path,
            "mesh:nx=" + str(nx + 4),
            "mesh:dx=" + str(nxdx / nx),
            "MZ=" + str(nx),
            "mesh:ny=" + str(nx),
            "mesh:dy=" + str(nydy / nx),
        ]
    )

    print("Running with " + args)

    # Delete old data so that stale dump files are not collected below
    shell("rm %s/BOUT.dmp.*.nc" % (path,))

    # Command to run
    cmd = "./gbs " + args
    # Launch using MPI; the returned status is not inspected in this script
    _status, out = launch_safe(cmd, nproc=nproc, pipe=True)

    # Save output to a per-resolution log file; the context manager closes
    # the file even if the write fails
    with open(path + "/run.log." + str(nx), "w") as logfile:
        logfile.write(out)

    for var in varlist:
        # Collect the error field at time index 1 only
        E = collect("E_" + var, tind=[1, 1], path=path, info=False)
        # Keep the single collected time point and drop two points at each
        # end of the second axis
        E = E[0, 2:-2, :, :]

        l2 = sqrt(mean(E**2))  # RMS error
        linf = max(abs(E))  # maximum absolute error

        error_2[var].append(l2)
        error_inf[var].append(linf)

        print("%s : l-2 %f l-inf %f" % (var, l2, linf))

# Calculate (normalised) grid spacing. The runs use mesh:dx = nxdx / nx, so
# the spacing scales exactly as 1/nx; the previous 1/(nx - 2) did not match
# the simulated grids and skewed the estimated convergence order slightly.
# Only ratios of dx matter for the order, so the constant factor is dropped.
dx = 1.0 / array(nxlist)

# Save data: the resolution list and both error dictionaries are written as
# three consecutive pickle records, in that order
with open("mms-slab3d.pkl", "wb") as output:
    for record in (nxlist, error_2, error_inf):
        pickle.dump(record, output)


# Check convergence: estimate the order of accuracy of every variable from
# the L2 errors of the last two resolutions in nxlist

log_dx_ratio = log(dx[-1] / dx[-2])  # identical for all variables
for var in varlist:
    l2_errors = error_2[var]
    order = log(l2_errors[-1] / l2_errors[-2]) / log_dx_ratio
    print("%s Convergence order = %f" % (var, order))
    # The test only passes if the order lies strictly between 1.9 and 2.1
    if not (1.9 < order < 2.1):
        success = False

# Plot the error norms against mesh spacing. Plotting is best-effort: if
# matplotlib is unavailable or plotting fails, a message is printed and the
# test result is left unchanged.
try:
    import matplotlib.pyplot as plt

    for var, mark in zip(varlist, markers):
        # Solid line: L2 norm (appears in the legend);
        # dashed line: maximum norm, drawn with the same marker
        plt.plot(dx, error_2[var], "-" + mark, label=var)
        plt.plot(dx, error_inf[var], "--" + mark)

    plt.legend(loc="upper left")
    plt.grid()

    plt.yscale("log")
    plt.xscale("log")

    plt.xlabel(r"Mesh spacing $\delta x$")
    plt.ylabel("Error norm")

    plt.savefig("norm-slab3d.pdf")
except Exception as exc:
    # Catch Exception rather than using a bare except so that
    # KeyboardInterrupt and SystemExit still propagate
    print("Skipping plot: %s" % (exc,))

# Report the outcome through the process exit status: 0 while `success` is
# still true, 1 otherwise. SystemExit is raised directly because the exit()
# helper is installed by the site module for interactive use and is not
# guaranteed to exist (for example under "python -S").
raise SystemExit(0 if success else 1)
