#!/usr/bin/env python
#
# Test of staggered grids for wave equation
#
# Runs the test case twice, once with staggered grids enabled, and
# once with staggered grids disabled.
#
# In each case, a top-hat density perturbation is used
# to give a mixture of starting modes. By taking the FFT
# in both time and space, the frequency of each mode is
# calculated.
#

from __future__ import print_function
from boutdata.collect import collect
from numpy.fft import rfftn
from numpy import argmax

from boututils.run_wrapper import build_and_log, launch

def analyse(path="test"):
  """
  Find the dominant time harmonic of each Y wavenumber.

  The density "n" is collected from `path` and reduced to a
  2D [t,y] array. Removing the background value of 1 leaves
  the perturbation, which is Fourier transformed in both
  t and y.

  Returns, for every Y wavenumber, the index of the time
  harmonic with the largest amplitude. Only the first half
  of the time harmonics is searched.
  """

  # Fix the second index at 2 and the last at 0, keeping t and y
  density = collect("n", path=path)[:,2,:,0]

  # FFT of the perturbation about the background, in t and y
  spectrum = rfftn(density - 1.0)

  # Restrict the search to the first half of the time harmonics
  half_nt = spectrum.shape[0] // 2
  amplitude = abs(spectrum[:half_nt,:])

  # Strongest time harmonic for each Y wavenumber
  return argmax(amplitude, axis=0)

####################################################


build_and_log("Staggered wave")

def run_case(stagger, logfile):
  """
  Run the wave test once and return the result of analyse(),
  i.e. the peak time harmonic of each Y wavenumber.

  stagger  Value given to mesh:StaggerGrids ("true" or "false")
  logfile  Name of the file which receives the output of the run

  Exits with an error message if the run returns a non-zero status.
  """
  s, out = launch("./test_staggered -d test mesh:StaggerGrids=" + stagger, nproc=1, pipe=True)

  # Save the output before checking the status, so that the log
  # of a failed run is kept
  with open(logfile, "w") as f:
    f.write(out)

  if s:
    # Both runs write to the same "test" directory, so analysing
    # after a failed run could read output left by another run
    raise SystemExit("Run with mesh:StaggerGrids={0} failed with status {1}, see {2}".format(stagger, s, logfile))

  return analyse()

# Run with and without staggered grids
print("Running with staggered grids")
stag = run_case("true", "run.log.stag")

print("Running without staggered grids")
nostag = run_case("false", "run.log.nostag")

# Imported only once both runs have completed
import matplotlib.pyplot as plt

# Peak time harmonic against Y wavenumber index for both cases
plt.plot(stag, '-o', label="Staggered grids")
plt.plot(nostag, '-x', label="Centered grids")

plt.legend(loc="upper left")
plt.xlabel(r"Spatial harmonic $k$")
plt.ylabel(r"Frequency $\omega$")

# Write the figure to file before opening the interactive window
plt.savefig("test_staggered.pdf")
plt.show()
