#!/usr/bin/env python3

import numpy as np
from sys import exit

from boututils.datafile import DataFile
from boututils.run_wrapper import build_and_log, shell, launch_safe

# Cores: 18
# Requires: netcdf

# Note, the expected values for the communicated cells, which are used to test the
# results, are shown in corner_communication_diagram.ods

#################################################
# Test disconnected-double-null diverted topology
#################################################

# At least 6 are needed here so that all regions are included
nype = 6

# run with processor splitting at separatrices
##################################################

# nxpe=3 puts both separatrices on processor boundaries
nxpe = 3
command = "./test-communications NXPE={}".format(nxpe)

build_and_log("Communications Test")

# clear out dump files left over from any earlier run
shell("rm data/BOUT.dmp.*")

print("Running Communications Test, nproc={}".format(nxpe * nype))

# The total process count is nxpe * nype; the captured output is kept in a
# log file whose suffix is the value of nxpe used for this run
s, out = launch_safe(command, nproc=nxpe * nype, pipe=True)
with open("run.log.{}".format(nxpe), "w") as log_file:
    log_file.write(out)


def test(actual, expected, procnum, region, boundary):
    if not np.all(np.array(np.rint(actual).astype(int)) == np.array(expected)):
        print(
            "failed in {} boundary of region {} ({}). Expected: {}  Actual: {}".format(
                boundary, procnum, region, expected, actual
            )
        )
        exit(1)


# Expected boundary values for the NXPE=3 run, one entry per processor.
# Entries are in processor order: the list index is both the N in
# data/BOUT.dmp.N.nc and the processor number reported if a check fails.
ddn_nxpe3_expected = [
    {
        "region": "lower, inner PF",
        "inner x": [72, 73],
        "outer x": [24, 25],
        "lower y": [96, 97, 98, 99],
        "upper y": [82, 10, 22, 26],
    },
    {
        "region": "lower, inner between separatrices",
        "inner x": [12, 13],
        "outer x": [48, 49],
        "lower y": [98, 99, 100, 101],
        "upper y": [22, 26, 38, 50],
    },
    {
        "region": "lower, inner SOL",
        "inner x": [36, 37],
        "outer x": [84, 85],
        "lower y": [100, 101, 102, 103],
        "upper y": [38, 50, 62, 86],
    },
    {
        "region": "inner core",
        "inner x": [74, 75],
        "outer x": [26, 27],
        "lower y": [81, 9, 21, 25],
        "upper y": [80, 8, 20, 32],
    },
    {
        "region": "inner between separatrices",
        "inner x": [14, 15],
        "outer x": [50, 51],
        "lower y": [21, 25, 37, 49],
        "upper y": [20, 32, 44, 52],
    },
    {
        "region": "inner SOL",
        "inner x": [38, 39],
        "outer x": [86, 87],
        "lower y": [37, 49, 61, 85],
        "upper y": [44, 52, 64, 88],
    },
    {
        "region": "upper inner PF",
        "inner x": [76, 77],
        "outer x": [28, 29],
        "lower y": [79, 7, 19, 31],
        "upper y": [104, 105, 106, 107],
    },
    {
        "region": "upper inner SOL, nearer PF",
        "inner x": [16, 17],
        "outer x": [52, 53],
        "lower y": [19, 31, 43, 51],
        "upper y": [106, 107, 108, 109],
    },
    {
        "region": "upper inner SOL, further from PF",
        "inner x": [40, 41],
        "outer x": [88, 89],
        "lower y": [43, 51, 63, 87],
        "upper y": [108, 109, 110, 111],
    },
    {
        "region": "upper outer PF",
        "inner x": [78, 79],
        "outer x": [30, 31],
        "lower y": [96, 97, 98, 99],
        "upper y": [76, 4, 16, 28],
    },
    {
        "region": "upper outer SOL, nearer PF",
        "inner x": [18, 19],
        "outer x": [54, 55],
        "lower y": [98, 99, 100, 101],
        "upper y": [16, 28, 40, 56],
    },
    {
        "region": "upper outer SOL, further from PF",
        "inner x": [42, 43],
        "outer x": [90, 91],
        "lower y": [100, 101, 102, 103],
        "upper y": [40, 56, 68, 92],
    },
    {
        "region": "outer core",
        "inner x": [80, 81],
        "outer x": [32, 33],
        "lower y": [75, 3, 15, 27],
        "upper y": [74, 2, 14, 34],
    },
    {
        "region": "outer between separatrices",
        "inner x": [20, 21],
        "outer x": [56, 57],
        "lower y": [15, 27, 39, 55],
        "upper y": [14, 34, 46, 58],
    },
    {
        "region": "outer SOL",
        "inner x": [44, 45],
        "outer x": [92, 93],
        "lower y": [39, 55, 67, 91],
        "upper y": [46, 58, 70, 94],
    },
    {
        "region": "lower outer PF",
        "inner x": [82, 83],
        "outer x": [34, 35],
        "lower y": [73, 1, 13, 33],
        "upper y": [104, 105, 106, 107],
    },
    {
        "region": "lower outer between separatrices",
        "inner x": [22, 23],
        "outer x": [58, 59],
        "lower y": [13, 33, 45, 57],
        "upper y": [106, 107, 108, 109],
    },
    {
        "region": "lower outer SOL",
        "inner x": [46, 47],
        "outer x": [94, 95],
        "lower y": [45, 57, 69, 93],
        "upper y": [108, 109, 110, 111],
    },
]

for procnum, expected in enumerate(ddn_nxpe3_expected):
    region = expected["region"]
    # Index 0 is taken along the first and last axes of "f", leaving a 2D
    # array (presumably x by y, going by the boundary labels used below)
    f = DataFile("data/BOUT.dmp.{}.nc".format(procnum))["f"][0, :, :, 0]
    # The x-boundary checks leave out the first and last entries of the second
    # axis; the y-boundary checks use the full extent of the first axis
    test(f[0, 1:-1], expected["inner x"], procnum, region, "inner x")
    test(f[-1, 1:-1], expected["outer x"], procnum, region, "outer x")
    test(f[:, 0], expected["lower y"], procnum, region, "lower y")
    test(f[:, -1], expected["upper y"], procnum, region, "upper y")

# run with processor splitting not at separatrices
##################################################

# With nxpe=2, neither separatrix is on a processor boundary
nxpe = 2
command = "./test-communications NXPE=" + str(nxpe)

# remove old outputs
shell("rm data/BOUT.dmp.*")

print("Running Communications Test, nproc=" + str(nxpe * nype))

s, out = launch_safe(command, nproc=nxpe * nype, pipe=True)
with open("run.log." + str(nxpe), "w") as f:
    f.write(out)

# Expected boundary values for the NXPE=2 run, one entry per processor.
# Entries are in processor order: the list index is both the N in
# data/BOUT.dmp.N.nc and the processor number reported if a check fails.
ddn_nxpe2_expected = [
    {
        "region": "lower, inner leg, interior",
        "inner x": [72, 73],
        "outer x": [36, 37],
        "lower y": [96, 97, 98, 99, 100],
        "upper y": [82, 10, 22, 26, 38],
    },
    {
        "region": "lower, inner leg, exterior",
        "inner x": [24, 25],
        "outer x": [84, 85],
        "lower y": [99, 100, 101, 102, 103],
        "upper y": [26, 38, 50, 62, 86],
    },
    {
        "region": "inner, interior",
        "inner x": [74, 75],
        "outer x": [38, 39],
        "lower y": [81, 9, 21, 25, 37],
        "upper y": [80, 8, 20, 32, 44],
    },
    {
        "region": "inner, exterior",
        "inner x": [26, 27],
        "outer x": [86, 87],
        "lower y": [25, 37, 49, 61, 85],
        "upper y": [32, 44, 52, 64, 88],
    },
    {
        "region": "upper, inner leg, interior",
        "inner x": [76, 77],
        "outer x": [40, 41],
        "lower y": [79, 7, 19, 31, 43],
        "upper y": [104, 105, 106, 107, 108],
    },
    {
        "region": "upper, inner leg, exterior",
        "inner x": [28, 29],
        "outer x": [88, 89],
        "lower y": [31, 43, 51, 63, 87],
        "upper y": [107, 108, 109, 110, 111],
    },
    {
        "region": "upper, outer leg, interior",
        "inner x": [78, 79],
        "outer x": [42, 43],
        "lower y": [96, 97, 98, 99, 100],
        "upper y": [76, 4, 16, 28, 40],
    },
    {
        "region": "upper, outer leg, exterior",
        "inner x": [30, 31],
        "outer x": [90, 91],
        "lower y": [99, 100, 101, 102, 103],
        "upper y": [28, 40, 56, 68, 92],
    },
    {
        "region": "outer, interior",
        "inner x": [80, 81],
        "outer x": [44, 45],
        "lower y": [75, 3, 15, 27, 39],
        "upper y": [74, 2, 14, 34, 46],
    },
    {
        "region": "outer, exterior",
        "inner x": [32, 33],
        "outer x": [92, 93],
        "lower y": [27, 39, 55, 67, 91],
        "upper y": [34, 46, 58, 70, 94],
    },
    {
        "region": "lower, outer leg, interior",
        "inner x": [82, 83],
        "outer x": [46, 47],
        "lower y": [73, 1, 13, 33, 45],
        "upper y": [104, 105, 106, 107, 108],
    },
    {
        "region": "lower, outer leg, exterior",
        "inner x": [34, 35],
        "outer x": [94, 95],
        "lower y": [33, 45, 57, 69, 93],
        "upper y": [107, 108, 109, 110, 111],
    },
]

for procnum, expected in enumerate(ddn_nxpe2_expected):
    region = expected["region"]
    # Index 0 is taken along the first and last axes of "f", leaving a 2D
    # array (presumably x by y, going by the boundary labels used below)
    f = DataFile("data/BOUT.dmp.{}.nc".format(procnum))["f"][0, :, :, 0]
    # The x-boundary checks leave out the first and last entries of the second
    # axis; the y-boundary checks use the full extent of the first axis
    test(f[0, 1:-1], expected["inner x"], procnum, region, "inner x")
    test(f[-1, 1:-1], expected["outer x"], procnum, region, "outer x")
    test(f[:, 0], expected["lower y"], procnum, region, "lower y")
    test(f[:, -1], expected["upper y"], procnum, region, "upper y")

# run with no processor splitting in the x-direction
####################################################

nxpe = 1
command = "./test-communications NXPE=" + str(nxpe)

# remove old outputs
shell("rm data/BOUT.dmp.*")

print("Running Communications Test, nproc=" + str(nxpe * nype))

s, out = launch_safe(command, nproc=nxpe * nype, pipe=True)
with open("run.log." + str(nxpe), "w") as f:
    f.write(out)

# Expected boundary values for the NXPE=1 run, one entry per processor.
# Entries are in processor order: the list index is both the N in
# data/BOUT.dmp.N.nc and the processor number reported if a check fails.
ddn_nxpe1_expected = [
    {
        "region": "lower, inner leg",
        "inner x": [72, 73],
        "outer x": [84, 85],
        "lower y": [96, 97, 98, 99, 100, 101, 102, 103],
        "upper y": [82, 10, 22, 26, 38, 50, 62, 86],
    },
    {
        "region": "inner",
        "inner x": [74, 75],
        "outer x": [86, 87],
        "lower y": [81, 9, 21, 25, 37, 49, 61, 85],
        "upper y": [80, 8, 20, 32, 44, 52, 64, 88],
    },
    {
        "region": "upper, inner leg",
        "inner x": [76, 77],
        "outer x": [88, 89],
        "lower y": [79, 7, 19, 31, 43, 51, 63, 87],
        "upper y": [104, 105, 106, 107, 108, 109, 110, 111],
    },
    {
        "region": "upper, outer leg",
        "inner x": [78, 79],
        "outer x": [90, 91],
        "lower y": [96, 97, 98, 99, 100, 101, 102, 103],
        "upper y": [76, 4, 16, 28, 40, 56, 68, 92],
    },
    {
        "region": "outer",
        "inner x": [80, 81],
        "outer x": [92, 93],
        "lower y": [75, 3, 15, 27, 39, 55, 67, 91],
        "upper y": [74, 2, 14, 34, 46, 58, 70, 94],
    },
    {
        "region": "lower, outer leg",
        "inner x": [82, 83],
        "outer x": [94, 95],
        "lower y": [73, 1, 13, 33, 45, 57, 69, 93],
        "upper y": [104, 105, 106, 107, 108, 109, 110, 111],
    },
]

for procnum, expected in enumerate(ddn_nxpe1_expected):
    region = expected["region"]
    # Index 0 is taken along the first and last axes of "f", leaving a 2D
    # array (presumably x by y, going by the boundary labels used below)
    f = DataFile("data/BOUT.dmp.{}.nc".format(procnum))["f"][0, :, :, 0]
    # The x-boundary checks leave out the first and last entries of the second
    # axis; the y-boundary checks use the full extent of the first axis
    test(f[0, 1:-1], expected["inner x"], procnum, region, "inner x")
    test(f[-1, 1:-1], expected["outer x"], procnum, region, "outer x")
    test(f[:, 0], expected["lower y"], procnum, region, "lower y")
    test(f[:, -1], expected["upper y"], procnum, region, "upper y")


#################################################
# Test limiter topology
#################################################

nype = 1

# run with processor splitting at the separatrix
####################################################

nxpe = 3
command = "./test-communications -d data_limiter NXPE=" + str(nxpe)

# remove old outputs
shell("rm data_limiter/BOUT.dmp.*")

print("Running Communications Test, nproc=" + str(nxpe * nype))

s, out = launch_safe(command, nproc=nxpe * nype, pipe=True)
with open("run_limiter.log." + str(nxpe), "w") as f:
    f.write(out)

# Expected boundary values for the limiter NXPE=3 run, one entry per processor.
# Entries are in processor order: the list index is both the N in
# data_limiter/BOUT.dmp.N.nc and the processor number reported if a check fails.
limiter_nxpe3_expected = [
    {
        "region": "core",
        "inner x": [12, 13],
        "outer x": [4, 5],
        "lower y": [13, 1, 3, 19],
        "upper y": [12, 0, 2, 27],
    },
    {
        "region": "interior SOL",
        "inner x": [2, 3],
        "outer x": [8, 9],
        "lower y": [3, 19, 20, 21],
        "upper y": [2, 27, 28, 29],
    },
    {
        "region": "exterior SOL",
        "inner x": [6, 7],
        "outer x": [14, 15],
        "lower y": [20, 21, 22, 23],
        "upper y": [28, 29, 30, 31],
    },
]

for procnum, expected in enumerate(limiter_nxpe3_expected):
    region = expected["region"]
    # Index 0 is taken along the first and last axes of "f", leaving a 2D
    # array (presumably x by y, going by the boundary labels used below)
    f = DataFile("data_limiter/BOUT.dmp.{}.nc".format(procnum))["f"][0, :, :, 0]
    # The x-boundary checks leave out the first and last entries of the second
    # axis; the y-boundary checks use the full extent of the first axis
    test(f[0, 1:-1], expected["inner x"], procnum, region, "inner x")
    test(f[-1, 1:-1], expected["outer x"], procnum, region, "outer x")
    test(f[:, 0], expected["lower y"], procnum, region, "lower y")
    test(f[:, -1], expected["upper y"], procnum, region, "upper y")

# run with processor splitting not at the separatrix
####################################################

nxpe = 2
command = "./test-communications -d data_limiter NXPE=" + str(nxpe)

# remove old outputs
shell("rm data_limiter/BOUT.dmp.*")

print("Running Communications Test, nproc=" + str(nxpe * nype))

s, out = launch_safe(command, nproc=nxpe * nype, pipe=True)
with open("run_limiter.log." + str(nxpe), "w") as f:
    f.write(out)

# Expected boundary values for the limiter NXPE=2 run, one entry per processor.
# Entries are in processor order: the list index is both the N in
# data_limiter/BOUT.dmp.N.nc and the processor number reported if a check fails.
limiter_nxpe2_expected = [
    {
        "region": "interior",
        "inner x": [12, 13],
        "outer x": [6, 7],
        "lower y": [13, 1, 3, 19, 20],
        "upper y": [12, 0, 2, 27, 28],
    },
    {
        "region": "exterior",
        "inner x": [4, 5],
        "outer x": [14, 15],
        "lower y": [19, 20, 21, 22, 23],
        "upper y": [27, 28, 29, 30, 31],
    },
]

for procnum, expected in enumerate(limiter_nxpe2_expected):
    region = expected["region"]
    # Index 0 is taken along the first and last axes of "f", leaving a 2D
    # array (presumably x by y, going by the boundary labels used below)
    f = DataFile("data_limiter/BOUT.dmp.{}.nc".format(procnum))["f"][0, :, :, 0]
    # The x-boundary checks leave out the first and last entries of the second
    # axis; the y-boundary checks use the full extent of the first axis
    test(f[0, 1:-1], expected["inner x"], procnum, region, "inner x")
    test(f[-1, 1:-1], expected["outer x"], procnum, region, "outer x")
    test(f[:, 0], expected["lower y"], procnum, region, "lower y")
    test(f[:, -1], expected["upper y"], procnum, region, "upper y")

# run with no processor splitting in the x-direction
####################################################

nxpe = 1
command = "./test-communications -d data_limiter NXPE={}".format(nxpe)

# clear out dump files left over from the previous run
shell("rm data_limiter/BOUT.dmp.*")

print("Running Communications Test, nproc={}".format(nxpe * nype))

s, out = launch_safe(command, nproc=nxpe * nype, pipe=True)
with open("run_limiter.log.{}".format(nxpe), "w") as log_file:
    log_file.write(out)

# Only one dump file (processor 0) is checked for this run
region = "all"
f = DataFile("data_limiter/BOUT.dmp.0.nc")["f"][0, :, :, 0]
for boundary, values, wanted in (
    ("inner x", f[0, 1:-1], [12, 13]),
    ("outer x", f[-1, 1:-1], [14, 15]),
    ("lower y", f[:, 0], [13, 1, 3, 19, 20, 21, 22, 23]),
    ("upper y", f[:, -1], [12, 0, 2, 27, 28, 29, 30, 31]),
):
    test(values, wanted, 0, region, boundary)

# Every failed check exits immediately, so reaching this point means success
print("Pass")
exit(0)
