#!/usr/bin/env python3

#
# Run the test, compare results against the benchmark
#

# requires: not metric_3d
# requires: netcdf
# cores: 2

from __future__ import print_function
from __future__ import division

# Processor count handed to growth_rate(), which forwards it to launch_safe
# as nproc for every simulation run.
nproc = 2  # Number of processors to run on
# A case fails when |growth - reference| / reference exceeds this value.
reltol = 1.0e-3  # Allowed relative tolerance in growth-rate

from boututils.run_wrapper import build_and_log, shell, launch_safe
from boutdata.collect import collect
import numpy as np
from sys import stdout, exit

# Thread count passed to launch_safe as mthread
nthreads = 1

build_and_log("interchange instability test")

# Delete output files left behind by previous runs, one data directory at a time
for _datadir in ("data_1", "data_10"):
    shell("rm " + _datadir + "/BOUT.dmp.*")


def growth_rate(path, nproc, log=False):
    """Run ``./2fluid`` on one input directory and measure the growth rate.

    The executable is launched with ``./2fluid -d <path>``.  Afterwards
    ``t_array``, ``wci`` and the density ``Ni`` (at ``xind=5``, ``yind=30``)
    are collected from ``path``.  For every output step the RMS of ``Ni``
    over its last axis (Z) is formed, and the slope of ``ln(Ni_rms)`` with
    respect to ``t_array / wci`` is estimated near the end of the run.

    Parameters
    ----------
    path : str
        Directory given to ``-d`` and read back from with ``collect``.
    nproc : int
        Processor count forwarded to ``launch_safe``.
    log : str or False, optional
        Name of a file to receive the captured run output.  With a falsy
        value (the default) the output is not piped and no file is written.

    Returns
    -------
    float
        Estimated d(ln Ni_rms)/dt, with time measured as ``t_array / wci``.

    Raises
    ------
    ValueError
        If fewer than five output steps are available; the stencil below
        needs five samples and would otherwise wrap around silently through
        negative indexing.
    """
    pipe = bool(log)
    # The returned status is not inspected here.
    # NOTE(review): presumably launch_safe raises on a failed run - confirm.
    _, out = launch_safe(
        "./2fluid -d " + path, nproc=nproc, mthread=nthreads, pipe=pipe
    )
    if pipe:
        # Context manager guarantees the log is closed even if write() fails
        with open(log, "w") as logfile:
            logfile.write(out)

    # Get time base
    tarr = collect("t_array", path=path)
    nt = len(tarr)
    if nt < 5:
        raise ValueError(
            "growth_rate needs at least 5 output steps in '%s', found %d" % (path, nt)
        )
    wci = collect("wci", path=path)
    time = tarr / wci

    # Read density
    Ni = collect("Ni", path=path, xind=5, yind=30, info=False)

    # Calculate RMS value in Z for all time steps in one vectorised pass
    Nirms = np.sqrt(np.mean(np.asarray(Ni)[:, 0, 0, :] ** 2, axis=-1))

    # Get the growth rate: five-point central difference of ln(Nirms) centred
    # on the third-from-last output.  For equally spaced outputs with step dt
    # the denominator 6 * (t[-2] - t[-4]) equals 12 * dt.
    ln = np.log(Nirms)
    numerator = ln[-5] - 8.0 * ln[-4] + 8.0 * ln[-2] - ln[-1]
    return numerator / (6.0 * (time[-2] - time[-4]))


code = 0  # Return code


def _run_case(title, path, logname, orig, analytic):
    """Run one test case, print its report and return True if it passed.

    ``orig`` is the stored reference growth rate the result is compared with;
    ``analytic`` is only printed alongside it.  The case passes when the
    relative deviation from ``orig`` is at most ``reltol``.
    """
    print(title)
    growth = growth_rate(path, nproc, log=logname)
    print("   Log file " + logname)
    print(
        "   Growth-rate = %e, original = %e, analytic = %e"
        % (growth, orig, analytic)
    )
    absdev = abs(growth - orig)
    reldev = absdev / orig
    print("   Deviation from original: %e (%e %%)" % (absdev, reldev * 100.0))

    # Written as "<=" so that a NaN deviation (every comparison with NaN is
    # False) is reported as a failure rather than slipping through as a pass.
    passed = bool(reldev <= reltol)
    print("  => Passed" if passed else "  => Failed")
    return passed


# Run case 1
# Reference value: 24th October 2011,
# revision c4f7ec92786b333a5502c5256b5e602ba867090f
if not _run_case(
    "Test case 1: R = 1m", "data_1", "run_1.log", orig=2.148177e05, analytic=2.2e5
):
    # A failure here aborts immediately; case 2 is not run
    exit(1)

# Run case 2
# Reference value: 25th April 2014, revision fd032da
# (previously 65570., 24th October 2011,
# revision c4f7ec92786b333a5502c5256b5e602ba867090f)
if not _run_case(
    "Test case 2: R = 10m", "data_10", "run_10.log", orig=6.457835e04, analytic=6.3e4
):
    code = 1

exit(code)
