#!/usr/bin/env python3

from boutdata import collect
from boututils.run_wrapper import build_and_log, launch_safe
import numpy
from sys import exit

# Settings for this test run
datapath = "data"  # directory that collect() reads the output from
nproc = 1  # number of processes passed to launch_safe()
tol = 1.0e-13  # largest allowed |test_aligned - check| in the final comparison

build_and_log("twistshift test")

# Run the test executable and save its captured output to run.log.<nproc>
status, output = launch_safe("./test-twistshift", nproc=nproc, pipe=True)
with open(f"run.log.{nproc}", "w") as log_file:
    log_file.write(output)

# Read the three output variables with identical options (yguards=True, so the
# y-guard cells are requested along with the rest of the data)
collect_options = {"path": datapath, "yguards": True, "info": False}
test = collect("test", **collect_options)
test_aligned = collect("test_aligned", **collect_options)
check = collect("check", **collect_options)

# Uncomment for visual debugging of the collected fields:
# from boututils.showdata import showdata
# showdata([test, test_aligned, check], titles=['test', 'test_aligned', 'check'])

# Overall result flag; any failing check below sets it to False
success = True


# Check test_aligned is *not* periodic in y
def test1(ylower, yupper):
    """Record a failure if two y-slices of ``test_aligned`` coincide anywhere.

    Compares ``test_aligned[:, yupper, :]`` with ``test_aligned[:, ylower, :]``.
    If the absolute difference is below 1.0e-6 at *any* point, the module-level
    ``success`` flag is set to False and a failure message is printed.
    Otherwise nothing happens.

    Parameters
    ----------
    ylower, yupper
        Indices along the second axis of ``test_aligned`` selecting the two
        slices to compare.
    """
    global success

    upper_slice = test_aligned[:, yupper, :]
    lower_slice = test_aligned[:, ylower, :]
    separation = numpy.abs(upper_slice - lower_slice)

    # Nothing to report when every point differs by at least the threshold
    if not numpy.any(separation < 1.0e-6):
        return

    success = False
    message = (
        "Fail - test_aligned should not be periodic jy=%i and jy=%i should be "
        "different" % (yupper, ylower)
    )
    print(message)


# Pair each of the first four y-indices with the corresponding one of the last
# four: (0, -4), (1, -3), (2, -2), (3, -1)
for ylower, yupper in zip(range(4), range(-4, 0)):
    test1(ylower, yupper)

# Check test_aligned is the same as check
# Cannot check in guard cells, as the expression used for 'zShift' in the input file
# is not the same as the corrected zShift used for the transforms in the guard cells,
# so the outermost two points at each end of the first two axes are excluded.
interior = (slice(2, -2), slice(2, -2), slice(None))
deviation = numpy.abs(test_aligned[interior] - check[interior])
if numpy.any(deviation > tol):
    success = False
    print("Fail - test_aligned is different from the expected value")
    # Dump both full arrays to help diagnose the mismatch
    print("test_aligned", test_aligned)
    print("check", check)

# Report the overall result through the process exit status
if not success:
    exit(1)

print("Pass")
exit(0)
