#!/usr/bin/env python3
#
# Test file I/O by loading from restart files and writing to dump files
#
# requires: netcdf
# cores: 4

from boutdata import restart
from boutdata.collect import collect
from boututils.boutarray import BoutArray
from boututils.datafile import DataFile
from boututils.run_wrapper import build_and_log, shell, launch_safe
import numpy
import os
from sys import exit
import uuid

# Grid dimensions and guard-cell widths used to size the reference arrays
nx, ny, nz = 8, 16, 4
mxg, myg = 2, 2

build_and_log("restart I/O test")

# Coordinate axes on [0, 1]; each is shaped so that x, y and z broadcast
# against one another into a single (x, y, z) array.
x = numpy.linspace(0.0, 1.0, nx + 2 * mxg).reshape(-1, 1, 1)
y = numpy.linspace(0.0, 1.0, ny + 2 * myg).reshape(1, -1, 1)
z = numpy.linspace(0.0, 1.0, nz).reshape(1, 1, -1)

# Reference data, keyed by variable name. Each entry carries the bout_type
# attribute (plus yindex_global for the FieldPerps) that is checked later.
testvars = {
    "f3d": BoutArray(
        numpy.exp(numpy.sin(x + y + z)),
        attributes={"bout_type": "Field3D"},
    ),
    # x + y has a length-1 z axis, which is dropped to give a 2D array
    "f2d": BoutArray(
        numpy.exp(numpy.sin(x + y + 1.0))[..., 0],
        attributes={"bout_type": "Field2D"},
    ),
    # x + z has a length-1 y axis, which is dropped to give an (x, z) array
    "fperp_lower": BoutArray(
        numpy.exp(numpy.sin(x + z + 2.0))[:, 0],
        attributes={"bout_type": "FieldPerp", "yindex_global": 0},
    ),
    "fperp_upper": BoutArray(
        numpy.exp(numpy.sin(x + z + 3.0))[:, 0],
        attributes={"bout_type": "FieldPerp", "yindex_global": ny - 1},
    ),
}

# Directory that holds the single-processor restart file the test starts from.
# makedirs(exist_ok=True) also creates a missing parent "data" directory and
# is a no-op when the directory is already there; it still raises if the path
# exists but is not a directory.
restartdir = os.path.join("data", "restart")
os.makedirs(restartdir, exist_ok=True)

def _string_variable(text):
    """Return ``text`` as a BoutArray of single characters tagged as a string.

    The text is stored as an array of characters because NetCDF doesn't like
    generic strings.
    """
    characters = numpy.array(list(text), dtype="S1")
    return BoutArray(characters, attributes={"bout_type": "string"})


# Value written as run_id; the checks after each run compare the collected
# run_restart_from against this string.
run_id_string = "36 character run_id test string ****"

# Scalar entries of the restart file, in the order they are written
restart_scalars = (
    ("MXSUB", nx),
    ("MYSUB", ny),
    ("MZSUB", nz),
    ("MXG", mxg),
    ("MYG", myg),
    ("MZG", 0),
    ("nx", nx + 2 * mxg),
    ("ny", ny),
    ("nz", nz),
    ("MZ", nz),
    ("NXPE", 1),
    ("NYPE", 1),
    ("NZPE", 1),
    ("tt", 0.0),
    ("hist_hi", 0),
    # set BOUT_VERSION to stop collect from changing nz or printing a warning
    ("BOUT_VERSION", 4.0),
)

restart_path = os.path.join(restartdir, "BOUT.restart.0.nc")
with DataFile(restart_path, create=True) as base_restart:
    for key, value in restart_scalars:
        base_restart.write(key, value)

    # the reference fields themselves
    for key in ("f3d", "f2d", "fperp_lower", "fperp_upper"):
        base_restart.write(key, testvars[key])

    # run identification strings
    base_restart.write("run_id", _string_variable(run_id_string))
    base_restart.write(
        "run_restart_from",
        _string_variable("36 character run_restart_from string"),
    )

# Overall result of the test; set to False by any failed check below.
success = True

# Set to True to display mismatching "_once" fields with showdata while
# debugging. Nothing is plotted by default.
plot_mismatches = False

# Variables that carry a yindex_global attribute which must survive the I/O
fieldperp_names = ("fperp_lower", "fperp_upper")


def _yindex_matches(name, result, testvar, label=""):
    """Check that a collected FieldPerp kept its ``yindex_global`` attribute.

    Args:
        name: Name of the variable, used in the failure message.
        result: Collected array; its ``attributes["yindex_global"]`` is read.
        testvar: Reference array; its ``attributes["yindex_global"]`` is the
            expected value.
        label: Text inserted after ``name`` in the failure message, e.g.
            ``" evolving version"``.

    Returns:
        True if the two attributes are equal. Otherwise a failure message is
        printed and False is returned.
    """
    yindex_result = result.attributes["yindex_global"]
    yindex_test = testvar.attributes["yindex_global"]
    if yindex_result != yindex_test:
        print(
            f"Fail: yindex_global of {name}{label} is {yindex_result}"
            f" should be {yindex_test}"
        )
        return False
    return True


# Note: expect this to fail for 16 processors, because when there are 2
# y-points per processor, the fperp_lower FieldPerp is in the grid cells of one
# set of processors and also in the guard cells of another set. This means
# valid FieldPerps get written to output files with different
# y-processor-indices, and collect() cannot handle this.
for nproc in [1, 2, 4]:
    # delete any existing output
    shell("rm -f data/BOUT.dmp.*.nc data/BOUT.restart.*.nc")

    # create restart files for the run
    restart.redistribute(nproc, path=restartdir, output="data")

    print("   %d processor...." % (nproc))

    # run the test executable and keep its output in a per-nproc log file
    s, out = launch_safe("./test-restart-io", nproc=nproc, pipe=True)
    with open("run.log." + str(nproc), "w") as f:
        f.write(out)

    # check the results
    for name, testvar in testvars.items():
        is_fieldperp = name in fieldperp_names

        # check non-evolving version
        result = collect(
            name + "_once", path="data", xguards=True, yguards=True, info=False
        )

        if not numpy.allclose(testvar, result):
            success = False
            print(f"{name}_once is different: {numpy.max(numpy.abs(testvar - result))}")
            if plot_mismatches:
                from boututils.showdata import showdata

                showdata([result, testvar])
        if is_fieldperp and not _yindex_matches(name, result, testvar):
            success = False

        # check evolving versions: every time slice must equal the input exactly
        result = collect(name, path="data", xguards=True, yguards=True, info=False)

        for result_timeslice in result:
            if not numpy.all(testvar == result_timeslice):
                success = False
                difference = numpy.max(numpy.abs(testvar - result_timeslice))
                print(f"{name} evolving version is different: {difference}")

        # The attribute belongs to the collected array as a whole, not to a
        # time slice, so it is checked once rather than per slice.
        if is_fieldperp and not _yindex_matches(
            name, result, testvar, " evolving version"
        ):
            success = False

    # check the run_id
    run_id = str(collect("run_id", path="data", info=False))
    # check run_id can be converted to a valid UUID
    try:
        uuid.UUID(run_id)
    except ValueError:
        success = False
        print(f"run_id='{run_id}' is not a valid UUID")

    # run_restart_from of the new run must be the run_id stored in the
    # restart file it was started from
    run_restart_from = str(collect("run_restart_from", path="data", info=False))
    if run_restart_from != run_id_string:
        success = False
        print(
            f"incorrect run_restart_from='{run_restart_from}'. Expected '{run_id_string}'"
        )

# Report the overall outcome through the exit status: 1 on any failure,
# 0 when every check passed.
if not success:
    print("=> Some failed tests")
    exit(1)

print("=> All restart I/O tests passed")
# clean up binary files (only done after a fully successful run)
cleanup_command = (
    "rm -f data/BOUT.dmp.*.nc data/BOUT.restart.*.nc data/restart/BOUT.restart.0.nc"
)
shell(cleanup_command)
exit(0)
