#!/usr/bin/env python3

from boutdata import collect
from boututils.run_wrapper import build_and_log, launch_safe
import numpy
from sys import exit

# Run settings: directory read by collect(), process count handed to
# launch_safe(), and the absolute tolerance used by the comparisons below.
datapath = "data"
nproc = 1
tol = 1.0e-13

build_and_log("twistshift test")

# Run the test executable with its output captured, then keep that output
# in a log file named after the process count.
s, out = launch_safe("./test-twistshift", nproc=nproc, pipe=True)
log_name = "run.log.{}".format(nproc)
with open(log_name, "w") as log_file:
    log_file.write(out)

# Read back the three variables written by the run, all with the same options.
collect_options = dict(path=datapath, yguards=True, info=False)
test = collect("test", **collect_options)
test_aligned = collect("test_aligned", **collect_options)
result = collect("result", **collect_options)

# from boututils.showdata import showdata
# showdata([test, test_aligned, result], titles=['test', 'test_aligned', 'result'])

# Overall outcome; any failing check below flips this to False.
success = True


# Check test_aligned is *not* periodic in y
def test1(ylower, yupper):
    """Check that ``test_aligned`` differs between two y indices.

    Compares the slices ``test_aligned[:, ylower, :]`` and
    ``test_aligned[:, yupper, :]``. If any pair of corresponding points
    differs by less than 1.0e-6 the check fails: the module-level
    ``success`` flag is set to False and a message naming both indices is
    printed.

    Parameters
    ----------
    ylower : int
        Index along the second axis of ``test_aligned`` for the first slice.
    yupper : int
        Index along the second axis of ``test_aligned`` for the second slice.
        The callers pass negative values, i.e. counted from the end.
    """
    global success
    difference = numpy.abs(test_aligned[:, yupper, :] - test_aligned[:, ylower, :])
    if numpy.any(difference < 1.0e-6):
        success = False
        # print() does not interpolate printf-style arguments, so the message
        # has to be formatted explicitly before it is printed.
        message = (
            "Fail - test_aligned should not be periodic jy=%i and jy=%i should be "
            "different"
        ) % (yupper, ylower)
        print(message)


# Pair each of the first four y indices with one of the last four.
for lower_index, upper_index in ((0, -4), (1, -3), (2, -2), (3, -1)):
    test1(lower_index, upper_index)

# Check test and result are the same
mismatch = numpy.abs(result - test) > tol
if numpy.any(mismatch):
    print("Fail - result has not been communicated correctly - is different from input")
    success = False


# Check result is periodic in y
def test2(ylower, yupper):
    """Check that ``result`` matches between two y indices.

    Compares the slices ``result[:, ylower, :]`` and
    ``result[:, yupper, :]``. If any pair of corresponding points differs by
    more than the module-level tolerance ``tol`` the check fails: the
    module-level ``success`` flag is set to False, a message naming both
    indices is printed, and both slices and their difference are dumped to
    help diagnose the mismatch.

    Parameters
    ----------
    ylower : int
        Index along the second axis of ``result`` for the first slice.
    yupper : int
        Index along the second axis of ``result`` for the second slice.
        The callers pass negative values, i.e. counted from the end.
    """
    global success
    lower_slice = result[:, ylower, :]
    upper_slice = result[:, yupper, :]
    if numpy.any(numpy.abs(upper_slice - lower_slice) > tol):
        success = False
        # print() does not interpolate printf-style arguments, so the message
        # has to be formatted explicitly before it is printed.
        message = (
            "Fail - result should be periodic jy=%i and jy=%i should not be "
            "different"
        ) % (yupper, ylower)
        print(message)
        # Diagnostic dump: each slice with its index, then their difference.
        print(ylower, lower_slice)
        print(yupper, upper_slice)
        print(lower_slice - upper_slice)


# Pair each of the first four y indices with one of the last four.
for lower_index, upper_index in ((0, -4), (1, -3), (2, -2), (3, -1)):
    test2(lower_index, upper_index)

# Report the overall outcome through the process exit status.
if not success:
    exit(1)

print("Pass")
exit(0)
