#!/usr/bin/env python3

import boutpp

# requires boutpp
# requires not make

import numpy as np
import itertools
from sys import exit

# Message accumulators filled by runtests() and reported at the end of the
# script: one list for failures, one for successes.
errorlist, passlist = [], []

# Initialise boutpp with command-line style arguments.
boutpp.init("-d data -q -q -q".split())


def runtests(functions, derivatives, directions, stag, msg):
    """Measure the convergence order of derivative operators and record results.

    For every direction, and every combination of function triple, derivative
    method, field location and velocity location, the operator is evaluated
    once per grid size in the module-level ``nzs`` and compared against the
    analytic expression. The convergence order is estimated from the last two
    entries of ``nzs`` and must lie strictly within 0.1 of the expected order.

    Results are appended to the module-level ``passlist`` and ``errorlist``;
    nothing is returned. When the module-level ``doPlot`` is true, a failing
    case additionally plots the analytic and computed data of the last grid.

    Args:
        functions: iterable of ``(vfunc, ffunc, outfunc)`` expression strings.
            Every ``%s`` in them is replaced by ``<direction>*<fac>``.
        derivatives: iterable of ``(order, method)`` pairs; ``order`` is the
            expected convergence order, ``method`` is passed to the operator.
        directions: iterable of ``(direction, fac, guards, diff_func)``.
            ``diff_func`` is called as
            ``diff_func(v, f, method=..., outloc=...)``.
        stag: if true, the ``<direction>low`` location is tested in addition
            to ``centre``.
        msg: prefix used for every pass/fail message.
    """
    for direction, fac, guards, diff_func in directions:
        locations = ["centre"]
        if stag:
            locations.append(direction.lower() + "low")

        # Values that depend only on the direction are computed once here
        # instead of once per grid size.
        size_option = "meshD:nD".replace("D", direction)
        spacing_option = "meshD:dD".replace("D", direction)
        mesh_section = "mesh" + direction
        dirnfac = direction + "*" + fac
        # Only in x is the requested size enlarged by the guard cells.
        extra_cells = 2 * guards if direction == "x" else 0

        for (vfunc, ffunc, outfunc), (order, diff), floc, vloc in itertools.product(
            functions, derivatives, locations, locations
        ):
            errors = []
            for nz in nzs:
                boutpp.setOption(size_option, "%d" % (nz + extra_cells), force=True)
                boutpp.setOption(spacing_option, "2*pi/(%d)" % (nz), force=True)
                mesh = boutpp.Mesh(section=mesh_section)
                f = boutpp.create3D(ffunc.replace("%s", dirnfac), mesh, outloc=floc)
                v = boutpp.create3D(vfunc.replace("%s", dirnfac), mesh, outloc=vloc)
                sim = diff_func(v, f, method=diff, outloc=floc)

                # The result must live at the requested output location.
                sim_loc = sim.getLocation()
                if sim_loc != floc:
                    errorlist.append(
                        "Location does not match - expected %s but got %s"
                        % (floc, sim_loc)
                    )

                ana = boutpp.create3D(outfunc.replace("%s", dirnfac), mesh, outloc=floc)
                err = (sim - ana).getAll().flatten()
                if guards:
                    # Drop `guards` entries from each end of the flattened data.
                    err = err[guards:-guards]
                errors.append(np.max(np.abs(err)))

            # Observed order: slope of log(error) against log(grid size),
            # taken from the last two grid sizes.
            errc = np.log(errors[-2] / errors[-1])
            difc = np.log(nzs[-1] / nzs[-2])
            conv = errc / difc

            info = "%s - %s - %s - %s - %s -> %s " % (
                vfunc,
                ffunc,
                diff,
                direction,
                floc,
                vloc,
            )
            if order - 0.1 < conv < order + 0.1:
                passlist.append(
                    "%s: %s is working. Expected %f got %f" % (msg, info, order, conv)
                )
                continue

            errorlist.append(
                "%s: %s is not working. Expected %f got %f" % (msg, info, order, conv)
            )
            if doPlot:
                # Imported lazily so matplotlib is only needed when plotting.
                from matplotlib import pyplot as plt

                # `ana` and `sim` are those of the last grid size evaluated.
                plt.plot(ana.getAll().flatten())
                plt.plot(sim.getAll().flatten())
                plt.show()


# Grid sizes are the powers of two 2**start, ..., 2**mmax (one per exponent).
start = 6
mmax = 7
nzs = np.logspace(start, mmax, num=mmax - start + 1, base=2)

# When true, runtests() plots analytic and computed data for failing cases.
doPlot = False

# Each entry: [v expression, f expression, analytic result expression].
# runtests() replaces every "%s" with "<direction>*<factor>".
functions = [
    ["sin(%s)", "sin(%s)", "sin(%s)*cos(%s)"],
    ["sin(%s)", "cos(%s)", "-sin(%s)*sin(%s)"],
]

# Each entry: [expected convergence order, method name].
# The same list is used for both the non-staggered and the staggered run.
derivatives = [
    [2, "C2"],
    [4, "C4"],
    [1, "U1"],
    [2, "U2"],
    # [3,"W3"] ,
]

# Each entry: [direction, coordinate factor, guard cells, operator].
directions = [
    ["x", "2*pi", 2, boutpp.VDDX],
    ["y", "1", 2, boutpp.VDDY],
    ["z", "1", 0, boutpp.VDDZ],
]

# First run without staggered locations, then with them.
for staggered in (False, True):
    runtests(functions, derivatives, directions, stag=staggered, msg="DD")

boutpp.finalise()

if not errorlist:
    print("Pass")
    exit(0)

# On failure, report the errors first and then the passing cases for context.
for line in errorlist + passlist:
    print(line)
exit(1)
