#!/usr/bin/env python3

# Convert a GATO file to a NetCDF equilibrium file
# This code should be more resilient than the C
# code gato2pdb since it uses regular expressions to
# match tokens, rather than sscanf
#
# Ben Dudson, University of York, Nov 2009

from __future__ import print_function
from __future__ import division
from builtins import str
from builtins import next
from past.utils import old_div
import sys
import numpy as np
import re
from boututils.datafile import DataFile

# Check command-line arguments: one required input path and an
# optional output path.

nargs = len(sys.argv)

infile = None
outfile = None

if (nargs > 3) or (nargs < 2):
    # Wrong number of arguments: show the usage text and stop with a
    # non-zero status so calling scripts can detect the failure.
    print("Usage:")
    print("  " + sys.argv[0] + " <input file> [ <output file> ]")
    print("input file is a GATO format file")
    print("output is (optionally) an output NetCDF file")
    raise SystemExit(1)

infile = sys.argv[1]
if nargs == 2:
    # Only the input file was given: tag ".nc" on the end of its name
    outfile = infile + ".nc"
else:
    # Both input and output files were given
    outfile = sys.argv[2]

# Load data file output routines
# NOTE: DataFile is also imported unconditionally at the top of this file,
# so a missing boututils package fails there first and this guard is only
# reached when that import succeeded.
try:
    from boututils.datafile import DataFile
except ImportError:
    print("ERROR: boututils.DataFile not available")
    print("=> Set $PYTHONPATH variable to include BOUT++ pylib")
    # Exit with a failure status rather than the default success status
    raise SystemExit(1)

print("Reading: " + infile)

################################################
# Read the GATO equilibrium file

# The handle is kept open after this section: the rest of the file is read
# lazily through it by the token generator created further down.
try:
    f = open(infile)
except IOError as err:
    # Missing or unreadable input: report it in the same style as the
    # other errors instead of dumping a traceback
    print("ERROR: Cannot open input file '" + infile + "': " + str(err))
    raise SystemExit(1)

# First line contains the date
date = f.readline()
if not date:
    # readline() returns an empty string only at end of file
    print("ERROR: Cannot read from input file")
    f.close()
    raise SystemExit(1)

# The line keeps its trailing newline, so this also prints a blank line
print("Date: " + date)

# Second is description
desc = f.readline()
print("Description: " + desc)


# Define a generator to get the next token from the file
def file_tokens(fp):
    """Generate the numeric tokens found in a text stream, in order.

    Lines are pulled from ``fp`` with ``readline()`` until it returns an
    empty string. Text that does not look like a number is skipped.

    Parameters
    ----------
    fp : file-like object
        Open text stream providing ``readline()``.

    Yields
    ------
    str
        Each number as written in the input; the caller converts it with
        ``int()`` or ``float()``.
    """
    # Two alternatives for the mantissa, tried in this order:
    #  1. digits and a bare "." directly followed by an exponent, e.g.
    #     "1.E+05". Without this case the token was split into "1" and
    #     "+05", silently corrupting the data.
    #  2. the general form with optional integer part and optional ".",
    #     e.g. "12", "-3.5", ".25". A trailing "." with no exponent after
    #     it is left out of the token ("1." -> "1").
    # An optional exponent may follow either form.
    number = re.compile(
        r"[+-]?(?:\d+\.(?=[Ee][+-]?\d)|\d*\.?\d+)(?:[Ee][+-]?\d+)?"
    )
    # iter(callable, sentinel) stops as soon as readline() returns ""
    for line in iter(fp.readline, ""):
        for tok in number.findall(line):
            yield tok


# Token stream over the rest of the file (the two header lines have
# already been consumed with readline())
token = file_tokens(f)

# Third contains number of mesh points
try:
    npsi = int(next(token))
    ntheta = int(next(token))
    isym = int(next(token))
except StopIteration:
    print("ERROR: Unexpected end of file while reading grid size")
    raise SystemExit(1)
except ValueError:
    # A token was found but is not an integer (e.g. "1.5")
    print("ERROR: Third line should contain npsi, ntheta and isym")
    raise SystemExit(1)

print("Grid size %d %d %d" % (npsi, ntheta, isym))

# Check values
if (isym < 0) or (isym > 1):
    print("ERROR: isym must be either 0 or 1")
    raise SystemExit(1)
if (npsi < 1) or (ntheta < 1):
    # Both values are used as array sizes below, so they must be positive.
    # Previously this printed an empty message and carried on.
    print("ERROR: npsi and ntheta must both be at least 1")
    raise SystemExit(1)

# Read normalisation factors
# Only the two failures that token reading can produce are handled:
# StopIteration (file ended early) and ValueError (token is not a float).
# Anything else, including Ctrl-C, is left to propagate.

try:
    rcnt = float(next(token))
    xma = float(next(token))
    zma = float(next(token))
    btor = float(next(token))
except (StopIteration, ValueError):
    print("ERROR: Couldn't read normalisation factors")
    raise SystemExit(1)

print("rcnt = %e, xma = %e, zma = %e, btor = %e" % (rcnt, xma, zma, btor))

try:
    curtot = float(next(token))
    eaxe = float(next(token))
    dnorm = float(next(token))
except (StopIteration, ValueError):
    print("ERROR: Couldn't read normalisation factors")
    raise SystemExit(1)

print("curtot = %e, eaxe = %e, dnorm = %e\n" % (curtot, eaxe, dnorm))


def read_array(n, name="Unknown"):
    """Read the next ``n`` tokens from the global token stream as floats.

    Parameters
    ----------
    n : int
        Number of values to read.
    name : str, optional
        Label for the array; used only in the error message.

    Returns
    -------
    numpy.ndarray
        1D array of length ``n`` holding the values in the order read.

    Raises
    ------
    SystemExit
        If the token stream ends early or a token is not a valid float.
        An error line naming the array is printed first.
    """
    data = np.zeros([n])
    try:
        for i in range(n):
            data[i] = float(next(token))
    except (StopIteration, ValueError):
        # StopIteration: ran out of tokens; ValueError: malformed number.
        # Other exceptions (e.g. KeyboardInterrupt) are not intercepted.
        print("ERROR: Failed reading array '" + name + "' of size ", n)
        raise SystemExit(1)
    return data


def read_2d(nx, ny, name="Unknown"):
    """Read an ``nx`` by ``ny`` array of floats, one row at a time.

    Row ``i`` is obtained from ``read_array(ny, "<name>[<i>]")``, so the
    values are taken from the token stream row after row. A progress line
    naming the variable is printed before reading starts.

    Returns the filled 2D numpy array.
    """
    grid = np.zeros([nx, ny])
    print("Reading 2D variable " + name)
    for row in range(nx):
        row_label = name + "[" + str(row) + "]"
        grid[row, :] = read_array(ny, row_label)
    return grid


# Read 1D arrays
# Seven arrays of length npsi, read in the order they appear in the file.
# The generator is consumed left to right by the unpacking, so the reads
# happen in exactly the order of the labels.
psiflux, fnorm, ffpnorm, ponly, pponly, qsf, d = (
    read_array(npsi, label)
    for label in ("psiflux", "fnorm", "ffpnorm", "ponly", "pponly", "qsf", "d")
)

# Two arrays of length ntheta
dpdz, dpdr = (read_array(ntheta, label) for label in ("dpdz", "dpdr"))

# 2D arrays, each ntheta rows by npsi columns

xnorm, znorm = (read_2d(ntheta, npsi, label) for label in ("xnorm", "znorm"))

# Try to read Br and Bz (may be present)
# read_array() reports a short or malformed file by printing its own
# "ERROR: Failed reading array" line and raising SystemExit. SystemExit is
# not a subclass of Exception, so it is named explicitly here to treat the
# optional fields as absent instead of terminating the script.
try:
    Br = read_2d(ntheta, npsi, "Br")
    Bz = read_2d(ntheta, npsi, "Bz")
except (SystemExit, Exception):
    print("=> File does not contain Br and Bz")
    # Both are discarded even if Br alone was read successfully
    Br = Bz = None

# Nothing further is read from the input file, so release the handle
f.close()

ny = ntheta

if isym == 1:
    # Fill in values for up-down symmetric case
    print("Grid is up-down symmetric. Reflecting grid about midplane")
    # The ntheta rows read from the file are mirrored; the last row is not
    # duplicated, giving 2*(ntheta-1)+1 rows in total.
    ny = tsize = 2 * (ntheta - 1) + 1

    def reflect(data, mapfunc=lambda x: x):
        """Reflect a variable about the midplane.

        ``data`` has ntheta rows and npsi columns. The result has tsize
        rows: the first ntheta are copied from ``data``, the remaining
        ntheta-1 are rows ntheta-2 down to 0 of ``data`` passed through
        ``mapfunc``. ``mapfunc`` is applied to the whole mirrored block at
        once, so it must operate elementwise on arrays.
        """
        data2 = np.zeros([tsize, npsi])
        # Copy the original data
        data2[:ntheta, :] = data[:ntheta, :]
        # Now fill in the remainder: rows 0..ntheta-2 in reverse order
        # (slicing first, then reversing, stays valid when ntheta == 1)
        data2[ntheta:, :] = mapfunc(data[: ntheta - 1, :][::-1, :])
        return data2

    xnorm = reflect(xnorm)
    znorm = reflect(znorm, lambda x: 2.0 * zma - x)  # Reflect about zma
    # Identity tests are required: "Br != None" on a numpy array is an
    # elementwise comparison whose truth value raises ValueError.
    if Br is not None:
        Br = reflect(Br, lambda x: -x)  # Br reverses
    if Bz is not None:
        Bz = reflect(Bz)  # Bz remains the same
    # NOTE(review): "theta" is not read anywhere else in this file; kept
    # so the module-level names are unchanged.
    theta = tsize

# Make sure we have Br, Bz and Bpol

# Identity tests are required: "Br == None" on a numpy array is an
# elementwise comparison whose truth value raises ValueError, which made
# the script fail whenever Br and Bz were present in the input.
if (Br is None) or (Bz is None):
    # TODO: calculate Bpol from psi, then Br and Bz from Bpol,
    # using dpsi = 2*PI*R*Bp dx. Not implemented, so Bpol stays unset.
    Bpol = None
else:
    Bpol = np.sqrt(Br**2 + Bz**2)

# Calculate toroidal field
# fnorm has length npsi and xnorm has npsi columns, so the division
# broadcasts fnorm across every row of xnorm.
Btor = old_div(fnorm, xnorm)

# Output file
print("Writing data to '" + outfile + "'")
of = DataFile()
of.open(outfile, create=True)

mu0 = 4.0e-7 * np.pi

# Close the output file even if one of the writes raises
try:
    of.write("nx", npsi)
    of.write("ny", ny)

    # Write 1D flux functions (pressure terms are scaled by mu0)
    of.write("psi", psiflux)
    of.write("f", fnorm)
    of.write("ffprime", ffpnorm)
    of.write("mu0p", ponly * mu0)
    of.write("mu0pprime", pponly * mu0)
    of.write("qsafe", qsf)
    of.write("Ni", d)

    # Write 2D grid locations, transposed so that the npsi axis comes first
    of.write("Rxy", np.transpose(xnorm))
    of.write("Zxy", np.transpose(znorm))

    # Write 2D field components
    # Br, Bz and Bpol are currently not written to the output; they are
    # None when the input file does not contain Br and Bz.
    of.write("Btor", np.transpose(Btor))
finally:
    of.close()

print("Done")
