#!/usr/bin/env python3

"""Plot FAMUS and REGCOIL_PM magnetization on R-Z cross sections.

The script reads a FAMUS ``.focus`` table and a REGCOIL netCDF output file
(both paths are hard-coded further down) and draws a 2 x 3 grid of panels:
FAMUS in the top row and REGCOIL_PM in the bottom row.

Run with no arguments to open an interactive window; pass any command-line
argument to write a PDF next to this file instead.
"""

# Usage hint, printed on every run before the (slow) plotting imports.
print("Add any arguments to save a PDF")

import os
import sys

import matplotlib as mpl
import matplotlib.pyplot as plt
#from matplotlib import cm
import numpy as np
from matplotlib.colors import ListedColormap
from mpl_toolkits.axes_grid1 import make_axes_locatable
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from scipy.interpolate import interp1d
from scipy.io import netcdf
from scipy.io import netcdf_file

# Any command-line argument at all switches the output from an interactive
# window to a saved PDF.
makePDF = len(sys.argv) > 1

# netCDF file to read. Files used for earlier runs, kept for reference:
#   regcoil_out.20200706-01-001-c09r00_withPorts_lambda1e-15_Picard_thetaZeta128_mpolNtor32_ns2_mgrid.nc
#   regcoil_out.20200716-01-059_nthetanzetaPlasma128_nthetanzetaCoil129_mpolntor64_ns3_dinit0.1_Picard_ports.nc
#   regcoil_out.20200901-01-011-regcoilpm_famus_benchmark.nc
filename = 'regcoil_out.20200901-01-015_regcoilpm_famus_benchmark.nc'

# Passed as the `width` argument of every plt.quiver call in this script.
arrowwidth = 0.003

################################################################
# Read FAMUS data
################################################################

famus_filename = '/Users/mattland/Downloads/init_orient_pm_nonorm_5E4_q4_dp.focus'

# Comma-separated table; the first three rows are skipped and only the
# listed columns are loaded, one 1-D array per column.
ox, oy, oz, M, p, mp, mt = np.loadtxt(
    famus_filename,
    delimiter=',',
    skiprows=3,
    usecols=(3,4,5,7,8,10,11),
    unpack=True,
)

# Cylindrical radius and azimuth of each (ox, oy) position.
r = np.sqrt(ox * ox + oy * oy)
phi = np.arctan2(oy, ox)

rho = p ** 4
# Components built from the magnitude mm and the two angles (mt, mp).
# See https://github.com/zhucaoxiang/CoilPy/blob/master/coilpy/dipole.py
mm = rho
mx = mm * np.sin(mt) * np.cos(mp)
my = mm * np.sin(mt) * np.sin(mp)
mz = mm * np.cos(mt)

# Project the horizontal components (mx, my) onto the radial direction.
sinphi = np.sin(phi)
cosphi = np.cos(phi)
mr = cosphi * mx + sinphi * my

print(ox.shape)
print('max phi:', phi.max())
print('min phi:', phi.min())
nfp = 3

# The flat table is interpreted as a 3-D grid of this shape.
# (An earlier ordering tried here was (64, 64, 14).)
newshape = (14, 64, 64)
r, oz, rho, phi, mr, mz = (
    np.reshape(flat, newshape) for flat in (r, oz, rho, phi, mr, mz)
)

# Close in theta: repeat the first entry along axis 1 at the end.
# mr and mz are intentionally left at their unclosed length.
r, oz, rho, phi = (
    np.concatenate((grid, grid[:, :1, :]), axis=1) for grid in (r, oz, rho, phi)
)

maxphi = 2 * np.pi / (nfp * 2.0 * 64)
mask = phi < maxphi
print('maxphi: ',maxphi)
print('mask[:,:,0]: ', mask[:,:,0])


print('x shape:', r[:,:,0].shape)
print('y shape:', oz[:,:,0].shape)

# N-entry RGBA colormap: green and alpha stay at 1 while red and blue fall
# linearly from 1 to 0, i.e. a ramp from white to pure green.
N = 10
colors = np.ones((N, 4))
for channel in (0, 2):
    colors[:, channel] = np.linspace(1, 0, N)
cmap = ListedColormap(colors)
# Disabled option: force the lowest 10% of entries to white.
#index = int(N * 0.1)
#colors[:index, :] = 1

################################################################
# Done reading FAMUS data
################################################################









# Read every needed variable from the REGCOIL netCDF output.
# - `netcdf_file` is imported from the public `scipy.io` namespace; the
#   `scipy.io.netcdf` submodule path is deprecated.
# - The `with` block closes the file even if a variable lookup raises.
# - mmap=False reads the data into memory instead of memory-mapping the file,
#   so the arrays are still usable after the file is closed.
# Note that `nfp` read here replaces any value assigned earlier in the script.
with netcdf_file(filename, 'r', mmap=False) as f:
    nfp = f.variables['nfp'][()]
    ntheta_plasma = f.variables['ntheta_plasma'][()]
    ntheta_coil = f.variables['ntheta_coil'][()]
    nzeta_plasma = f.variables['nzeta_plasma'][()]
    nzeta_coil = f.variables['nzeta_coil'][()]
    nzetal_plasma = f.variables['nzetal_plasma'][()]
    nzetal_coil = f.variables['nzetal_coil'][()]
    theta_plasma = f.variables['theta_plasma'][()]
    theta_coil = f.variables['theta_coil'][()]
    zeta_plasma = f.variables['zeta_plasma'][()]
    zeta_coil = f.variables['zeta_coil'][()]
    zetal_plasma = f.variables['zetal_plasma'][()]
    zetal_coil = f.variables['zetal_coil'][()]
    r_plasma = f.variables['r_plasma'][()]
    # r_coil was previously read twice; one read is sufficient.
    r_coil = f.variables['r_coil'][()]
    chi2_B = f.variables['chi2_B'][()]
    chi2_M = f.variables['chi2_M'][()]
    max_M = f.variables['max_M'][()]
    min_M = f.variables['min_M'][()]
    abs_M = f.variables['abs_M'][()]
    rmnc_outer = f.variables['rmnc_outer'][()]
    rmns_outer = f.variables['rmns_outer'][()]
    zmnc_outer = f.variables['zmnc_outer'][()]
    zmns_outer = f.variables['zmns_outer'][()]
    mnmax_coil = f.variables['mnmax_coil'][()]
    xm_coil = f.variables['xm_coil'][()]
    xn_coil = f.variables['xn_coil'][()]
    ports_weight = f.variables['ports_weight'][()]
    d = f.variables['d'][()]
    ns_magnetization = f.variables['ns_magnetization'][()]
    Bnormal_from_TF_and_plasma_current = f.variables['Bnormal_from_TF_and_plasma_current'][()]
    Bnormal_total = f.variables['Bnormal_total'][()]
    Mvec = f.variables['magnetization_vector'][()]
    nlambda = f.variables['nlambda'][()]
    lambdas = f.variables['lambda'][()]

print('Mvec.shape: ', Mvec.shape)
# Shape seen by the original author: (12, 3, 1, 129, 129)
# (nsaved, RZetaZ, ns_magnetization, nzeta_coil, ntheta_coil)

print('r_coil.shape: ', r_coil.shape)
# Shape seen by the original author: (387, 129, 3)
# (nzetal_coil, ntheta_coil, xyz)
R_coil = np.sqrt(r_coil[:,:,0] ** 2 + r_coil[:,:,1] ** 2)
Z_coil = r_coil[:,:,2]

# Number of saved solutions, taken from the length of chi2_B.
nsaved = len(chi2_B)

# Guard against a lambda array that is identically zero by shifting it to 1
# (in place), as the printed message explains.
if np.abs(lambdas).max() < 1.0e-200:
    print("lambda array appears to be all 0. Changing it to all 1 to avoid a python error.")
    lambdas += 1

########################################################
# Sort in order of lambda, since for a lambda search (general_option=4 or 5),
# the output arrays are in the order of the search, which is not so convenient for plotting.
########################################################

# NOTE(review): abs_M and Mvec are not reordered here, so Mvec[-1] still means
# the last entry in file order -- confirm that this is intended.
if nsaved == nlambda:
    permutation = np.argsort(lambdas)
    lambdas, chi2_M, chi2_B, max_M, min_M = (
        per_lambda[permutation]
        for per_lambda in (lambdas, chi2_M, chi2_B, max_M, min_M)
    )
    Bnormal_total = Bnormal_total[permutation,:,:]
    # A huge final lambda is treated as a stand-in for infinity.
    if lambdas[-1]>1.0e199:
        lambdas[-1] = np.inf

########################################################
# For 3D plotting, 'close' the arrays in u and v
########################################################

#r_plasma  = np.append(r_plasma,  r_plasma[[0],:,:], axis=0)
#r_plasma  = np.append(r_plasma,  r_plasma[:,[0],:], axis=1)
#zetal_plasma = np.append(zetal_plasma,nfp)

#r_coil  = np.append(r_coil,  r_coil[[0],:,:], axis=0)
#r_coil  = np.append(r_coil,  r_coil[:,[0],:], axis=1)
#zetal_coil = np.append(zetal_coil,nfp)

########################################################
# Extract cross-sections of the 3 surfaces at several toroidal angles
########################################################

def getCrossSection(rArray, zetal_old, zeta_new):
    """Interpolate a surface onto the requested toroidal angles.

    Parameters
    ----------
    rArray : array indexed as [izeta, itheta, xyz]
        Cartesian points of the surface.
    zetal_old : 1-D array
        Toroidal angle of each row of ``rArray``.
    zeta_new : 1-D array
        Toroidal angles at which cross sections are wanted.

    Returns
    -------
    (R_slice, Z_slice) : two arrays of shape (len(zeta_new), ntheta + 1)
        Cylindrical radius and height of each cross section. The final
        column repeats the first so that each curve is closed.

    The data are tiled three times, with the angle shifted by -2*pi, 0 and
    +2*pi, so that linear interpolation works across the ends of
    ``zetal_old``. The two array shapes after tiling are printed.
    """
    shifted_angles = (zetal_old-2*np.pi, zetal_old, zetal_old+2*np.pi)
    zeta_tiled = np.concatenate(shifted_angles)
    points_tiled = np.concatenate((rArray,) * 3, axis=0)

    print("zetal_old shape:",zeta_tiled.shape)
    print("rArray shape:",points_tiled.shape)

    radius = np.sqrt(points_tiled[:,:,0]**2 + points_tiled[:,:,1]**2)
    height = points_tiled[:,:,2]

    ntheta = height.shape[1]
    R_slice = np.zeros([len(zeta_new), ntheta+1])
    Z_slice = np.zeros_like(R_slice)
    # One 1-D interpolation in zeta per poloidal grid point.
    for itheta, (radius_col, height_col) in enumerate(zip(radius.T, height.T)):
        R_slice[:, itheta] = interp1d(zeta_tiled, radius_col)(zeta_new)
        Z_slice[:, itheta] = interp1d(zeta_tiled, height_col)(zeta_new)

    # Close each curve in theta.
    R_slice[:, -1] = R_slice[:, 0]
    Z_slice[:, -1] = Z_slice[:, 0]

    return R_slice, Z_slice

# Four toroidal angles spread over one field period: 0, 1/4, 1/2, 3/4 of 2*pi/nfp.
zeta_slices = np.array([0, 0.25, 0.5, 0.75])*2*np.pi/nfp
R_slice_plasma, Z_slice_plasma = getCrossSection(r_plasma, zetal_plasma, zeta_slices)
R_slice_coil, Z_slice_coil = getCrossSection(r_coil, zetal_coil, zeta_slices)

# The outer surface is evaluated from its Fourier coefficients rather than
# interpolated. theta includes both 0 and 2*pi, so each curve comes out closed.
# These arrays are indexed [itheta, izeta], the transpose of the layout that
# getCrossSection returns.
theta_coil_big = np.linspace(0,2*np.pi,ntheta_coil+1)
R_slice_outer = np.zeros((ntheta_coil+1,4))
Z_slice_outer = np.zeros((ntheta_coil+1,4))
for izeta in range(4):
    for imn in range(mnmax_coil):
        angle = xm_coil[imn] * theta_coil_big - xn_coil[imn] * zeta_slices[izeta]
        cosangle = np.cos(angle)
        sinangle = np.sin(angle)
        R_slice_outer[:,izeta] += rmnc_outer[imn] * cosangle + rmns_outer[imn] * sinangle
        Z_slice_outer[:,izeta] += zmnc_outer[imn] * cosangle + zmns_outer[imn] * sinangle

########################################################
# Now make plot of surfaces at given toroidal angle
########################################################

figureNum = 1
# Figure sizes tried earlier: (10, 4) and (10, 3.6).
fig = plt.figure(figureNum,figsize=(9.7,7))
fig.patch.set_facecolor('white')

# Layout of the subplot grid.
numRows = 2
numCols = 3

# Common axis limits for every panel, taken from the extent of the outer
# surface cross sections and padded by `margin` times the extent.
Rmin = R_slice_outer.min()
Rmax = R_slice_outer.max()
Zmin = Z_slice_outer.min()
Zmax = Z_slice_outer.max()

dR = Rmax - Rmin
dZ = Zmax - Zmin
margin = 0.09

# NOTE(review): both plotting loops later in this file assign `index` again
# before reading it, so this value is not used by the current code.
index = np.argmin(np.abs(zeta_coil - 0.25*2*np.pi/3))

markerSize = 3
# Math-text fragments for the panel titles. Raw strings keep the backslash
# in '\pi' literal; in a normal string '\p' is an invalid escape sequence,
# which newer Python versions warn about.
angles = ['0', r'(1/4)(2\pi/3)', r'(1/2)(2\pi/3)']

#################################
# FAMUS plots
#################################

# Top row of the figure: one panel per entry of `indices`, which selects the
# position along the last axis of the FAMUS grids for that panel.
indices = [0, 31, 63]
for whichPlot in range(3):
    plt.subplot(numRows,numCols,whichPlot+1)
    plt.fill(R_slice_plasma[whichPlot,:], Z_slice_plasma[whichPlot,:], facecolor=[1,0.8,0.8], label='Plasma region')
    plt.plot(R_slice_plasma[whichPlot,:], Z_slice_plasma[whichPlot,:], 'r-', label='Plasma boundary')
    # Degenerate green patch; it appears to exist only to give the legend a
    # 'Magnetization region' entry, since the contour plot below has none.
    #plt.fill([[0,0],[0.1,0.1]], [[0,1],[1,0]], label='Magnetization region',facecolor=[0,1,0])
    plt.fill([0,0.1], [0,1], label='Magnetization region',facecolor=[0,1,0])
    index = indices[whichPlot]
    # No `label` is passed here: contourf either reports it as an unused
    # keyword (older Matplotlib) or attaches it to an artist the legend cannot
    # draw (newer Matplotlib). The patch above carries the legend entry.
    cnt = plt.contourf(r[:,:,index], oz[:,:,index], rho[:,:,index], 10, cmap=cmap, vmin=0, vmax=1)
    # This is the fix for the white lines between contour levels.
    # Matplotlib >= 3.8 makes the contour set itself a Collection and
    # deprecates (3.10: removes) `.collections`; older versions only expose
    # the per-level list.
    if hasattr(cnt, 'set_edgecolor'):
        cnt.set_edgecolor("face")
    else:
        for c in cnt.collections:
            c.set_edgecolor("face")
    stride = 4
    # r and oz were closed in theta (one extra entry on axis 1) while mr and
    # mz were not, hence the [:-1] on the position arrays only.
    plt.quiver(r[::stride,:-1,index], oz[::stride,:-1,index], mr[::stride,:,index], mz[::stride,:,index], pivot='mid', width=arrowwidth, scale=16)
    #plt.fill(R_slice_outer[:,whichPlot], Z_slice_outer[:,whichPlot], label='Magnetization region',facecolor=[0.7,1,0.7])
    #plt.plot(R_slice_outer[:,whichPlot], Z_slice_outer[:,whichPlot], 'g-', label='Outer magnetization surface')
    #plt.fill(R_slice_coil[whichPlot,:], Z_slice_coil[whichPlot,:], facecolor=[1,1,1])
    #plt.plot(R_slice_coil[whichPlot,:], Z_slice_coil[whichPlot,:], 'b-', label='Inner magnetization surface')


    #if whichPlot == 1:
    #    mask = ports_weight[index,:] > 10
    #    # Close the curve:
    #    mask = np.append(mask,mask[0])
    #    x = R_slice_coil[whichPlot,mask]
    #    y = Z_slice_coil[whichPlot,mask]
    #    perm = np.argsort(y)
    #    plt.plot(x[perm], y[perm], 'c:', label='Port')
    #elif whichPlot == 2:
    #    plt.plot([0,0],[0,1],'c:', label='Port')

    plt.gca().set_aspect('equal',adjustable='box')
    plt.title(r'FAMUS,  $\phi = ' + angles[whichPlot] + '$')
    plt.xlabel('R [meters]')
    plt.ylabel('Z [meters]')
    plt.xlim([Rmin - dR * margin, Rmax + dR * margin])
    plt.ylim([Zmin - dZ * margin, Zmax + dZ * margin])
    if whichPlot==2:
        # Only the last panel gets the legend and the colorbar, which is
        # drawn in an inset axes placed just right of the panel.
        plt.legend(fontsize=7, framealpha=1.0)
        ax2 = plt.gca()
        cax = inset_axes(ax2,
                           width="5%",  # width = 5% of parent_bbox width
                           height="100%",  # height = 100% of parent_bbox height
                           loc='lower left',
                           bbox_to_anchor=(1.05, 0., 1, 1),
                           bbox_transform=ax2.transAxes,
                           borderpad=0,
                       )
        plt.colorbar(cnt, ticks=np.linspace(0,1,6), cax=cax)

#################################
# REGCOIL_PM plots
#################################

# Bottom row of the figure. `indices` picks the zeta_coil grid index at
# (truncated) 0, 1/4 and 1/2 of nzeta_coil for the three panels.
indices = [int(nzeta_coil * x) for x in [0, 0.25, 0.5]]
for whichPlot in range(3):
    plt.subplot(numRows,numCols,whichPlot+4)
    # Paint the outer cross section green, then paint the inner (coil) cross
    # section white on top, leaving a green band between the two curves.
    plt.fill(R_slice_outer[:,whichPlot], Z_slice_outer[:,whichPlot], facecolor=[0,1,0])
    #plt.plot(R_slice_outer[:,whichPlot], Z_slice_outer[:,whichPlot], 'g-', label='Outer magnetization surface')
    plt.fill(R_slice_coil[whichPlot,:], Z_slice_coil[whichPlot,:], facecolor=[1,1,1])
    #plt.plot(R_slice_coil[whichPlot,:], Z_slice_coil[whichPlot,:], 'b-', label='Inner magnetization surface')
    plt.fill(R_slice_plasma[whichPlot,:], Z_slice_plasma[whichPlot,:], facecolor=[1,0.8,0.8], label='Plasma region')
    plt.plot(R_slice_plasma[whichPlot,:], Z_slice_plasma[whichPlot,:], 'r-', label='Plasma boundary')
    # Degenerate green patch; it appears to exist only to give the legend a
    # 'Magnetization region' entry (the band fills above carry no label).
    plt.fill([0,0.1], [0,1], label='Magnetization region',facecolor=[0,1,0])
    index = indices[whichPlot]
    print('index=',index)
    stride=2

    # Arrow anchor points: midway between the outer and inner curves.
    R_arrow = (R_slice_outer[:,whichPlot] + R_slice_coil[whichPlot,:]) / 2
    Z_arrow = (Z_slice_outer[:,whichPlot] + Z_slice_coil[whichPlot,:]) / 2
    # Report the shape of the strided arrays actually passed to quiver.
    # (Previously the shape *tuple* was sliced, which always printed `()`.)
    print('R_arrow.shape:', R_arrow[:-1:stride].shape)
    print('Mvec.shape:', Mvec[-1,0,0,index,::stride].shape)
    #plt.quiver(R_coil[index,::stride], Z_coil[index,::stride], Mvec[-1,0,0,index,::stride], Mvec[-1,2,0,index,::stride], pivot='mid', width=arrowwidth, zorder=999)
    # Components 0 and 2 of the last saved magnetization vector are drawn as
    # the horizontal and vertical arrow components. [:-1] drops the repeated
    # closing point of the curves before striding.
    plt.quiver(R_arrow[:-1:stride], Z_arrow[:-1:stride], Mvec[-1,0,0,index,::stride], Mvec[-1,2,0,index,::stride], pivot='mid', width=arrowwidth, zorder=999)


    #if whichPlot == 1:
    #    mask = ports_weight[index,:] > 10
    #    # Close the curve:
    #    mask = np.append(mask,mask[0])
    #    x = R_slice_coil[whichPlot,mask]
    #    y = Z_slice_coil[whichPlot,mask]
    #    perm = np.argsort(y)
    #    plt.plot(x[perm], y[perm], 'c:', label='Port')
    #elif whichPlot == 2:
    #    plt.plot([0,0],[0,1],'c:', label='Port')

    plt.gca().set_aspect('equal',adjustable='box')
    plt.title(r'REGCOIL_PM,  $\phi = ' + angles[whichPlot] + '$')
    plt.xlabel('R [meters]')
    plt.ylabel('Z [meters]')
    plt.xlim([Rmin - dR * margin, Rmax + dR * margin])
    plt.ylim([Zmin - dZ * margin, Zmax + dZ * margin])
    if whichPlot==2:
        plt.legend(fontsize=7, framealpha=1.0)

# Manual margins are used instead of plt.tight_layout().
plt.subplots_adjust(left=0.06, bottom=0.07, right=0.94, top=0.96, wspace=0.34, hspace=0.32)
# Text placed in the top-right corner of the figure.
plt.figtext(0.999, 0.99, r'$M / M_t$', fontsize=12, ha='right', va='top')

# Panel letters: (a)-(c) along the top row, (d)-(f) along the bottom row.
abc=['a','b','c']
defg=['d','e','f']
for j, (upper, lower) in enumerate(zip(abc, defg)):
    x = 0.015 + 0.32*j
    for y, letter in ((0.999, upper), (0.49, lower)):
        plt.figtext(x, y, '('+letter+')', fontsize=15, va='top')

if not makePDF:
    plt.show()
else:
    print("Saving PDF")
    # The PDF is named after this script, with ".pdf" appended.
    plt.savefig(__file__ + ".pdf")


