# =======================================================
# Custom Functions
# =======================================================
def readVarible(filename, variable, vertical=False, data_dir='ModelData/'):
    """Read one variable (or the 'depth' axis) from a NetCDF file.

    Parameters
    ----------
    filename : str
        File name, appended to ``data_dir`` to form the path.
    variable : str
        Name of the NetCDF variable to read. Ignored when ``vertical`` is True.
    vertical : bool, optional
        If True, return the file's ``'depth'`` variable instead of ``variable``.
    data_dir : str, optional
        Prefix prepended verbatim to ``filename`` (default ``'ModelData/'``,
        the previously hard-coded location).

    Returns
    -------
    numpy.ndarray
        The requested data converted with ``np.array``.

    Notes
    -----
    ``variable`` is sliced with ``[:, :, :]``, so it must have at least three
    dimensions. Any further trailing dimensions are returned in full; the
    4-D ``'hgt'`` field read by the script relies on this.
    NOTE(review): if netCDF4 returns a masked array, ``np.array`` keeps the
    raw underlying values rather than NaN, so confirm that fill values are
    handled downstream.
    """
    fileName = data_dir + filename
    # The context manager closes the file even when the lookup raises,
    # e.g. KeyError for a misspelled variable name. Previously that leaked
    # the open handle.
    with nc.Dataset(fileName) as f:
        if vertical:
            return np.array(f.variables['depth'][:])
        return np.array(f.variables[variable][:, :, :])

def getNewLatLon(lat=None, lon=None, lon_step=1.5):
    """Build the latitude/longitude axes of the regridded data.

    Latitude is spaced 1 degree apart starting at ``lat[0]``. Longitude is
    spaced ``lon_step`` degrees apart starting at ``lon[0]``. The number of
    longitude points is ``ceil(span / lon_step + 1)``, so when the span is
    not a multiple of ``lon_step`` the last point overshoots ``lon[1]``.
    With the defaults, both ends fall exactly on the grid.

    Parameters
    ----------
    lat : sequence of two numbers, optional
        [south, north] bounds in degrees; defaults to [-50, 65].
    lon : sequence of two numbers, optional
        [west, east] bounds in degrees; defaults to [122, 290].
    lon_step : float, optional
        Longitude spacing in degrees (default 1.5).

    Returns
    -------
    list of numpy.ndarray
        ``[newlat, newlon]``.
    """
    if lat is None:
        lat = [-50, 65]
    if lon is None:
        lon = [122, 290]
    newlat = np.arange((lat[1] - lat[0]) + 1) + lat[0]
    newlon = np.arange((lon[1] - lon[0]) / lon_step + 1) * lon_step + lon[0]
    return [newlat, newlon]
        

def getNino34Index(ssta, lat=None, lon=None):
    """Average ``ssta`` over the Nino-3.4 box of the regridded domain.

    The grid is rebuilt from the extents ``lat`` (default [-20, 20], 1-degree
    spacing) and ``lon`` (default [122, 290], 1.5-degree spacing). The box
    keeps grid points with -5 < latitude < 5 and 190 < longitude < 240
    (strict bounds).

    NOTE(review): getEasternSSTIndex/getWesternSSTIndex use inclusive
    bounds, so they keep the +/-5 degree rows that this index drops. Confirm
    the difference is intended.

    Parameters
    ----------
    ssta : numpy.ndarray
        2-D (lat, lon) or 3-D field whose last two axes match the rebuilt
        grid. Axis 0 is presumably time.
    lat, lon : sequence of two numbers, optional
        Grid extents the data was sliced to.

    Returns
    -------
    numpy.ndarray or float
        3-D input: one value per entry of axis 0, computed as the longitude
        mean of each row, then the mean over rows. 2-D input: a single mean
        over all box cells. NaNs are ignored and cells are equally weighted.
    """
    lat_ext = [-20, 20] if lat is None else lat
    lon_ext = [122, 290] if lon is None else lon
    grid_lat = np.arange(lat_ext[1] - (lat_ext[0]) + 1) + lat_ext[0]
    grid_lon = np.arange((lon_ext[1] - lon_ext[0]) / 1.5 + 1) * 1.5 + lon_ext[0]

    rows = np.flatnonzero((grid_lat > -5) & (grid_lat < 5))
    cols = np.flatnonzero((grid_lon > 190) & (grid_lon < 240))

    if len(ssta.shape) != 3:
        return np.nanmean(ssta[rows, :][:, cols])

    box = ssta[:, rows, :][:, :, cols]
    row_means = np.nanmean(box, axis=2)
    return np.nanmean(row_means, axis=1)
                
def getEasternSSTIndex(sst, lat=None, lon=None, box_lat=(-5, 5), box_lon=(240, 270)):
    """Average ``sst`` over an eastern equatorial Pacific box.

    The input is assumed to lie on the grid rebuilt from ``lat``/``lon``:
    1-degree latitude from ``lat[0]`` to ``lat[1]`` and 1.5-degree longitude
    starting at ``lon[0]``. Despite the name, any field on that grid works;
    the script also applies it to precipitation.

    Parameters
    ----------
    sst : numpy.ndarray
        2-D (lat, lon) or 3-D field whose last two axes match the rebuilt
        grid. Axis 0 is presumably time.
    lat, lon : sequence of two numbers, optional
        Grid extents; default [-20, 20] and [122, 290].
    box_lat, box_lon : sequence of two numbers, optional
        Inclusive averaging box as (min, max). The defaults, 5S-5N and
        240-270 E, are the previously hard-coded region.

    Returns
    -------
    numpy.ndarray or float
        3-D input: one value per entry of axis 0, computed as the longitude
        mean of each row, then the mean over rows. 2-D input: a single mean
        over all box cells. The two agree unless rows contain different
        numbers of NaNs. NaNs are ignored; cells are equally weighted, with
        no cos(latitude) factor.
    """
    if lat is None:
        lat = [-20, 20]
    if lon is None:
        lon = [122, 290]
    newlat = np.arange(lat[1] - (lat[0]) + 1) + lat[0]
    newlon = np.arange((lon[1] - lon[0]) / 1.5 + 1) * 1.5 + lon[0]
    maskLat = np.where((newlat <= box_lat[1]) & (newlat >= box_lat[0]))[0]
    maskLon = np.where((newlon <= box_lon[1]) & (newlon >= box_lon[0]))[0]
    if len(sst.shape) == 3:
        box = sst[:, maskLat, :][:, :, maskLon]
        avg = np.nanmean(box, axis=2)  # longitude mean of each latitude row
        avg = np.nanmean(avg, axis=1)  # then the mean over rows
    else:
        box = sst[maskLat, :][:, maskLon]
        avg = np.nanmean(box)
    return avg
    
def getWesternSSTIndex(sst, lat=None, lon=None, box_lat=(-5, 5), box_lon=(160, 190)):
    """Average ``sst`` over a western equatorial Pacific box.

    The input is assumed to lie on the grid rebuilt from ``lat``/``lon``:
    1-degree latitude from ``lat[0]`` to ``lat[1]`` and 1.5-degree longitude
    starting at ``lon[0]``. Despite the name, any field on that grid works;
    the script also applies it to precipitation.

    Parameters
    ----------
    sst : numpy.ndarray
        2-D (lat, lon) or 3-D field whose last two axes match the rebuilt
        grid. Axis 0 is presumably time.
    lat, lon : sequence of two numbers, optional
        Grid extents; default [-20, 20] and [122, 290].
    box_lat, box_lon : sequence of two numbers, optional
        Inclusive averaging box as (min, max). The defaults, 5S-5N and
        160-190 E, are the previously hard-coded region.

    Returns
    -------
    numpy.ndarray or float
        3-D input: one value per entry of axis 0, computed as the longitude
        mean of each row, then the mean over rows. 2-D input: a single mean
        over all box cells. The two agree unless rows contain different
        numbers of NaNs. NaNs are ignored; cells are equally weighted, with
        no cos(latitude) factor.
    """
    if lat is None:
        lat = [-20, 20]
    if lon is None:
        lon = [122, 290]
    newlat = np.arange(lat[1] - (lat[0]) + 1) + lat[0]
    newlon = np.arange((lon[1] - lon[0]) / 1.5 + 1) * 1.5 + lon[0]
    maskLat = np.where((newlat <= box_lat[1]) & (newlat >= box_lat[0]))[0]
    maskLon = np.where((newlon <= box_lon[1]) & (newlon >= box_lon[0]))[0]
    if len(sst.shape) == 3:
        wSSTInd = sst[:, maskLat, :][:, :, maskLon]
        avg = np.nanmean(wSSTInd, axis=2)  # longitude mean of each latitude row
        avg = np.nanmean(avg, axis=1)      # then the mean over rows
    else:
        wSSTInd = sst[maskLat, :][:, maskLon]
        avg = np.nanmean(wSSTInd)
    return avg

def plotFigure1(eSST, wSST, ssta, winda, sst, sf_mode):
    """Plot Figure 1 by delegating to UnclassifiedComputations.

    Every argument is forwarded unchanged, in order, to
    ``UnclassifiedComputations.plotSchematicOfENSOComplexityNaturePaper1``
    on a freshly constructed instance. Returns None.
    """
    UnclassifiedComputations().plotSchematicOfENSOComplexityNaturePaper1(
        eSST, wSST, ssta, winda, sst, sf_mode)
    
def plotExtendedFigure1(ssta, lon_grid, meof, nino34):
    """Plot Extended Data Figure 1 by delegating to OnsetWithoutENSO.

    Builds ``OnsetWithoutENSO(1)`` and forwards all arguments unchanged to
    its ``plotThreeSFEvolutionsForENSOComplexityNaturePaper1`` method.
    NOTE(review): the meaning of the constructor argument ``1`` is defined
    in OnsetWithoutENSO and is not visible here. Returns None.
    """
    OnsetWithoutENSO(1).plotThreeSFEvolutionsForENSOComplexityNaturePaper1(
        ssta, lon_grid, meof, nino34)

def plotExtendedFigure2(eSST, negsf, ePrep, ssta, winda, hgt200a):
    """Plot Extended Data Figure 2 by delegating to UnclassifiedComputations.

    Forwards all arguments to ``plotFigure2ForSFAsymmetryGRL`` with the
    latitude extent fixed at ``[-30, 30]``. Returns None.
    """
    UnclassifiedComputations().plotFigure2ForSFAsymmetryGRL(
        eSST, negsf, ePrep, ssta, winda, hgt200a, lat=[-30, 30])

def plotExtendedFigure3(wSST, sf, wPrep, ssta, winda, hgt850a):
    """Plot Extended Data Figure 3 by delegating to UnclassifiedComputations.

    Forwards all arguments to ``plotFigure3ForSFAsymmetryGRL`` with the
    latitude extent fixed at ``[-30, 30]``. Returns None.
    """
    UnclassifiedComputations().plotFigure3ForSFAsymmetryGRL(
        wSST, sf, wPrep, ssta, winda, hgt850a, lat=[-30, 30])
    

    
        
# =======================================================
# Import Packages
# =======================================================
import netCDF4 as nc
import numpy as np
import os.path
#from scipy.interpolate import griddata
import glob

# =======================================================
# Custom Packages
# =======================================================
from EOF import EOF
from MEOF import MEOF
from OnsetWithoutENSO import OnsetWithoutENSO
from UnclassifiedComputations import UnclassifiedComputations

# =======================================================
# Main Codes
# =======================================================

# Prepare data for plot.
# This code runs at module level: importing this file reads every data file
# below and draws the active figure.

# grid = [lat axis, lon axis] from getNewLatLon(): 1-degree latitude -50..65
# and 1.5-degree longitude 122..290. The data files are assumed to already be
# on this grid (TODO confirm against the files in ModelData/).
grid                = getNewLatLon()
# Row indices of the 20S-20N band within grid[0]. They are applied to axis 1
# of every field below.
lat                 = np.where((grid[0] <= 20) & (grid[0] >= -20))[0] 
sst_65N50S          = readVarible('HADISST.sst0.nc', 'tos')
# Anomalies via EOF.removeSeasonal. The meaning of the (False, True) flags is
# defined in the EOF module. The result is reshaped back to the 3-D shape of
# the raw field (presumably time, lat, lon).
sst_anomaly_65N50S  = EOF.removeSeasonal(sst_65N50S, False, True)
ssta_65N50S         = sst_anomaly_65N50S.reshape((sst_65N50S.shape[0], sst_65N50S.shape[1], sst_65N50S.shape[2]))
ssta_20N20S         = ssta_65N50S[:, lat, :]

uwind_65N50S        = readVarible('NCEP.uwind0.nc', 'uwind')
vwind_65N50S        = readVarible('NCEP.vwind0.nc', 'vwind')
# u and v are placed side by side along axis 2 (presumably longitude), so
# axis 2 of the wind arrays is twice as long as for the other fields.
wind_65N50S         = np.concatenate((uwind_65N50S, vwind_65N50S), axis=2)
wind_anomaly_65N50S = EOF.removeSeasonal(wind_65N50S, False, True)
winda_65N50S        = wind_anomaly_65N50S.reshape((wind_65N50S.shape[0], wind_65N50S.shape[1], wind_65N50S.shape[2]))
winda_20N20S        = winda_65N50S[:, lat, :]

# NOTE(review): precip and hgt (and eSSTIndex/wSSTIndex/ePrepIndex/wPrepIndex
# below) are used only by the commented-out plot calls at the bottom. They
# are still read and computed on every run.
precip_65N50S       = readVarible('NCEP.precip0.nc', 'precip')
hgt_65N50S          = readVarible('NCEP.hgt0.nc', 'hgt')

# hgt is 4-D; levels 2 and 9 of the last axis are taken. The variable names
# suggest these are 850 hPa and 200 hPa; verify against the file's level
# axis.
hgt850_anomaly_65N50S   = EOF.removeSeasonal(hgt_65N50S[:,:,:,2], False, True)
hgt850a_65N50S          = hgt850_anomaly_65N50S.reshape((hgt_65N50S.shape[0], hgt_65N50S.shape[1], hgt_65N50S.shape[2]))
hgt200_anomaly_65N50S   = EOF.removeSeasonal(hgt_65N50S[:,:,:,9], False, True)
hgt200a_65N50S          = hgt200_anomaly_65N50S.reshape((hgt_65N50S.shape[0], hgt_65N50S.shape[1], hgt_65N50S.shape[2]))

# Box-averaged indices on the 20S-20N slice. This matches the index
# functions' default grid extent lat=[-20, 20]. Nino-3.4 uses the SST
# anomaly. The eastern/western indices use the raw SST field, and the same
# functions are reused for precipitation.
nino34              = getNino34Index(ssta_20N20S)
eSSTIndex           = getEasternSSTIndex(sst_65N50S[:,lat,:])
wSSTIndex           = getWesternSSTIndex(sst_65N50S[:,lat,:])
ePrepIndex          = getEasternSSTIndex(precip_65N50S[:,lat,:])
wPrepIndex          = getWesternSSTIndex(precip_65N50S[:,lat,:])


# Precomputed EOFs of SST, wind and sea level, combined into a multivariate
# EOF labelled '(SST, TAU, SSH)'. The role of np.array([1, 1, 1]) is defined
# by MEOF (presumably per-field weights; confirm).
eof_sst_0           = EOF('HA_NC_GE_0_HADISST_sst_eof_0.nc')
eof_wind_0          = EOF('HA_NC_GE_0_NCEP_wind_eof_0.nc')
eof_sl_0            = EOF('HA_NC_GE_0_GECCO2_sl_eof_0.nc')
meof1_0             = MEOF(eof_sst_0, eof_wind_0, eof_sl_0, np.array([1, 1, 1]), 'HA_NC_GE_0',  '(SST, TAU, SSH)') 

# plot Figure 1 (currently disabled)
#plotFigure1(eSSTIndex, wSSTIndex, ssta_20N20S, winda_20N20S, sst_65N50S[:,lat,:], meof1_0.pc[1])

# plot Extended Figures (only Extended Figure 1 is currently active)
plotExtendedFigure1(ssta_20N20S, grid[1], meof1_0, nino34)

#lat_30N30S = np.where((grid[0] <= 30) & (grid[0] >= -30))[0] 
#lat_65N20S = np.where((grid[0] <= 65) & (grid[0] >= -20))[0] 
#plotExtendedFigure2(eSSTIndex, -1*meof1_0.pc[1], ePrepIndex, ssta_65N50S[:,lat_30N30S,:], winda_65N50S[:,lat_30N30S,:], hgt200a_65N50S[:,lat_30N30S,:])
# NOTE(review): hgt850a below is sliced with lat_65N20S while the other fields
# use lat_30N30S and plotExtendedFigure3 fixes lat=[-30, 30]. Confirm this is
# intended before re-enabling.
#plotExtendedFigure3(wSSTIndex, meof1_0.pc[1], wPrepIndex, ssta_65N50S[:,lat_30N30S,:], winda_65N50S[:,lat_30N30S,:], hgt850a_65N50S[:,lat_65N20S,:])

