#!/usr/bin/env python

# Python script to run and analyse MMS test

from __future__ import division
from __future__ import print_function

# Prefer ``str`` from the ``builtins`` module when it can be imported; if the
# module is missing, silently keep the interpreter's own ``str``.
# Only ImportError is ignored so that Ctrl-C or a genuine error raised while
# importing is not hidden.
try:
    from builtins import str
except ImportError:
    pass

from boututils.run_wrapper import shell, launch, getmpirun, build_and_log
from boutdata.collect import collect

from numpy import sqrt, max, abs, mean, array, log, concatenate


# NOTE(review): presumably compiles the test executable and logs the build
# under the given title -- confirm against boututils.run_wrapper.build_and_log.
build_and_log("Fluid model MMS test")

# Resolutions to scan: each entry is passed to a run as "mesh:ny=<value>"
nylist = [20, 40, 80, 160, 320, 640, 1280]

# Output settings passed to every run as "nout=" and "timestep="
nout, timestep = 1, 1

# Number of processes requested from the launcher
nproc = 1

# Variables whose error fields ("E_<name>") are analysed
varnames = ["n", "p", "nv"]

# Error norms per variable; one value is appended to each list per resolution
results = {name: {"l2": [], "inf": []} for name in varnames}

# Run the model once per resolution and record the error norms of each variable
for ny in nylist:
    args = "mesh:ny=" + str(ny) + " nout=" + str(nout) + " timestep=" + str(timestep)

    print("Running with " + args)

    # Delete old output from the directory this run actually uses: the model
    # is started with "-d mms" and the results are collected from path="mms",
    # so stale dump files there could otherwise be read back as if they were
    # new. "-f" keeps rm quiet when there is nothing to delete.
    shell("rm -f mms/BOUT.dmp.*.nc")

    # Command to run
    cmd = "./fluid -d mms " + args
    # Launch using MPI, capturing the output. The exit status is not
    # inspected here (unchanged from the original behaviour).
    _, out = launch(cmd, nproc=nproc, pipe=True)

    # Save output to a per-resolution log file; the context manager closes
    # the file even if the write fails.
    with open("run.log." + str(ny), "w") as f:
        f.write(out)

    # Collect the error field of every variable and reduce it to norms
    for var in varnames:
        E = collect("E_" + var, tind=[nout, nout], path="mms", info=False)
        # Keep the whole third axis at index 0 of the other three
        # (assumes collect returns [t, x, y, z] -- TODO confirm)
        E = E[0, 0, :, 0]

        l2 = sqrt(mean(E**2))  # root-mean-square error
        linf = max(abs(E))     # largest absolute error
        results[var]["l2"].append(l2)
        results[var]["inf"].append(linf)

        print("Error norm %s: l-2 %f l-inf %f" % (var, l2, linf))

# Calculate grid spacing for each resolution
dy = 1. / array(nylist)

success = True

# Estimate the convergence order of each variable from the two finest
# resolutions: the slope of log(error) against log(spacing).
for var in varnames:
    l2 = results[var]["l2"]
    order = log(l2[-1] / l2[-2]) / log(dy[-1] / dy[-2])
    print("Convergence order %s = %f" % (var, order))

    # Written as "not >=" rather than "<" so that a NaN order (for example
    # from NaN error norms) fails the test: every comparison with NaN is
    # False, so "order < 1.8" would let NaN pass.
    if not order >= 1.8:
        success = False

# Plot the error norms against mesh spacing. This is best effort only: a
# plotting failure must never change the outcome of the test.
try:
    import matplotlib.pyplot as plt

    for var in varnames:
        l2 = results[var]["l2"]
        inf = results[var]["inf"]
        # Slope between the two finest resolutions, used for the guide line
        order = log(l2[-1] / l2[-2]) / log(dy[-1] / dy[-2])

        plt.plot(dy, l2, '-o', label=r'$l^2$ ('+var+')')
        plt.plot(dy, inf, '-x', label=r'$l^\infty$ ('+var+')')
        # Reference line of the estimated order through the finest-grid point
        plt.plot(dy, l2[-1]*(dy/dy[-1])**order, '--', label="Order %.1f" % (order))

    plt.legend(loc="upper left")
    plt.grid()

    plt.yscale('log')
    plt.xscale('log')

    plt.xlabel(r'Mesh spacing $\delta y$')
    plt.ylabel("Error norm")

    plt.savefig("fluid_norm.pdf")
    plt.savefig("fluid_norm.png")

    plt.close()
except Exception as exc:
    # Plotting could fail for any number of reasons, and the actual error
    # raised may depend on, among other things, the current matplotlib
    # backend, so catch every ordinary exception. KeyboardInterrupt and
    # SystemExit are deliberately allowed to propagate.
    print("Plotting skipped: " + repr(exc))

# Report the verdict and end the script with a matching exit status
# (0 = passed, 1 = failed). SystemExit is raised directly because the
# ``exit`` helper is installed by the ``site`` module for interactive use
# and is not available when the interpreter is started without it.
if success:
    print(" => Test passed")
    raise SystemExit(0)
else:
    print(" => Test failed")
    raise SystemExit(1)
