#!/usr/bin/env python3

# Python script to run and analyse MMS test
#
# Tests a range of different schemes
#

# requires: all_tests
# requires: make
# cores: 8

from boututils.run_wrapper import shell, launch_safe, build_and_log
from boutdata.collect import collect

from numpy import sqrt, max, abs, mean, array, log, pi

from os.path import join

import pickle

import time


if __name__ == "__main__":
    # Executed directly as a script: only the build step is triggered here.
    # Nothing in this guard calls run_mms; the remaining definitions in this
    # file are merely defined for use by whoever imports the module.
    build_and_log("MMS advection")


def make_parent():
    """Invoke ``build_and_log`` with the label "MMS advection".

    Gives importing scripts a callable way to trigger the same step that
    runs when this file is executed directly.
    NOTE(review): presumably this compiles the test executable used by
    ``run_mms`` -- confirm against boututils.run_wrapper.
    """
    test_label = "MMS advection"
    build_and_log(test_label)


# Resolutions scanned by run_mms, in order. Each value nx is passed to the
# executable as MZ=nx and (with 4 extra points) as mesh:nx=nx+4.
# Default scan: [16, 32, 64, 128, 256, 512, 1024]
nxlist = [16, 32, 64, 128, 256, 512, 1024]

# Number of processes requested from launch_safe for every run
nproc = 8


# NOTE(review): dt0 is not referenced in the visible part of this file;
# presumably it is a base timestep read by scripts that import this module
# -- confirm against the callers.
dt0 = 0.15


def _convergence_order(err_fine, err_coarse, n_fine, n_coarse):
    """Return the observed convergence order between two resolutions.

    Assuming the error scales as ``(1 / n) ** order``, the order follows
    from the ratio of the two errors and the ratio of the two grid sizes.
    """
    return log(err_fine / err_coarse) / log(n_coarse / n_fine)


def run_mms(options, exit=True):
    """Run the advection MMS test for each option set and check convergence.

    For every ``(opts, exp_order)`` pair the executable ``../advection`` is
    run once per entry of the module-level ``nxlist``. The error field
    ``E_f`` is collected from ``data`` and reduced to an RMS (l-2) and a
    maximum (l-inf) norm. The convergence order measured between the two
    finest resolutions is then compared against ``exp_order``.

    Parameters
    ----------
    options : iterable of (str, number)
        Each item is a string of extra command-line arguments placed in
        front of the mesh settings, and the expected convergence order.
    exit : bool, optional
        If true (default), call ``_exit(success)`` once all option sets
        have been processed, which terminates the interpreter with status
        0 or 1. If false, the result is returned instead.

    Returns
    -------
    bool
        True if every option set converged with an order inside
        ``(exp_order - 0.25, exp_order + 0.75)``.

    Raises
    ------
    IndexError
        If ``options`` is non-empty and ``nxlist`` has fewer than two
        entries, since the order needs two resolutions.
    """
    success = True

    for opts, exp_order in options:
        error_2 = []  # The L2 error (RMS) at each resolution

        for i, nx in enumerate(nxlist):
            # Set the X and Z mesh size; the X spacing covers 2*pi in nx cells
            dx = 2.0 * pi / nx

            args = f"{opts} mesh:nx={nx+4} mesh:dx={dx} MZ={nx}"

            print("  Running with " + args)

            # Delete old data so that collect() cannot pick up stale output
            shell("rm -f data/BOUT.dmp.*.nc")

            # Command to run
            cmd = "../advection " + args
            # Launch using MPI; the timer also covers log writing and
            # data collection below
            start_time_ = time.time()

            _, out = launch_safe(cmd, nproc=nproc, pipe=True)

            # Save output to log file.
            # NOTE(review): the name depends only on nx, so a later option
            # set overwrites the logs written by an earlier one.
            with open(f"run.log.{nx}", "w") as f:
                f.write(out)

            # Collect data
            E_f = collect("E_f", xguards=False, tind=[1, 1], info=False, path="data")

            # Average error over domain
            l2 = sqrt(mean(E_f**2))
            linf = max(abs(E_f))

            error_2.append(l2)

            elapsed = time.time() - start_time_
            if i > 0:
                # Use the actual refinement ratio rather than assuming that
                # every resolution is exactly double the previous one.
                order = _convergence_order(
                    error_2[-1], error_2[-2], nx, nxlist[i - 1]
                )
                print(" %d | %7.3f s | %.6f" % (nx, elapsed, order))
            else:
                # A single resolution gives no order estimate yet
                print(" %d | %7.3f s | ?" % (nx, elapsed))

            print("  -> Error norm: l-2 %f l-inf %f" % (l2, linf))

        # Calculate convergence order from the two finest resolutions
        order = _convergence_order(error_2[-1], error_2[-2], nxlist[-1], nxlist[-2])
        print("Convergence order = %f" % (order))

        # Asymmetric tolerance: converging faster than expected is accepted
        # more readily than converging slower.
        if exp_order - 0.25 < order < exp_order + 0.75:
            print("............ PASS")
        else:
            success = False
            print("............ FAIL")

    if exit:
        _exit(success)
    return success


def _exit(success):
    """Print the overall verdict and terminate the interpreter.

    Exits with status 0 when ``success`` is truthy and with status 1
    otherwise, by raising ``SystemExit`` through ``sys.exit``.
    """
    import sys

    if not success:
        print(" => Some failed tests")
        sys.exit(1)

    print(" => All tests passed")
    sys.exit(0)
