import numpy as np
from matplotlib import pyplot as plt
from mpl_toolkits.basemap import Basemap, shiftgrid
plt.switch_backend('agg')
#************************************************************************

def plotContour(var, lat, lon, levels=None, cmap=None, latLabel=True, lonLabel=True, ax=None, fontsize=8):
    """Draw a filled-contour map of `var` on a new Miller-projection Basemap.

    Parameters
    ----------
    var : 2-D array indexed (lat, lon).
    lat, lon : 1-D coordinates of `var`; their first/last values also set the
        lower-left / upper-right map corners.
    levels : contour levels for the fill; every second level is additionally
        drawn as a thin black contour line. None lets matplotlib choose.
    cmap : colormap for the fill (default ``plt.cm.RdBu_r``).
    latLabel, lonLabel : draw labelled parallels / meridians when True.
    ax : matplotlib axes handed to Basemap.
    fontsize : font size of the parallel/meridian labels.

    Returns
    -------
    (plot1, m) : the object returned by ``m.contourf`` and the Basemap.
    """
    m = Basemap(llcrnrlon=lon[0], llcrnrlat=lat[0], urcrnrlat=lat[-1], urcrnrlon=lon[-1],
                projection='mill', ax=ax)
    m.drawcoastlines(linewidth=0.5)
    ny, nx = var.shape[0], var.shape[1]
    if len(lat) == ny and len(lon) == nx:
        # Put every value at its true (lon, lat). m.makegrid() spaces points
        # evenly in projected x/y, which for the Miller projection is not
        # evenly spaced in latitude, so it would misplace the data.
        x, y = m(*np.meshgrid(lon, lat))
    else:
        # Coordinates do not describe var's grid: fall back to an evenly
        # spaced grid spanning the whole map.
        lons, lats = m.makegrid(nx, ny)
        x, y = m(lons, lats)
    cmap = plt.cm.RdBu_r if cmap is None else cmap
    if levels is not None:
        plot1 = m.contourf(x, y, var, levels, cmap=cmap, extend='both')
        m.contour(x, y, var, levels[::2], colors='k', linewidths=0.3)
    else:
        # Honour cmap here too; previously it was silently ignored.
        plot1 = m.contourf(x, y, var, cmap=cmap)
        m.contour(x, y, var, linewidths=0.3)

    # Label every 10 deg when the map spans less than 60 deg of latitude,
    # otherwise every 20 deg (the span, not the grid spacing, is compared).
    lat_span = abs(lat[-1] - lat[0])
    lat_label_arange = np.arange(-50, 70, 10) if lat_span < 60 else np.arange(-50, 70, 20)
    if latLabel:
        m.drawparallels(lat_label_arange, labels=[1, 0, 0, 0], fontsize=fontsize, linewidth=0.1)
    else:
        m.drawparallels(np.arange(-60, 60, 30), labels=[0, 0, 0, 0], fontsize=fontsize, linewidth=0.1)
    meridian_labels = [0, 0, 0, 1] if lonLabel else [0, 0, 0, 0]
    m.drawmeridians([120, 180, 240], labels=meridian_labels, fontsize=fontsize, linewidth=0.1)
    return plot1, m


def plotContourOnly(var, lat, lon, m, levels=None, cmap=None, colors='k'):
    """Draw contour lines (no fill) of `var` onto an existing Basemap `m`.

    Parameters
    ----------
    var : 2-D array indexed (lat, lon).
    lat, lon : 1-D coordinates of `var`.
    m : Basemap instance to draw on.
    levels : contour levels (None lets matplotlib choose).
    cmap : if given, lines are coloured by this colormap with width 0.7;
        otherwise they use `colors` with width 0.3.
    colors : line colour(s) used when `cmap` is None.

    Returns
    -------
    The object returned by ``m.contour``.
    """
    ny, nx = var.shape[0], var.shape[1]
    if len(lat) == ny and len(lon) == nx:
        # Use the real grid positions; makegrid() is evenly spaced in
        # projected coordinates, not in latitude.
        x, y = m(*np.meshgrid(lon, lat))
    else:
        # Coordinates do not match var: fall back to an evenly spaced map grid.
        lons, lats = m.makegrid(nx, ny)
        x, y = m(lons, lats)
    if cmap is not None:
        return m.contour(x, y, var, levels, linewidths=0.7, cmap=cmap)
    return m.contour(x, y, var, linewidths=0.3, levels=levels, colors=colors)

def plotVector(u_ori, v_ori, lat_ori, lon_ori, ax=None, m=None, maximum=None, latLabel=True, lonLabel=True, fontsize=8):
    """Overlay wind vectors, subsampled every 4th grid point, on a Basemap.

    Parameters
    ----------
    u_ori, v_ori : 2-D arrays indexed (lat, lon) with the vector components;
        only their real part is plotted.
    lat_ori, lon_ori : 1-D coordinates of the u/v grid.
    ax : axes handed to a newly created Basemap (only used when `m` is None).
    m : Basemap to draw on. When None, a Miller-projection map with
        coastlines and grey continents is created from the coordinates.
    maximum : reference value for the quiver scale. When None it is the
        largest absolute u or v component, and it is printed.
    latLabel, lonLabel : draw labelled parallels / meridians when True.
    fontsize : font size of the parallel/meridian labels.

    Returns
    -------
    The object returned by ``m.quiver``.
    """
    stride = 4
    # Work on copies so the caller's arrays are never touched.
    u = np.copy(u_ori)
    v = np.copy(v_ori)
    lat = np.copy(lat_ori)
    lon = np.copy(lon_ori)
    if m is None:
        m = Basemap(llcrnrlon=lon[0], llcrnrlat=lat[0], urcrnrlat=lat[-1], urcrnrlon=lon[-1],
                    projection='mill', ax=ax)
        m.drawcoastlines(linewidth=0.5)
        m.fillcontinents(color='0.8')

    # Arrows per axis. Floor division keeps these integers on Python 2 and 3.
    ny = u.shape[0] // stride
    nx = u.shape[1] // stride
    # Drop trailing rows/columns that do not fill a whole stride so the
    # subsampled grid has exactly ny x nx points.
    u = u[:ny * stride, :nx * stride]
    v = v[:ny * stride, :nx * stride]
    lat = lat[:ny * stride]
    lon = lon[:nx * stride]

    lons, lats = m.makegrid(nx, ny)  # ny x nx grid evenly spaced in map coordinates
    x, y = m(lons, lats)
    lat_sub = lat[::stride]
    lon_sub = lon[::stride]
    u_sub = u[::stride, ::stride]
    v_sub = v[::stride, ::stride]
    # NOTE(review): lon0 equals the first longitude, so this shift presumably
    # leaves the data order unchanged; kept for parity with earlier output.
    ugrid, newlons = shiftgrid(lon[0], u_sub.real, lon_sub)
    vgrid, newlons = shiftgrid(lon[0], v_sub.real, lon_sub)
    # Interpolate the vectors onto the nx x ny projection grid used for x, y.
    uproj, vproj, xx, yy = m.transform_vector(ugrid, vgrid, newlons, lat_sub, nx, ny, returnxy=True)

    if maximum is None:
        # Scale by the largest component *magnitude*. A signed max can be
        # small or negative when winds are mostly negative, which flips arrows.
        maximum = max(np.nanmax(np.abs(u)), np.nanmax(np.abs(v)))
        print(maximum)
    plot = m.quiver(x, y, uproj, vproj, scale=maximum * nx / 1.5, pivot='mid', units='width')

    # Label every 10 deg when the data span less than 60 deg of latitude,
    # otherwise every 20 deg.
    lat_span = abs(lat[-1] - lat[0])
    lat_label_arange = np.arange(-50, 70, 10) if lat_span < 60 else np.arange(-50, 70, 20)
    if latLabel:
        m.drawparallels(lat_label_arange, labels=[1, 0, 0, 0], fontsize=fontsize, linewidth=0.1)
    else:
        m.drawparallels(np.arange(-60, 60, 30), labels=[0, 0, 0, 0], fontsize=fontsize, linewidth=0.1)
    meridian_labels = [0, 0, 0, 1] if lonLabel else [0, 0, 0, 0]
    m.drawmeridians([120, 180, 240], labels=meridian_labels, fontsize=fontsize, linewidth=0.1)
    return plot

def readVariableFromNetCDF(filename, var_name):
    """Read variable `var_name` from NetCDF file `filename` as a NumPy array.

    The whole variable is loaded and converted with ``np.array``. The file is
    always closed afterwards, even if the read fails.
    """
    import netCDF4 as nc
    f = nc.Dataset(filename)
    try:
        # np.array copies the data into memory before the file is closed.
        var = np.array(f.variables[var_name][:])
    finally:
        f.close()
    return var

def getEasternSSTIndex(sst, lat, lon, lat_range=(-5, 5), lon_range=(240, 270)):
    """Unweighted box average of a (..., lat, lon) field.

    The default box is lat -5..5, lon 240..270 (bounds inclusive).

    Parameters
    ----------
    sst : array whose last two axes are (lat, lon); any leading axes
        (e.g. time, ensemble) are kept.
    lat, lon : 1-D coordinate arrays matching the last two axes.
    lat_range, lon_range : inclusive (min, max) bounds of the box.

    Returns
    -------
    A scalar for 2-D input. Otherwise an array over the leading axes,
    averaged over longitude first and then latitude. NaNs are ignored.
    """
    mask_lat = np.where((lat >= lat_range[0]) & (lat <= lat_range[1]))[0]
    mask_lon = np.where((lon >= lon_range[0]) & (lon <= lon_range[1]))[0]
    region = sst[..., mask_lat, :][..., mask_lon]
    if region.ndim == 2:
        return np.nanmean(region)
    return np.nanmean(np.nanmean(region, axis=-1), axis=-1)

def getVarFromRegion(var, lat, lon, lat_range, lon_range):
    """Cut out the sub-array of `var` inside a lat/lon box.

    Parameters
    ----------
    var : array whose last two axes are (lat, lon); any number of leading
        axes (time, level, ...) is preserved.
    lat, lon : 1-D coordinates of those last two axes.
    lat_range, lon_range : inclusive (min, max) bounds of the box.

    Returns
    -------
    A copy of `var` restricted to the selected latitudes and longitudes.
    """
    lat_plot_ind = np.where((lat >= lat_range[0]) & (lat <= lat_range[1]))[0]
    lon_plot_ind = np.where((lon >= lon_range[0]) & (lon <= lon_range[1]))[0]
    # Ellipsis indexing addresses the last two axes whatever var.ndim is.
    return var[..., lat_plot_ind, :][..., lon_plot_ind]

def get3MonthSST(center=None, forcing=None, method='large'):
    """Load three records of prescribed SST plus their lat/lon coordinates.

    With `center` None the control climatology file is read. Otherwise the
    Gaussian-forcing file built from `center`, `forcing` and `method` is read.

    Returns
    -------
    (sst_3months, lat, lon), where sst_3months holds the records at indices
    10, 11 and 0 of the file's first axis (presumably Nov, Dec, Jan of a
    12-month climatology — confirm against the file).
    """
    if center is None:
        sst_file = 'SST/sst_HadOIBl_bc_1.9x2.5_2000climo_c180511.nc'
    else:
        sst_file = 'SST/sst_gaussianForcing_C' + str(center) + '_' + forcing + '_' + method + '.nc'
    lat = readVariableFromNetCDF(sst_file, 'lat')
    lon = readVariableFromNetCDF(sst_file, 'lon')
    sst = readVariableFromNetCDF(sst_file, 'SST_cpl_prediddle')
    record_order = np.array([10, 11, 0])
    return sst[record_order], lat, lon

def get3MonthsVar(var_name, startyear, model):
    """Stack three consecutive monthly records of `var_name` for one run.

    Reads ``<var>/<var>.<model>.cam.h0.<YYYY>-<MM>.nc`` for months 11 and 12
    of `startyear` and month 1 of `startyear + 1`. From each file it takes
    the first record (index 0 of the variable's first axis).

    Returns
    -------
    Array of shape (3, nlev, nlat, nlon) when `var_name` is 'U', 'V' or 'Z3',
    otherwise (3, nlat, nlon). The sizes are read from the November file.
    """
    def _fileName(year, month):
        # e.g. 'U/U.<model>.cam.h0.0003-11.nc'
        return (var_name + '/' + var_name + '.' + model + '.cam.h0.'
                + str(year).zfill(4) + '-' + str(month).zfill(2) + '.nc')

    first_file = _fileName(startyear, 11)
    lat = readVariableFromNetCDF(first_file, 'lat')
    lon = readVariableFromNetCDF(first_file, 'lon')
    vertical = var_name in ['U', 'V', 'Z3']
    if vertical:
        lev = readVariableFromNetCDF(first_file, 'lev')
        var = np.zeros((3, lev.shape[0], lat.shape[0], lon.shape[0]))
    else:
        var = np.zeros((3, lat.shape[0], lon.shape[0]))

    for i in range(3):
        # Months 11, 12, 1; January rolls over into the next model year.
        # Explicit floor division keeps this integer on Python 2 and 3
        # (the old '(11+i)/13' form only worked with Python 2 division).
        offset = 10 + i
        year = startyear + offset // 12
        month = offset % 12 + 1
        var[i] = readVariableFromNetCDF(_fileName(year, month), var_name)[0]
    return var

def get3MonthsVariablesForOLR(startyear, model):
    """Return the three-month FLUT field of `model` multiplied by -1.

    Units are whatever the FLUT files store. NOTE(review): an earlier comment
    said "mm/month", which does not fit an outgoing-longwave field — confirm.
    """
    flut = get3MonthsVar('FLUT', startyear, model)
    return -1 * flut

def get3MonthsVariablesForModel(startyear, model):
    """Load the three-month U, V and FLUT fields of one model run.

    Returns
    -------
    (u, v, olr) as returned by ``get3MonthsVar``. FLUT is NOT sign-flipped here.
    """
    fields = [get3MonthsVar(name, startyear, model) for name in ('U', 'V', 'FLUT')]
    u, v, olr = fields
    return u, v, olr

def get3MonthsVariablesIndex(vari, lat, lon):
    """Return the eastern-box average of `vari` (see getEasternSSTIndex)."""
    return getEasternSSTIndex(vari, lat, lon)

def getRegionalExpirimentDifference(var_exp, var_ctl, lat, lon, lat_range, lon_range):
    """Return experiment minus control, both cut to the same lat/lon box.

    Both fields are passed through getVarFromRegion with identical bounds
    before subtracting.
    """
    exp_region, ctl_region = [getVarFromRegion(field, lat, lon, lat_range, lon_range)
                              for field in (var_exp, var_ctl)]
    return exp_region - ctl_region

def createSubplots(order):
    """Create axes in the current figure at preset slot `order` (0-6).

    Slot positions are [left, bottom, width, height] in figure fractions.
    """
    slots = [
        [0.10, 0.53, 0.40, 0.20],
        [0.05, 0.05, 0.20, 0.20],
        [0.28, 0.05, 0.20, 0.20],
        [0.52, 0.53, 0.20, 0.20],
        [0.52, 0.05, 0.20, 0.20],
        [0.75, 0.53, 0.20, 0.20],
        [0.75, 0.05, 0.20, 0.20],
    ]
    return plt.axes(slots[order])

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
totalyear = 4
method = 'thin'  # alternative: 'large'

u_name, v_name = 'U', 'V'
psl_name = 'PSL'
ght_name = 'Z3'
file_dir = '/DFS-B/DATA/jinyi/shihweif/ModelRuns/'

# Reference grid taken from one control-run U file.
filename_forlatlon = u_name + '/' + u_name + '.f_present_day.cam.h0.0003-11.nc'
lat = readVariableFromNetCDF(filename_forlatlon, 'lat')
lon = readVariableFromNetCDF(filename_forlatlon, 'lon')
lev = readVariableFromNetCDF(filename_forlatlon, 'lev')

lat_range = [-20, 20]
lat_plot_ind = np.where((lat >= lat_range[0]) & (lat <= lat_range[1]))[0]
lat_plot = lat[lat_plot_ind]
lon_range = [120, 280]
lon_plot_ind = np.where((lon >= lon_range[0]) & (lon <= lon_range[1]))[0]
lon_plot = lon[lon_plot_ind]

center = 255
posneg_all = ['neg4', 'neg2', 'neg1', 'pos1', 'pos2', 'pos4']
beginyear = 3
model_month = '11'
# Eastern-box index difference (experiment - control): (experiment, year, month).
olr_east_exp = np.zeros((len(posneg_all), totalyear, 3))

# Control-run identifiers; they do not change from year to year.
model2 = 'f_present_day'
olr_name = 'FLUT'

# ---------------------------------------------------------------------------
# Eastern-box OLR index for every experiment and start year
# ---------------------------------------------------------------------------
for s, startyear in enumerate(range(beginyear, beginyear + totalyear)):
    model_year = str(startyear).zfill(2)
    month = str(startyear).zfill(4) + '-11'
    filename_base = '.cam.h0.' + month + '.nc'
    olr_file = olr_name + '/' + olr_name + '.' + model2 + filename_base
    lat = readVariableFromNetCDF(olr_file, 'lat')
    lon = readVariableFromNetCDF(olr_file, 'lon')

    olr_ctl = get3MonthsVariablesForOLR(startyear, model2)
    olr_east = get3MonthsVariablesIndex(olr_ctl, lat, lon)

    for e, exp_model in enumerate(posneg_all):
        model1 = 'exp_C' + str(center) + '_' + exp_model + '_' + model_year + '_' + model_month + '_' + method
        olr_exp = get3MonthsVariablesForOLR(startyear, model1)
        olr_east_temp = get3MonthsVariablesIndex(olr_exp, lat, lon)
        olr_east_exp[e, s] = olr_east_temp - olr_east

# ---------------------------------------------------------------------------
# Spatial fields (lat -30..30, lon 120..280) for the 'neg4' / 'pos4' runs
# ---------------------------------------------------------------------------
lat_range = [-30, 30]
lon_range = [120, 280]
posneg_array = ['neg4', 'pos4']
lat_30_range = [-30, 30]
lat_plot_ind = np.where((lat >= lat_30_range[0]) & (lat <= lat_30_range[1]))[0]
lat_plot = lat[lat_plot_ind]
n_exp = len(posneg_array)

# SST forcing = experiment SST minus control SST inside the plotting box.
sst_forcing = np.zeros((n_exp, 3, lat_plot.shape[0], lon_plot.shape[0]))
sst_ctl, lat_sst, lon_sst = get3MonthSST(method=method)
for i, forcing in enumerate(posneg_array):
    sst_exp, lat_sst, lon_sst = get3MonthSST(center=center, forcing=forcing, method=method)
    sst_forcing[i] = getRegionalExpirimentDifference(sst_exp, sst_ctl, lat_sst, lon_sst, lat_range, lon_range)
# Coordinates of sst_forcing: select from the SST grid itself (lat_sst, not
# the atmosphere `lat`) so they match the region cut out above.
lat_sst_plot_ind = np.where((lat_sst >= lat_range[0]) & (lat_sst <= lat_range[1]))[0]
lat_sst_plot = lat_sst[lat_sst_plot_ind]
lon_sst_plot_ind = np.where((lon_sst >= lon_range[0]) & (lon_sst <= lon_range[1]))[0]
lon_sst_plot = lon_sst[lon_sst_plot_ind]

# Experiment-minus-control fields: (experiment, year, month, lat, lon).
u_850_exp = np.zeros((n_exp, totalyear, 3, lat_plot.shape[0], lon_plot.shape[0]))
v_850_exp = np.zeros_like(u_850_exp)
olr_all_exp = np.zeros_like(u_850_exp)

# Index of the `lev` value closest to 850; `lev` does not change inside the
# loops below, so compute it once.
ind_850 = np.argmin(np.abs(lev - 850))

for y, startyear in enumerate(range(beginyear, beginyear + totalyear)):
    model_year = str(startyear).zfill(2)
    u_ctl, v_ctl, olr_ctl = get3MonthsVariablesForModel(startyear, model2)
    for p, posneg in enumerate(posneg_array):
        model1 = 'exp_C' + str(center) + '_' + posneg + '_' + model_year + '_' + model_month + '_' + method
        u_exp, v_exp, olr_exp = get3MonthsVariablesForModel(startyear, model1)
        u_850_exp[p, y] = getRegionalExpirimentDifference(u_exp[:, ind_850], u_ctl[:, ind_850], lat, lon, lat_range, lon_range)
        v_850_exp[p, y] = getRegionalExpirimentDifference(v_exp[:, ind_850], v_ctl[:, ind_850], lat, lon, lat_range, lon_range)
        olr_all_exp[p, y] = getRegionalExpirimentDifference(olr_exp, olr_ctl, lat, lon, lat_range, lon_range)


# ---------------------------------------------------------------------------
# Figure: (a) eastern-box index vs. forcing, (b)/(c) maps for neg4 / pos4
# ---------------------------------------------------------------------------
plt.close('all')
import os
from matplotlib import font_manager as fm, rcParams
rcParams.update({'errorbar.capsize': 2})
prop_small = fm.FontProperties(fname="fonts/Georgia.ttf", size=8)
prop = fm.FontProperties(fname="fonts/Georgia.ttf", size=12)
fig = plt.figure()
fig.set_size_inches(9, 6, forward=True)
axes_posi = [[0.07, 0.56, 0.82, 0.37], [0.07, 0.05, 0.38, 0.40], [0.51, 0.05, 0.38, 0.40]]
# Raw strings: '\c' is not a valid escape sequence and warns on Python 3.
title_texts = ['Deep convections responses to warming/cooling in the equatorial eastern Pacific',
               r'OLR responses to EEP cooling (-4$^\circ$C)',
               r'OLR responses to EEP warming (4$^\circ$C)']

# Indices along the 3-month axis averaged in every panel.
season = [0, 1, 2]

# ---- Panel (a): index response vs. signed forcing amplitude ----
ax1 = plt.axes(axes_posi[0])
# Signed amplitude parsed from experiment names: 'neg4' -> -4, 'pos2' -> 2.
X = np.array([float(name[3:]) if name[0] == 'p' else -float(name[3:]) for name in posneg_all])
east_plot = np.nanmean(olr_east_exp[:, :, season], axis=(1, 2))
ax1.scatter([0], [0], c='k', label='Control')
ax1.scatter(X, east_plot, c='C0')
# One line through all experiments and the control point (0, 0), ordered by
# forcing amplitude.
line_x = np.append(X, 0.)
line_y = np.append(east_plot, 0.)
order = np.argsort(line_x)
ax1.plot(line_x[order], line_y[order], c='C1')
ax1.axhline(y=0, linestyle=':', color='k', alpha=0.8, linewidth=0.7)
ax1.set_title(title_texts[0], fontproperties=prop)
ax1.tick_params(axis='both', which='major', labelsize=8)
# olr_east_exp is built from get3MonthsVariablesForOLR, i.e. FLUT * -1, so the
# axis shows minus the OLR anomaly. OLR is a flux, not a rate like mm/day.
ax1.set_ylabel(r'$-$OLR deviation to controls (W m$^{-2}$)', fontproperties=prop_small)
ax1.set_xlabel(r'Gaussian SST center for experiments ($^\circ$C)', fontproperties=prop_small)
ax1.text(.005, .98, '(a)', horizontalalignment='left', verticalalignment='top', transform=ax1.transAxes, fontsize=10, color='k')

# ---- Panels (b), (c): -OLR fill, 850-level wind vectors, SST forcing lines ----
cmap = plt.cm.coolwarm
maximum = 4
levels = (np.arange(11.) - 5) / 5 * 50
panel_tags = ('(b)', '(c)')
for i in range(2):
    ax1 = plt.axes(axes_posi[i + 1])
    plot, m = plotContour(np.nanmean(-1 * olr_all_exp[i, :, season], axis=(0, 1)), lat_plot, lon_plot,
                          levels=levels, cmap=cmap, ax=ax1)
    plotVector(np.nanmean(u_850_exp[i, :, season], axis=(0, 1)), np.nanmean(v_850_exp[i, :, season], axis=(0, 1)),
               lat_plot, lon_plot, m=m, maximum=maximum)
    plotContourOnly(np.nanmean(sst_forcing[i, season], axis=0), lat_sst_plot, lon_sst_plot, m,
                    colors='g', levels=[-3, -2, -1, 0, 1, 2, 3])
    ax1.set_aspect('auto')
    ax1.set_title(title_texts[i + 1], fontproperties=prop)
    ax1.text(.005, .98, panel_tags[i], horizontalalignment='left', verticalalignment='top',
             transform=ax1.transAxes, fontsize=10, color='k')

# Shared colorbar for the two map panels.
cbar_ax = fig.add_axes([0.93, 0.1, 0.015, 0.8])
cbar = plt.colorbar(plot, cax=cbar_ax)
plt.savefig('Figure/ExtendedFigure5.png', dpi=800)







