#!/usr/bin/env python3

# Usage hint, emitted before the imports below run. Passing any command-line
# argument makes the script save a PDF instead of opening a window.
print("Add any arguments to save a PDF")

import matplotlib as mpl
import matplotlib.pyplot as plt
#from matplotlib import cm
import numpy as np
from scipy.io import netcdf
from scipy.interpolate import interp1d
import sys
import os

# Any command-line argument at all (its value is never inspected) switches
# the script from interactive display to writing a PDF.
makePDF = len(sys.argv) > 1

#filename = 'regcoil_out.20200706-01-001-c09r00_withPorts_lambda1e-15_Picard_thetaZeta128_mpolNtor32_ns2_mgrid.nc'
filename = 'regcoil_out.20200716-01-037_uniformThickness10cm.nc'

# Read every needed variable inside a context manager so the file is closed
# even if one of the lookups below fails (e.g. a variable is missing).
# mmap=False is kept from the original call so that the arrays remain usable
# after the file is closed (see the scipy.io.netcdf_file notes on mmap).
# Indexing with [()] pulls the whole variable out of the netCDF wrapper.
with netcdf.netcdf_file(filename,'r',mmap=False) as f:
    # Grid sizes and angle grids on the plasma and coil surfaces.
    nfp = f.variables['nfp'][()]
    ntheta_plasma = f.variables['ntheta_plasma'][()]
    ntheta_coil = f.variables['ntheta_coil'][()]
    nzeta_plasma = f.variables['nzeta_plasma'][()]
    nzeta_coil = f.variables['nzeta_coil'][()]
    nzetal_plasma = f.variables['nzetal_plasma'][()]
    nzetal_coil = f.variables['nzetal_coil'][()]
    theta_plasma = f.variables['theta_plasma'][()]
    theta_coil = f.variables['theta_coil'][()]
    zeta_plasma = f.variables['zeta_plasma'][()]
    zeta_coil = f.variables['zeta_coil'][()]
    zetal_plasma = f.variables['zetal_plasma'][()]
    zetal_coil = f.variables['zetal_coil'][()]

    # Surface positions; later code reads the last axis as (x, y, z).
    r_plasma  = f.variables['r_plasma'][()]
    r_coil  = f.variables['r_coil'][()]

    # Per-lambda diagnostics of the solution.
    chi2_B = f.variables['chi2_B'][()]
    chi2_M = f.variables['chi2_M'][()]
    max_M = f.variables['max_M'][()]
    min_M = f.variables['min_M'][()]
    abs_M = f.variables['abs_M'][()]

    # Fourier coefficients and mode numbers used later to evaluate the
    # outer surface.
    rmnc_outer = f.variables['rmnc_outer'][()]
    rmns_outer = f.variables['rmns_outer'][()]
    zmnc_outer = f.variables['zmnc_outer'][()]
    zmns_outer = f.variables['zmns_outer'][()]
    mnmax_coil = f.variables['mnmax_coil'][()]
    xm_coil = f.variables['xm_coil'][()]
    xn_coil = f.variables['xn_coil'][()]

    ports_weight = f.variables['ports_weight'][()]
    d = f.variables['d'][()]
    ns_magnetization = f.variables['ns_magnetization'][()]
    Bnormal_from_TF_and_plasma_current = f.variables['Bnormal_from_TF_and_plasma_current'][()]
    Bnormal_total = f.variables['Bnormal_total'][()]
    # Template for reading further variables:
    # = f.variables[''][()]

    nlambda = f.variables['nlambda'][()]
    # Number of solutions actually stored, taken from the length of chi2_B.
    nsaved = len(chi2_B)
    lambdas = f.variables['lambda'][()]

# If every lambda is (numerically) zero, shift the whole array to 1 in place;
# per the message below this avoids a Python error further on.
largest_abs_lambda = np.abs(lambdas).max()
if largest_abs_lambda < 1.0e-200:
    print("lambda array appears to be all 0. Changing it to all 1 to avoid a python error.")
    lambdas += 1

########################################################
# Sort in order of lambda, since for a lambda search (general_option=4 or 5),
# the output arrays are in the order of the search, which is not so convenient for plotting.
########################################################

# Only reorder when one solution was saved per lambda; otherwise the arrays
# are left in file order.
if nsaved == nlambda:
    permutation = np.argsort(lambdas)
    # Apply the same ordering to every per-lambda 1D quantity.
    lambdas, chi2_M, chi2_B, max_M, min_M = (
        per_lambda[permutation]
        for per_lambda in (lambdas, chi2_M, chi2_B, max_M, min_M)
    )
    # Bnormal_total carries lambda on its first axis.
    Bnormal_total = Bnormal_total[permutation,:,:]
    # A huge final lambda is treated as a stand-in for infinity.
    if lambdas[-1] > 1.0e199:
        lambdas[-1] = np.inf

########################################################
# For 3D plotting, 'close' the arrays in u and v
########################################################

#r_plasma  = np.append(r_plasma,  r_plasma[[0],:,:], axis=0)
#r_plasma  = np.append(r_plasma,  r_plasma[:,[0],:], axis=1)
#zetal_plasma = np.append(zetal_plasma,nfp)

#r_coil  = np.append(r_coil,  r_coil[[0],:,:], axis=0)
#r_coil  = np.append(r_coil,  r_coil[:,[0],:], axis=1)
#zetal_coil = np.append(zetal_coil,nfp)

########################################################
# Extract cross-sections of the 3 surfaces at several toroidal angles
########################################################

def getCrossSection(rArray, zetal_old, zeta_new):
    """Interpolate a surface onto planes of constant toroidal angle.

    Parameters
    ----------
    rArray : array indexed as [izeta, itheta, component]; components 0, 1
        and 2 are read as Cartesian x, y and z.
    zetal_old : 1D array of the toroidal angles belonging to the first axis
        of rArray. Copies shifted by -2*pi and +2*pi are attached on either
        side, so the data are assumed to cover one full turn
        (NOTE(review): confirm against the file being read).
    zeta_new : 1D sequence of toroidal angles at which slices are wanted.

    Returns
    -------
    (R_slice, Z_slice) : two arrays of shape (len(zeta_new), ntheta + 1)
        holding the cylindrical radius sqrt(x**2 + y**2) and the height z.
        The last column repeats the first so each curve is closed.
    """
    # Pad the angle grid and the data with one copy on each side so that
    # interpolation near either end of the original grid has neighbours.
    full_turn = 2*np.pi
    zetal_old = np.concatenate((zetal_old - full_turn, zetal_old, zetal_old + full_turn))
    rArray = np.concatenate([rArray]*3, axis=0)

    print("zetal_old shape:",zetal_old.shape)
    print("rArray shape:",rArray.shape)

    cyl_radius = np.sqrt(rArray[:,:,0]**2 + rArray[:,:,1]**2)
    height = rArray[:,:,2]

    ntheta = height.shape[1]
    nzeta_new = len(zeta_new)

    def sample_closed(field):
        # Interpolate each poloidal column in the toroidal angle, then
        # append the first column again to close the curve.
        closed = np.zeros([nzeta_new, ntheta+1])
        for itheta in range(ntheta):
            closed[:,itheta] = interp1d(zetal_old, field[:,itheta])(zeta_new)
        closed[:,-1] = closed[:,0]
        return closed

    return sample_closed(cyl_radius), sample_closed(height)

# Toroidal angles of the slices: 0, 1/4, 1/2 and 3/4 of 2*pi/nfp.
zeta_slices = np.array([0, 0.25, 0.5, 0.75])*2*np.pi/nfp
R_slice_plasma, Z_slice_plasma = getCrossSection(r_plasma, zetal_plasma, zeta_slices)
R_slice_coil, Z_slice_coil = getCrossSection(r_coil, zetal_coil, zeta_slices)

# The outer surface is evaluated directly from its Fourier coefficients.
# The number of slices is taken from zeta_slices rather than hard-coded, so
# changing the list of angles above cannot leave these arrays out of step.
# Note the layout is [itheta, islice], the transpose of the getCrossSection
# results. The poloidal grid includes the endpoint 2*pi, which closes each
# curve.
nslices = len(zeta_slices)
R_slice_outer = np.zeros((ntheta_coil+1,nslices))
Z_slice_outer = np.zeros((ntheta_coil+1,nslices))
theta_coil_big = np.linspace(0,2*np.pi,ntheta_coil+1)
for imn in range(mnmax_coil):
    # Broadcast theta (rows) against the slice angles (columns); each element
    # is m*theta - n*zeta for this mode.
    angle = xm_coil[imn] * theta_coil_big[:,None] - xn_coil[imn] * zeta_slices[None,:]
    sinangle = np.sin(angle)
    cosangle = np.cos(angle)
    R_slice_outer += rmnc_outer[imn] * cosangle + rmns_outer[imn] * sinangle
    Z_slice_outer += zmnc_outer[imn] * cosangle + zmns_outer[imn] * sinangle

########################################################
# Now make plot of surfaces at given toroidal angle
########################################################

figureNum = 1
fig = plt.figure(figureNum,figsize=(10,4))
fig.patch.set_facecolor('white')

# One row of three panels.
numRows = 1
numCols = 3

# Common axis limits for all panels, taken from the extent of the outer
# surface over all slices.
Rmin = R_slice_outer.min()
Rmax = R_slice_outer.max()
Zmin = Z_slice_outer.min()
Zmax = Z_slice_outer.max()

dR = Rmax - Rmin
dZ = Zmax - Zmin
# Fractional padding added around the data extent on each side.
margin = 0.05

# Index of the zeta_coil grid point closest to the second slice. The slice
# angle is taken from zeta_slices instead of the previous hard-coded
# 0.25*2*pi/3, which was only right for nfp = 3. It is currently referenced
# only by the commented-out port overlay in the panel loop.
index = np.argmin(np.abs(zeta_coil - zeta_slices[1]))

markerSize = 3
# Panel titles in mathtext. Raw strings are used so that "\pi" is passed
# through untouched ("\p" in an ordinary literal is an invalid escape
# sequence and triggers a warning on recent Python versions). The denominator
# follows nfp, matching how zeta_slices is built, instead of a hard-coded 3.
period_label = r'(2\pi/' + str(int(nfp)) + ')'
angles = ['0', '(1/4)' + period_label, '(1/2)' + period_label]
for whichPlot, angle_label in enumerate(angles):
    plt.subplot(numRows,numCols,whichPlot+1)
    # Toroidal angle of this panel (informational; not used below).
    zeta = zeta_slices[whichPlot]
    # Draw order matters: the outer region is filled first, then the inside
    # of the inner surface is blanked out in white, then the plasma is drawn
    # on top.
    plt.fill(R_slice_outer[:,whichPlot], Z_slice_outer[:,whichPlot], label='Magnetization region',facecolor=[0.7,1,0.7])
    plt.plot(R_slice_outer[:,whichPlot], Z_slice_outer[:,whichPlot], 'g-', label='Outer magnetization surface')
    plt.fill(R_slice_coil[whichPlot,:], Z_slice_coil[whichPlot,:], facecolor=[1,1,1])
    plt.plot(R_slice_coil[whichPlot,:], Z_slice_coil[whichPlot,:], 'b-', label='Inner magnetization surface')
    plt.fill(R_slice_plasma[whichPlot,:], Z_slice_plasma[whichPlot,:], facecolor=[1,0.8,0.8], label='Plasma region')
    plt.plot(R_slice_plasma[whichPlot,:], Z_slice_plasma[whichPlot,:], 'r-', label='Plasma boundary')

    #if whichPlot == 1:
    #    mask = ports_weight[index,:] > 10
    #    # Close the curve:
    #    mask = np.append(mask,mask[0])
    #    x = R_slice_coil[whichPlot,mask]
    #    y = Z_slice_coil[whichPlot,mask]
    #    perm = np.argsort(y)
    #    plt.plot(x[perm], y[perm], 'c:', label='Port')
    #elif whichPlot == 2:
    #    plt.plot([0,0],[0,1],'c:', label='Port')

    plt.gca().set_aspect('equal',adjustable='box')
    # Show the legend only once, in the last panel.
    if whichPlot==2:
        plt.legend(fontsize=7, framealpha=1.0)
    plt.title(r'$\phi = ' + angle_label + '$')
    plt.xlabel('R [meters]')
    plt.ylabel('Z [meters]')
    # Identical limits in every panel so the slices are directly comparable.
    plt.xlim([Rmin - dR * margin, Rmax + dR * margin])
    plt.ylim([Zmin - dZ * margin, Zmax + dZ * margin])

plt.tight_layout()

# Panel letters (a), (b), (c), placed in figure coordinates near the top-left
# of each of the three columns.
abc=['a','b','c']
for j, letter in enumerate(abc):
    plt.figtext(0.015 + 0.33*j, 0.925, '('+letter+')', fontsize=15)

# Either show the figure interactively or write it next to this script as
# "<script name>.pdf".
if not makePDF:
    plt.show()
else:
    print("Saving PDF")
    plt.savefig(__file__ + ".pdf")

