#!/usr/bin/env python3

import boutpp

# requires boutpp
# requires not make

from boutdata.data import BoutOptionsFile
import numpy as np
import itertools
from sys import exit

# Failure messages accumulated by every runtests() call; printed at the end of
# the script to decide the exit status.
errorlist = []

# Initialise BOUT++ once, before any mesh or field is created.
# NOTE(review): "-d data" presumably selects the "data" directory (the same
# directory "data/BOUT.inp" is read from below) and the repeated "-q"
# presumably reduces output verbosity - confirm against the BOUT++
# command-line documentation.
boutpp.init("-d data -q -q -q")


def runtests(functions, derivatives, direction, stag, msg, errorlist, ordertol=0.1):
    """Check the measured convergence order of a differential operator.

    For every combination of test function, derivative method, input location
    and output location, the operator is applied on each resolution in the
    module-level ``nzs`` and the maximum absolute error against the analytic
    expression is recorded. The convergence order is estimated from the two
    finest resolutions and compared with the expected order.

    Args:
        functions: Iterable of ``(input_expr, expected_expr)`` string pairs.
            Every ``"%s"`` in an expression is replaced by
            ``"<direction>*<fac>"``.
        derivatives: Iterable of ``(expected_order, method_name)`` pairs.
        direction: Four-element sequence ``(direction_name, fac, guards,
            diff_func)`` where ``diff_func`` is called as
            ``diff_func(field, method=..., outloc=...)``.
        stag: If true, the ``"<DIRECTION>low"`` location is tested in addition
            to ``"centre"``.
        msg: Prefix used in convergence-failure messages.
        errorlist: List that failure messages are appended to (mutated in
            place).
        ordertol: Allowed absolute deviation of the measured order from the
            expected order.

    Reads the module-level names ``nzs`` (resolutions) and ``doPlot`` (show a
    matplotlib figure after each combination).
    """
    direction, fac, guards, diff_func = direction
    locations = ["centre"] + ([direction.upper() + "low"] if stag else [])

    # Text substituted for each "%s" placeholder in the test expressions.
    dirnfac = direction + "*" + fac
    # Extra points are only added to the grid size when testing the x direction.
    x_padding = 2 * guards if direction == "x" else 0

    def _make_field(opt, mesh, key, expression):
        """Evaluate ``expression`` through ``opt`` and wrap it in a Field3D.

        BOUT++'s input expressions translate from field aligned to orthogonal.
        We want to input an orthogonal-grid expression, so the
        BoutOptionsFile routines are used to generate the values instead.

        Appending "+0.*x*y*z" leaves the values unchanged but guarantees the
        evaluated array has shape (nx, ny, nz). Without it, a dimension whose
        coordinate does not appear in the expression might only have length 1
        (normally array broadcasting would cover that, but broadcasting is not
        used when creating a boutpp.Field3D object).
        """
        opt["mms"][key] = expression.replace("%s", dirnfac) + "+0.*x*y*z"
        values = opt.evaluate("mms:" + key)
        field = boutpp.Field3D.fromMesh(mesh)
        field.set(values)
        return field

    combinations = itertools.product(functions, derivatives, locations, locations)
    for (infunc, outfunc), (order, diff), inloc, outloc in combinations:
        errors = []
        for nz in nzs:
            # Configure the y and z grid for this resolution.
            npoints = "%d" % (nz + x_padding)
            spacing = "2*pi/(%d)" % (nz)
            boutpp.setOption("mesh:ny", npoints, force=True)
            boutpp.setOption("mesh:dy", spacing, force=True)
            boutpp.setOption("mesh:nz", npoints, force=True)
            boutpp.setOption("mesh:dz", spacing, force=True)
            mesh = boutpp.Mesh(section="mesh")

            # Matching options object used to evaluate the expressions.
            opt = BoutOptionsFile("data/BOUT.inp")
            opt["mesh"]["ny"] = nz
            opt["mesh"]["nz"] = nz
            opt.getSection("mms")
            opt.recalculate_xyz()

            # Numerical result of the operator under test.
            f = _make_field(opt, mesh, "f", infunc)
            mesh.communicate(f)  # communicate to set yup/ydown fields
            sim = diff_func(f, method=diff, outloc=outloc)
            actual_loc = sim.getLocation()
            if actual_loc != outloc:
                errorlist.append(
                    "Location does not match - expected %s but got %s"
                    % (outloc, actual_loc)
                )

            # Analytic result to compare against.
            ana = _make_field(opt, mesh, "ana", outfunc)

            difference = (sim - ana).getAll()
            if guards:
                # Drop one cell at each end of the first axis and `guards`
                # cells at each end of the second axis before measuring.
                difference = difference[1:-1, guards:-guards, :]
            err = np.max(np.abs(difference))
            print("this iteration " + str(nz) + "\t" + str(err))
            errors.append(err)

        # Estimate the order from the two finest resolutions.
        errc = np.log(errors[-2] / errors[-1])
        difc = np.log(nzs[-1] / nzs[-2])
        conv = errc / difc
        if not (order - ordertol < conv < order + ordertol):
            info = "%s - %s - %s - %s -> %s " % (infunc, diff, direction, inloc, outloc)
            errorlist.append(
                "%s: %s is not working. Expected %f got %f" % (msg, info, order, conv)
            )

        if doPlot:
            from matplotlib import pyplot as plt

            # Left: analytic vs. simulated values at the finest resolution.
            plt.subplot(121)
            plt.plot(ana.getAll().flatten())
            plt.plot(sim.getAll().flatten())
            # Right: error against grid spacing on log-log axes.
            plt.subplot(122)
            plt.loglog(1.0 / nzs, errors)
            plt.show()


# Resolutions are the powers of two from 2**start up to 2**mmax inclusive.
start = 6
mmax = 9
# Set to True to show a matplotlib figure for every tested combination.
doPlot = False
nzs = np.logspace(start, mmax, num=mmax - start + 1, base=2).astype(int)

# This test is derived from derivatives3, with some additional restrictions:
# - only derivatives needing at most a 3-point stencil can be used here,
#   because only 1 yup and 1 ydown field are calculated
# - derivatives in the z-direction are not tested: z-variation is added to
#   exercise the ParallelTransform methods, but it is not accounted for when
#   calculating the analytic derivatives

# Note that the inputs are given in field-aligned coordinates, while the
# outputs are in orthogonal x-z coordinates (if ShiftedMetric or
# ShiftedMetricInterp are used).

# --- First derivative in y -------------------------------------------------
functions = [
    ["sin(%s + (z-mesh:zShift))", "cos(%s + (z-mesh:zShift))"],
    ["cos(%s + (z-mesh:zShift))", "-sin(%s + (z-mesh:zShift))"],
]
derivatives = [[2, "C2"]]
direction = ["y", "1", 2, boutpp.DDY]

runtests(functions, derivatives, direction, False, "DD", errorlist)

# --- Second derivative in y ------------------------------------------------
functions = [
    ["sin(%s + (z-mesh:zShift))", "-sin(%s + (z-mesh:zShift))"],
    ["cos(%s + (z-mesh:zShift))", "-cos(%s + (z-mesh:zShift))"],
]
derivatives = [[2, "C2"]]
direction = ["y", "1", 2, boutpp.D2DY2]

runtests(
    functions,
    derivatives,
    direction,
    stag=False,
    msg="D2D2",
    errorlist=errorlist,
    ordertol=0.2,
)

# Report: exit status 0 when nothing failed, otherwise list every failure and
# exit with status 1.
if not errorlist:
    print("Pass")
    exit(0)

for failure in errorlist:
    print(failure)
exit(1)
