#!/usr/bin/env python3
#
# Python script to run and analyse MMS test
#

# Cores: 2
# requires: zoidberg

from __future__ import division
from __future__ import print_function

from boututils.run_wrapper import build_and_log, launch_safe
from boutdata.collect import collect
import boutconfig as conf

from numpy import array, log, polyfit, linspace, arange

import pickle

from sys import stdout

import zoidberg as zb

# ---------------------------------------------------------------------------
# Scan parameters
# ---------------------------------------------------------------------------

# Number of x points; the same value is used for every run
nx = 3

# Grid sizes scanned; each value is used for both the y and z resolution
nlist = [8, 16, 32, 64, 128]

# Numbers of parallel slices (in each direction) to test
nslices = [1, 2]

# Periodic in Y? Selects the y-coordinate layout and is passed on to the run
yperiodic = True

# ---------------------------------------------------------------------------
# Run settings
# ---------------------------------------------------------------------------

# Directory the simulation output is collected from
directory = "data"

# Passed to launch_safe as nproc and mthread respectively
nproc = mthread = 2

# ---------------------------------------------------------------------------
# Results, filled in as the scan proceeds
# ---------------------------------------------------------------------------

# Each dictionary is keyed by the number of slices
error_2, error_inf, method_orders = {}, {}, {}

# Overall outcome, plus the names of the schemes that did not converge
success, failures = True, []

# Presumably builds the fci_mms executable that is launched below -- see
# boututils.run_wrapper.build_and_log
build_and_log("FCI MMS test")

# Nominal mesh spacing for each entry of nlist. It is the same for every
# slice count, so it is computed once before the scan starts.
dx = 1.0 / array(nlist)

for nslice in nslices:
    error_2[nslice] = []
    error_inf[nslice] = []

    # Which central difference scheme to use and its expected order
    expected_order = nslice * 2
    method_orders[nslice] = {
        "name": "C{}".format(expected_order),
        "order": expected_order,
    }

    for n in nlist:
        # Define the magnetic field using new poloidal gridding method
        # Note that the Bz and Bzprime parameters here must be the same as in mms.py
        field = zb.field.Slab(Bz=0.05, Bzprime=0.1)
        # Create rectangular poloidal grids
        poloidal_grid = zb.poloidal_grid.RectangularPoloidalGrid(nx, n, 0.1, 1.0)
        # Set the ylength and y locations
        ylength = 10.0

        if yperiodic:
            ycoords = linspace(0.0, ylength, n, endpoint=False)
        else:
            # Doesn't include the end points
            ycoords = (arange(n) + 0.5) * ylength / float(n)

        # Create the grid
        grid = zb.grid.Grid(poloidal_grid, ycoords, ylength, yperiodic=yperiodic)
        # Make and write maps
        maps = zb.make_maps(grid, field, nslice=nslice, quiet=True)
        zb.write_maps(
            grid, field, maps, new_names=False, metric2d=conf.isMetric2D(), quiet=True
        )

        args = " MZ={} MYG={} mesh:paralleltransform:y_periodic={} mesh:ddy:first={}".format(
            n, nslice, yperiodic, method_orders[nslice]["name"]
        )

        # Command to run
        cmd = "./fci_mms " + args

        print("Running command: " + cmd)

        # Launch using MPI
        s, out = launch_safe(cmd, nproc=nproc, mthread=mthread, pipe=True)

        # Save output to log file, before the status check so that the output
        # of a failed run is kept too. The name includes nslice as well as n;
        # otherwise each slice count would overwrite the logs of the previous
        # one, because the same resolutions are run for every slice count.
        with open("run.log.{}.{}".format(nslice, n), "w") as f:
            f.write(out)

        if s:
            print("Run failed!\nOutput was:\n")
            print(out)
            # Stop immediately, using the run's status as the exit status
            raise SystemExit(s)

        # Collect the error norms written by the run (tind=[1, 1])
        l_2 = collect(
            "l_2", tind=[1, 1], info=False, path=directory, xguards=False, yguards=False
        )
        l_inf = collect(
            "l_inf",
            tind=[1, 1],
            info=False,
            path=directory,
            xguards=False,
            yguards=False,
        )

        error_2[nslice].append(l_2)
        error_inf[nslice].append(l_inf)

        print("Errors : l-2 {:f} l-inf {:f}".format(l_2, l_inf))

    # Calculate convergence order: slope of a straight-line fit of
    # log(error) against log(dx) over all resolutions
    fit = polyfit(log(dx), log(error_2[nslice]), 1)
    fit_order = fit[0]
    stdout.write("Convergence order = {:f} (fit)".format(fit_order))

    # The same slope, estimated from the two finest resolutions only
    fine_order = log(error_2[nslice][-2] / error_2[nslice][-1]) / log(
        dx[-2] / dx[-1]
    )
    stdout.write(", {:f} (small spacing)".format(fine_order))

    # Should be close to the expected order. Only the fine-grid estimate is
    # used for the pass/fail decision; the fitted value is informational.
    if fine_order > method_orders[nslice]["order"] * 0.95:
        print("............ PASS\n")
    else:
        print("............ FAIL\n")
        success = False
        failures.append(method_orders[nslice]["name"])


# Save the results as a sequence of pickles in one file: first the list of
# resolutions, then the l-2 and l-inf error lists for each slice count in turn
with open("fci_mms.pkl", "wb") as pickle_file:
    pickle.dump(nlist, pickle_file)
    for nslice in nslices:
        for errors in (error_2[nslice], error_inf[nslice]):
            pickle.dump(errors, pickle_file)

# Do we want to show the plot as well as save it to file.
showPlot = True

# Plotting is currently switched off by the constant condition below.
# Note that the body is still compiled when the script is loaded.
if False:
    try:
        # Plot using matplotlib if available
        import matplotlib.pyplot as plt

        fig, ax = plt.subplots(1, 1)

        # One solid (l-2) and one dashed (l-inf) line per slice count
        for nslice in nslices:
            ax.plot(
                dx,
                error_2[nslice],
                "-",
                label="{} $l_2$".format(method_orders[nslice]["name"]),
            )
            ax.plot(
                dx,
                error_inf[nslice],
                "--",
                # Raw string: "\i" is not a valid escape sequence, so a plain
                # literal triggers an invalid-escape warning at compile time
                label=r"{} $l_\inf$".format(method_orders[nslice]["name"]),
            )
        ax.legend(loc="upper left")
        ax.grid()
        ax.set_yscale("log")
        ax.set_xscale("log")
        ax.set_title("error scaling")
        ax.set_xlabel(r"Mesh spacing $\delta x$")
        ax.set_ylabel("Error norm")

        plt.savefig("fci_mms.pdf")

        print("Plot saved to fci_mms.pdf")

        if showPlot:
            plt.show()
        plt.close()
    except ImportError:
        # Plotting is optional: carry on without it
        print("No matplotlib")

# Report the overall result: exit status 0 if every scheme converged at the
# expected order, 1 otherwise. SystemExit is raised directly rather than
# calling exit(), which is a helper added by the site module.
if success:
    print("All tests passed")
    raise SystemExit(0)

print("Some tests failed:")
for failure in failures:
    # One failing scheme name per line
    print("\t" + failure)
raise SystemExit(1)
