#!/usr/bin/env python3

# Convert an ELITE .eqin file to a NetCDF equilibrium file

from __future__ import print_function
from builtins import next
import sys
import numpy as np

# Check command-line arguments
#
# Accepted forms:
#   <script> <input file>                  -> output is "<input file>.nc"
#   <script> <input file> <output file>

nargs = len(sys.argv)

infile = None
outfile = None

if nargs not in (2, 3):
    # Wrong number of arguments: describe the calling convention and stop.
    print("Usage:")
    print("  " + sys.argv[0] + " <input file> [ <output file> ]")
    print("input file is an ELITE .eqin format file")
    print("output is (optionally) an output NetCDF file")
    # A bare SystemExit reports success (status 0); use a non-zero status
    # so calling scripts can detect the failure.
    raise SystemExit(1)

infile = sys.argv[1]

if nargs == 3:
    # Specified both input and output files
    outfile = sys.argv[2]
else:
    # Only specified input file: just tag ".nc" on the end of its name
    outfile = infile + ".nc"


# Load data file output routines.
# The import is guarded so that a missing BOUT++ Python library produces a
# short hint instead of a traceback.
try:
    from boututils.datafile import DataFile
except ImportError:
    print("ERROR: boututils.DataFile not available")
    print("=> Set $PYTHONPATH variable to include BOUT++ pylib")
    # Non-zero status: a bare SystemExit would report success to the caller.
    raise SystemExit(1)


print("Reading: " + infile)

################################################
# Read the ELITE EQIN file

# Report a missing or unreadable input file with a short message rather
# than an unhandled traceback.
try:
    f = open(infile)
except (IOError, OSError):
    print("ERROR: Cannot open input file " + infile)
    raise SystemExit(1)

# First line contains description

desc = f.readline()
if not desc:
    # readline() returns an empty string only at end of file, so the file
    # has no content at all.
    f.close()
    print("ERROR: Cannot read from input file")
    raise SystemExit(1)

print("Description: " + desc)


def file_tokens(fp):
    """Yield the whitespace-separated tokens of *fp*, one at a time.

    Lines are consumed lazily from the current position of the file-like
    object *fp*; blank lines contribute no tokens.  The generator finishes
    when the end of the file is reached.
    """
    for text_line in fp:
        for word in text_line.split():
            yield word


# Token stream over the remainder of the input file
token = file_tokens(f)

# The first two tokens after the description are the grid dimensions
try:
    npsi = int(next(token))
    npol = int(next(token))
except StopIteration:
    print("ERROR: Unexpected end of file while reading grid size")
    raise SystemExit(1)
except ValueError:
    print("ERROR: Second line should contain Npsi and Npol")
    raise SystemExit(1)

print("Size of grid: %d by %d\n" % (npsi, npol))

# Reject sizes that cannot describe a grid.  Without this check a zero size
# makes every following token look like a variable name, and a negative size
# fails later with a misleading "Expecting float" message.
if npsi < 1 or npol < 1:
    print("ERROR: Npsi and Npol must be positive")
    raise SystemExit(1)

# var will be a dictionary of variables, keyed by the name used in the file
var = dict()

try:
    # The rest of the file is a sequence of records: a name token followed
    # by that variable's values.  The loop ends when the token stream is
    # exhausted; any other error raised while reading propagates instead of
    # being mistaken for end of file.
    for varname in token:
        # "R", "Z" and anything starting with "B" are 2D (psi, poloidal)
        # variables; everything else is assumed to be a 1D (psi) variable.
        is_2d = varname in ("R", "Z") or varname.startswith("B")
        nvalues = npsi * npol if is_2d else npsi

        try:
            values = [float(next(token)) for _ in range(nvalues)]
        except StopIteration:
            print("ERROR: Unexpected end of file while reading " + varname)
            raise SystemExit(1)
        except ValueError:
            print("ERROR: Expecting float while reading " + varname)
            raise SystemExit(1)
        except Exception:
            print("ERROR: Out of cheese. Variable " + varname)
            raise SystemExit(1)

        if is_2d:
            # Values are stored with the psi index varying fastest, so the
            # flat list is laid out as [npol][npsi]; transpose to get an
            # array indexed [psi, poloidal].
            data = np.array(values, dtype=float).reshape((npol, npsi)).T.copy()
            kind = "2D"
        else:
            data = np.array(values, dtype=float)
            kind = "1D"

        # Add this variable to the dictionary
        var[varname] = data

        print("Read " + kind + " variable " + varname)
finally:
    # Finished reading data (or giving up): always release the input file
    f.close()

# Calculate toroidal field

try:
    fpol = var["f(psi)"]
    R = var["R"]
except KeyError:
    print("ERROR: Need f(psi) and R to calculate Bt")
    # Non-zero status: a bare SystemExit would report success to the caller.
    raise SystemExit(1)

# The flux function is f(psi) = R * Bt, so the toroidal field is f / R
# (not f * R).  fpol is 1D over psi and R is 2D [psi, poloidal]; adding a
# trailing axis broadcasts each f value along its flux surface.
Bt = fpol[:, np.newaxis] / R
var["Bt"] = Bt

# Temperatures: when only one of Te / Ti was read, reuse it for the other

if "Te" in var:
    var.setdefault("Ti", var["Te"])
elif "Ti" in var:
    var["Te"] = var["Ti"]


# Calculate pressure (only when the file did not provide "p")

if "p" not in var:
    can_compute_pressure = ("ne" in var) and ("Te" in var)
    if not can_compute_pressure:
        # No plasma quantities to use, so pprime would have to be integrated
        print("SORRY: integrating pprime not implemented yet")
    else:
        # Calculate using ne, Te and Ti.  Ti is always available here,
        # because the step above fills it in from Te when it is missing.
        print("Calculating pressure from ne, Te, Ti")
        # NOTE(review): 1.602e-19 is presumably the elementary charge,
        # converting temperatures in eV to Joules -- confirm units.
        electron_charge = 1.602e-19
        # Truncated value of 4*pi*1e-7; the result is later written out
        # under the name "mu0p", i.e. the pressure carries this factor.
        mu0 = 4.0 * 3.14159 * 1e-7
        summed_temperature = var["Te"] + var["Ti"]
        var["p"] = var["ne"] * summed_temperature * electron_charge * mu0
        # Could check against pprime

# Open the output file

print("\nWriting: " + outfile)

try:
    of = DataFile()
    of.open(outfile, create=True)
except Exception:
    print("ERROR: Could not open output file")
    # Non-zero status: a bare SystemExit would report success to the caller.
    raise SystemExit(1)

# Write data

# Map between input and output variable names
# (key: name written to the output file, value: name in the var dictionary)
vnames = {
    "psi": "psi",
    "f": "f(psi)",
    "ffprime": "ffprime",
    "mu0p": "p",
    "mu0pprime": "pprime",
    "qsafe": "q",
    "Ni": "ne",
    "Te": "Te",
    "Ti": "Ti",
    "Rxy": "R",
    "Zxy": "Z",
    "Bpxy": "Bp",
    "Btxy": "Bt",
}

try:
    of.write("nx", npsi)
    of.write("ny", npol)

    for on, vn in vnames.items():
        # Only the dictionary lookup is guarded, so that a KeyError raised
        # inside of.write() is reported as a write failure rather than
        # being mistaken for a missing variable.
        try:
            data = var[vn]
        except KeyError:
            print("WARNING: Not writing variable " + on)
            continue
        of.write(on, data)

except Exception:
    print("ERROR: Could not write data")
    raise
finally:
    # Close the output file whether or not writing succeeded
    of.close()

print("done")
