#!/usr/bin/env python3

# Python script to run and analyse MMS test
#
# Exits with status 1 if any test failed


from __future__ import division
from __future__ import print_function
from builtins import str

import numpy as np

# Requires: boutpp
# requires: not make

import boutpp as bc

# Initialise boutpp before any other bc call in this script.
# NOTE(review): the repeated "-q" flags presumably reduce output verbosity --
# confirm against the BOUT++ command-line documentation.
bc.init("-q -q -q")


# Test cases, one per entry: [f, g, expected]. The test computes
# bracket(f, g) and compares it against the analytic expression `expected`.
funcs = [["sin(z)", "sin(4*x)", "4*cos(z)*cos(4*x)"]]

# List of NX values to use
# (the convergence loop only reads the last two entries, nlist[-2:])
nlist = [8, 16, 32]

# Allowed deviation of the measured convergence order from the expected
# order; a case outside (order - prec, order + prec) is reported as failed.
prec = 0.2


def genMesh(nx, ny, nz, **kwargs):
    """Set the mesh options for the given grid size and return a new mesh.

    Parameters
    ----------
    nx, ny, nz : int
        Number of points requested in x, y and z.
    **kwargs
        Additional ``option name -> value`` pairs. They are applied after
        the size-derived options, so they win on a name clash.

    Returns
    -------
    The object returned by ``bc.Mesh(section="mesh")``.
    """
    # Guard cells are only requested for directions with more than one point.
    x_guards = 2 if nx > 1 else 0
    y_guards = 2 if ny > 1 else 0
    # NOTE(review): the extra 4 points presumably account for the two x guard
    # cells on each side -- confirm against the BOUT++ mesh documentation.
    padded_nx = nx + 4 if nx > 1 else nx

    # A list of pairs (not a dict) so the options are applied in this order.
    size_options = [
        ("mesh:MXG", str(x_guards)),
        ("mesh:MYG", str(y_guards)),
        ("mesh:nx", str(padded_nx)),
        ("mesh:ny", str(ny)),
        ("mz", str(nz)),
        ("mesh:nz", str(nz)),
        ("mesh:dx", "2*pi/(%d)" % (nx)),
        ("mesh:dy", "1/(%d)" % (ny)),
        ("mesh:dz", "1/(%d)" % (nz)),
    ]
    for option, value in size_options:
        bc.setOption(option, value, force=True)

    # Caller-supplied options go last so they override the ones above.
    for option, value in kwargs.items():
        bc.setOption(option, value, force=True)

    return bc.Mesh(section="mesh")


# Bracket methods under test. Each entry is
#   [method name passed to bc.bracket, expected convergence order,
#    extra options forwarded to genMesh for this method]
brackets = [
    ["ARAKAWA", 2, {}],
    ["STD", 1, {"mesh:ddx:upwind": "U1", "mesh:ddz:upwind": "U1"}],
    ["STD", 2, {"mesh:ddx:upwind": "C2", "mesh:ddz:upwind": "C2"}],
    # WENO convergence order is currently under discussion:
    # https://github.com/boutproject/BOUT-dev/issues/1049
    # ["STD", 3 , {"mesh:ddx:upwind":"W3",
    #              "mesh:ddz:upwind":"W3"}
    #  ]
]

# The convergence order is estimated from the two finest resolutions only.
resolutions = nlist[-2:]

# Interior of the domain: genMesh uses two x guard cells on each side when
# nx > 1, and the meshes below have a single y point (index 0).
interior = (slice(2, -2), 0, slice(None))

# One message per failed (functions, method) combination.
failures = []
for in1, in2, outfunc in funcs:
    for name, order, args in brackets:
        errors = []
        for n in resolutions:
            mesh = genMesh(n, 1, n, **args)
            inf1 = bc.create3D(in1, mesh)
            inf2 = bc.create3D(in2, mesh)
            computed = bc.bracket(inf1, inf2, method=name)
            expected = bc.create3D(outfunc, mesh)
            # Root-mean-square difference over the interior points
            diff = (computed - expected)[interior]
            errors.append(np.sqrt(np.mean(diff**2)))

        # Observed order = log(error ratio) / log(resolution ratio)
        errc = np.log(errors[-2] / errors[-1])
        difc = np.log(resolutions[-1] / resolutions[-2])
        conv = errc / difc

        if not (order - prec < conv < order + prec):
            # Include the method name: several entries in `brackets` share
            # the same functions, so without it failures are ambiguous.
            failures.append(
                "%s: {%s,%s} -> %s failed (conv: %f - expected: %f)\n"
                % (name, in1, in2, outfunc, conv, order)
            )

if failures:
    print("".join(failures))
    # NOTE(review): as in the original script, bc.finalise() is not called
    # on the failure path -- confirm whether that is intentional.
    raise SystemExit(1)

# FIXME: this is a workaround for a strange bug, MPI_Comm_free after
# MPI_Finalize. Seems to be from the global mesh getting destroyed
# after/at the same time as finalise is called -- hence explicitly
# deleting it here
mesh = bc.Mesh.getGlobal()
del mesh
bc.finalise()

print("All tests passed")
# SystemExit is raised directly instead of calling the interactive `exit`
# helper, which is only present when the `site` module has been loaded.
raise SystemExit(0)
