#!/usr/bin/env python3

# Python script to run and analyse MMS test
#
# Tests a range of different schemes
#

from __future__ import division
from __future__ import print_function
from builtins import str

from boututils.run_wrapper import shell, build_and_log, launch_safe
from boutdata.collect import collect

# requires: all_tests

verbose = False

from numpy import sqrt, max, abs, mean, array, log, pi

from os.path import join

import sys

# matplotlib is optional: if it cannot be loaded, plt is set to None so
# that the rest of the script can detect that plotting is unavailable.
try:
    import matplotlib.pyplot as plt
except Exception:
    # "except Exception" rather than a bare "except" so that
    # KeyboardInterrupt and SystemExit are not swallowed here.
    plt = None


# Build the test before running it (boututils helper)
build_and_log("MMS time integration test")

# Solvers to test. Each entry is a tuple of
#   (command-line options, label, matplotlib format string, expected order)
if "only_petsc" in sys.argv:
    # this requires petsc:
    options = [("solver:type=imexbdf2 -snes_mf", "IMEX-BDF2", "-+", 2)]
else:
    options = [
        ("solver:type=euler", "Euler", "-^", 1),
        ("solver:type=rk3ssp", "RK3-SSP", "-s", 3),
        ("solver:type=rk4", "RK4", "-o", 4),
        ("solver:type=rkgeneric", "RK-generic", "-o", 4),
        ("solver:type=splitrk", "SplitRK", "-x", 2),
    ]

# Missing: cvode, ida, slepc, power, arkode, snes
# Is there a better way to check a certain solver should be enabled?

# Inverse time step sizes: each run is given solver:timestep = 1/nt
timesteps = [4, 8, 16, 32, 64, 128, 256]

nproc = 1

# Set to True to show an error plot for every solver (needs matplotlib)
show_solver_plots = False

success = True

# Time step sizes used for the runs. The same for every solver, so this
# is computed once rather than inside the loop.
dt = 1.0 / array(timesteps)

err_2_all = []
err_inf_all = []
for opts, label, sym, expected_order in options:
    error_2 = []  # The L2 error (RMS)
    error_inf = []  # The maximum error

    for nt in timesteps:
        args = " solver:timestep=" + str(1.0 / nt) + " " + opts

        if verbose:
            print("  Running with " + args)

        # Delete old data. -f avoids an error message when there is
        # nothing to delete, e.g. on the very first run.
        shell("rm -f data/BOUT.dmp.*.nc")

        # Command to run
        cmd = "./time " + args
        # Launch using MPI. The returned status is not inspected.
        # NOTE(review): launch_safe presumably raises if the run fails;
        # confirm against the boututils documentation.
        _, out = launch_safe(cmd, nproc=nproc, pipe=True)

        # Save output to log file; "with" closes it even if the write fails
        with open("run.log." + label + "." + str(nt), "w") as f:
            f.write(out)

        # Collect the error field E_f at time index 1
        E_f = collect("E_f", tind=[1, 1], info=False, path="data")
        # Reduce the error over all collected points to two norms
        l2 = sqrt(mean(E_f**2))
        linf = max(abs(E_f))

        error_2.append(l2)
        error_inf.append(linf)

        if verbose:
            # %e rather than %f: small errors would all print as 0.000000
            print("  -> Error norm: l-2 %e l-inf %e" % (l2, linf))

    # Append to list of all results
    err_2_all.append((error_2, label, sym))
    err_inf_all.append((error_inf, label, sym))

    # Calculate convergence order from the two smallest time steps:
    # the slope of log(error) against log(dt)
    order = log(error_2[-1] / error_2[-2]) / log(dt[-1] / dt[-2])
    print("Convergence order = %f (expected: %f) %s" % (order, expected_order, label))

    # The measured order must be within 0.2 of the expected order.
    # Written as a negated range check so that a NaN order also fails.
    if not (expected_order - 0.2 < order < expected_order + 0.2):
        success = False
        print(" -> FAILED")

    # Plot errors for this solver (disabled unless show_solver_plots is set)
    if show_solver_plots and plt is not None:
        plt.figure()

        plt.plot(dt, error_2, "-o", label=r"$l^2$")
        plt.plot(dt, error_inf, "-x", label=r"$l^\infty$")

        # Reference line with the measured order through the last point
        plt.plot(
            dt, error_2[-1] * (dt / dt[-1]) ** order, "--", label="Order %.1f" % (order)
        )

        plt.legend(loc="lower right")
        plt.grid()

        plt.yscale("log")
        plt.xscale("log")

        plt.xlabel(r"Time step $\delta t$")
        plt.ylabel("Error norm")

        # plt.savefig("norm.pdf")

        plt.show()
        plt.close()

# Plot the l-2 error of every solver on one figure for comparison.
# Plotting is best-effort: it is skipped when matplotlib is unavailable,
# and a failure while plotting is reported but does not change the
# test result.
if plt is not None:
    try:
        # Recomputed here so the plot does not depend on a variable
        # left over from the solver loop
        dt = 1.0 / array(timesteps)

        plt.figure()
        for errors, label, sym in err_2_all:
            plt.plot(dt, errors, sym, label=label)

        plt.legend(loc="lower right")
        plt.grid()
        plt.yscale("log")
        plt.xscale("log")

        plt.xlabel(r"Time step $\delta t$")
        plt.ylabel(r"$l^2$ error norm")
        plt.savefig("time_norm.pdf")
        plt.close()
    except Exception as exc:
        # Report instead of silently discarding the error
        print("Could not create time_norm.pdf: %s" % exc)

# Report the overall result through the exit status: 0 if every solver
# converged at its expected order, 1 otherwise.
# sys.exit is used rather than exit(), which is added by the site module
# and is not guaranteed to be available.
if success:
    print(" => All tests passed")
    sys.exit(0)
else:
    print(" => Some failed tests")
    sys.exit(1)
