import netCDF4 as nc
import numpy as np
import os.path
from scipy.interpolate import griddata
import glob

class ModelData:
    """Loader for the netCDF data sets handled by the ``get*`` methods.

    Most ``get*`` methods locate their input files with ``glob``, cut out a
    latitude/longitude box plus an index window along the first axis, pass
    the result to ``regrid`` and write it out through ``saveAsNetCDF``.
    """
    # Latitude bounds handed to getMaskLatLon for period codes 0, 9 and 11
    # (keys 'N65toS50', 'PacAtlc' and '201650S65N' in getVariable).
    _lat = [-50, 65]
    # Latitude bounds for period code 7 (key 'N30toS30').
    _lat2 = [-30, 30]
    # Latitude bounds for period code 4 (key 'N30toS20').
    # NOTE(review): the values read 30S-20N while the key name reads
    # 30N-20S -- confirm which of the two is intended.
    _lat3 = [-30, 20]
    # Default longitude bounds (presumably degrees east -- TODO confirm).
    _lon = [122, 290]
    # Wider longitude bounds, used only for period code 9 ('PacAtlc').
    _lon2 = [122, 350]
#    _newLat = np.arange(20-(-20)+1)-20
#    _newLon = np.arange((290-122)/1.5+1)*1.5+122
#    _newLon = np.arange((285-124.5)/1.5+1)*1.5+124.5
    
    def getVariable(self, var, version, period, **kwargs):
        print 'getVariable: {}, {}, {}'.format(var, version, period)
        switcher = {
            'sst':                  self.getSST,
            'wind':                 self.getWind,
            'sea level':            self.getSealevel,
            '20C Isotherm':         self.get20CIsotherm,
            'sea level pressure':   self.getSeaLevelPressure,
            'Geopotential Height':  self.getGeopotentialHeight,
            'precipitation':        self.getPrecipitation,
            'u Ocean':              self.getUOcean,
            'w50 Ocean':            self.getW50Ocean,
            'temp Ocean':           self.getTempOcean,
            'wind speed':           self.getSurfaceWindSpeed,
            'sensible heat':        self.getSensibleHeatFlux,
            'latent heat':          self.getLatentHeatFlux,
            'Net Q flux':           self.getNetQFlux,
            'vertical profile':     self.getVerticalProfile,
            'PMM vertical profile': self.getVerticalProfileAlongPMM
            }
        # Get the function from switcher dictionary
        func = switcher.get(var, lambda: "no such variable")
        # Get the period number
        switcher2 = {
            'N65toS50':     0,
            'N30toS30':     7,
            'N30toS20':     4,
            'pg':           3,
            'pd':           2,
            'PacAtlc':      9,
            '201650S65N':   11,
        }
        period = switcher2.get(period, 99)
        # Execute the function
        return func(version, period, **kwargs)
            
    def getSST(self, version, period=None, **kwargs):
        if isinstance(version, bool):
            return self.getSSTXue(version)
        else:
            return self.getSSTMine(version, period, **kwargs)
        
    def getSSTXue(self, is1982Data):
        """Placeholder: no SST is loaded; always returns None."""
        return None
        
    def getSSTMine(self, version , period, grid=False):
        switcher = {
            'OISST':    0,
            'ICOADS':   1,
            'HADISST':  2,
            'ERSST':    3,
            'GECCO2':   4,
        }
        case = switcher.get(version, 99)
        lat_attr    = 'lat'
        lon_attr    = 'lon'
        sst_attr    = 'sst'
#        time_attr   = 'time'
        if case == 0:
            list_of_files = glob.glob('../SST/tos_OISST_L4_AVHRR-only-v2_*nc')
            sst_attr = 'tos'
        elif case == 1:
            list_of_files = glob.glob('../SST/ICOADS.sst.mean.nc')
        elif case == 2:
            if period==11:
                list_of_files = glob.glob('../SST/HadISST_sst_2018.nc')
            else:
                list_of_files = glob.glob('../SST/HadISST_sst.nc')
            lat_attr = 'latitude'
            lon_attr = 'longitude'
        elif case == 3:
            list_of_files = glob.glob('../SST/ERSST.sst.mnmean.v3.nc')
        elif case == 4:
            list_of_files = glob.glob('../SST/GECCO2.temp29_34_55_lev5.nc')
            sst_attr    = 'temp'
        else:
            return "no such version"
        lat, lon            = self.getLatLon(list_of_files[0], lat_attr, lon_attr)
        lon_bnd = ModelData._lon
        if period == 0:
            lat_bnd = ModelData._lat
        elif period == 4:
            lat_bnd = ModelData._lat3
        elif period == 7:
            lat_bnd = ModelData._lat2
        elif period == 9:
            lat_bnd = ModelData._lat
            lon_bnd = ModelData._lon2
        elif period == 11:
            lat_bnd = ModelData._lat
        else:
            print 'No such period'
            lat_bnd = None
#        lon_bnd             = ModelData._lon2 if period == 1 else ModelData._lon
        maskLat, maskLon    = self.getMaskLatLon(lat, lon, lat_bnd, lon_bnd)

        sst = self.getData(list_of_files, sst_attr, maskLat, maskLon)
        if case == 0:
            return "no such period"
#            if period == 1:
#                sst = sst[160:400,:,:] - 273.15
#            else: 
#                return "no such period"
        elif case == 1:
            return "no such period"
#            if period == 2:
#                sst = sst[120:360,:,:]
#            else: 
#                return "no such period"
        elif case == 2:
#            sst = sst[1056:1740,:,:]
#            if period == 0:
#                sst = sst[1200:1740,:,:]
            if period == 11:
                sst = sst[936:1764,:,:]
        elif case == 3:
            sst = sst[1248:1932,:,:]
#            if period == 0:
#                sst = sst[1392:1932,:,:]
#            else:
#                return "no such period"
        elif case == 4:
            return "no such period"
#            if period == 0:
#                sst = sst[264:504,:,:] 
#            elif period == 1:
#                sst = sst[564:804,:,:]
#            elif period == 2:
#                sst = sst[264:804,:,:]
#            else:
#                return "no such period"
        else:
            return "no such case"
        sst = self.regrid(sst, lat[maskLat], lon[maskLon], 'nearest', period)
#        sst = sst + 273.15 if period == 1 else sst

        self.saveAsNetCDF(sst, version+'.sst'+str(period)+'.nc', 'tos', period)
        if grid is not None:
            if grid:
                newLat = np.arange(ModelData._lat[1]-(ModelData._lat[0])+1)+ModelData._lat[0]
                newLon = np.arange((ModelData._lon[1]-(ModelData._lon[0]))/1.5+1)*1.5+ModelData._lon[0]
                return sst, [newLat, newLon]
        return sst
        
    def getWind(self, version, period=None):
        if isinstance(version, bool):
            return self.getWindXue(version)
        else:
            return self.getWindMine(version, period)
    
    def getWindXue(self, is1982Data):
        """Placeholder: no wind data is loaded; always returns None."""
        return None
    
    def getWindMine(self, version, period):
        switcher = {
            'ICOADS':   0,
            'NCEP':     1,
            'GECCO2':   2,
            'JRA55':    3,
        }
        case = switcher.get(version, 99)
        lat_attr    = 'lat'
        lon_attr    = 'lon'
#        time_attr   = 'time'
        if case == 0:
            files  = glob.glob('../Wind/ICOADS.*pstr.mean.nc') # wind stress
            u_attr  = 'upstr'
            v_attr  = 'vpstr'
        elif case == 1:
            if period == 11:
                files  = glob.glob('../Wind/NCEP.2018.*wnd.mon.mean.nc')
            else:
                files  = glob.glob('../Wind/NCEP.*wnd.mon.mean.nc')# wind
            files.sort()
            u_attr  = 'uwnd'
            v_attr  = 'vwnd'
        elif case == 2:
            files  = glob.glob('../Wind/GECCO2.f*29_34_55.nc') # wind stress, "N/m^2"
            u_attr  = 'fu'
            v_attr  = 'fv'
        elif case == 3:
            files  = glob.glob('../Wind/JRA55.*wind.nc') # wind
            u_attr  = 'var33'
            v_attr  = 'var34'
        else:
            return "no such version"
        lat, lon    = self.getLatLon(files[0], lat_attr, lon_attr)
        lon_bnd = ModelData._lon
        if period == 0:
            lat_bnd = ModelData._lat
        elif period == 4:
            lat_bnd = ModelData._lat3
        elif period == 7:
            lat_bnd = ModelData._lat2
        elif period == 9:
            lat_bnd = ModelData._lat
            lon_bnd = ModelData._lon2
        elif period == 11:
            lat_bnd = ModelData._lat
        else:
            print 'No such period'
#        lon_bnd             = ModelData._lon2 if period == 1 else ModelData._lon
        maskLat, maskLon    = self.getMaskLatLon(lat, lon, lat_bnd, lon_bnd)
        vwind       = self.getWindData(files[1], v_attr, maskLat, maskLon)
        uwind       = self.getWindData(files[0], u_attr, maskLat, maskLon)
        if case == 0:
            return "no such period"
#            if period == 1:
#                vwind = vwind[120:360,:,:]
#                uwind = uwind[120:360,:,:]
#            else: 
#                return "no such period"
        elif case == 1:
            if period == 11:
                vwind = vwind[:828,:,:]
                uwind = uwind[:828,:,:]
            else:
                vwind = vwind[120:804,:,:]
                uwind = uwind[120:804,:,:]
#            if period == 0:
#                vwind = vwind[264:804,:,:]
#                uwind = uwind[264:804,:,:]
#            else:
#                return "no such period"
        elif case == 2:
            return "no such period"
#            if period == 0:
#                vwind = vwind[264:504,:,:]
#                uwind = uwind[264:504,:,:]
#            elif period == 1:
#                vwind = vwind[564:804,:,:]
#                uwind = uwind[564:804,:,:]
#            elif period == 2:
#                vwind = vwind[264:804,:,:]
#                uwind = uwind[264:804,:,:]
#            else:
#                return "no such period"
        elif case == 3:
            vwind = vwind[:684]
            uwind = uwind[:684]
#            if period == 0:
#                vwind = vwind[144:684]
#                uwind = uwind[144:684]
        else:
            return "no such period"
        vwind       = self.regrid(vwind, lat[maskLat], lon[maskLon], 'nearest', period)
        uwind       = self.regrid(uwind, lat[maskLat], lon[maskLon], 'nearest', period)
        fwind       = np.concatenate((uwind, vwind), axis=2)
#        fwind[np.isnan(fwind)] = 0

        self.saveAsNetCDF(uwind, version+'.uwind'+str(period)+'.nc', 'uwind', period)
        self.saveAsNetCDF(vwind, version+'.vwind'+str(period)+'.nc', 'vwind', period)
        return fwind

    def getSealevel(self, version, period=None):
        if isinstance(version, bool):
            return self.getSealevelXue(version)
        else:
            return self.getSealevelMine(version, period)    
    
    def getSealevelXue(self, is1982Data):
        """Placeholder: no sea-level data is loaded; always returns None."""
        # A disabled earlier implementation read either the GODAS
        # 'Sea level height/GODAS/sshg.*.nc' files (variable 'sshg', rows
        # 24:229) or 'Sea level height/zeta29_34_55.nc' (variable 'zeta',
        # rows 192:408), restricted to 20S-20N / 124-285, multiplied by 100,
        # regridded with 'nearest' and saved as 'sl.nc' / 'hisSl.nc'.
        return None
        
    def getSealevelMine(self, version, period):
        switcher = {
            'GODAS':    0,
            'SODA':     1,
            'GECCO2':   2,
            'ORAS4':    3
        }
        case = switcher.get(version, 99)
        lat_attr    = 'lat'
        lon_attr    = 'lon'
#        time_attr   = 'time'
        if case == 0:
            files   = glob.glob('../Sea level height/GODAS/sshg.*.nc')
            sl_attr = 'sshg'
        elif case == 1:
            files   = glob.glob('../Sea level height/SODA/SODA.slAll.nc')
            sl_attr = 'sshg'
#            files   = glob.glob('Sea level height/SODA/SODA_2.2.4_*.cdf')
#            sl_attr = ['ssh', 'SSH']
        elif case == 2:
            if period==11:
                files   = glob.glob('../Sea level height/zeta2016_29_34_70.nc')
            else:
                files   = glob.glob('../Sea level height/zeta29_34_55.nc')
            sl_attr = 'zeta'
        elif case == 3:
            files   = glob.glob('../Sea level height/ORAS4/zos_oras4_1m_*_grid_1x1.nc')
            sl_attr = 'zos'
        else:
            return "no such version"
        lat, lon = self.getLatLon(files[0], lat_attr, lon_attr)
        if period == 0:
            lat_bnd = ModelData._lat
        elif period == 4:
            lat_bnd = ModelData._lat3
        elif period == 7:
            lat_bnd = ModelData._lat2
        elif period == 11:
            lat_bnd = ModelData._lat
        else:
            print 'No such period'
#        lon_bnd             = ModelData._lon2 if period == 1 else ModelData._lon
        maskLat, maskLon    = self.getMaskLatLon(lat, lon, lat_bnd, ModelData._lon)
        sl = self.getData(files, sl_attr, maskLat, maskLon)
        if case == 0:
            return "no such period"
#            if period == 1:
#                sl = sl[180:420,:,:]*100
#            if period == 2:
#                sl = sl*100
#            else:
#                return "no such period"
        elif case == 1:
            return "no such period"
#            if period == 0:
#                sl = sl[1188:1428,:,:]*100
#            elif period == 1:
#                sl = sl[1488:1680,:,:]*100
#            elif period == 2:
#                sl = sl[1188:1680,:,:]*100
#            elif period == 3:
#                sl = sl[1548:1680,:,:]*100
#            else:
#                return "no such period"
#            if period == 0:
#                return "no such period"
#            else:
#                return "no such period"
        elif case == 2:
            if period ==11:
                sl = sl[:828,:,:]*100
            else:
                sl = sl[120:804,:,:]*100
#            if period == 0:
#                sl = sl[264:804,:,:]*100
#            else:
#                return "no such period"
        elif case == 3:
            sl = sl[:684,:,:]*100
#            if period == 0:
#                sl = sl[144:684,:,:]*100
#            else:
#                return "no such period"
        else:
            return "no such case"
        sl = self.regrid(sl, lat[maskLat], lon[maskLon], 'nearest', period)
        self.saveAsNetCDF(sl, version+'.sl'+str(period)+'.nc', 'sshg', period)
#        if case == 0 and period == 2:
#            sl_SODA = self.readVarible('SODA.sl0.nc', 'sshg')
#            missingIndex_0 = np.where(np.isnan(sl_SODA[0]))
#            missingIndex_1 = np.where(np.isnan(sl[0]))
##            print missingIndex
#            sl = np.concatenate((sl_SODA[:120,:,:],sl[:420,:,:]), axis=0)
#            sl[:,missingIndex_0[0],missingIndex_0[1]] = np.NaN
#            sl[:,missingIndex_1[0],missingIndex_1[1]] = np.NaN
##            sl_temp = np.ones([120+420, sl.shape[1], sl.shape[2]], dtype=sl.dtype) * np.NaN
##            sl_temp = sl_temp.astype(sl.dtype)
##            sl[nonMissingIndex,:] = sl
#        if sl.shape[0] == 240:
#            self.saveAsNetCDF(sl, version+'.sl'+str(period)+'.nc', 'sshg', period)
#        elif sl.shape[0] == 192:
#            self.saveAsNetCDF(sl, version+'.sl'+str(period)+'.nc', 'sshg', [16, period])
#        elif sl.shape[0] == 540:
#            self.saveAsNetCDF(sl, version+'.sl'+str(period)+'.nc', 'sshg', period)
#        elif sl.shape[0] == 492:
#            self.saveAsNetCDF(sl, version+'.sl'+str(period)+'.nc', 'sshg', [41, period])
#        elif sl.shape[0] == 180:
#            self.saveAsNetCDF(sl, version+'.sl'+str(period)+'.nc', 'sshg', period)
#        elif sl.shape[0] == 804:
#            self.saveAsNetCDF(sl, version+'.sl'+str(period)+'.nc', 'sshg', period)
#        else:
#            return 'No such length'
        return sl
    
    def getNetQFlux(self, version, period):
        switcher = {
            'GODAS':    0,
            'SODA':     1,
            'GECCO2':   2,
            'ORAS4':    3
        }
        case = switcher.get(version, 99)
        lat_attr    = 'lat'
        lon_attr    = 'lon'
#        time_attr   = 'time'
        if case == 0:
            files   = glob.glob('../Sea level height/GODAS/sshg.*.nc')
            qnet_attr = 'sshg'
        elif case == 1:
            files   = glob.glob('../Sea level height/SODA/SODA.slAll.nc')
            qnet_attr = 'sshg'
#            files   = glob.glob('Sea level height/SODA/SODA_2.2.4_*.cdf')
#            sl_attr = ['ssh', 'SSH']
        elif case == 2:
            files   = glob.glob('../Net Q Flux/qnet29_34_70.nc')
            qnet_attr = 'qnet'
        elif case == 3:
            files   = glob.glob('../Sea level height/ORAS4/zos_oras4_1m_*_grid_1x1.nc')
            qnet_attr = 'zos'
        else:
            return "no such version"
        lat, lon = self.getLatLon(files[0], lat_attr, lon_attr)
        if period == 0:
            lat_bnd = ModelData._lat
        elif period == 4:
            lat_bnd = ModelData._lat3
        elif period == 7:
            lat_bnd = ModelData._lat2
        else:
            print 'No such period'
#        lon_bnd             = ModelData._lon2 if period == 1 else ModelData._lon
        maskLat, maskLon    = self.getMaskLatLon(lat, lon, lat_bnd, ModelData._lon)
        qnet = self.getData(files, qnet_attr, maskLat, maskLon)
        if case == 0:
            return "no such period"
        elif case == 1:
            return "no such period"
        elif case == 2:
            qnet = qnet[120:804,:,:]
        elif case == 3:
            qnet = qnet[:684,:,:]
        else:
            return "no such case"
        qnet = self.regrid(qnet, lat[maskLat], lon[maskLon], 'nearest', period)
        self.saveAsNetCDF(qnet, version+'.qnet'+str(period)+'.nc', 'qnet', period)
        return qnet

    def get20CIsotherm(self, version , period, grid=False):
        switcher = {
            'GODAS':    0,
            'SODA':     1,
            'GECCO2':   2,
            'ORAS4':    3
        }
        case = switcher.get(version, 99)
        lat_attr    = 'lat'
        lon_attr    = 'lon'
        z_attr      = 'Depth'
#        time_attr   = 'time'
        if case == 0:
            files   = glob.glob('../Sea level height/GODAS/sshg.*.nc')
            temp_attr = 'sshg'
        elif case == 1:
            files   = glob.glob('../Sea level height/SODA/SODA.slAll.nc')
            temp_attr = 'sshg'
#            files   = glob.glob('Sea level height/SODA/SODA_2.2.4_*.cdf')
#            sl_attr = ['ssh', 'SSH']
        elif case == 2:
            lat_attr    = 'y'
            lon_attr    = 'x'
            files   = glob.glob('../Temp/Depth*.nc')
            temp_attr = 'temp'
        elif case == 3:
            files   = glob.glob('../Sea level height/ORAS4/zos_oras4_1m_*_grid_1x1.nc')
            temp_attr = 'zos'
        else:
            return "no such version"
        lat, lon = self.getLatLon(files[0], lat_attr, lon_attr)
        if period == 0:
            lat_bnd = ModelData._lat
        elif period == 4:
            lat_bnd = ModelData._lat3
        elif period == 7:
            lat_bnd = ModelData._lat2
        else:
            print 'No such period'
#        lon_bnd             = ModelData._lon2 if period == 1 else ModelData._lon
        maskLat, maskLon    = self.getMaskLatLon(lat, lon, lat_bnd, ModelData._lon)
        temp = self.getData(files, temp_attr, maskLat, maskLon, sCase=True)
        z    = self.getDepth(files, z_attr)
        if case == 0:
            return "no such period"
        elif case == 1:
            return "no such period"
        elif case == 2:
            temp = temp[120:804,:,:]
        elif case == 3:
            temp = temp[:684,:,:]
        else:
            return "no such case"
        z20 = self.getIsothermFromTemperature(temp, z, 20)
        z20 = self.regrid(z20, lat[maskLat], lon[maskLon], 'nearest', period)
        self.saveAsNetCDF(z20, version+'.z20'+str(period)+'.nc', 'z20', period)
        return z20
    
    def getVerticalProfile(self, version, period, variable='Temp', lat_bnd=[-2,2]):
        """Build and save a latitude-averaged vertical section.

        ``variable`` selects the input ('Temp' or 'Salt'); any other value
        returns 'no such variable'.  The field is limited to ``lat_bnd`` and
        the ``_lon`` box, averaged over axis 2 with nanmean, passed to
        regridVertical and written out with saveVerticalAsNetCDF.
        ``version`` and ``period`` are accepted but not used.
        """
        if variable == 'Temp':
            pattern, var_attr = '../Temp/temp_lev5To105.nc', 'temp'
        elif variable == 'Salt':
            pattern, var_attr = '../Salinity/salt_lev5To105.nc', 'salt'
        else:
            return 'no such variable'
        files = glob.glob(pattern)

        lat, lon = self.getLatLon(files[0], 'y', 'x')
        maskLat, maskLon = self.getMaskLatLon(lat, lon, lat_bnd, ModelData._lon)
        field = self.getData(files, var_attr, maskLat, maskLon, sCase=True)
        depth = self.getDepth(files, 'Depth')

        # Collapse axis 2, then swap axes 1 and 2 before regridding.
        lat_mean = np.nanmean(field, axis=2)
        section = self.regridVertical(np.swapaxes(lat_mean, 1, 2),
                                      lon[maskLon], np.array(depth), 'nearest')

        out_name = 'vertical.' + var_attr + '.' + str(lat_bnd) + '.nc'
        self.saveVerticalAsNetCDF(section, out_name, var_attr, depth=depth)
        return section
    
    def getDataAlongPMM(self, data, lat, lon):
        """Average ``data`` inside a longitude window that shifts with latitude.

        The window starts at longitudes 170-190 on latitude 0 and is shifted
        in equal steps until its lower edge reaches 220 on latitude 20.  For
        each latitude step the values inside the window are nan-averaged
        over axes 2 and 3.

        data -- array indexed with four axes; the last two are masked with
                the latitude and longitude masks respectively.
        lat, lon -- coordinate arrays; ``lat[1]-lat[0]`` sets the step.

        Returns ``(profile, lats, lons)`` where ``profile`` has shape
        ``(data.shape[0], steps, data.shape[1])`` and ``lats`` / ``lons``
        hold the centre of every step.
        """
        pmm_lon_S = [170, 190]
        pmm_lon_N = [220, 240]
        pmm_lat = [0, 20]

        lat_step = lat[1]-lat[0]
        half_step = lat_step/2.
        n_steps = int(np.ceil((pmm_lat[1]-pmm_lat[0])/float(lat_step))+1)
        lon_step = (pmm_lon_N[0]-pmm_lon_S[0])/float(n_steps-1)

        profile = np.zeros((data.shape[0], n_steps, data.shape[1]))
        centre_lats = np.zeros(n_steps)
        centre_lons = np.zeros(n_steps)

        for k in range(n_steps):
            centre_lat = pmm_lat[0]+lat_step*k
            shift = lon_step*k
            west = pmm_lon_S[0]+shift
            east = pmm_lon_S[1]+shift

            maskLat, maskLon = self.getMaskLatLon(
                lat, lon,
                [centre_lat-half_step, centre_lat+half_step],
                [west, east])
            print(lat[maskLat])
            print(lon[maskLon])

            # Mask longitude first, then latitude, in two separate steps.
            window = data[:,:,:,maskLon][:,:,maskLat]
            profile[:,k,:] = np.nanmean(window, axis=(2,3))
            centre_lats[k] = centre_lat
            centre_lons[k] = (west+ pmm_lon_S[1]+shift)/2.
        return profile, centre_lats, centre_lons
    
    
    def getVerticalProfileAlongPMM(self, version, period, variable='Temp'):
        """Build and save the vertical section produced by getDataAlongPMM.

        ``variable`` selects the input ('Temp' or 'Salt'); any other value
        returns 'no such variable'.  The field is limited to the ``_lat`` /
        ``_lon`` box, reduced by getDataAlongPMM and written out with
        saveVerticalPMMAsNetCDF.  ``version`` and ``period`` are accepted
        but not used.
        """
        if variable == 'Temp':
            pattern, var_attr = '../Temp/temp_lev5To105.nc', 'temp'
        elif variable == 'Salt':
            pattern, var_attr = '../Salinity/salt_lev5To105.nc', 'salt'
        else:
            return 'no such variable'
        files = glob.glob(pattern)

        lat, lon = self.getLatLon(files[0], 'y', 'x')
        maskLat, maskLon = self.getMaskLatLon(lat, lon, ModelData._lat, ModelData._lon)
        field = self.getData(files, var_attr, maskLat, maskLon, sCase=True)

        section, pmm_lat, pmm_lon = self.getDataAlongPMM(field, lat[maskLat], lon[maskLon])
        depth = self.getDepth(files, 'Depth')

        out_name = 'vertical.' + var_attr + '.pmm.nc'
        self.saveVerticalPMMAsNetCDF(section, out_name, var_attr, pmm_lat, pmm_lon, depth)
        return section
    
    def getSeaLevelPressure(self, version , period, grid=False):
        switcher = {
            'ICOADS':   0,
            'NCEP':     1,
            'GECCO2':   2,
            'JRA55':    3,
        }
        case = switcher.get(version, 99)
        lat_attr    = 'lat'
        lon_attr    = 'lon'
#        time_attr   = 'time'
        if case == 0:
            files  = glob.glob('../Wind/ICOADS.*pstr.mean.nc') # wind stress
            slp_attr  = 'upstr'
        elif case == 1:
            files  = glob.glob('../SLP/NCEP.slp.mon.mean.nc') # wind
            slp_attr  = 'slp'
        elif case == 2:
            files  = glob.glob('../Wind/GECCO2.f*29_34_55.nc') # wind stress, "N/m^2"
            slp_attr  = 'fu'
        elif case == 3:
            files  = glob.glob('../Wind/JRA55.*wind.nc') # wind
            slp_attr  = 'var33'
        else:
            return "no such version"
        lat, lon    = self.getLatLon(files[0], lat_attr, lon_attr)
        if period == 0:
            lat_bnd = ModelData._lat
        elif period == 4:
            lat_bnd = ModelData._lat3
        elif period == 7:
            lat_bnd = ModelData._lat2
        if period == 11:
            lat_bnd = ModelData._lat
        else:
            print 'No such period'
#        lon_bnd             = ModelData._lon2 if period == 1 else ModelData._lon
        maskLat, maskLon    = self.getMaskLatLon(lat, lon, lat_bnd, ModelData._lon)
        slp = self.getData(files, slp_attr, maskLat, maskLon)
        if case == 0:
            return "no such period"
        elif case == 1:
            if period ==11:
                slp = slp[:828,:,:]
            else:
                slp = slp[120:804,:,:]
        elif case == 2:
            return "no such period"
        elif case == 3:
            slp = slp[:684]
        else:
            return "no such period"
        slp = self.regrid(slp, lat[maskLat], lon[maskLon], 'nearest', period)
        self.saveAsNetCDF(slp, version+'.slp'+str(period)+'.nc', 'slp', period)
        return slp
    
    def getGeopotentialHeight(self, version , period, grid=False):
        switcher = {
            'ICOADS':   0,
            'NCEP':     1,
            'GECCO2':   2,
            'JRA55':    3,
        }
        case = switcher.get(version, 99)
        lat_attr    = 'lat'
        lon_attr    = 'lon'
#        time_attr   = 'time'
        if case == 0:
            files  = glob.glob('../Wind/ICOADS.*pstr.mean.nc') # wind stress
            hgt_attr  = 'upstr'
        elif case == 1:
            files  = glob.glob('../HGT/NCEP.hgt.mon.mean.nc') # wind
            hgt_attr  = 'hgt'
        elif case == 2:
            files  = glob.glob('../Wind/GECCO2.f*29_34_55.nc') # wind stress, "N/m^2"
            hgt_attr  = 'fu'
        elif case == 3:
            files  = glob.glob('../Wind/JRA55.*wind.nc') # wind
            hgt_attr  = 'var33'
        else:
            return "no such version"
        lat, lon    = self.getLatLon(files[0], lat_attr, lon_attr)
        lon_bnd = ModelData._lon
        if period == 0:
            lat_bnd = ModelData._lat
        elif period == 4:
            lat_bnd = ModelData._lat3
        elif period == 7:
            lat_bnd = ModelData._lat2
        elif period == 9:
            lat_bnd = ModelData._lat
            lon_bnd = ModelData._lon2
        else:
            print 'No such period'
#        lon_bnd             = ModelData._lon2 if period == 1 else ModelData._lon
        maskLat, maskLon    = self.getMaskLatLon(lat, lon, lat_bnd, lon_bnd)
        z    = self.getDepthFromOne(files[0], 'level')
        hgt = self.getData(files, hgt_attr, maskLat, maskLon, depth=z.shape[0])
        print hgt.shape
        if case == 0:
            return "no such period"
        elif case == 1:
            hgt = hgt[120:804]
        elif case == 2:
            return "no such period"
        elif case == 3:
            hgt = hgt[:684]
        else:
            return "no such period"
        hgt = self.regrid(hgt, lat[maskLat], lon[maskLon], 'nearest', period)
        self.saveAsNetCDF(hgt, version+'.hgt'+str(period)+'.nc', 'hgt', period, depth=z)
        return hgt
    
    def getSurfaceWindSpeed(self, version , period, grid=False):
        switcher = {
            'ICOADS':   0,
            'NCEP':     1,
            'GECCO2':   2,
            'JRA55':    3,
        }
        case = switcher.get(version, 99)
        lat_attr    = 'lat'
        lon_attr    = 'lon'
#        time_attr   = 'time'
        if case == 0:
            files  = glob.glob('../Wind/ICOADS.*pstr.mean.nc') # wind stress
            wspd_attr  = 'upstr'
        elif case == 1:
            files  = glob.glob('../windSpeed/NCEP.wspd.mon.mean.nc') # wind speed "m/s"
            wspd_attr  = 'wspd'
        elif case == 2:
            files  = glob.glob('../Wind/GECCO2.f*29_34_55.nc') # wind stress, "N/m^2"
            wspd_attr  = 'fu'
        elif case == 3:
            files  = glob.glob('../Wind/JRA55.*wind.nc') # wind
            wspd_attr  = 'var33'
        else:
            return "no such version"
        lat, lon    = self.getLatLon(files[0], lat_attr, lon_attr)
        if period == 0:
            lat_bnd = ModelData._lat
        elif period == 4:
            lat_bnd = ModelData._lat3
        elif period == 7:
            lat_bnd = ModelData._lat2
        else:
            print 'No such period'
#        lon_bnd             = ModelData._lon2 if period == 1 else ModelData._lon
        maskLat, maskLon    = self.getMaskLatLon(lat, lon, lat_bnd, ModelData._lon)
        wspd = self.getData(files, wspd_attr, maskLat, maskLon)
        if case == 0:
            return "no such period"
        elif case == 1:
            wspd = wspd[120:804,:,:]
        elif case == 2:
            return "no such period"
        elif case == 3:
            wspd = wspd[:684]
        else:
            return "no such period"
        wspd = self.regrid(wspd, lat[maskLat], lon[maskLon], 'nearest', period)
        self.saveAsNetCDF(wspd, version+'.wspd'+str(period)+'.nc', 'wspd', period)
        return wspd
    
    def getLatentHeatFlux(self, version , period, grid=False):
        switcher = {
            'ICOADS':   0,
            'NCEP':     1,
            'GECCO2':   2,
            'JRA55':    3,
        }
        case = switcher.get(version, 99)
        lat_attr    = 'lat'
        lon_attr    = 'lon'
#        time_attr   = 'time'
        if case == 0:
            files  = glob.glob('../Wind/ICOADS.*pstr.mean.nc') # wind stress
            lheat_attr  = 'upstr'
        elif case == 1:
            files  = glob.glob('../latentHeatFlux/NCEP.lhtfl.sfc.mon.mean.nc') # "W/m^2"
            lheat_attr  = 'lhtfl'
        elif case == 2:
            files  = glob.glob('../Wind/GECCO2.f*29_34_55.nc') # wind stress, "N/m^2"
            lheat_attr  = 'fu'
        elif case == 3:
            files  = glob.glob('../Wind/JRA55.*wind.nc') # wind
            lheat_attr  = 'var33'
        else:
            return "no such version"
        lat, lon    = self.getLatLon(files[0], lat_attr, lon_attr)
        lon_bnd = ModelData._lon
        if period == 0:
            lat_bnd = ModelData._lat
        elif period == 4:
            lat_bnd = ModelData._lat3
        elif period == 7:
            lat_bnd = ModelData._lat2
        else:
            print 'No such period'
#        lon_bnd             = ModelData._lon2 if period == 1 else ModelData._lon
        maskLat, maskLon    = self.getMaskLatLon(lat, lon, lat_bnd, lon_bnd)
        lheat = self.getData(files, lheat_attr, maskLat, maskLon)
        if case == 0:
            return "no such period"
        elif case == 1:
            lheat = lheat[120:804,:,:]
        elif case == 2:
            return "no such period"
        elif case == 3:
            lheat = lheat[:684]
        else:
            return "no such period"
        lheat = self.regrid(lheat, lat[maskLat], lon[maskLon], 'nearest', period)
        self.saveAsNetCDF(lheat, version+'.lhtfl'+str(period)+'.nc', 'lhtfl', period)
        return lheat
    
    def getSensibleHeatFlux(self, version , period, grid=False):
        switcher = {
            'ICOADS':   0,
            'NCEP':     1,
            'GECCO2':   2,
            'JRA55':    3,
        }
        case = switcher.get(version, 99)
        lat_attr    = 'lat'
        lon_attr    = 'lon'
#        time_attr   = 'time'
        if case == 0:
            files  = glob.glob('../Wind/ICOADS.*pstr.mean.nc') # wind stress
            sheat_attr  = 'upstr'
        elif case == 1:
            files  = glob.glob('../sensibleHeatFlux/NCEP.shtfl.sfc.mon.mean.nc') # "W/m^2"
            sheat_attr  = 'shtfl'
        elif case == 2:
            files  = glob.glob('../Wind/GECCO2.f*29_34_55.nc') # wind stress, "N/m^2"
            sheat_attr  = 'fu'
        elif case == 3:
            files  = glob.glob('../Wind/JRA55.*wind.nc') # wind
            sheat_attr  = 'var33'
        else:
            return "no such version"
        lat, lon    = self.getLatLon(files[0], lat_attr, lon_attr)
        lon_bnd = ModelData._lon
        if period == 0:
            lat_bnd = ModelData._lat
        elif period == 4:
            lat_bnd = ModelData._lat3
        elif period == 7:
            lat_bnd = ModelData._lat2
        else:
            print 'No such period'
#        lon_bnd             = ModelData._lon2 if period == 1 else ModelData._lon
        maskLat, maskLon    = self.getMaskLatLon(lat, lon, lat_bnd, lon_bnd)
        sheat = self.getData(files, sheat_attr, maskLat, maskLon)
        if case == 0:
            return "no such period"
        elif case == 1:
            sheat = sheat[120:804,:,:]
        elif case == 2:
            return "no such period"
        elif case == 3:
            sheat = sheat[:684]
        else:
            return "no such period"
        sheat = self.regrid(sheat, lat[maskLat], lon[maskLon], 'nearest', period)
        self.saveAsNetCDF(sheat, version+'.shtfl'+str(period)+'.nc', 'shtfl', period)
        return sheat
    
    def getPrecipitation(self, version , period, grid=False):
        switcher = {
            'ICOADS':   0,
            'NCEP':     1,
            'GECCO2':   2,
            'JRA55':    3,
        }
        case = switcher.get(version, 99)
        lat_attr    = 'lat'
        lon_attr    = 'lon'
#        time_attr   = 'time'
        if case == 0:
            files  = glob.glob('../Precipitation/') 
            slp_attr  = 'precip'
        elif case == 1:
            files  = glob.glob('../Precipitation/precip.mon.anom.nc') 
            slp_attr  = 'precip'
        elif case == 2:
            files  = glob.glob('../Wind/GECCO2.f*29_34_55.nc') # wind stress, "N/m^2"
            slp_attr  = 'fu'
        elif case == 3:
            files  = glob.glob('../Precipitation/full_data_monthly_v2018_05.nc') 
            slp_attr  = 'precip'
        else:
            return "no such version"
        
        lat, lon    = self.getLatLon(files[0], lat_attr, lon_attr)
        lon_bnd = ModelData._lon
        if period == 0:
            lat_bnd = ModelData._lat
        elif period == 4:
            lat_bnd = ModelData._lat3
        elif period == 7:
            lat_bnd = ModelData._lat2
        elif period == 9:
            lat_bnd = ModelData._lat
            lon_bnd = ModelData._lon2
        else:
            print 'No such period'
#        lon_bnd             = ModelData._lon2 if period == 1 else ModelData._lon
        maskLat, maskLon    = self.getMaskLatLon(lat, lon, lat_bnd, lon_bnd)
        precip = self.getData(files, slp_attr, maskLat, maskLon)
        
        if case == 0:
            return "no such period"            
        elif case == 1:
            precip = precip[120:804,:,:]
        elif case == 2:
            return "no such period"
        elif case == 3:
            precip[precip<-1000] = np.nan
            precip = precip[67*12:(67+57),:,:]
        else:
            return "no such period"
        precip = self.regrid(precip, lat[maskLat], lon[maskLon], 'nearest', period)
        self.saveAsNetCDF(precip, version+'.precip'+str(period)+'.nc', 'precip', period)
        return precip
    
    @staticmethod
    def getIsothermFromTemperature(temperature, z, isovalue):
        # Code rewrote from ra_isotherm(temp,Z,isovalue) in MATLAB
        temp = temperature.reshape(temperature.shape[0], temperature.shape[1], temperature.shape[2]*temperature.shape[3])
        nonMissingIndex = np.where(np.logical_not(np.isnan(temp[0,0])))[0]
        therm = np.ones((temp.shape[0], temp.shape[2]))*np.NaN
        for k in range(temp.shape[0]):
            for i in nonMissingIndex:
                te = temp[k,:,i]
                pos = np.where(te<isovalue)
                if (pos[0].shape[0]>0) and (pos[0][0]>0):
                    p2 = pos[0][0]
                    p1 = p2 - 1
#                    print 'p1:{}, p2:{}'.format(p1,p2)
#                    print 'z1:{}, z2:{}'.format(z[p1],z[p2])
#                    print 'te1:{}, te2:{}'.format(te[p1], te[p2])
                    therm[k, i] = np.interp(20, [te[p1], te[p2]],[z[p1],z[p2]])
#                elif (pos[0].shape[0]>0) and (pos[0][0]<=0):
#                    therm[k, i] = 0
                else:
                    therm[k, i] = np.NaN
        return therm.reshape(temperature.shape[0], temperature.shape[2], temperature.shape[3])
            
    def getUOcean(self, version, period):
        switcher = {
            'GODAS':    0,
            'SODA':     1,
            'GECCO2':   2,
            'ORAS4':    3
        }
        case = switcher.get(version, 99)
        lat_attr    = 'lat'
        lon_attr    = 'lon'
#        time_attr   = 'time'
        if case == 0:
            files   = glob.glob('../uOcean/GODAS/sshg.*.nc')
            uo_attr = 'sshg'
        elif case == 1:
            files   = glob.glob('../uOcean/SODA/SODA.slAll.nc')
            uo_attr = 'sshg'
#            files   = glob.glob('Sea level height/SODA/SODA_2.2.4_*.cdf')
#            sl_attr = ['ssh', 'SSH']
        elif case == 2:
            files   = glob.glob('../uOcean/u29_34_55_Depth105.nc')
            uo_attr = 'u'
        elif case == 3:
            files   = glob.glob('../uOcean/ORAS4/zos_oras4_1m_*_grid_1x1.nc')
            uo_attr = 'zos'
        else:
            return "no such version"
        lat, lon = self.getLatLon(files[0], lat_attr, lon_attr)
        if period == 0:
            lat_bnd = ModelData._lat
        elif period == 4:
            lat_bnd = ModelData._lat3
        elif period == 7:
            lat_bnd = ModelData._lat2
        else:
            print 'No such period'
#        lon_bnd             = ModelData._lon2 if period == 1 else ModelData._lon
        maskLat, maskLon    = self.getMaskLatLon(lat, lon, lat_bnd, ModelData._lon)
        z      = self.getDepthFromOne(files[0], 'Depth')
        uOcean = self.getData(files, uo_attr, maskLat, maskLon, depth=z.shape[0])
        if case == 0:
            return "no such period"
#            if period == 1:
#                sl = sl[180:420,:,:]*100
#            if period == 2:
#                sl = sl*100
#            else:
#                return "no such period"
        elif case == 1:
            return "no such period"
#            if period == 0:
#                sl = sl[1188:1428,:,:]*100
#            elif period == 1:
#                sl = sl[1488:1680,:,:]*100
#            elif period == 2:
#                sl = sl[1188:1680,:,:]*100
#            elif period == 3:
#                sl = sl[1548:1680,:,:]*100
#            else:
#                return "no such period"
#            if period == 0:
#                return "no such period"
#            else:
#                return "no such period"
        elif case == 2:
            uOcean = uOcean[120:804,:,:]
#            if period == 0:
#                sl = sl[264:804,:,:]*100
#            else:
#                return "no such period"
        elif case == 3:
            uOcean = uOcean[:684,:,:]*100
#            if period == 0:
#                sl = sl[144:684,:,:]*100
#            else:
#                return "no such period"
        else:
            return "no such case"
        uOcean = self.regrid(uOcean, lat[maskLat], lon[maskLon], 'nearest', period)
        self.saveAsNetCDF(uOcean, version+'.uOcean'+str(period)+'.nc', 'uOcean', period, depth=z)
#        if case == 0 and period == 2:
#            sl_SODA = self.readVarible('SODA.sl0.nc', 'sshg')
#            missingIndex_0 = np.where(np.isnan(sl_SODA[0]))
#            missingIndex_1 = np.where(np.isnan(sl[0]))
##            print missingIndex
#            sl = np.concatenate((sl_SODA[:120,:,:],sl[:420,:,:]), axis=0)
#            sl[:,missingIndex_0[0],missingIndex_0[1]] = np.NaN
#            sl[:,missingIndex_1[0],missingIndex_1[1]] = np.NaN
##            sl_temp = np.ones([120+420, sl.shape[1], sl.shape[2]], dtype=sl.dtype) * np.NaN
##            sl_temp = sl_temp.astype(sl.dtype)
##            sl[nonMissingIndex,:] = sl
#        if sl.shape[0] == 240:
#            self.saveAsNetCDF(sl, version+'.sl'+str(period)+'.nc', 'sshg', period)
#        elif sl.shape[0] == 192:
#            self.saveAsNetCDF(sl, version+'.sl'+str(period)+'.nc', 'sshg', [16, period])
#        elif sl.shape[0] == 540:
#            self.saveAsNetCDF(sl, version+'.sl'+str(period)+'.nc', 'sshg', period)
#        elif sl.shape[0] == 492:
#            self.saveAsNetCDF(sl, version+'.sl'+str(period)+'.nc', 'sshg', [41, period])
#        elif sl.shape[0] == 180:
#            self.saveAsNetCDF(sl, version+'.sl'+str(period)+'.nc', 'sshg', period)
#        elif sl.shape[0] == 804:
#            self.saveAsNetCDF(sl, version+'.sl'+str(period)+'.nc', 'sshg', period)
#        else:
#            return 'No such length'
        return uOcean
    
    def getW50Ocean(self, version, period):
        switcher = {
            'GODAS':    0,
            'SODA':     1,
            'GECCO2':   2,
            'ORAS4':    3
        }
        case = switcher.get(version, 99)
        lat_attr    = 'lat'
        lon_attr    = 'lon'
#        time_attr   = 'time'
        if case == 0:
            files   = glob.glob('../w50Ocean/GODAS/sshg.*.nc')
            wo_attr = 'sshg'
        elif case == 1:
            files   = glob.glob('../w50Ocean/SODA/SODA.slAll.nc')
            wo_attr = 'sshg'
#            files   = glob.glob('Sea level height/SODA/SODA_2.2.4_*.cdf')
#            sl_attr = ['ssh', 'SSH']
        elif case == 2:
            files   = glob.glob('../w50Ocean/w50_29_34_55_Depth50.nc')
            wo_attr = 'w'
        elif case == 3:
            files   = glob.glob('../w50Ocean/ORAS4/zos_oras4_1m_*_grid_1x1.nc')
            wo_attr = 'zos'
        else:
            return "no such version"
    
        lat, lon = self.getLatLon(files[0], lat_attr, lon_attr)
        if period == 0:
            lat_bnd = ModelData._lat
        elif period == 4:
            lat_bnd = ModelData._lat3
        elif period == 7:
            lat_bnd = ModelData._lat2
        else:
            print 'No such period'
#        lon_bnd             = ModelData._lon2 if period == 1 else ModelData._lon
        maskLat, maskLon    = self.getMaskLatLon(lat, lon, lat_bnd, ModelData._lon)
        z      = self.getDepthFromOne(files[0], 'Depth')
        w50Ocean = self.getData(files, wo_attr, maskLat, maskLon, depth=z.shape[0])
        if case == 0:
            return "no such period"
#            if period == 1:
#                sl = sl[180:420,:,:]*100
#            if period == 2:
#                sl = sl*100
#            else:
#                return "no such period"
        elif case == 1:
            return "no such period"
#            if period == 0:
#                sl = sl[1188:1428,:,:]*100
#            elif period == 1:
#                sl = sl[1488:1680,:,:]*100
#            elif period == 2:
#                sl = sl[1188:1680,:,:]*100
#            elif period == 3:
#                sl = sl[1548:1680,:,:]*100
#            else:
#                return "no such period"
#            if period == 0:
#                return "no such period"
#            else:
#                return "no such period"
        elif case == 2:
            w50Ocean = w50Ocean[120:804,:,:]
#            if period == 0:
#                sl = sl[264:804,:,:]*100
#            else:
#                return "no such period"
        elif case == 3:
            w50Ocean = w50Ocean[:684,:,:]*100
#            if period == 0:
#                sl = sl[144:684,:,:]*100
#            else:
#                return "no such period"
        else:
            return "no such case"
        w50Ocean = self.regrid(w50Ocean, lat[maskLat], lon[maskLon], 'nearest', period)
        self.saveAsNetCDF(w50Ocean, version+'.w50Ocean'+str(period)+'.nc', 'w50Ocean', period, depth=z)
#        if case == 0 and period == 2:
#            sl_SODA = self.readVarible('SODA.sl0.nc', 'sshg')
#            missingIndex_0 = np.where(np.isnan(sl_SODA[0]))
#            missingIndex_1 = np.where(np.isnan(sl[0]))
##            print missingIndex
#            sl = np.concatenate((sl_SODA[:120,:,:],sl[:420,:,:]), axis=0)
#            sl[:,missingIndex_0[0],missingIndex_0[1]] = np.NaN
#            sl[:,missingIndex_1[0],missingIndex_1[1]] = np.NaN
##            sl_temp = np.ones([120+420, sl.shape[1], sl.shape[2]], dtype=sl.dtype) * np.NaN
##            sl_temp = sl_temp.astype(sl.dtype)
##            sl[nonMissingIndex,:] = sl
#        if sl.shape[0] == 240:
#            self.saveAsNetCDF(sl, version+'.sl'+str(period)+'.nc', 'sshg', period)
#        elif sl.shape[0] == 192:
#            self.saveAsNetCDF(sl, version+'.sl'+str(period)+'.nc', 'sshg', [16, period])
#        elif sl.shape[0] == 540:
#            self.saveAsNetCDF(sl, version+'.sl'+str(period)+'.nc', 'sshg', period)
#        elif sl.shape[0] == 492:
#            self.saveAsNetCDF(sl, version+'.sl'+str(period)+'.nc', 'sshg', [41, period])
#        elif sl.shape[0] == 180:
#            self.saveAsNetCDF(sl, version+'.sl'+str(period)+'.nc', 'sshg', period)
#        elif sl.shape[0] == 804:
#            self.saveAsNetCDF(sl, version+'.sl'+str(period)+'.nc', 'sshg', period)
#        else:
#            return 'No such length'
        return w50Ocean
    
    def getTempOcean(self, version, period):
        switcher = {
            'GODAS':    0,
            'SODA':     1,
            'GECCO2':   2,
            'ORAS4':    3
        }
        case = switcher.get(version, 99)
        lat_attr    = 'lat'
        lon_attr    = 'lon'
#        time_attr   = 'time'
        if case == 0:
            files   = glob.glob('../uOcean/GODAS/sshg.*.nc')
            to_attr = 'sshg'
        elif case == 1:
            files   = glob.glob('../uOcean/SODA/SODA.slAll.nc')
            to_attr = 'sshg'
#            files   = glob.glob('Sea level height/SODA/SODA_2.2.4_*.cdf')
#            sl_attr = ['ssh', 'SSH']
        elif case == 2:
            files   = glob.glob('../tempOcean/temp29_34_55_Depth105.nc')
            to_attr = 'temp'
        elif case == 3:
            files   = glob.glob('../uOcean/ORAS4/zos_oras4_1m_*_grid_1x1.nc')
            to_attr = 'zos'
        else:
            return "no such version"
        lat, lon = self.getLatLon(files[0], lat_attr, lon_attr)
        if period == 0:
            lat_bnd = ModelData._lat
        elif period == 4:
            lat_bnd = ModelData._lat3
        elif period == 7:
            lat_bnd = ModelData._lat2
        else:
            print 'No such period'
#        lon_bnd             = ModelData._lon2 if period == 1 else ModelData._lon
        maskLat, maskLon    = self.getMaskLatLon(lat, lon, lat_bnd, ModelData._lon)
        z    = self.getDepthFromOne(files[0], 'Depth')
        tempO = self.getData(files, to_attr, maskLat, maskLon, depth=z.shape[0])
        if case == 0:
            return "no such period"
#            if period == 1:
#                sl = sl[180:420,:,:]*100
#            if period == 2:
#                sl = sl*100
#            else:
#                return "no such period"
        elif case == 1:
            return "no such period"
#            if period == 0:
#                sl = sl[1188:1428,:,:]*100
#            elif period == 1:
#                sl = sl[1488:1680,:,:]*100
#            elif period == 2:
#                sl = sl[1188:1680,:,:]*100
#            elif period == 3:
#                sl = sl[1548:1680,:,:]*100
#            else:
#                return "no such period"
#            if period == 0:
#                return "no such period"
#            else:
#                return "no such period"
        elif case == 2:
            tempO = tempO[120:804,:,:]
#            if period == 0:
#                sl = sl[264:804,:,:]*100
#            else:
#                return "no such period"
        elif case == 3:
            tempO = tempO[:684,:,:]*100
#            if period == 0:
#                sl = sl[144:684,:,:]*100
#            else:
#                return "no such period"
        else:
            return "no such case"
        tempO = self.regrid(tempO, lat[maskLat], lon[maskLon], 'nearest', period)
        self.saveAsNetCDF(tempO, version+'.tempO'+str(period)+'.nc', 'tempO', period, depth=z)
#        if case == 0 and period == 2:
#            sl_SODA = self.readVarible('SODA.sl0.nc', 'sshg')
#            missingIndex_0 = np.where(np.isnan(sl_SODA[0]))
#            missingIndex_1 = np.where(np.isnan(sl[0]))
##            print missingIndex
#            sl = np.concatenate((sl_SODA[:120,:,:],sl[:420,:,:]), axis=0)
#            sl[:,missingIndex_0[0],missingIndex_0[1]] = np.NaN
#            sl[:,missingIndex_1[0],missingIndex_1[1]] = np.NaN
##            sl_temp = np.ones([120+420, sl.shape[1], sl.shape[2]], dtype=sl.dtype) * np.NaN
##            sl_temp = sl_temp.astype(sl.dtype)
##            sl[nonMissingIndex,:] = sl
#        if sl.shape[0] == 240:
#            self.saveAsNetCDF(sl, version+'.sl'+str(period)+'.nc', 'sshg', period)
#        elif sl.shape[0] == 192:
#            self.saveAsNetCDF(sl, version+'.sl'+str(period)+'.nc', 'sshg', [16, period])
#        elif sl.shape[0] == 540:
#            self.saveAsNetCDF(sl, version+'.sl'+str(period)+'.nc', 'sshg', period)
#        elif sl.shape[0] == 492:
#            self.saveAsNetCDF(sl, version+'.sl'+str(period)+'.nc', 'sshg', [41, period])
#        elif sl.shape[0] == 180:
#            self.saveAsNetCDF(sl, version+'.sl'+str(period)+'.nc', 'sshg', period)
#        elif sl.shape[0] == 804:
#            self.saveAsNetCDF(sl, version+'.sl'+str(period)+'.nc', 'sshg', period)
#        else:
#            return 'No such length'
        return tempO
    
    @staticmethod
    def getData(files, variable, maskLat, maskLon, sCase=False, depth=None):
#        print files
        fvar = np.zeros([0, maskLat.size, maskLon.size, depth]) if depth else np.zeros([0, maskLat.size, maskLon.size])
        files.sort()
        for fileName in files:
            print fileName
            f       = nc.Dataset(fileName)
            if isinstance(variable, basestring):
                var     = f.variables[variable][:,:]
            else:
                vs      = f.variables
                for vari in variable:
                    if vari in vs:
                        var = f.variables[vari][:,:]
            var     = np.array(var)
            if sCase:
                if fvar.ndim == 3:
                    fvar = np.zeros([ var.shape[0], 0, maskLat.size, maskLon.size]) 
                var     = var[:,:,maskLat,:]
                var     = var[:,:,:,maskLon]
                var[var>1000] = np.nan
                var[var<-1000] = np.nan
                print var.shape
                print fvar.shape
                fvar    = np.concatenate((fvar, var), axis=1) 
            elif depth:
                var     = np.swapaxes(var,1,2) if var.ndim  == 4 else var
                var     = np.swapaxes(var,2,3) if var.ndim  == 4 else var
                var     = var[:,maskLat,:]
                var     = var[:,:,maskLon]
#                var[var>2000] = np.nan
                var[var<-1000] = np.nan
                fvar    = np.concatenate((fvar, var), axis=0) 
            else:
                var     = var[:,0,:,:] if var.ndim  == 4 else var
                var     = var.reshape((1, var.shape[0], var.shape[1]))if var.ndim == 2 else var
                var     = var[:,maskLat,:]
                var     = var[:,:,maskLon]
                var[var>2000] = np.nan
                var[var<-1000] = np.nan
                fvar    = np.concatenate((fvar, var), axis=0) 
        return fvar
        
    @staticmethod
    def getWindData(files, variable, maskLat, maskLon):
        """Read one variable from a single netCDF file and subset it.

        Parameters
        ----------
        files : str
            Path of the file to open (a single path despite the name).
        variable : str
            Name of the variable to read.
        maskLat, maskLon : numpy.ndarray
            Index arrays applied to axes 1 and 2 of the (reduced) data.

        Returns
        -------
        numpy.ndarray
            The subset data with values outside [-1000, 1000] set to NaN.
        """
        print(files)
        f = nc.Dataset(files)
        try:
            var = np.array(f.variables[variable][:,:,:])
        finally:
            # Close the file even when the read fails.
            f.close()
        if var.ndim == 4:
            var = var[:,0,:,:]      # keep only index 0 of axis 1
        var     = var[:,maskLat,:]
        var     = var[:,:,maskLon]
        var[var>1000] = np.nan
        var[var<-1000] = np.nan
        return var
        
    @staticmethod    
    def getLatLon(fileName, lat, lon):
        """Read the latitude and longitude coordinate variables of a file.

        Parameters
        ----------
        fileName : str
            Path of the netCDF file.
        lat, lon : str
            Names of the latitude and longitude variables.

        Returns
        -------
        tuple
            ``(lat, lon)`` values as read from the file, with negative
            longitudes shifted by +360.
        """
        f = nc.Dataset(fileName)
        try:
            lat_values = f.variables[lat][:]
            lon_values = f.variables[lon][:]
        finally:
            # The values are already in memory; the file used to stay open.
            f.close()
        lon_values[lon_values<0] += 360
        return lat_values, lon_values
    
    @staticmethod    
    def getDepth(files, depth):
        deps = []
        if len(files) == 1:
            f       = nc.Dataset(files[0])
            deps     = f.variables[depth][:]
        else:
            for fileName in files:
                print fileName
                f       = nc.Dataset(fileName)
                dep     = f.variables[depth][:]
                deps.append(dep[0])
        return deps
    
    @staticmethod    
    def getDepthFromOne(fileName, depth):
        """Return the depth variable of one netCDF file as a numpy array.

        Parameters
        ----------
        fileName : str
            Path of the netCDF file.
        depth : str
            Name of the depth variable.
        """
        f = nc.Dataset(fileName)
        try:
            return np.array(f.variables[depth][:])
        finally:
            # The file used to be left open after reading.
            f.close()
        
    
    @staticmethod
    def getMaskLatLon(lat, lon, latLim, lonLim):            
        maskLat = np.where((lat <= latLim[1]) & (lat >= latLim[0]))[0]
        maskLon = np.where((lon <= lonLim[1]) & (lon >= lonLim[0]))[0]
        maskLat = maskLat[::-1] if lat[1] > 0 else maskLat
        return maskLat, maskLon
    
    @staticmethod
    def getTime(f, timevar):            
        """Convert the time variable ``timevar`` of the open dataset ``f`` to dates.

        The conversion uses the variable's own ``units`` attribute.
        """
        time_axis = f.variables[timevar]
        return nc.num2date(time_axis[:], time_axis.units)
    
        
    @staticmethod    
    def regrid(var, lat, lon, method, period):
        ndim = var.ndim
        months = np.arange(var.shape[0])
        newlon = np.arange((ModelData._lon[1]-ModelData._lon[0])/1.5+1)*1.5+ModelData._lon[0]
        if period == 0:
            newlat = np.arange(ModelData._lat[1]-(ModelData._lat[0])+1)+ModelData._lat[0]
        elif period == 4:
            newlat = np.arange(ModelData._lat3[1]-(ModelData._lat3[0])+1)+ModelData._lat3[0]
        elif period == 7:
            newlat = np.arange(ModelData._lat2[1]-(ModelData._lat2[0])+1)+ModelData._lat2[0]
        elif period == 9:
            newlat = np.arange(ModelData._lat[1]-(ModelData._lat[0])+1)+ModelData._lat[0]
            newlon = np.arange((ModelData._lon2[1]-ModelData._lon2[0])/1.5+1)*1.5+ModelData._lon2[0]
        elif period == 11:
            newlat = np.arange(ModelData._lat[1]-(ModelData._lat[0])+1)+ModelData._lat[0]
        else:
            print 'No such period'
        #create mesh
        X, Y = np.meshgrid(lon, lat)
        XI, YI = np.meshgrid(newlon, newlat)
        regrid = np.zeros((months.size, newlat.size, newlon.size, var.shape[3])) if ndim == 4 else np.zeros((months.size, newlat.size, newlon.size))
#        month = 0
        for month in range(months.size):
            if ndim == 4:
                for d in range(var.shape[3]):
                    regrid[month,:,:,d] = griddata((X.flatten(),Y.flatten()), var[month,:,:,d].flatten(), (XI, YI), method = method)
            else:    
                regrid[month,:,:] = griddata((X.flatten(),Y.flatten()), var[month,:,:].flatten(), (XI, YI), method = method)
#        regrid[month,:,:] = griddata((X.flatten(),Y.flatten()), var[month,:,:].flatten(), (XI, YI), method = 'nearest')
        return regrid
    
    @staticmethod    
    def regridVertical(var, lon, z, method):
        months = np.arange(var.shape[0])
        newlon = np.arange((ModelData._lon[1]-ModelData._lon[0])/1.5+1)*1.5+ModelData._lon[0]

        #create mesh
        X, Y = np.meshgrid(z, lon)
        XI, YI = np.meshgrid(z, newlon)
        regrid = np.zeros((months.size, newlon.size, z.size))
        print var.shape
        print X.shape
        print Y.shape
        print XI.shape
        print YI.shape
        print regrid.shape
#        month = 0
        for month in range(months.size): 
            regrid[month,:,:] = griddata((X.flatten(),Y.flatten()), var[month,:,:].flatten(), (XI, YI), method = method)
#        regrid[month,:,:] = griddata((X.flatten(),Y.flatten()), var[month,:,:].flatten(), (XI, YI), method = 'nearest')
        return regrid
        
    @staticmethod    
    def regrid2d(var, lat, lon, method):
        lon    = np.tile(lon,2)
        months = np.arange(var.shape[0])
        newlat = np.arange(ModelData._lat[1]-(ModelData._lat[0])+1)+ModelData._lat[0]
        newlon = np.arange((ModelData._lon[1]-ModelData._lon[0])/1.5+1)*1.5+ModelData._lon[0]
        newlon = np.tile(newlon,2)
        #create mesh
        X, Y = np.meshgrid(lon, lat)
        XI, YI = np.meshgrid(newlon, newlat)
        regrid = np.zeros((months.size, newlat.size, newlon.size))
#        month = 0
        for month in range(months.size):
            regrid[month,:,:] = griddata((X.flatten(),Y.flatten()), var[month,:,:].flatten(), (XI, YI), method = method)
#        regrid[month,:,:] = griddata((X.flatten(),Y.flatten()), var[month,:,:].flatten(), (XI, YI), method = 'nearest')
        return regrid
    
    @staticmethod    
    def saveAsNetCDF(var, filename, variable, version, newlat=None, newlon=None, depth=None):
        fName   = 'ModelData/' + filename
        print 'In ModelData: saveAsNetCDF: {}'.format(fName)
#        print 'In saveAsNetCDF version: {}'.format(version)
        newlon  = np.arange((ModelData._lon[1]-ModelData._lon[0])/1.5+1)*1.5+ModelData._lon[0]
        if newlat is None:       
            if version == 0:
                newlat = np.arange(ModelData._lat[1]-(ModelData._lat[0])+1)+ModelData._lat[0]
            elif version == 4:
                newlat = np.arange(ModelData._lat3[1]-(ModelData._lat3[0])+1)+ModelData._lat3[0]
            elif version == 7:
                newlat = np.arange(ModelData._lat2[1]-(ModelData._lat2[0])+1)+ModelData._lat2[0]
            elif version == 9:
                newlat = np.arange(ModelData._lat[1]-(ModelData._lat[0])+1)+ModelData._lat[0]
                newlon = np.arange((ModelData._lon2[1]-ModelData._lon2[0])/1.5+1)*1.5+ModelData._lon2[0]
            elif version == 11:
                newlat = np.arange(ModelData._lat[1]-(ModelData._lat[0])+1)+ModelData._lat[0]
            else:
                print 'No such period'
        if isinstance(version, bool):
            years   = 18
            start   = 1982 if version else 1964
        elif isinstance(version, int):
            if version == 0:
                years   = 57
                start   = 1958 
            elif version == 7:
                years = 45
                start = 1970
            elif version == 4:
                years = 45
                start = 1970
            elif version == 9:
                years   = 57
                start   = 1958 
            elif version == 11:
                years = 69
                start = 1948
#            elif version == 2:
#                years = 45
#                start = 1970
#            elif version == 3:
#                years = 15
#                start = 2000
#            elif version == 4:
#                years = 67
#                start = 1948
        elif isinstance(version, basestring):
            start   = float(version)
        else:
            years   = version[0]
            if version[1] == 0:
                start   = 1974 
            elif version[1] == 1:
                start   = 1995
            elif version[1] == 2:
                start   = 1970
            else:
                print 'No such period'
                return 'No such period'
        
        if isinstance(version, basestring):
            months = start+np.arange(var.shape[0])/12.
        else:
            months  = np.zeros(years*12)
            for y in range(years):
                for m in range(12):
                    months[y*12+m] = int((start+y)*100+m+1)
        if isinstance(version, bool):
            months = months[:-11] if version else months   
        if os.path.isfile(fName):
            try:
                os.remove(fName)
            except OSError:
                pass
        f   = nc.Dataset(fName, "w", format="NETCDF4")
        f.createDimension("time", var.shape[0])
        f.createDimension("lat", newlat.shape[0])
        f.createDimension("lon", newlon.shape[0])
        print 'time dimension: {}'.format(var.shape[0])
        print months
#        print 'lat dimension: {}'.format(newlat.shape[0])
#        print 'lon dimension: {}'.format(newlon.shape[0])
#        print 'var dimension: {}'.format(var.shape)
        times   = f.createVariable("time","u4",("time",))
        times[:]= months
        lats    = f.createVariable("lat","f4",("lat",))
        lats[:] = newlat
        lons    = f.createVariable("lon","f4",("lon",))
        lons[:] = newlon
        if depth is not None:
            f.createDimension("depth", depth.shape[0])
            depths    = f.createVariable("depth","f4",("depth",))
            depths[:] = depth
            v       = f.createVariable(variable,"f4",("time", "lat", "lon","depth"))
        else:
            v       = f.createVariable(variable,"f4",("time", "lat", "lon"))
        print 'var:{}'.format(var.shape)
        print 'time:{}'.format(months.shape)
        print 'lat:{}'.format(newlat.shape)
        print 'lon:{}'.format(newlon.shape)
        v[:,:,:]= var
        f.close()
#        print 'In ModelData:  saveAsNetCDF: {}'.format(var.shape)
        
    @staticmethod    
    def saveVerticalAsNetCDF(var, filename, variable, depth=None):
        """Write a (time, lon, depth) section to ``ModelData/<filename>`` as NETCDF4.

        The longitude axis is the default 1.5-degree target axis and the
        time labels are ``1948 + month/12``, written to an unsigned
        32-bit integer variable. An existing file of the same name is
        replaced. ``depth`` must be an array despite its None default.
        """
        out_path = 'ModelData/' + filename
        print('In ModelData: saveAsNetCDF: {}'.format(out_path))
        lon_lo, lon_hi = ModelData._lon[0], ModelData._lon[1]
        lon_axis = np.arange((lon_hi-lon_lo)/1.5+1)*1.5+lon_lo

        n_steps = var.shape[0]
        time_axis = 1948+np.arange(n_steps)/12.

        # Remove a previous output; failure to delete is ignored.
        if os.path.isfile(out_path):
            try:
                os.remove(out_path)
            except OSError:
                pass

        dataset = nc.Dataset(out_path, "w", format="NETCDF4")
        dataset.createDimension("time", n_steps)
        dataset.createDimension("lon", lon_axis.shape[0])
        dataset.createDimension("depth", depth.shape[0])
        print('time dimension: {}'.format(n_steps))
        print(time_axis)

        time_var = dataset.createVariable("time","u4",("time",))
        time_var[:] = time_axis
        lon_var = dataset.createVariable("lon","f4",("lon",))
        lon_var[:] = lon_axis
        depth_var = dataset.createVariable("depth","f4",("depth",))
        depth_var[:] = depth
        data_var = dataset.createVariable(variable,"f4",("time", "lon","depth"))

        print('var:{}'.format(var.shape))
        print('time:{}'.format(time_axis.shape))
        print('lon:{}'.format(lon_axis.shape))
        print('depth:{}'.format(depth.shape))
        data_var[:,:,:] = var
        dataset.close()
        
    @staticmethod    
    def saveVerticalPMMAsNetCDF(var, filename, variable, pmm_lat, pmm_lon, depth):
        """Write a section along the PMM points to ``ModelData/<filename>`` as NETCDF4.

        The file stores ``pmm_lat``, ``pmm_lon`` and ``depth`` as coordinate
        variables; the data variable is dimensioned (time, pmm_lat, depth).
        Time labels are ``1948 + month/12``, written to an unsigned 32-bit
        integer variable. An existing file of the same name is replaced.
        """
        out_path = 'ModelData/' + filename
        print('In ModelData: saveAsNetCDF: {}'.format(out_path))

        n_steps = var.shape[0]
        time_axis = 1948+np.arange(n_steps)/12.

        # Remove a previous output; failure to delete is ignored.
        if os.path.isfile(out_path):
            try:
                os.remove(out_path)
            except OSError:
                pass

        dataset = nc.Dataset(out_path, "w", format="NETCDF4")
        dataset.createDimension("time", n_steps)
        dataset.createDimension("pmm_lat", pmm_lat.shape[0])
        dataset.createDimension("pmm_lon", pmm_lon.shape[0])
        dataset.createDimension("depth", depth.shape[0])
        print('time dimension: {}'.format(n_steps))
        print(time_axis)

        time_var = dataset.createVariable("time","u4",("time",))
        time_var[:] = time_axis
        lat_var = dataset.createVariable("pmm_lat","f4",("pmm_lat",))
        lat_var[:] = pmm_lat
        lon_var = dataset.createVariable("pmm_lon","f4",("pmm_lon",))
        lon_var[:] = pmm_lon
        depth_var = dataset.createVariable("depth","f4",("depth",))
        depth_var[:] = depth
        data_var = dataset.createVariable(variable,"f4",("time", "pmm_lat","depth"))

        print('var:{}'.format(var.shape))
        print('time:{}'.format(time_axis.shape))
        print('pmm_lat:{}'.format(pmm_lat.shape))
        print('depth:{}'.format(depth.shape))
        data_var[:,:,:] = var
        dataset.close()
#        print 'In ModelData:  saveAsNetCDF: {}'.format(var.shape)
        
    def readVarible(self, filename, variable, vertical=False):
        """Read one array from the NetCDF file 'ModelData/<filename>'.

        Parameters:
            filename: file name appended to the 'ModelData/' prefix.
            variable: name of the data variable to read; it is sliced with
                      three indices, so it must be three-dimensional.
                      Ignored when vertical is true.
            vertical: when true, return the file's 'depth' variable instead
                      of `variable`.

        Returns:
            The requested data converted with np.array().

        The dataset is closed before returning and also when the lookup or
        slicing raises (e.g. the variable is missing from the file).
        """
        fileName = 'ModelData/' + filename
        f = nc.Dataset(fileName)
        try:
            if vertical:
                data = f.variables['depth'][:]
            else:
                data = f.variables[variable][:, :, :]
            # Convert while the file is still open so the result does not
            # depend on the dataset after it has been closed.
            return np.array(data)
        finally:
            f.close()
    
    def getNewLatLon(self, version):
        if version == 0:
            newlat = np.arange(ModelData._lat[1]-(ModelData._lat[0])+1)+ModelData._lat[0]
        elif version == 4:
            newlat = np.arange(ModelData._lat3[1]-(ModelData._lat3[0])+1)+ModelData._lat3[0]
        elif version == 7:
            newlat = np.arange(ModelData._lat2[1]-(ModelData._lat2[0])+1)+ModelData._lat2[0]
        else:
            print 'No such period'
        newlon  = np.arange((ModelData._lon[1]-ModelData._lon[0])/1.5+1)*1.5+ModelData._lon[0]
        return [newlat, newlon]