#!/usr/bin/env python3

from boututils.run_wrapper import build_and_log, shell, launch_safe
from boutdata.collect import collect
from netCDF4 import Dataset
import numpy
import os.path
from sys import stdout, exit

# Cores: 6
# Requires: netcdf

# Build step for this test, run before any grid files are written.
# NOTE(review): presumably compiles the ./test_griddata executable that is
# launched further down and logs the build under the given label -- confirm
# against boututils.run_wrapper.build_and_log.
build_and_log("griddata test")

# Grid sizes shared by every generated grid file and by the checks below.
nx = 4
ny = 24
# One sixth of the y points. Integer division keeps this an int: every value
# derived from it is an index that is stored in an int32 grid-file variable.
blocksize = ny // 6

# Generate the grid files that the test executable reads.
# Double-null case: one grid file per number of y boundary guard cells.
for n_yguards in (0, 1, 2):
    datadir = "data-doublenull-{}".format(n_yguards)
    gridname = "grid-doublenull-{}.nc".format(n_yguards)

    # The y dimension holds ny points plus 4 * n_yguards extra points
    ny_total = ny + 4 * n_yguards

    # Scalar int32 variables, listed in the order they are written to the file
    scalars = [
        ("nx", nx),
        ("ny", ny),
        ("y_boundary_guards", n_yguards),
        ("MXG", 1),
        # With no guard cells in the grid file, MYG is still set to 2
        ("MYG", n_yguards if n_yguards != 0 else 2),
        ("ixseps1", nx // 2 - 1),
        ("ixseps2", nx // 2 - 1),
        ("jyseps1_1", blocksize - 1),
        ("jyseps2_1", 2 * blocksize - 1),
        ("ny_inner", 3 * blocksize),
        ("jyseps1_2", 4 * blocksize - 1),
        ("jyseps2_2", 5 * blocksize - 1),
    ]

    with Dataset(os.path.join(datadir, gridname), "w") as gridfile:
        gridfile.createDimension("x", nx)
        gridfile.createDimension("y", ny_total)

        for varname, value in scalars:
            scalar_var = gridfile.createVariable(varname, numpy.int32)
            scalar_var[...] = value

        # "test" holds the y index of each point, identical for every x
        test_var = gridfile.createVariable("test", float, ("x", "y"))
        test_var[...] = numpy.tile(numpy.arange(ny_total, dtype=float), (nx, 1))

# Single-null case: one grid file per number of y boundary guard cells.
for n_yguards in (0, 1, 2):
    datadir = "data-singlenull-{}".format(n_yguards)
    gridname = "grid-singlenull-{}.nc".format(n_yguards)

    # The y dimension holds ny points plus 2 * n_yguards extra points
    ny_total = ny + 2 * n_yguards

    # Scalar int32 variables, listed in the order they are written to the file
    scalars = [
        ("nx", nx),
        ("ny", ny),
        ("y_boundary_guards", n_yguards),
        ("MXG", 1),
        # With no guard cells in the grid file, MYG is still set to 2
        ("MYG", n_yguards if n_yguards != 0 else 2),
        ("ixseps1", nx // 2 - 1),
        ("ixseps2", nx // 2 - 1),
        ("jyseps1_1", blocksize - 1),
        # The three middle indices all take the same value, ny // 2
        ("jyseps2_1", ny // 2),
        ("ny_inner", ny // 2),
        ("jyseps1_2", ny // 2),
        ("jyseps2_2", 5 * blocksize - 1),
    ]

    with Dataset(os.path.join(datadir, gridname), "w") as gridfile:
        gridfile.createDimension("x", nx)
        gridfile.createDimension("y", ny_total)

        for varname, value in scalars:
            scalar_var = gridfile.createVariable(varname, numpy.int32)
            scalar_var[...] = value

        # "test" holds the y index of each point, identical for every x
        test_var = gridfile.createVariable("test", float, ("x", "y"))
        test_var[...] = numpy.tile(numpy.arange(ny_total, dtype=float), (nx, 1))


# Run the test executable on every generated grid file and compare the
# collected "test" field (including y-guard cells) with the values expected
# from the grid file. Exits with status 1 if any comparison fails.
for nproc in [6]:
    stdout.write("Checking %d processors ... " % (nproc))

    # Remove dump and log files left over from a previous run
    shell("rm ./data*/BOUT.dmp.*.nc run.log.*")

    success = True

    # double null tests
    for n_yguards in [0, 1, 2]:
        datadir = "data-doublenull-" + str(n_yguards)

        # Only the captured output is used; the returned status is ignored
        _, out = launch_safe("./test_griddata -d " + datadir, nproc=nproc, pipe=True)

        with open("run.log.doublenull." + str(nproc), "a") as f:
            f.write(out)

        testfield = collect("test", path=datadir, info=False, yguards=True)

        if n_yguards == 0:
            # output has 2 y-guard cells, but grid file did not
            myg = 2
            # lower guard cells are expected to be zero
            checkfield = list(numpy.zeros(myg))
            # the two halves of the domain carry consecutive y indices
            checkfield += list(numpy.arange(ny // 2))
            checkfield += list(numpy.arange(ny // 2) + checkfield[-1] + 1)
            # upper guard cells are expected to repeat the last value
            checkfield += list(numpy.zeros(myg) + checkfield[-1])
        else:
            checkfield = []
            # lower guard cells come straight from the grid file
            checkfield += list(numpy.arange(n_yguards))
            checkfield += list(numpy.arange(ny // 2) + checkfield[-1] + 1)
            # the 2 * n_yguards grid-file values that follow the first half
            # are skipped before the second half starts
            checkfield += list(
                numpy.arange(ny // 2) + checkfield[-1] + 1 + 2 * n_yguards
            )
            # upper guard cells continue the sequence
            checkfield += list(numpy.arange(n_yguards) + checkfield[-1] + 1)
        checkfield = numpy.array(checkfield)

        # Test value of testfield
        if numpy.max(numpy.abs(testfield - checkfield)) > 1e-13:
            print(
                "Failed: testfield does not match in doublenull case for n_yguards="
                + str(n_yguards)
            )
            success = False

    # single null tests
    for n_yguards in [0, 1, 2]:
        datadir = "data-singlenull-" + str(n_yguards)

        # Only the captured output is used; the returned status is ignored
        _, out = launch_safe("./test_griddata -d " + datadir, nproc=nproc, pipe=True)

        with open("run.log.singlenull." + str(nproc), "a") as f:
            f.write(out)

        testfield = collect("test", path=datadir, info=False, yguards=True)

        if n_yguards == 0:
            # output has 2 y-guard cells, but grid file did not
            myg = 2
            # lower guard cells are expected to be zero
            checkfield = list(numpy.zeros(myg))
            checkfield += list(numpy.arange(ny))
            # upper guard cells are expected to repeat the last value
            checkfield += list(numpy.zeros(myg) + checkfield[-1])
        else:
            # guard cells and domain form one consecutive sequence
            checkfield = []
            checkfield += list(numpy.arange(n_yguards))
            checkfield += list(numpy.arange(ny) + checkfield[-1] + 1)
            checkfield += list(numpy.arange(n_yguards) + checkfield[-1] + 1)
        checkfield = numpy.array(checkfield)

        # Test value of testfield
        if numpy.max(numpy.abs(testfield - checkfield)) > 1e-13:
            # Name the single-null case here so a failure is attributed to
            # the right set of grid files
            print(
                "Failed: testfield does not match in singlenull case for n_yguards="
                + str(n_yguards)
            )
            success = False

    if not success:
        exit(1)
    else:
        print("Passed")

exit(0)
