#!/usr/bin/env python
# Usage hint, printed on every run. Written as a call with a single string
# argument so that it prints the same text under both Python 2 and Python 3.
print("usage: include any arguments to save a PDF.")

# coilsPerHalfPeriod is used only to choose the number of contours to plot for the total current potential:
coilsPerHalfPeriod = 3

import matplotlib as mpl
import matplotlib.pyplot as plt
#from matplotlib import cm
import numpy as np
from scipy.io import netcdf
from scipy.interpolate import interp1d
import sys
import os

# Any command-line argument at all switches the script from showing the
# figure interactively to saving it as a PDF.
makePDF = len(sys.argv) > 1

filename = 'regcoil_out.benchmark.nc'
# mmap=False reads the data into memory instead of memory-mapping the file,
# so the arrays extracted below stay usable after the file is closed.
f = netcdf.netcdf_file(filename,'r',mmap=False)
try:
    # Grid sizes and angle grids for the plasma and coil surfaces.
    nfp = f.variables['nfp'][()]
    ntheta_plasma = f.variables['ntheta_plasma'][()]
    ntheta_coil = f.variables['ntheta_coil'][()]
    nzeta_plasma = f.variables['nzeta_plasma'][()]
    nzeta_coil = f.variables['nzeta_coil'][()]
    nzetal_plasma = f.variables['nzetal_plasma'][()]
    nzetal_coil = f.variables['nzetal_coil'][()]
    theta_plasma = f.variables['theta_plasma'][()]
    theta_coil = f.variables['theta_coil'][()]
    zeta_plasma = f.variables['zeta_plasma'][()]
    zeta_coil = f.variables['zeta_coil'][()]
    zetal_plasma = f.variables['zetal_plasma'][()]
    zetal_coil = f.variables['zetal_coil'][()]
    # Surface position arrays; getCrossSection() indexes these as
    # [izeta, itheta, xyz].
    r_plasma  = f.variables['r_plasma'][()]
    r_coil  = f.variables['r_coil'][()]
    # Per-saved-solution diagnostics.
    chi2_B = f.variables['chi2_B'][()]
    chi2_M = f.variables['chi2_M'][()]
    max_M = f.variables['max_M'][()]
    min_M = f.variables['min_M'][()]
    abs_M = f.variables['abs_M'][()]
    # Fourier amplitudes and mode numbers used to rebuild the outer surface.
    rmnc_outer = f.variables['rmnc_outer'][()]
    rmns_outer = f.variables['rmns_outer'][()]
    zmnc_outer = f.variables['zmnc_outer'][()]
    zmns_outer = f.variables['zmns_outer'][()]
    mnmax_coil = f.variables['mnmax_coil'][()]
    xm_coil = f.variables['xm_coil'][()]
    xn_coil = f.variables['xn_coil'][()]
    ports_weight = f.variables['ports_weight'][()]
    d = f.variables['d'][()]
    ns_magnetization = f.variables['ns_magnetization'][()]
    Bnormal_from_TF_and_plasma_current = f.variables['Bnormal_from_TF_and_plasma_current'][()]
    Bnormal_total = f.variables['Bnormal_total'][()]
    # Template for reading further variables:
    # = f.variables[''][()]

    nlambda = f.variables['nlambda'][()]
    # Number of solutions actually stored in the file.
    nsaved = len(chi2_B)
    lambdas = f.variables['lambda'][()]

    # The prints use %-formatting of a single string so that the output is
    # the same under Python 2 and Python 3.
    print("ntheta_plasma:  %s" % (ntheta_plasma,))
    print("nzetal_plasma:  %s" % (nzetal_plasma,))
    print("r_plasma.shape:  %s" % (r_plasma.shape,))
    print("Bnormal_total.shape:  %s" % (Bnormal_total.shape,))
    print("abs_M.shape:  %s" % (abs_M.shape,))
finally:
    # Close the file even when one of the lookups above fails (for example
    # a variable missing from the file), so the handle is never leaked.
    f.close()

if np.max(np.abs(lambdas)) < 1.0e-200:
    print("lambda array appears to be all 0. Changing it to all 1 to avoid a python error.")
    lambdas += 1

########################################################
# Sort in order of lambda, since for a lambda search (general_option=4 or 5),
# the output arrays are in the order of the search, which is not so convenient for plotting.
########################################################

if nsaved == nlambda:
    # Reorder every per-lambda array with the permutation that puts the
    # lambda values in ascending order.
    permutation = np.argsort(lambdas)
    lambdas, chi2_M, chi2_B, max_M, min_M = tuple(
        values[permutation]
        for values in (lambdas, chi2_M, chi2_B, max_M, min_M)
    )
    Bnormal_total = Bnormal_total[permutation, :, :]
    # After sorting, the last entry is the largest lambda; a value above
    # 1e199 is replaced by an actual infinity.
    largest_is_huge = lambdas[-1] > 1.0e199
    if largest_is_huge:
        lambdas[-1] = np.inf

########################################################
# For 3D plotting, 'close' the arrays in u and v
########################################################

#r_plasma  = np.append(r_plasma,  r_plasma[[0],:,:], axis=0)
#r_plasma  = np.append(r_plasma,  r_plasma[:,[0],:], axis=1)
#zetal_plasma = np.append(zetal_plasma,nfp)

#r_coil  = np.append(r_coil,  r_coil[[0],:,:], axis=0)
#r_coil  = np.append(r_coil,  r_coil[:,[0],:], axis=1)
#zetal_coil = np.append(zetal_coil,nfp)

########################################################
# Extract cross-sections of the 3 surfaces at several toroidal angles
########################################################

def getCrossSection(rArray, zetal_old, zeta_new):
    """Interpolate a surface onto cross-sections at given toroidal angles.

    The input grid is extended with one copy shifted by -2*pi and one
    shifted by +2*pi along the toroidal axis, so that linear interpolation
    also works for requested angles that lie outside the range covered by
    ``zetal_old`` itself (for example between the last grid point and 2*pi).

    Parameters
    ----------
    rArray : ndarray
        Cartesian surface positions indexed as ``[izeta, itheta, xyz]``;
        the last axis holds x, y, z at indices 0, 1, 2.
    zetal_old : 1D ndarray
        Toroidal angle of each entry along the first axis of ``rArray``.
        The surface is treated as 2*pi-periodic in this angle.
    zeta_new : 1D sequence
        Toroidal angles at which cross-sections are wanted.

    Returns
    -------
    R_slice, Z_slice : ndarray, shape (len(zeta_new), ntheta + 1)
        Cylindrical radius sqrt(x**2 + y**2) and height z of each
        cross-section. The last column repeats the first one so that each
        curve is closed when plotted.
    """
    two_pi = 2*np.pi
    zetal_old = np.concatenate((zetal_old - two_pi, zetal_old, zetal_old + two_pi))
    rArray = np.concatenate((rArray, rArray, rArray), axis=0)

    # %-formatting of a single string prints identically under Python 2 and 3.
    print("zetal_old shape: %s" % (zetal_old.shape,))
    print("rArray shape: %s" % (rArray.shape,))

    R = np.sqrt(rArray[:, :, 0]**2 + rArray[:, :, 1]**2)
    z = rArray[:, :, 2]

    ntheta = z.shape[1]
    nzeta_new = len(zeta_new)
    R_slice = np.zeros([nzeta_new, ntheta + 1])
    Z_slice = np.zeros([nzeta_new, ntheta + 1])

    # Interpolate all theta columns at once along the toroidal axis instead
    # of building a separate interpolator for every column.
    R_slice[:, :ntheta] = interp1d(zetal_old, R, axis=0)(zeta_new)
    Z_slice[:, :ntheta] = interp1d(zetal_old, z, axis=0)(zeta_new)

    # Close each curve by repeating its first point.
    R_slice[:, -1] = R_slice[:, 0]
    Z_slice[:, -1] = Z_slice[:, 0]

    return R_slice, Z_slice

# Toroidal angles of the cross-sections: fractions 0, 1/4, 1/2 and 3/4 of 2*pi/nfp.
zeta_slices = np.array([0, 0.25, 0.5, 0.75])*2*np.pi/nfp
R_slice_plasma, Z_slice_plasma = getCrossSection(r_plasma, zetal_plasma, zeta_slices)
R_slice_coil, Z_slice_coil = getCrossSection(r_coil, zetal_coil, zeta_slices)

# The outer surface is rebuilt from its Fourier amplitudes rather than
# interpolated. Note that these arrays are indexed [itheta, izeta], the
# transpose of what getCrossSection returns. The number of columns follows
# zeta_slices instead of being hard-coded, so the slice list can be changed
# in one place.
R_slice_outer = np.zeros((ntheta_coil+1,len(zeta_slices)))
Z_slice_outer = np.zeros((ntheta_coil+1,len(zeta_slices)))
# ntheta_coil+1 points from 0 to 2*pi inclusive, so each curve is closed.
theta_coil_big = np.linspace(0,2*np.pi,ntheta_coil+1)
for imn in range(mnmax_coil):
    for izeta in range(len(zeta_slices)):
        # Phase of mode imn: m*theta - n*zeta.
        angle = xm_coil[imn] * theta_coil_big - xn_coil[imn] * zeta_slices[izeta]
        sinangle = np.sin(angle)
        cosangle = np.cos(angle)
        R_slice_outer[:,izeta] += rmnc_outer[imn] * cosangle + rmns_outer[imn] * sinangle
        Z_slice_outer[:,izeta] += zmnc_outer[imn] * cosangle + zmns_outer[imn] * sinangle

########################################################
# Pick the isaved values to show in the 2D plots
########################################################

# Upper limit on how many saved solutions are selected for the contour plots.
max_nsaved_for_contour_plots = 16
numPlots = min(nsaved,max_nsaved_for_contour_plots)
# Spread numPlots values evenly between 1 and nsaved, truncate them to
# integers, and keep each distinct value once in ascending order.
isaved_to_plot = np.unique(np.linspace(1,nsaved,numPlots).astype(int))
# Truncation can merge neighbouring values, so recount.
numPlots = len(isaved_to_plot)
# %-formatting of a single string prints identically under Python 2 and 3.
print("isaved_to_plot: %s" % (isaved_to_plot,))

########################################################
# Now make plot of chi^2 over lambda scan
########################################################

figureNum = 1
fig = plt.figure(figureNum,figsize=(10,4))
fig.patch.set_facecolor('white')

# Two panels side by side.
numRows = 1
numCols = 2

# Marker size shared by both panels.
ms=2

# Panel (a): chi2_B against chi2_M. The first entry of each array is left
# out of both panels (NOTE(review): presumably the lambda = 0 point, which
# cannot be shown on a logarithmic lambda axis -- confirm).
plt.subplot(numRows,numCols,1)
plt.loglog(chi2_M[1:],chi2_B[1:],'.-r',ms=ms)
plt.xlabel(r'$f_M$ [A$^2$ m$^2$]')
plt.ylabel(r'$f_B$ [T$^2$ m$^2$]')

# Panel (b): chi2_B against lambda.
plt.subplot(numRows,numCols,2)
plt.loglog(lambdas[1:],chi2_B[1:],'.-r',ms=ms)
plt.xlabel(r'$\lambda$ [T$^2$ / A$^2$]')
plt.ylabel(r'$f_B$ [T$^2$ m$^2$]')


#plt.subplot(numRows,numCols,3)
#plt.contourf(zeta_coil, theta_coil, np.transpose(ports_weight), 20)
#plt.colorbar()
#plt.xlabel('zeta (coil)',fontsize='x-small')
#plt.ylabel('theta (coil)',fontsize='x-small')
#plt.title('ports_weight',fontsize='x-small')

plt.tight_layout()
plt.subplots_adjust(top=0.95)
#plt.figtext(0.5,0.99,"Blue dots indicate the points that are plotted in later figures",horizontalalignment='center',verticalalignment='top',fontsize='small')

#plt.figtext(0.5,0.00,os.path.abspath(filename),ha='center',va='bottom',fontsize=7)

# Panel labels, positioned in figure coordinates.
plt.figtext(0.02,0.9,'(a)',fontsize=15)
plt.figtext(0.515,0.9,'(b)',fontsize=15)

if makePDF:
    # Single-argument call form prints identically under Python 2 and 3.
    print("Saving PDF")
    # The PDF is written next to this script, named after it.
    plt.savefig(__file__+".pdf")
else:
    plt.show()
