#!/usr/bin/env python3

from boututils.datafile import DataFile
import itertools
import time
import numpy as np
from boututils.run_wrapper import launch_safe, shell_safe, build_and_log
import argparse
import re


# requires: all_tests
# requires: netcdf
# cores: 4

# Variables whose names *start* with one of these alternatives are skipped by
# verify(): the pattern is applied with ``.match`` (anchored at the start of
# the name) and ends in ``.*``, so each entry acts as a name prefix.
# Judging by the names these are timings, solver counters, processor-layout
# sizes and run identifiers -- presumably values that legitimately differ
# between runs; confirm before adding or removing entries.
IGNORED_VARS_PATTERN = re.compile(
    "({}).*".format(
        "|".join(
            (
                "wtime",
                "ncalls",
                "arkode",
                "cvode",
                "run_id",
                "run_restart_from",
                "M.?SUB",
                "N.?PE",
                "iteration",
                "wall_time",
                "has_legacy_netcdf",
                "hist_hi",
            )
        )
    )
)


class timer(object):
    """Context manager that prints how long the enclosed block took.

    On leaving the ``with`` block (normally or through an exception) one
    line is printed: the elapsed time in seconds followed by ``msg``.
    Exceptions raised inside the block are not suppressed.

    Parameters
    ----------
    msg
        Label printed after the elapsed time.
    """

    def __init__(self, msg):
        self.msg = msg
        # Set in __enter__; None until the context has been entered.
        self.start = None

    def __enter__(self):
        # perf_counter is monotonic, so the measured interval cannot be
        # skewed (or go negative) if the system clock is adjusted mid-run.
        self.start = time.perf_counter()
        # Return the instance so ``with timer(...) as t`` is usable.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        elapsed = time.perf_counter() - self.start
        print("{:12.8f}s {}".format(elapsed, self.msg))


def timed_shell_safe(cmd, *args, **kwargs):
    """Run ``shell_safe(cmd, *args, **kwargs)`` and print how long it took.

    The command string itself is used as the label of the timing line.

    Returns
    -------
    Whatever ``shell_safe`` returns, passed through unchanged so that this
    wrapper can be used wherever ``shell_safe`` is.
    """
    with timer(cmd):
        # Returning from inside the ``with`` still runs timer.__exit__, so
        # the timing line is printed before the result reaches the caller.
        return shell_safe(cmd, *args, **kwargs)


def timed_launch_safe(cmd, *args, **kwargs):
    """Run ``launch_safe(cmd, *args, **kwargs)`` and print how long it took.

    The command string itself is used as the label of the timing line.

    Returns
    -------
    Whatever ``launch_safe`` returns, passed through unchanged so that this
    wrapper can be used wherever ``launch_safe`` is.
    """
    with timer(cmd):
        # Returning from inside the ``with`` still runs timer.__exit__, so
        # the timing line is printed before the result reaches the caller.
        return launch_safe(cmd, *args, **kwargs)


def verify(f1, f2):
    """Check that two BOUT++ data files hold the same data.

    Every variable in ``f1`` whose name is not matched by
    ``IGNORED_VARS_PATTERN`` is compared with the variable of the same name
    in ``f2``. Variables that exist only in ``f2`` are not examined.

    Parameters
    ----------
    f1, f2 : str
        Paths of the files to compare; each is opened with ``DataFile``.

    Raises
    ------
    RuntimeError
        If a compared variable is missing from ``f2``, has a different
        shape, or differs beyond the default ``numpy.isclose`` tolerances
        (NaNs in matching positions count as equal).
    """
    with timer("verify %s %s" % (f1, f2)):
        # Open both files as context managers so the handles are closed
        # even when a mismatch raises.
        with DataFile(f1) as d1, DataFile(f2) as d2:
            for name in d1.keys():
                if IGNORED_VARS_PATTERN.match(name):
                    continue

                # Read each variable once instead of once per use.
                v1 = d1[name]
                v2 = d2[name]

                # Defensive: turn a failed lookup into a clear error rather
                # than an AttributeError on ``.shape`` below.
                if v2 is None:
                    raise RuntimeError(
                        "variable '{}' is missing from {}".format(name, f2)
                    )

                if v1.shape != v2.shape:
                    raise RuntimeError(
                        "shape mismatch in '{}': {} vs {}".format(
                            name, v1.shape, v2.shape
                        )
                    )

                # Ignore corners by setting them to zero
                if len(v1.shape) >= 2:
                    # Probably Field2D or Field3D
                    # NOTE(review): this indexes the first two axes; if a
                    # variable carries a leading time axis these would not be
                    # the spatial corners -- confirm this is intended.
                    v1[0, 0] = v1[0, -1] = v1[-1, 0] = v1[-1, -1] = 0.0
                    v2[0, 0] = v2[0, -1] = v2[-1, 0] = v2[-1, -1] = 0.0

                close = np.isclose(v1, v2, equal_nan=True)
                if np.all(close):
                    continue

                # Report exactly the elements that failed the tolerance check
                # above (not every bitwise difference, and not NaN pairs).
                bad_indices = np.argwhere(~np.asarray(close)).tolist()
                err = "".join(
                    "{}: {} != {}\n".format(tuple(idx), v1[tuple(idx)], v2[tuple(idx)])
                    for idx in bad_indices
                )
                raise RuntimeError("data mismatch in '{}':\n{}".format(name, err))


# Command line: optional directory that contains the bout-squashoutput script.
parser = argparse.ArgumentParser(description="Test the bout-squashoutput wrapper")
parser.add_argument(
    "executable",
    # Without nargs="?" argparse treats a positional as required and never
    # uses ``default``; "?" makes the argument optional so the default applies.
    nargs="?",
    default="../../../bin",
    help="Path to the directory containing bout-squashoutput (default: %(default)s)",
)
args = parser.parse_args()

build_and_log("Squash test")

bout_squashoutput = args.executable + "/bout-squashoutput"

# The two squash invocations used below. They differ only in the extra ``a``
# flag of the second one (presumably "append" -- see bout-squashoutput --help).
squash_cmd = "{} -qdcl 9 data --outputname ../f2.nc".format(bout_squashoutput)
squash_append_cmd = "{} -qdcal 9 data --outputname ../f2.nc".format(
    bout_squashoutput
)

# Reference data: a single run whose dump file is moved aside as f1.nc.
# Every scenario below must reproduce it in f2.nc.
print("Run once to get normal data")
timed_shell_safe("./squash -q -q -q nout=2")
timed_shell_safe("mv data/BOUT.dmp.0.nc f1.nc")

# Scenario 1: one run on 4 processes, squashed into a single file.
print("Parallel test")
timed_shell_safe("rm -f f2.nc")
timed_launch_safe("./squash -q -q -q nout=2", nproc=4, mthread=1)
timed_shell_safe(squash_cmd)

verify("f1.nc", "f2.nc")

# Scenario 2: run, squash, restart, then squash again with the ``a`` flag.
print("Parallel and in two pieces")
timed_shell_safe("rm -f f2.nc")
timed_launch_safe("./squash -q -q -q", nproc=4, mthread=1)
timed_shell_safe(squash_cmd)
timed_launch_safe("./squash -q -q -q restart", nproc=4, mthread=1)
timed_shell_safe(squash_append_cmd)

verify("f1.nc", "f2.nc")

# Scenario 3: as scenario 2, but the restarted run gets dump_on_restart=false.
print("Parallel and in two pieces without dump_on_restart")
timed_shell_safe("rm -f f2.nc")
timed_launch_safe("./squash -q -q -q", nproc=4, mthread=1)
timed_shell_safe(squash_cmd)
timed_launch_safe("./squash -q -q -q restart dump_on_restart=false", nproc=4, mthread=1)
timed_shell_safe(squash_append_cmd)

verify("f1.nc", "f2.nc")

# Scenario 4: a plain single run passed through the squash script.
print("Sequential test")
timed_shell_safe("rm -f f2.nc")
timed_shell_safe("./squash -q -q -q nout=2")
timed_shell_safe(squash_cmd)

verify("f1.nc", "f2.nc")
