#!/usr/bin/env python3

# Test initial conditions

import configparser
import itertools
import os
import platform
import sys

import numpy as np
from scipy.special import erf

from boutdata.collect import collect
from boututils.run_wrapper import build_and_log, shell, launch_safe

# requires not make
# requires scipy
# cores: 4

########################################
# Implementations of BOUT++ functions


def bout_round(x):
    """
    BOUT++ rounding: nearest integer, with halves rounded away from zero

    Follows the BOUT++ ``ROUND`` helper, which shifts the value by one half
    away from zero and then converts it to an integer. Without the integer
    conversion the result is not a rounded value at all (e.g. 2.3 -> 2.8),
    and a seed of exactly zero gives ``genRand`` one iteration too few.

    Parameters
    ----------
    x : scalar
        Value to round. Zero takes the ``x - 0.5`` branch and rounds to 0.

    Returns
    -------
    int
        The rounded value
    """
    shifted = x + 0.5 if x > 0.0 else x - 0.5
    # int() truncates towards zero, like a C++ conversion to int
    return int(shifted)


def genRand(seed):
    """
    BOUT++ pseudo-random number generator

    The generator keeps no state between calls: the result depends only
    on ``seed``, so a different seed must be passed each time a different
    number is wanted.
    """
    # Only the magnitude of the seed is used
    seed = -seed if seed < 0.0 else seed

    # The rounded seed sets how many times the map is applied
    steps = int(11 + (23 + bout_round(seed)) % 79)

    # Initial value, strictly between 0 and 1
    offset = 0.01
    period = 1.23456789
    value = (offset + np.mod(seed, period)) / (period + 2.0 * offset)

    # Apply the logistic map `steps` times
    for _ in range(steps):
        value = 3.99 * value * (1.0 - value)

    return value


def ballooning(x, ball_n=3):
    """
    Placeholder for the ballooning function, which is not implemented here

    Always raises NotImplementedError carrying the function name, so that
    callers can report which function was skipped.
    """
    feature = "ballooning"
    raise NotImplementedError(feature)


def gauss(x, width=1.0):
    """
    Gaussian in x with the given width, scaled by 1/sqrt(2*pi)

    Note that ``width`` only enters the exponent; the prefactor is
    1/sqrt(2*pi) whatever the width.
    """
    exponent = -(x**2) / (2 * width**2)
    prefactor = np.sqrt(2 * np.pi)
    return np.exp(exponent) / prefactor


def mixmode(x, seed=0.5):
    """
    Sum of 14 cosine modes whose phases come from genRand

    Mode ``n`` (n = 0..13) has weight 1/(1 + |n - 4|)^2 and a phase
    generated from ``seed + n``.
    """
    total = 0.0
    for mode in range(14):
        # genRand output is mapped linearly onto a phase: pi * (2*r - 1)
        shift = np.pi * (2.0 * genRand(seed + mode) - 1.0)
        # Largest weight at mode 4, falling off on either side
        weight = 1.0 / (1.0 + np.abs(mode - 4.0)) ** 2
        total += weight * np.cos(mode * x + shift)
    return total


def heaviside(x):
    """
    Heaviside step function: 1 where x > 0, otherwise 0 (including x == 0)
    """
    is_positive = x > 0
    # Multiplying by 1 turns the boolean result into an integer one
    return is_positive * 1


def tanhhat(x, width, centre, steepness):
    """
    BOUT++ TanhHat function

    Half the sum of a rising tanh at ``centre - width/2`` and a falling
    tanh at ``centre + width/2``; ``steepness`` scales the argument of both.
    """
    lower_edge = centre - width / 2.0
    upper_edge = centre + width / 2.0
    rising = np.tanh(steepness * (x - lower_edge))
    falling = np.tanh(-steepness * (x - upper_edge))
    return 0.5 * (rising + falling)


def atan(x, y=None):
    """
    Resolves to either np.arctan or np.arctan2 depending on the number of arguments

    With one argument returns np.arctan(x); with two returns np.arctan2(x, y).
    """
    if y is None:
        return np.arctan(x)
    return np.arctan2(x, y)


def max(*args):
    """
    Pointwise maximum over all of the arguments

    Deliberately shares its name with the builtin so that expressions
    written with ``max(...)`` pick up this numpy-based version.
    """
    # Pass the first argument through np.maximum as well, so that a single
    # argument still comes back as a numpy value
    largest = np.maximum(args[0], args[0])
    for candidate in args[1:]:
        largest = np.maximum(candidate, largest)
    return largest


def min(*args):
    """
    Pointwise minimum over all of the arguments

    Deliberately shares its name with the builtin so that expressions
    written with ``min(...)`` pick up this numpy-based version.
    """
    # Pass the first argument through np.minimum as well, so that a single
    # argument still comes back as a numpy value
    smallest = np.minimum(args[0], args[0])
    for candidate in args[1:]:
        smallest = np.minimum(candidate, smallest)
    return smallest


def fmod(x, denominator=1.0):
    """
    Remainder of x / denominator using the fmod convention (rem in Matlab)

    The result takes the sign of ``x``. With the default denominator of
    1.0 this is the signed fractional part of ``x``.
    """
    remainder = np.fmod(x, denominator)
    return remainder


# Aliases under the BOUT++ names, so that expressions written with those
# names can be evaluated as Python. Most are plain numpy functions.

# Constants
pi = np.pi

# Trigonometric functions and their inverses
sin = np.sin
cos = np.cos
tan = np.tan
asin = np.arcsin
acos = np.arccos

# Hyperbolic functions
sinh = np.sinh
cosh = np.cosh
tanh = np.tanh

# Exponentials, powers and magnitudes
exp = np.exp
log = np.log
sqrt = np.sqrt
power = np.power
abs = np.abs

# Functions implemented in this file
H = heaviside
TanhHat = tanhhat
ballooning = ballooning  # already has the BOUT++ name

########################################
# Running the test

# Some parameters
success = True
# Largest root-mean-square difference that still counts as a pass
tolerance = 1e-13
cmd = "./test_initial"
datadir = "data"
inputfile = os.path.join(datadir, "BOUT.inp")

# Variables that get a looser comparison on i686 (see the check below)
mixmode_vars = ("mixmode", "mixmode_seed")
arch = platform.machine()

# Read the input file. A "[global]" header is prepended so that configparser
# accepts any options placed before the first section header.
config = configparser.ConfigParser()
with open(inputfile, "r") as f:
    config.read_file(itertools.chain(["[global]"], f), source=inputfile)

# Find the variables that have a "function" option
varlist = [key for key, values in config.items() if "function" in values]

# Remove the coordinate arrays
for coord in ["var_x", "var_y", "var_z"]:
    varlist.remove(coord)

build_and_log("initial conditions test")

nprocs = [1, 2, 3, 4]
for nproc in nprocs:
    status, out = launch_safe(cmd, nproc=nproc, pipe=True, verbose=True)
    with open("run.log.{}".format(nproc), "w") as f:
        f.write(out)

    if status != 0:
        print("=> Could not run test")
        print(status)
        # sys.exit rather than the interactive-only exit() helper from site
        sys.exit(status)

    # Collect the coordinate arrays separately. They are module-level names
    # on purpose: the expressions evaluated below refer to x, y and z.
    x = collect("var_x", xguards=True, yguards=True, path=datadir, info=False)
    y = collect("var_y", xguards=True, yguards=True, path=datadir, info=False)
    z = collect("var_z", xguards=True, yguards=True, path=datadir, info=False)

    # Evaluate the functions
    for var in varlist:
        # Translate the input-file power operator into the Python one
        function = config[var]["function"].replace("^", "**")
        if ":" in function:
            print(
                "{} contains reference to variable - not possible to resolve at this time".format(
                    var
                )
            )
            continue

        # NOTE: eval runs whatever text the input file contains as Python.
        # That is only acceptable because BOUT.inp is a trusted file shipped
        # with the test; never point this at untrusted input.
        try:
            analytic = eval(function)
        except NotImplementedError as err:
            print("{} not implemented, skipping".format(err.args[0]))
            continue

        data = collect(var, xguards=True, yguards=True, path=datadir, info=False)
        # Root-mean-square difference between expected and computed values
        E2 = np.sqrt(np.mean((analytic - data) ** 2))

        # A mixmode variable that misses the tolerance but is still close
        near_miss = var in mixmode_vars and E2 < 1e-3
        if E2 < tolerance:
            success_string = "PASS"
        elif near_miss and arch == "i686":
            # This can happen due to excess precision e.g. on X87 architecture
            success_string = "WARNING"
        else:
            if near_miss:
                print("This should only happen in i686 with an x87 math architecture.")
                print("We detected however an %s architecture." % arch)
            success_string = "FAIL"
            success = False

        print(
            "\tChecking {var:<12}: l-2: {err:.4e} ... {success}".format(
                var=var, err=E2, success=success_string
            )
        )

if success:
    print(" => All tests passed")
    sys.exit(0)
else:
    print(" => Some failed tests")
    sys.exit(1)
