# Python script to generate all figures shown in Jucker and Reichler, Journal of Climate (2022)

This script needs `xarray`, `matplotlib`, `cartopy`, `seaborn`, `scipy`, and `aostools` to run (plus `numpy` and `pandas`).

In [None]:
%matplotlib inline

In [None]:
import xarray as xr
from aostools import climate as ac
from aostools import constants as at
# import matplotlib
# matplotlib.use('WebAgg')
from matplotlib import pyplot as plt
plt.ion()
from matplotlib.ticker import ScalarFormatter
import matplotlib.ticker as mticker
import matplotlib.path as mpath
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
from cartopy import crs as ccrs
import seaborn as sns
import os
from scipy import stats
import numpy as np
import pandas as pd

Directory where the data can be found:

In [None]:
# Root directory of the input .nc files. Paths below are built as
# data_dir + '<sub-directory>/<file>.nc', so the trailing slash is required.
data_dir = 'data/'

In [None]:
import warnings
# Silence every warning category for the rest of the session to keep the
# notebook output uncluttered.
warnings.simplefilter('ignore')

In [None]:
# seaborn appearance: 'talk' plotting context, then the 'whitegrid' style
sns.set_context(context='talk')
sns.set_style(style='whitegrid')

In [None]:
# Composite configuration.
methods = ['ucomp_0.0']  # event-selection tag used in the composite file names
datas = ['cm2p1']        # dataset sub-directories under data_dir
# 2D (lat/lon) variables drawn as lag-binned maps
lag_vars_2d = ['slp', 't_surf', 'precip']
# 3D variables: 'slice' is passed to .sel(); 'mean' names the dimension to
# average over, with a 'cos' prefix requesting a cos(lat)-weighted mean
lag_vars_3d = {
    'hghtstd': {
        'slice': {'lat': slice(-90, -60)},
        'mean': 'coslat',
    },
}
# number of lag bins and one label per bin
lag_bins = 3
lag_bin_names = [f'{stage} stage' for stage in ('early', 'mature', 'late')]

In [None]:
# Colormap per variable name; 'default' is used for names not listed here.
cmaps = dict(
    default='RdBu_r',
    slp='BrBG_r',
    t_surf='RdBu_r',
    precip='PuOr',
    nam_sh='RdBu',
    sam='Reds_r',
    tpp='cividis_r',
    tppfull='cividis_r',
    f2='BrBG',
    f2std='PuOr',
    u300='RdYlBu_r',
)
# all zonal-wind variants reuse the u300 colormap; 'tsurf' aliases 't_surf'
cmaps.update({wind_name: cmaps['u300'] for wind_name in ('ucomp', 'ucomp500', 'ucompfull')})
cmaps['tsurf'] = cmaps['t_surf']

# Explicit contour levels per variable. In the rounded lists zero is dropped,
# presumably so that no level boundary sits exactly on zero.
cont_vals = dict(
    slp=list(np.linspace(-4.4, 4.4, 12)),
    t_surf=[round(lev, 1) for lev in np.linspace(-0.7, 0.7, 15) if lev != 0],
    precip=[round(lev, 1) for lev in np.linspace(-0.7, 0.7, 15) if lev != 0],
    hghtstd=list(np.linspace(-3, 3, 16)),
    f2std=[round(lev, 2) for lev in np.linspace(-0.5, 0.5, 21) if lev != 0],
)

# Map projection per variable, with a 'default' fallback as for cmaps.
projections = dict(default='SouthPolarStereo')
# colours of the active matplotlib property cycle
colrs = list(plt.rcParams['axes.prop_cycle'].by_key()['color'])

In [None]:
def SaveFig(fig,outFile):
    """Write *fig* to *outFile* (tight bounding box, transparent background)
    and print the path as a progress message."""
    save_options = dict(bbox_inches='tight', transparent=True)
    fig.savefig(outFile, **save_options)
    print(outFile)

In [None]:
def StatTest(ds):
    """One-sample t-test of an event composite against zero.

    Parameters
    ----------
    ds : DataArray with an 'event' dimension.

    Returns
    -------
    dsall : the mean over 'event'.
    dnsig : the same mean, kept only where the two-sided p-value is > 0.05
            (the non-significant points, which the plots hatch) and NaN
            elsewhere.
    """
    composite = ds.mean('event')
    _, pvalue = stats.ttest_1samp(ds, 0, axis=ds.get_axis_num('event'))
    return composite, composite.where(pvalue > 0.05)

In [None]:
def SymmetricAxes(ax,choice):
    """Make the x and/or y limits of *ax* symmetric about zero.

    *choice* is a string: the x axis is handled if it contains 'x', the y axis
    if it contains 'y' (e.g. 'x', 'y', 'xy'). For a non-inverted axis with
    limits (lo, hi) the new limits are (min(-hi, lo), max(-lo, hi)).
    """
    for axis_name in ('x', 'y'):
        if axis_name not in choice:
            continue
        lo, hi = getattr(ax, 'get_' + axis_name + 'lim')()
        getattr(ax, 'set_' + axis_name + 'lim')(min(-hi, lo), max(-lo, hi))

In [None]:
def GetWidthHeight(fig,ax):
    """Return the (width, height) of *ax* in inches.

    The axes' window extent is mapped through the inverse of the figure's
    dpi scaling transform.
    """
    to_inches = fig.dpi_scale_trans.inverted()
    extent = ax.get_window_extent().transformed(to_inches)
    return extent.width, extent.height

In [None]:
def SetWidthHeight(w,h, ax=None):
    """Resize the figure of *ax* so that its subplot area is w x h inches.

    w, h: target width and height in inches. The figure size is obtained by
    dividing by the subplotpars fractions (right-left, top-bottom); this
    assumes *ax* fills that subplot region. Uses the current axes when *ax*
    is not given (or is falsy).
    """
    ax = ax or plt.gca()
    pars = ax.figure.subplotpars
    ax.figure.set_size_inches(float(w) / (pars.right - pars.left),
                              float(h) / (pars.top - pars.bottom))


Set the hatching parameters used to mark statistically non-significant regions.

In [None]:
# Grey level of the hatch lines and the hatch pattern used for the
# non-significant overlays in the plotting helpers.
sig_alpha = 0.65
hatch = ['..']
plt.rcParams['hatch.color'] = [sig_alpha, sig_alpha, sig_alpha]

In [None]:
def Plot2D(ds,var,data=None,method=None,return_transf=False,proj=None,colorrange=None,fig=None,ax=None,transf=None,lat_avg=None,saveFig=True):
    """Plot an event composite of *ds*, hatching the non-significant points.

    The layout depends on the coordinates of *ds*:

    * 'level' present: lag-pressure section with a log pressure axis, after
      ``ds.sel(level=slice(850, None))`` (presumably keeps 850 hPa and above
      when 'level' is stored in decreasing order; verify). If 'lat' is present
      it is first area-averaged over *lat_avg* with ``ac.GlobalAvgXr``.
    * 'lat' and 'lon' present: one map per lag bin (``lag_bins`` groups of the
      'lag' coordinate). The maps use the ``ac.Projection`` projection, or
      plain axes if *proj* == 'none'.
    * 'lat' only: latitude section (y='lat') restricted to 90S-20S.

    In every layout ``StatTest`` splits the event mean into the composite
    (filled contours) and its non-significant part (hatched with ``hatch``).

    Parameters
    ----------
    ds : DataArray with an 'event' dimension.
        NOTE: its ``name`` is set to None in place (the caller's object).
    var : str
        Variable name. It selects the colormap and projection and is used as
        the title and in the output file name.
    data : optional
        If None, the figure is returned without saving.
    method : not used in this function.
    return_transf : bool
        Also return the projection keyword dict *transf*.
    proj : str, optional
        Projection name for ``ac.Projection``, or 'none' for plain axes.
    colorrange : list, optional
        With two entries, only the second is used, as vmax. With more than
        two, the entries are the contour levels (every other one is a
        colorbar tick).
    fig, ax, transf : optional
        Existing figure, axes and projection keywords to draw into.
    lat_avg : [lat0, lat1], optional
        Averaging band for the 'level' layout (default [-90, -20]).
    saveFig : bool
        Add panel labels (map layout). When *data* is not None, also save to
        figures/<var>_composite.pdf.

    Returns
    -------
    (fig, ax), or (fig, ax, transf) if *return_transf*.
    """
    if ds.name is not None:
        ds.name = None
    if lat_avg is None:
        lat_avg = [-90,-20]
    # latitude range shown on polar stereographic maps
    stereo_extent = [-90,-20]
    figArgs = {'levels':21,'robust':True}
    if var in cmaps.keys():
        cm = cmaps[var]
    else:
        cm = cmaps['default']
    figArgs['cmap'] = cm
    figArgs['zorder'] = 0
    if 'level' in ds.coords:
        # --- lag-pressure section ---
        ds = ds.sel(level=slice(850,None))
        if 'lat' in ds.coords:
            ds = ac.GlobalAvgXr(ds,lat_avg)
        ds,dt = StatTest(ds)
        if fig is None:
            fig,ax = plt.subplots()
        figArgs['ax'] = ax
        figArgs['y'] = 'level'
        figArgs['yscale'] = 'log'
        figArgs['extend'] = 'both'
        if colorrange is not None:
            if len(colorrange) == 2:
                figArgs['vmax'] = colorrange[1]
            else:
                figArgs['levels'] = colorrange
        if isinstance(colorrange,list) and len(colorrange) > 2:
            # every other level, counted from each end, so the ticks are
            # symmetric for a symmetric level list
            figArgs['cbar_kwargs'] = {'ticks' : [c for c in colorrange[::2] if c < 0] + sorted([c for c in colorrange[-1::-2] if c > 0]),
                                      }
        ds.plot.contourf(**figArgs)
        # reuse levels/axes for the hatching layer, without fill or colorbar
        del figArgs['cmap']
        if 'cbar_kwargs' in figArgs.keys():
            del figArgs['cbar_kwargs']
        figArgs['add_colorbar'] = False
        figArgs['colors'] = 'none'
        figArgs['hatches'] = hatch
        dt.plot.contourf(**figArgs)
        ax.grid()
        ax.grid(zorder=10)
        ax.set_title(var)
        ax.invert_yaxis()
        ax.yaxis.set_major_formatter(ScalarFormatter())
    else:
        if 'lat' in ds.coords:
            if 'lon' in ds.coords:
                # --- one map per lag bin ---
                if lag_bins > 4:
                    ncols = lag_bins//2
                    nrows = 2
                else:
                    ncols = lag_bins
                    nrows = 1
                if proj is None:
                    if var in projections.keys():
                        proj = projections[var]
                    else:
                        proj = projections['default']
                if 'Stereo' in proj:
                    ds = ds.sel(lat=slice(*stereo_extent))
                    ds = ac.CloseGlobe(ds)
                    clon = 0
                else:
                    clon = 180
                if proj == 'none' and fig is None:
                    fig,ax = plt.subplots(ncols=ncols,nrows=nrows)
                    transf = {}
                elif fig is None:
                    fig,ax,transf = ac.Projection(projection=proj,ncols=ncols,nrows=nrows,coast=True
                                                ,kw_args={'central_longitude':clon})
                if 'Stereo' in proj:
                    data_aspect = 1
                    # prepare circular plot instead of default rectangle box
                    theta = np.linspace(0, 2*np.pi, 100)
                    center, radius = [0.5, 0.5], 0.5
                    verts = np.vstack([np.sin(theta), np.cos(theta)]).T
                    circle = mpath.Path(verts * radius + center)
                else:
                    data_aspect = (ds.lat.max()-ds.lat.min())/(ds.lon.max()-ds.lon.min())*2
                fig.set_figwidth(ncols*fig.get_figwidth())
                if proj != 'none':
                    fig.set_figheight(nrows*fig.get_figheight()*data_aspect)
                for key in transf.keys():
                    figArgs[key] = transf[key]
                ds = ds.groupby_bins('lag',lag_bins).mean()
                if lag_bin_names is not None and len(ds.lag_bins) == len(lag_bin_names):
                    ds = ds.assign_coords(lag_bins=lag_bin_names)
                ds,dt = StatTest(ds)
                cmax = None
                if colorrange is None:
                    # symmetric range whose magnitude is the smaller of
                    # |min| and |max| (for data spanning zero)
                    cmin = ds.quantile(0)
                    cmax = ds.quantile(1)
                    cmin = max(-cmax,cmin)
                    cmax = min(-cmin,cmax)
                elif isinstance(colorrange,list) and len(colorrange)==2:
                    cmin = colorrange[0]
                    cmax = colorrange[1]
                if cmax is None and isinstance(colorrange,list):
                    figArgs['levels'] = colorrange
                else:
                    figArgs['vmax'] = cmax
                figArgs['extend'] = 'both'
                figArgs['add_colorbar'] = False
                for i,a in enumerate(ax.flatten()):
                    # cartopy 0.18 bug: need to do this before plotting.
                    # Compare by value: 'is not' tests object identity and only
                    # worked through string interning (SyntaxWarning on 3.8+).
                    if proj != 'none':
                        if 'Stereo' in proj:
                            a.set_extent([-180, 180, stereo_extent[0], stereo_extent[1]], ccrs.PlateCarree())
                            gl = a.gridlines(crs=ccrs.PlateCarree(), linewidth=1.0, color='gray', alpha=1, linestyle='-')
                            gl.xlocator = mticker.FixedLocator([-180, -120, -60, 0, 60, 120, 180])
                            gl.ylocator = mticker.FixedLocator([-90,-60,-30])
                            gl.xformatter = LONGITUDE_FORMATTER
                            gl.yformatter = LATITUDE_FORMATTER
                            gl.n_steps = 90
                            a.set_boundary(circle, transform=a.transAxes)
                        else:
                            gl = a.gridlines(draw_labels=True)
                            llab = blab = False
                            if i == 0 or i == lag_bins/2:
                                llab = True
                            if i >= lag_bins/2:
                                blab = True
                            gl.bottom_labels = blab
                            gl.left_labels = llab
                            gl.top_labels = False
                            gl.right_labels = False
                    cont = ds.isel(lag_bins=i).plot.contourf(ax=a,**figArgs)
                    nsigArgs = {'colors':'none','hatches':hatch,'add_colorbar':False}
                    for key in transf.keys():
                        nsigArgs[key] = transf[key]
                    dt.isel(lag_bins=i).plot.contourf(ax=a,**nsigArgs)
                    a.grid()
                    a.grid(zorder=10)
                    ttle = a.get_title()
                    a.set_title(ttle.replace('lag_bins = ',''))
                cbArgs = {'ax':ax.flatten()}
                if isinstance(colorrange,list) and len(colorrange) > 2:
                    cbArgs['ticks'] = [c for c in colorrange[::2] if c < 0] + sorted([c for c in colorrange[-1::-2] if c > 0])
                # one colorbar shared by all panels. Passing the axes (rather
                # than a new cax) also works with a pre-defined facet plot.
                fig.colorbar(cont,**cbArgs)
                if saveFig:
                    ac.AddPanelLabels(ax.flatten(),loc='upper left',ypos=1.1)
            else:
                # --- latitude section, 90S-20S ---
                if fig is None:
                    fig,ax = plt.subplots()
                figArgs['ax'] = ax
                figArgs['y'] = 'lat'
                figArgs['extend'] = 'both'
                if colorrange is not None:
                    if len(colorrange) == 2:
                        figArgs['vmax'] = colorrange[1]
                    else:
                        figArgs['levels'] = colorrange
                ds = ds.sel(lat=slice(-90,-20))
                ds,dt = StatTest(ds)
                ds.plot.contourf(**figArgs)
                nsigArgs = {'colors':'none','hatches':hatch,'add_colorbar':False}
                dt.plot.contourf(ax=ax,y='lat',**nsigArgs)
                ax.set_title(var)
                ax.grid(zorder=10)
    if data is None:
        if return_transf:
            return fig,ax,transf
        else:
            return fig,ax
    #
    if saveFig:
        outFile = 'figures/{0}_composite.pdf'.format(var.lower())
        fig.savefig(outFile,bbox_inches='tight')
        print(outFile)
    if return_transf:
        return fig,ax,transf
    else:
        return fig,ax

In [None]:
def TropoLine(ds,negclr,posclr,fig=None,ax=None):
    """Line plot versus lag of the 1000-300 (level labels) mean composite.

    The level-mean of every event is averaged over events. Its negative and
    positive parts are drawn in *negclr* / *posclr*, each with a faint band of
    +-1 standard deviation across events. Lags where ``ac.StatTest``
    (test='T' against 0 over 'event') gives p < 0.05 are drawn again with a
    default-width line and a darker band. A new figure is created when *fig*
    is None.

    Returns (fig, ax).
    """
    tropo = ds.sel(level=slice(1000,300)).mean('level')
    tropom = tropo.mean('event')
    pval = ac.StatTest(tropo,0,test='T',dim='event')
    std = tropo.std(dim='event')
    if fig is None:
        fig,ax = plt.subplots()
    # first pass: all lags, thin line + faint band (negative part first)
    branches = []
    for sign_mask,clr in ((tropom<0,negclr),(tropom>0,posclr)):
        mean_part = tropom.where(sign_mask)
        std_part = std.where(sign_mask)
        mean_part.plot(ax=ax,linewidth=1.0,color=clr)
        ax.fill_between(mean_part.lag,mean_part-std_part,mean_part+std_part,color=clr,alpha=0.1)
        branches.append((mean_part,std_part,pval.where(sign_mask),clr))
    # second pass: emphasise the statistically significant lags
    for mean_part,std_part,p_part,clr in branches:
        is_sig = p_part<0.05
        sig_mean = mean_part.where(is_sig)
        sig_std = std_part.where(is_sig)
        sig_mean.plot(ax=ax,color=clr)
        ax.fill_between(sig_mean.lag,sig_mean-sig_std,sig_mean+sig_std,color=clr,alpha=0.2)
    sns.despine()
    SymmetricAxes(ax,'y')
    ax.set_xlim(-60,60)
    ax.axhline(0,color='k')
    ax.set_xticks([-50,-25,0,25,50])
    return fig,ax

In [None]:
for data in datas:
    # one row per 2D variable, one column per lag bin, south-polar maps
    ncols, nrows = lag_bins, len(lag_vars_2d)
    clon = 0
    proj = 'SouthPolarStereo'
    fig,axs,transf = ac.Projection(projection=proj,ncols=ncols,nrows=nrows,coast=True,
                                   kw_args={'central_longitude':clon})
    for v,var in enumerate(lag_vars_2d):
        ds = xr.open_dataarray(data_dir+'{0}/{1}_composite_ucomp_0.0_lag-60-60_full.nc'.format(data,var))
        if var == 'precip':
            ds = ds*86400 # mm/day
        if 'm' in ds.coords:
            ds = ds.sel(m=1)
        # explicit contour levels when configured, otherwise Plot2D's automatic range
        level_kw = {'colorrange': cont_vals[var]} if var in cont_vals else {}
        Plot2D(ds,var,data,None,fig=fig,ax=axs[v],transf=transf,saveFig=False,**level_kw)
    fig.set_figwidth(ncols*5.0)
    fig.set_figheight(nrows*4.0)
    ac.AddPanelLabels(axs.flatten(),loc='upper left',ypos=1.1)
    outFile = 'figures/2d_composites.pdf'
    SaveFig(fig,outFile)

In [None]:
# current seaborn colour palette (colours for the TropoLine plots)
cpan = sns.color_palette(palette=None)

In [None]:
# Reference curves for the lag-pressure sections, keyed '<method>_<data>':
#   tpc: event mean of tppfull, area-averaged over 90S-20S
#        (drawn as a dashed line on the pressure axis; presumably the tropopause)
#   upc: event mean of ucompfull, area-averaged over latdom
tpc = {}
upc = {}
latdom = [-70,-50]
for method in methods:
    for data in datas:
        key = '{0}_{1}'.format(method,data)
        tpp = xr.open_dataarray(data_dir+'{0}/tppfull_composite_{1}_lag-60-60_full.nc'.format(data,method))
        ucomp = xr.open_dataarray(data_dir+'{0}/ucompfull_composite_{1}_lag-60-60_full.nc'.format(data,method))
        tpc[key] = ac.GlobalAvgXr(tpp,[-90,-20]).mean('event')
        upc[key] = ac.GlobalAvgXr(ucomp,latdom).mean('event')

In [None]:
for method in methods:
    for data in datas:
        for var, spec in lag_vars_3d.items():
            ds = xr.open_dataarray(data_dir+'{0}/{1}_composite_{2}_lag-60-60_full.nc'.format(data,var,method))
            ds = ds.sel(spec['slice'])
            if 'mean' in spec:
                avg_spec = spec['mean']
                if 'cos' in avg_spec:
                    # cos(lat)-weighted mean over the named latitude dimension
                    latvar = avg_spec.replace('cos','')
                    coslat = np.cos(np.deg2rad(ds[latvar]))
                    ds = coslat*ds
                    norm = coslat.mean()
                else:
                    latvar = avg_spec
                    norm = 1
                ds = ds.mean(latvar)/norm
            colour_kw = {'colorrange': cont_vals[var]} if var in cont_vals else {}
            fig1,ax1 = Plot2D(ds,var,data=None,**colour_kw)
            w,h = GetWidthHeight(fig1,ax1)
            key = '{0}_{1}'.format(method,data)
            # zero-wind contour and dashed tpc reference line
            upc[key].plot.contour(ax=ax1,levels=[0],colors='k',x='lag',linestyles='-')
            tpc[key].plot.line(ax=ax1,color='k',x='lag',linestyle='--')
            ax1.set_ylabel('pressure [hPa]')
            ax1.set_ylim(1000,1)
            ax1.set_title('SAM')
            SaveFig(fig1,'figures/{0}_composite.pdf'.format(var))
            # matching tropospheric-mean line plot, resized to the section's axes
            fig2,ax2 = TropoLine(ds,cpan[0],cpan[3])
            SetWidthHeight(w,h,ax2)
            ax2.set_title('SAM 1000-300hPa mean')
            SaveFig(fig2,'figures/{0}_tropo_composite.pdf'.format(var))

In [None]:
lat_avg = [-90,-20]
for method in methods:
    for data in datas:
        key = '{0}_{1}'.format(method,data)
        # NOTE(review): m=1 selects a single wavenumber entry (m=0 is titled
        # 'total' in the per-wavenumber plot); confirm this is intended.
        ep = xr.open_dataarray(data_dir+'{0}/f2std_composite_{1}_lag-60-60_full.nc'.format(data,method)).sel(m=1)
        # lag-pressure section averaged over lat_avg (Plot2D also clears ep.name)
        fig1,ax1 = Plot2D(ep,'f2std',data=None,colorrange=cont_vals['f2std'],lat_avg=lat_avg)
        w,h = GetWidthHeight(fig1,ax1)
        upc[key].plot.contour(ax=ax1,levels=[0],colors='k',x='lag',yincrease=False,linestyles='-')
        tpc[key].plot.line(ax=ax1,color='k',x='lag',linestyle='--')
        ax1.set_ylabel('pressure [hPa]')
        ax1.set_title('upward EP flux')
        SaveFig(fig1,'figures/f2std_composite.pdf')
        # tropospheric-mean line plot resized to the section's axes
        fig2,ax2 = TropoLine(ac.GlobalAvgXr(ep,lat_avg),cpan[1],cpan[4])
        SetWidthHeight(w,h,ax2)
        ax2.set_title('standardized $F_p$ 1000-300hPa mean')
        SaveFig(fig2,'figures/f2std_tropo_composite.pdf')

In [None]:
# Per-wavenumber facets of the standardized f2 composite: mean over the
# levels labelled 1000-300, latitudes 90S-0, one facet per 'm' entry
# (renamed 'k'), hatched where StatTest finds p > 0.05.
for method in methods:
    for data in datas:
        ep = xr.open_dataarray(data_dir+'{0}/f2std_composite_{1}_lag-60-60_full.nc'.format(data,method))
        ep.name = None
        # empty units, so the latitude axis label gets no unit suffix
        ep.lat.attrs['units'] = ''
        epm = ep.sel(lat=slice(-90,0),level=slice(1000,300)).mean('level')
        epm = epm.rename({'m':'k'})
        ds,dt = StatTest(epm)
        ds.plot(col='k',cmap=cmaps['f2std'])
        # the facet grid is the current pyplot figure
        fig = plt.gcf()
        # all axes except the last, which is assumed to be the shared colorbar
        axs = fig.axes[:-1]
        nsigArgs = {'colors':'none','hatches':hatch,'add_colorbar':False}
        for i,a in enumerate(axs):
            dt.isel(k=i).plot.contourf(ax=a,**nsigArgs)
#             a.grid(True,zorder=100)
            if i == 0:
                # first facet, presumably m=0, is the total over wavenumbers
                a.set_title('total')
            else:
                a.set_ylabel('')
        SaveFig(fig,'figures/f2std_tropo_m_composite.pdf')

In [None]:
# Tropopause-height composite: cos(lat)-weighted mean over two latitude bands.
tpplats = {'polar':[-90,-70],'tropics':[-20,0]}
for method in methods:
    for data in datas:
        ds = xr.open_dataarray(data_dir+'{0}/tpz_composite_{1}_lag-60-60_full.nc'.format(data,method))
        event_id = ds.get_axis_num('event')
        _,pvalue = stats.ttest_1samp(ds,0,axis=event_id)
        std = ds.std('event')
        ds = ds.mean('event')
        upper = ds+std
        lower = ds-std
        # lines show significant (p<0.05) lags only; the +-1 std band covers all lags
        ds = ds.where(pvalue<0.05)
        coslat = np.cos(np.deg2rad(ds.lat))
        fig,ax = plt.subplots()
        ax.axhline(0,color='k')
        for key, band in tpplats.items():
            band_sel = {'lat': slice(*band)}
            norm = coslat.sel(band_sel).mean('lat')
            ttmp = (coslat*ds).sel(band_sel).mean('lat')/norm
            line = ttmp.plot.line(ax=ax,label='{}'.format(band))
            colr = line[0].get_color()
            low = (coslat*lower).sel(band_sel).mean('lat')/norm
            upp = (coslat*upper).sel(band_sel).mean('lat')/norm
            ax.fill_between(low.lag,low,upp,color=colr,alpha=0.3)
        ax.set_title('tropopause height [km]')
        sns.despine()
        ax.legend()
        outFile = 'figures/tpz_line_composite.pdf'
        fig.savefig(outFile,bbox_inches='tight')
        print(outFile)

# EP flux wavenumbers

In [None]:
# presumably pressure levels [hPa] for EP-flux diagnostics (not used below)
eplevs = [500, 250, 100, 10]
# latitude bands [south, north] for the vertical wavenumber profiles
eplats = [[-85, -60]]

In [None]:
def ReadF(data,method,var):
    """Read an EP-flux composite, append a synoptic entry and mask by significance.

    Parameters
    ----------
    data : sub-directory under ``data_dir``.
    method : event-selection tag in the file name.
    var : variable name. If it contains 'std' (e.g. 'f2std'), the raw field
        (name without 'std') is read and divided by the m=0 ratio
        raw / standardized. This applies the m=0 standard deviation to all
        wave numbers.

    Returns
    -------
    Event mean over the file's 'm' entries plus an extra m=4 entry, equal to
    m=0 minus the sum over m=1..3 (label slice). Values are NaN where a
    one-sample t-test against zero over 'event' does not give p < 0.05.
    """
    path_template = data_dir+'{0}/{2}_composite_{1}_lag-60-60_full.nc'
    standardize = 'std' in var
    read_var = var.replace('std','') if standardize else var
    f2_raw = xr.open_dataarray(path_template.format(data,method,read_var))
    if standardize:
        f2_std = xr.open_dataarray(path_template.format(data,method,var))
        # std of m=0 recovered as raw/standardized, then applied to all wave numbers
        std_0 = (f2_raw.isel(m=0)/f2_std.isel(m=0)).squeeze()
        del std_0['m']
        f2_raw = f2_raw/std_0
    # residual of the total after removing m=1..3, labelled m=4
    f2_synopt = (f2_raw.sel(m=0) - f2_raw.sel(m=slice(1,3)).sum('m')).assign_coords({'m':4})
    f2std = xr.concat([f2_raw,f2_synopt],dim='m')
    _,pvalue = stats.ttest_1samp(f2std,0,axis=f2std.get_axis_num('event'))
    return f2std.mean('event').where(pvalue<0.05)

In [None]:
# EP-flux variable for the wavenumber plots (standardized f2, read via ReadF)
var = 'f2std'

In [None]:
# NOTE(review): `method` and `data` are the values left over from the loops
# above (last entries of `methods` / `datas`), while ReadF hard-codes 'cm2p1'.
f2std = ReadF('cm2p1',method,var)
ucomp_file = data_dir+'{0}/ucomp_composite_{1}_lag-60-60_full.nc'.format(data,method)
ucomp = xr.open_dataarray(ucomp_file)

In [None]:
coslat = np.cos(np.deg2rad(f2std.lat))
# common keyword arguments for the vertical-profile line plots
figArgs = {'y':'level','yscale':'log'}
# one style per wavenumber entry: total, k=1, k=2, k=3, k>3
linestyles = ['-','--','--','--','-.']
markers    = ['' ,'x' ,'o' ,'^' ,'']
# two panel rows once there are four or more lag bins
nrows = 2 if lag_bins >= 4 else 1
ncols = lag_bins//nrows

In [None]:
# Vertical profiles of the significant, lag-binned wavenumber f2 composites
# (area mean over each band in `eplats`), plus the lag-binned event-mean
# zonal wind U on a twin x axis. One panel per lag bin.
f2lag = f2std.groupby_bins('lag',lag_bins).mean()
if lag_bin_names is not None:
    # NOTE(review): unlike Plot2D, no check that len(lag_bin_names) matches the bins
    f2lag = f2lag.assign_coords(lag_bins=lag_bin_names)
ulag  = ucomp.groupby_bins('lag',lag_bins).mean().mean('event')
if lag_bin_names is not None:
    ulag = ulag.assign_coords(lag_bins=lag_bin_names)
for l,lat in enumerate(eplats):
    fig,ax = plt.subplots(nrows,ncols,sharex=True,sharey=False)
    fig.set_figwidth(ncols*fig.get_figwidth())
    fig.set_figheight(nrows*fig.get_figheight())
    for i,a in enumerate(ax.flatten()):
        # twin axes (shared pressure axis) for U
        ax2 = a.twiny()
        utmp = ac.GlobalAvgXr(ulag.isel(lag_bins=i),lat)
#         utmp = ulag.isel(lag_bins=i).sel(lat=slice(*lat)).mean('lat')
        utmp.plot.line(color='gray',lw=2,ax=ax2,ls=':',**figArgs)
        f2tmp = ac.GlobalAvgXr(f2lag.isel(lag_bins=i),lat)
#         f2tmp = f2lag.isel(lag_bins=i).sel(lat=slice(*lat)).mean('lat')
        f2tmp.plot.line(ax=a,add_legend=False,**figArgs)
        # lines on `a` are assumed to follow the 'm' order (total, k=1..3, k>3),
        # matching `linestyles` and the legend labels below
        for j,line in enumerate(a.get_lines()):
            line.set_linestyle(linestyles[j])
#             line.set_marker(markers[j])
#         a.grid()
        ax2.legend(['U'],loc='lower left')
        a.legend(['tot','k=1','k=2','k=3','k>3'],loc='upper right')
        a.axvline(0,color='k')
        a.set_ylim(1e3,1)
        if i == 0:
            a.set_ylabel('pressure [hPa]')
        else:
            a.set_ylabel('')
        ttle = a.get_title()
        a.set_title(ttle.replace('lag_bins = ',''))
        # hide the U axis decorations; only its curve is shown
        ax2.set_title('')
        ax2.set_xlabel('')
        ax2.set_xlim(-30,30)
        ax2.set_xticks([])
        if var == 'f2std':
            a.set_xlim(-1.2,1.2)
        if nrows == 1 or i >= ncols:
            a.set_xlabel('standardized f2 []')
        #if i == 0 or i == ncols:
        a.yaxis.set_major_formatter(ScalarFormatter())
    ac.AddPanelLabels(ax,'upper left')
#     fig.suptitle('lat [{0} to {1}]'.format(*lat))   
    sns.despine(left=True)
    outFile = 'figures/{0}_wave_vert_{1}_to_{2}_composite.pdf'.format(var,lat[0],lat[1])
    fig.savefig(outFile,bbox_inches='tight')
    print(outFile)

In [None]:
n2_raw = xr.open_dataarray(data_dir+'cm2p1/n2_composite_{0}_lag-60-60_full.mean.nc'.format(method))
# One copy of n2 per 'm' entry of f2std, minus (m / coslat)**2: presumably
# the zonal-wavenumber term of the squared refractive index (confirm).
# NOTE(review): f2std's m=4 is the synoptic residual from ReadF, not a single wavenumber.
n2_raw_coslat = at.coslat(n2_raw.lat)
n2m = []
for m in f2std.m:
    n2k = n2_raw - (m/n2_raw_coslat)**2
    n2k['m'] = m
    n2m.append(n2k)
n2 = xr.concat(n2m,dim='m')

In [None]:
# only include k=1
# keep the total (m=0) and wavenumber 1 (m=1); only m=1 is plotted below
n2 = n2.sel({'m': [0, 1]})

In [None]:
# EP-flux components divided by at.a0 (presumably the Earth radius) and cos(lat).
# f2 is also divided by 100, presumably a unit/arrow-scaling factor (confirm).
ep_coslat = at.coslat(n2.lat)
f1 = ReadF('cm2p1',method,'f1')/at.a0/ep_coslat
f2 = ReadF('cm2p1',method,'f2')/at.a0/ep_coslat/100

In [None]:
# Fewer than four lag bins: all wavenumbers share one figure, with one row
# per plotted entry (m=0 is skipped) and one column per lag bin.
if lag_bins < 4:
    nrows, ncols = len(n2.m)-1, lag_bins

In [None]:
# n2 (filled, zero contour in black) with EP-flux arrows, per lag bin, for
# each 'm' entry except the first. With fewer than four lag bins everything
# goes into one figure (figures/n2_kall_composite.pdf); otherwise there is
# one figure per entry.
n2lag = (n2).groupby_bins('lag',lag_bins).mean()
# EP-flux components thinned to every second level/latitude for the arrows
f1lag = (f1.isel(level=slice(None,None,2),lat=slice(None,None,2))).groupby_bins('lag',lag_bins).mean()
f2lag = (f2.isel(level=slice(None,None,2),lat=slice(None,None,2))).groupby_bins('lag',lag_bins).mean()
lat = [-85,0]
figArgs['yscale'] = 'linear'
figArgs['zorder'] = 1
# filled-contour arguments: fixed range, no per-panel colorbar
contArgs = figArgs.copy()
contArgs['vmin'] = -300
contArgs['vmax'] =  300
contArgs['extend']= 'both'
contArgs['cmap'] = 'RdBu_r'
contArgs['zorder'] = 1
contArgs['add_colorbar'] = False
if lag_bins < 4:
    fig,axs = plt.subplots(nrows,ncols,sharex=True,sharey=True)
    fig.set_figwidth(ncols*fig.get_figwidth())
    fig.set_figheight(nrows*fig.get_figheight())
for mm,m in enumerate(n2lag.m[1:]):  # skip the first entry (m=0)
    if lag_bins >= 4:
        fig,ax = plt.subplots(nrows,ncols,sharex=False,sharey=False)
        fig.set_figwidth(ncols*fig.get_figwidth())
        fig.set_figheight(nrows*fig.get_figheight())
    else:
        if nrows > 1:
            ax = axs[mm]
        else:
            ax = axs
    for i,a in enumerate(ax.flatten()):
        cf = n2lag.isel(lag_bins=i).sel(lat=slice(*lat),m=m).plot.contourf(levels=21,ax=a,**contArgs)
        n2lag.isel(lag_bins=i).sel(lat=slice(*lat),m=m).plot.contour(levels=[0],colors='k',ax=a,**figArgs)
        ep1 = f1lag.isel(lag_bins=i).sel(m=m,lat=slice(*lat))
        ep2 = f2lag.isel(lag_bins=i).sel(m=m,lat=slice(*lat))
        x = ep1.lat
        y = ep2.level
        ac.PlotEPfluxArrows(x,y,ep1,ep2,fig,a,yscale='log',pivot='middle',scale=2e15)
        a.set_ylim(1e3,1)
        a.set_title(lag_bin_names[i])
        if i >= ncols:
            a.set_xlabel('latitude [deg]')
        else:
            a.set_xlabel('')
        if np.mod(i,ncols) == 0:
            a.set_ylabel('pressure [hPa]')
        else:
            a.set_ylabel('')
        a.yaxis.set_major_formatter(ScalarFormatter())
    if lag_bins >= 4:
        ac.AddColorbar(fig,ax,cf)
        ac.AddPanelLabels(ax,'upper left',ypos=1.11)
        fig.tight_layout()
        sns.despine()
        outFile = 'figures/{0}_k{1}_composite.pdf'.format('n2',m.values)
        fig.savefig(outFile,bbox_inches='tight')
        print(outFile)
if lag_bins < 4:
    if nrows == 1:
        shrnk = 1.0
    else:
        shrnk = 0.3
    ac.AddColorbar(fig,axs,cf,shrink=shrnk)
    ac.AddPanelLabels(axs,'upper left',ypos=1.11)
    for a in axs.flatten():
        a.grid(zorder=5)
        # `is_last_row` must be *called*; the bare bound method is always
        # truthy. It lives on the SubplotSpec (Axes.is_last_row was removed
        # in matplotlib 3.6).
        if a.get_subplotspec().is_last_row():
            a.set_xlabel('latitude')
    sns.despine()
    outFile = 'figures/{0}_kall_composite.pdf'.format('n2')
    fig.savefig(outFile,bbox_inches='tight')
    print(outFile)

# Vortex Moment Analysis

In [None]:
from glob import glob
# One model vortex-moment composite per vortex-edge value; the edge value is
# the file-name part after 'full.'.
vxfls = sorted(glob(data_dir+'cm2p1/vxmoms_composite_ucomp_0.0_lag-60-60_full.*.nc'))
pvms = []
for fle in vxfls:
    edge = fle.replace('.nc','').split('full.')[-1]
    raw_moms = xr.open_dataset(fle)
    # mask moments with aspect ratio >= 100 or centroid latitude >= 90
    tmp = raw_moms.where(raw_moms.aspect_ratio < 100).where(raw_moms.centroid_latitude < 90)
    tmp['edge'] = edge
    pvms.append(tmp)
pvmoms = xr.concat(pvms,dim='edge')
# note: centroid_longitude is \in [-90,90] bcse it is computed using an arctan.
#  In reality, it should be \in [0,360], so I am not sure where the missing 180 degrees are...

In [None]:
# ERA5 vortex moments and their 7-day rolling extremes
vxera = xr.open_mfdataset(data_dir+'era5/vxmoms_*.nc').squeeze().load()
vxera_extr = xr.merge([
    vxera.aspect_ratio.rolling(time=7).min(),       # 7-day minimum aspect ratio
    vxera.centroid_latitude.rolling(time=7).max(),  # 7-day maximum centroid latitude
])
# marker / line style for the ERA5 years 2002 and 2019
era_mrkrs = ['*','o']
era_stles = ['--','-.']

Seviour et al. (2013) define
- Displacement: centroid_latitude < 66N for at least 7 days
- Split: aspect_ratio > 2.4 for at least 7 days

For us, this means that between lags -3 and 3
- Displacement: max(centroid_latitude) > -66
- Split: min(aspect_ratio) > 2.4

In [None]:
# Earlier version (kept for reference) checked the moments near onset only:
#   asp  = pvmoms.sel(lag=slice(-3,3)).min('lag').aspect_ratio
#   cent = pvmoms.sel(lag=slice(-3,3)).max('lag').centroid_latitude
# Current version checks any time in the composite window, per event:
#   asp  = largest 7-day rolling MINIMUM aspect ratio  (split if > threshold)
#   cent = smallest 7-day rolling MAXIMUM centroid latitude (displacement if < 66)

# additional, lower aspect-ratio threshold (Seviour et al. use 2.4)
split_lim = 1.8

asp = pvmoms.aspect_ratio.rolling(lag=7).min().max('lag')
cent = pvmoms.centroid_latitude.rolling(lag=7).max().min('lag')
dat = xr.merge([asp,cent])

# Joint distribution of the per-event moment statistics in `dat` for one
# vortex edge. ERA5 2002/2019 values are shown in red, and the thresholds
# (66 deg, aspect ratio 2.4, and split_lim) are drawn as reference lines.
for edge in ['30.4']:
    if isinstance(edge,str):
        edg = edge
    else:
        # only reached if `edge` is taken from a DataArray coordinate
        edg = str(edge.values)
    jp = sns.JointGrid(data=dat.sel(edge=edge),x='aspect_ratio',y='centroid_latitude')
    jp.plot_joint(sns.kdeplot)
    jp.plot_joint(sns.scatterplot)
    for y,yr in enumerate(['2002','2019']):
        # largest 7-day-min aspect ratio / smallest 7-day-max centroid latitude of the year
        xtmp = vxera_extr.sel(time=yr).aspect_ratio.max()
        ytmp = vxera_extr.sel(time=yr).centroid_latitude.min()
        jp.ax_joint.scatter(xtmp,ytmp,marker=era_mrkrs[y],color='r',label=yr)
    jp.ax_joint.legend()
    jp.plot_marginals(sns.histplot,kde=True)
    ax = jp.ax_joint
    ylims = ax.get_ylim()
    xlims = ax.get_xlim()
    ax.set_ylim(ylims[0],90)
    ax.set_xlim(1,3)
    # overrides the y limits set just above
    ax.set_ylim(50,90)
    ax.set_xlabel('aspect ratio')
    ax.set_ylabel('centroid latitude')
    # NOTE(review): `ttle` is built but never applied to the figure
    ttle = 'Vortex moment density, edge = '+edg
    ax.axhline(66,color='black',linewidth=1.0,linestyle=':')
    jp.ax_marg_y.axhline(66,color='black',linewidth=1.0,linestyle=':')
    ax.axvline(2.4,color='black',linewidth=1.0,linestyle=':')
    jp.ax_marg_x.axvline(2.4,color='black',linewidth=1.0,linestyle=':')
    if split_lim != 2.4:
        ax.axvline(split_lim,color='black',linewidth=1.0,linestyle='--')
        jp.ax_marg_x.axvline(split_lim,color='black',linewidth=1.0,linestyle='--')
    outFile = 'figures/vx_moment_scatter.{0}.pdf'.format(edg)
    # saves the current pyplot figure, presumably the JointGrid's
    plt.savefig(outFile,bbox_inches='tight')
    print('saved file '+outFile)

In [None]:
# Lag-dependent histogram (contoured) of the 7-day rolling-minimum aspect
# ratio across events, for the `edge` left by the previous cell, with the
# ERA5 2002 and 2019 series overlaid.
# NOTE(review): `cmap` is not used below (contourf uses 'plasma').
cmap = 'plasma_r'
# NOTE: replaces the per-event `asp` of the earlier cell with this per-lag
#  version, which the split-composite cells below use.
asp = pvmoms.sel(edge=str(edge)).aspect_ratio.rolling(lag=7).min()
hst = []
bins = np.linspace(1,3,21)
for l in asp.lag.values:
    # event histogram at this lag, indexed by bin centres
    h,b = np.histogram(asp.sel(lag=l),bins=bins)
    hx = xr.DataArray(h,coords=[('aspect_ratio',0.5*(bins[1:]+bins[:-1]))],name='hist')
    hx['lag'] = l
    hst.append(hx)
hstx = xr.concat(hst,dim='lag')
hstx.plot.contourf(x='lag',cmap='plasma',vmin=0.1,vmax=50,extend='max',robust=True)
# add ERA5
# each date window is relabelled with the model lag coordinate (assumed to
# have the same length as hstx.lag)
tmp = vxera_extr.sel(time=slice('2002-07-26','2002-11-23')).aspect_ratio.rename({'time':'lag'})
tmp.assign_coords(lag=hstx.lag).plot(x='lag',color='r',ls=era_stles[0],label='2002')
tmp = vxera_extr.sel(time=slice('2019-07-17','2019-11-14')).aspect_ratio.rename({'time':'lag'})
tmp.assign_coords(lag=hstx.lag).plot(x='lag',color='r',ls=era_stles[1],label='2019')
plt.legend()
#
plt.ylabel('aspect ratio []')
plt.title('7-day minimum aspect ratio PDF')
#
outFile = 'figures/aspect_ratio_hist_cont.pdf'
plt.savefig(outFile,bbox_inches='tight')
print('saved file '+outFile)

In [None]:
# Lag-dependent histogram (contoured, inverted y axis) of the 7-day
# rolling-maximum centroid latitude across events, with the ERA5 2002 and
# 2019 series overlaid.
# NOTE(review): `cmap` is not used below (contourf uses 'plasma').
cmap = 'plasma_r'
# NOTE: replaces the per-event `cent` of the earlier cell
cent = pvmoms.sel(edge=str(edge)).centroid_latitude.rolling(lag=7).max()
hst = []
bins = np.linspace(55,90,21)
for l in cent.lag.values:
    # event histogram at this lag, indexed by bin centres
    h,b = np.histogram(cent.sel(lag=l),bins=bins)
    hx = xr.DataArray(h,coords=[('central_latitude',0.5*(bins[1:]+bins[:-1]))],name='hist')
    hx['lag'] = l
    hst.append(hx)
hstx = xr.concat(hst,dim='lag')
hstx.plot.contourf(x='lag',cmap='plasma',vmin=0.1,vmax=30,extend='max',robust=True,yincrease=False)
# add ERA5
# each date window is relabelled with the model lag coordinate (assumed to
# have the same length as hstx.lag)
tmp = vxera_extr.sel(time=slice('2002-07-26','2002-11-23')).centroid_latitude.rename({'time':'lag'})
tmp.assign_coords(lag=hstx.lag).plot(x='lag',color='r',ls=era_stles[0],label='2002')
tmp = vxera_extr.sel(time=slice('2019-07-17','2019-11-14')).centroid_latitude.rename({'time':'lag'})
tmp.assign_coords(lag=hstx.lag).plot(x='lag',color='r',ls=era_stles[1],label='2019')
plt.legend()
#
plt.ylabel('centroid latitude [deg south]')
plt.title('7-day maximum centroid latitude PDF')
#
outFile = 'figures/centroid_latitude_hist_cont.pdf'
plt.savefig(outFile,bbox_inches='tight')
print('saved file '+outFile)

In [None]:
# Composite of hght10full (presumably 10 hPa geopotential height, in km) over
# all (event, lag) points whose 7-day minimum aspect ratio exceeds 2.4.
filtr = asp > 2.4
# number of events exceeding the threshold at some lag
nsplits = sum(asp.max('lag') > 2.4)
# masked aspect ratio; only its 'lag' coordinate is used below
splits = pvmoms.sel(edge=edge).aspect_ratio.where(filtr)
z10 = xr.open_dataarray(data_dir+'cm2p1/hght10full_composite_ucomp_0.0_lag-60-60_full.nc')
z10.name = 'Z [km]'
z10 = ac.CloseGlobe(z10/1000)
# renumber events 0..N-1 (NOTE(review): assumed to match the event labels of `asp`)
z10 = z10.assign_coords(event=np.arange(len(z10.event)))
fig,ax,transf = ac.Projection('SouthPolarStereo')
split_mean = z10.where(filtr).sel(lat=slice(-90,-30),lag=splits.lag).mean(['event','lag'])
split_mean.plot.contourf(levels=21,ax=ax,**transf)
# keep the filled-contour extent when adding the vortex-edge contour
xlim = ax.get_xlim()
ylim = ax.get_ylim()
split_mean.plot.contour(levels=[float(edge)],colors='gray',ax=ax,xlim=xlim,ylim=ylim,**transf)
ax.set_title('N = {0}'.format(nsplits.values))
outFile = 'figures/hght10full_asp_gt2p4.pdf'
fig.savefig(outFile,transparent=True,bbox_inches='tight')
print(outFile)

In [None]:
# Composite over (event, lag) points with split_lim < aspect ratio < 2.4.
filtr = (asp > split_lim) & (asp < 2.4)
# NOTE(review): counts events whose *peak* lies in the range, while `filtr`
# selects points in the range (it can include events whose peak exceeds 2.4).
peak_asp = asp.max('lag')
nsplits = sum((peak_asp > split_lim) & (peak_asp < 2.4))
# masked aspect ratio; only its 'lag' coordinate is used below
splits = pvmoms.sel(edge=edge).aspect_ratio.where(filtr)
fig,ax,transf = ac.Projection('SouthPolarStereo')
weak_mean = z10.where(filtr).sel(lat=slice(-90,-30),lag=splits.lag).mean(['event','lag'])
weak_mean.plot.contourf(levels=21,ax=ax,**transf)
xlim = ax.get_xlim()
ylim = ax.get_ylim()
weak_mean.plot.contour(levels=[float(edge)],colors='gray',xlim=xlim,ylim=ylim,ax=ax,**transf)
ax.set_title('N = {0}'.format(nsplits.values))
tmp = '{0}'.format(split_lim).replace('.','p')
outFile = 'figures/hght10full_asp_gt{0}_lt2p4.pdf'.format(tmp)
fig.savefig(outFile,transparent=True,bbox_inches='tight')
print(outFile)

In [None]:
# Composite over (event, lag) points with aspect ratio below split_lim.
filtr = asp < split_lim
# number of events whose peak 7-day-min aspect ratio stays below split_lim
nsplits = sum(asp.max('lag')<split_lim)
# masked aspect ratio; only its 'lag' coordinate is used below
splits = pvmoms.sel(edge=edge).aspect_ratio.where(filtr)
fig,ax,transf = ac.Projection('SouthPolarStereo')
round_mean = z10.where(filtr).sel(lat=slice(-90,-30),lag=splits.lag).mean(['event','lag'])
round_mean.plot.contourf(levels=21,ax=ax,**transf)
xlim = ax.get_xlim()
ylim = ax.get_ylim()
round_mean.plot.contour(levels=[float(edge)],colors='gray',xlim=xlim,ylim=ylim,ax=ax,**transf)
ax.set_title('N = {0}'.format(nsplits.values))
tmp = '{0}'.format(split_lim).replace('.','p')
outFile = 'figures/hght10full_asp_lt{0}.pdf'.format(tmp)
fig.savefig(outFile,transparent=True,bbox_inches='tight')
print(outFile)

# large scale indices

In [None]:
ninos = xr.open_dataset(data_dir+'ninos_writ.nc')
# September-November mean of every index for 2002 and 2019, in long format
vals = {'year':[],'Index':[],'Value':[]}
for year in [2002,2019]:
    son_mean = ninos.sel(time=slice('{0}-09-01'.format(year),'{0}-11-30'.format(year))).mean('time')
    for var in ninos.data_vars:
        tmp = son_mean[var]
        vals['year'].append(year)
        vals['Index'].append(var)
        vals['Value'].append(float(tmp.data))
df_vals = pd.DataFrame.from_dict(vals)

In [None]:
def CreateIndex(lags,axlims=None,inds=None,color=None,ax=None,cmon=None):
    """Plot the distribution of lag-averaged index composites against climatology.

    For every index name in ``inds`` the composite file
    ``data_dir + 'cm2p1/<IND>_composite_ucomp_0.0_lag-60-60_full.nc'`` is
    averaged over ``lag`` within ``slice(*lags)`` (label based, so both ends
    are included). The result is compared with the climatology in
    ``data_dir + 'cm2p1/<ind>_complete.nc'`` through ``ac.StatTest(..., test='KS')``.

    Parameters
    ----------
    lags : sequence of two numbers
        Lag window passed to ``slice``.
    axlims : dict, optional
        ``{'x': [xmin, xmax] or None, 'y': [ymin, ymax] or None}``. Defaults to
        ``{'x': [-4, 4], 'y': [0, 0.25]}``. ``axlims['x']`` is also used as the
        histogram ``binrange``. If it is None, ``SymmetricAxes`` is applied.
    inds : list of str, optional
        Index names; defaults to ``['TNI', 'DMI']``.
    color : matplotlib colour, optional
        Colour for the histogram and the climatology KDE.
    ax : matplotlib Axes, optional
        Axes to draw on; a new figure/axes is created when omitted.
    cmon : int, optional
        Calendar month selecting the climatology; ``None`` selects the SON season.

    Returns
    -------
    (Figure, Axes)
        The figure and axes that were actually drawn on.
    """
    # None sentinels instead of shared mutable defaults; same default values.
    if axlims is None:
        axlims = {'x':[-4,4],'y':[0,0.25]}
    if inds is None:
        inds = ['TNI','DMI']
    if ax is None:
        _,ax = plt.subplots()
    cm_inds = []
    clims = {}
    pos = {}
    neg = {}
    mu  = {}
    pval= {}
    std = {}
    ttle = {}
    for ind in inds:
        tni = xr.open_dataarray(data_dir+'cm2p1/{0}_composite_ucomp_0.0_lag-60-60_full.nc'.format(ind))
        tni_tmp = tni.sel(lag=slice(*lags)).mean('lag')
        cm_inds.append(tni_tmp)
        # fractions of positive / negative composite values
        pos[ind] = np.sum(tni_tmp > 0)/len(tni_tmp)
        neg[ind] = np.sum(tni_tmp < 0)/len(tni_tmp)
        mu[ind] = tni_tmp.mean().values
        std[ind]= tni_tmp.std().values
        # KS-test against the climatology (SON, or the single month ``cmon``)
        tni_clim = xr.open_dataarray(data_dir+'cm2p1/{0}_complete.nc'.format(ind.lower()))
        if cmon is None:
            filtr = tni_clim.time.dt.season == 'SON'
        else:
            filtr = tni_clim.time.dt.month == cmon
        tni_clim = tni_clim.isel(time=filtr)
        clims[ind] = tni_clim
        pval[ind] = ac.StatTest(tni_tmp,tni_clim,test='KS')
        # NOTE(review): the title reads '(a,b]' but slice(*lags) includes both ends.
        ttle[ind] = 'lags ({0},{1}]'.format(*lags)
    cm_inds = xr.merge(cm_inds).to_dataframe()
    hist_kw = dict(binrange=axlims['x'],bins=10,stat='density',kde=False,color=color,ax=ax)
    if len(inds) > 1:
        # axes-level histplot: sns.displot is figure-level and ignores ``ax``,
        # which sent the histogram to a separate figure.
        sns.histplot(cm_inds,multiple='dodge',**hist_kw)
    else:
        sns.histplot(cm_inds,x=inds[0],**hist_kw)
    # climatological distribution of every index (previously only the last one)
    for clim_series in clims.values():
        sns.kdeplot(clim_series,cut=0,color=color,ax=ax)
    sns.despine(ax=ax,offset=10)
    ax.grid()
    ax.axvline(0,color='k',zorder=0)
    if axlims['x'] is None:
        SymmetricAxes(ax,'x')
    else:
        ax.set_xlim(*axlims['x'])
    if axlims['y'] is not None:
        ax.set_ylim(*axlims['y'])
    if len(inds) == 1:
        ind = inds[0]
        ax.text( 0.3,0.05,'{0:.0%}'.format(pos[ind].values),backgroundcolor=(1,1,1,0.7),ha='left',size='large')
        ax.text(-0.3,0.05,'{0:.0%}'.format(neg[ind].values),backgroundcolor=(1,1,1,0.7),ha='right',size='large')
        # backslashes escaped explicitly: '\m' / '\s' are invalid escape sequences
        ax.text(1.5,0.3,'$\\mu=${0:4.1f}\n$p=${1:4.1%}\n$\\sigma=${2:4.1f}'.format(mu[ind],pval[ind],std[ind]))
        ax.set_title(ttle[ind])
    return ax.figure,ax

In [None]:
# Current seaborn/matplotlib colour cycle (palette=None is the default).
cpal = sns.color_palette(palette=None)

In [None]:
indxs = ['TNI','DMI']
laglist = [[-60,-20],[-20,20],[20,60]]
# Climatology month for each lag window, one entry per element of laglist.
monthlist = [8,9,10] #months to compare to climatologies
ncols = len(laglist)
nrows = len(indxs)
fig,axs = plt.subplots(nrows=nrows,ncols=ncols,sharex=True,sharey=True,figsize=[4*1.1*ncols,3*1.1*nrows])
for i,ind in enumerate(indxs):
    for l,lags in enumerate(laglist):
        ax = axs[i,l]
        # The month follows the lag window (column l), not the index row i.
        # The original monthlist[i] compared TNI/DMI with Aug/Sep only.
        CreateIndex(lags,inds=[ind],axlims={'x':[-4,4],'y':[0,.5]},color=cpal[i],ax=ax,cmon=monthlist[l])
        # column headers on the top row only; CreateIndex's own title is replaced
        ax.set_title(lag_bin_names[l] if i == 0 else '')
        if l==0:
            ax.set_ylabel('{0} density'.format(ind))
        ax.set_xlabel(r'$\sigma$')
ac.AddPanelLabels(axs,'upper left')
outFile = 'figures/sst_inds_pdfs.pdf'
SaveFig(fig,outFile)

# Stationary waves

In [None]:
def PlotStatWavesLine(clim,anom,latslice):
    """Compare zonal profiles of a climatology, an anomaly and their sum.

    ``clim`` and ``anom`` are restricted to ``lat`` in ``slice(*latslice)`` and
    averaged over latitude with ``at.coslat`` weights. The left panel (two
    thirds of the width) shows the three longitude profiles. The right panel
    shows, for wavenumbers 1-3 of ``ac.GetWavesXr``, the maximum over longitude
    as hatched bars: clim plain, anom '//', tot '..'. Each bar set is outlined
    in its profile's line colour.

    Returns
    -------
    fig, ax
        The figure and the left (profile) axes.
    """
    from matplotlib.gridspec import GridSpec
    wavenumbers = [1,2,3]

    def band_mean(field):
        # cos(lat)-weighted mean over the requested latitude band
        band = field.sel(lat=slice(*latslice))
        return band.weighted(at.coslat(band.lat)).mean('lat')

    def wave_amplitudes(profile,name):
        # max over lon of each selected wave component, stacked along 'k'
        waves = ac.GetWavesXr(profile,wave=-1,dim='lon')
        amp = xr.concat([waves.isel(k=k).max('lon') for k in wavenumbers],dim='k')
        amp.name = name
        return amp

    clim_prof = band_mean(clim)
    anom_prof = band_mean(anom)
    tot_prof = clim_prof + anom_prof
    amps = xr.merge([wave_amplitudes(clim_prof,'clim'),
                     wave_amplitudes(anom_prof,'anom'),
                     wave_amplitudes(tot_prof,'tot')])

    fig = plt.figure(figsize=[3*4,3])
    grid = GridSpec(1,3)
    ax = fig.add_subplot(grid[:2])
    ax1 = fig.add_subplot(grid[-1])

    # Line colours are recorded so the bar outlines can reuse them. The label /
    # grid calls stay between the plots: xarray's .plot resets axis labels.
    first_lines = clim_prof.plot(label='clim',ax=ax)
    line_colors = [first_lines[-1].get_color()]
    ax.set_ylabel('climatology')
    ax.grid()
    for profile,name,style in ((anom_prof,'anom','--'),(tot_prof,'tot',':')):
        drawn = profile.plot(ax=ax,ls=style,label=name)
        line_colors.append(drawn[-1].get_color())
    ax.axhline(0,color='k')
    sns.despine()
    SymmetricAxes(ax,'y')
    ax.legend()

    # Bars come out column by column (clim, anom, tot), each with one bar per k.
    amps.to_dataframe().plot.bar(rot=0,width=0.7,ax=ax1,legend=False)
    n_k = len(amps.k.values)
    bar_styles = [(hatch,color)
                  for hatch,color in zip(['','//','..'],line_colors)
                  for _ in range(n_k)]
    for patch,(hatch,color) in zip(ax1.patches,bar_styles):
        patch.set_color('none')
        patch.set_hatch(hatch)
        patch.set_edgecolor(color)
    ax1.grid()
    ac.AddPanelLabels(np.array([ax,ax1]),'upper left',ypos=1.15)
    return fig,ax

In [None]:
# CM2.1 sea-level pressure: climatology (singleton dimensions squeezed away) and
# the SSW composite averaged over the 'event' dimension.
# NOTE(review): 'AS' in the file name presumably denotes a season (Aug-Sep?) - confirm.
cm2p1_slp_clim = xr.open_dataarray(data_dir+'cm2p1/slp_AS_climat.nc').squeeze()
cm2p1_slp_anom = xr.open_dataarray(data_dir+'cm2p1/slp_composite_ucomp_0.0_lag-60-60_full.nc').mean('event')
# Subtract the zonal mean so only the zonally asymmetric part remains, then pass
# through ac.StandardGrid (presumably normalises coordinate names/order - confirm).
cm2p1_slp_clima = ac.StandardGrid(cm2p1_slp_clim - cm2p1_slp_clim.mean('lon'))
cm2p1_slp_anoma = ac.StandardGrid(cm2p1_slp_anom - cm2p1_slp_anom.mean('lon'))

In [None]:
# Zonally asymmetric SLP averaged over 70S-50S: climatology vs. the composite
# averaged over lags -60..-20 (label-based slice, both ends included).
fig,ax = PlotStatWavesLine(cm2p1_slp_clima,cm2p1_slp_anoma.sel(lag=slice(-60,-20)).mean('lag'),[-70,-50])
ax.set_xlim(0,360)
# overrides whatever y label the profile plot left on the axes
ax.set_ylabel('SLP [hPa]')
ax.set_title('CM2.1 AS SLP clim vs SSW')
SaveFig(fig,'figures/slp_stat_waves_lines.pdf')

# Regional surface histograms

In [None]:
# Current seaborn/matplotlib colour cycle (palette=None is the default).
cpal = sns.color_palette(palette=None)

In [None]:
# Lon/lat boxes (degrees, longitudes in the 0-360 convention, e.g. SSA uses
# 287-298) used to select the regional means for the surface PDFs.
# NOTE(review): abbreviations presumably mean east/west Australia, New Zealand,
# southern Africa, southern South America, and Indian/Pacific sectors of
# Antarctica - confirm.
regions = dict(
    EAUS=dict(lon=slice(135, 155), lat=slice(-39, -17)),
    WAUS=dict(lon=slice(113, 135), lat=slice(-35, -14)),
    NZ=dict(lon=slice(166.5, 178.5), lat=slice(-47, -34.5)),
    SAF=dict(lon=slice(14, 35), lat=slice(-34.5, -17)),
    SSA=dict(lon=slice(287, 298), lat=slice(-55, -34)),
    IA=dict(lon=slice(30, 110), lat=slice(-90, -65)),
    PA=dict(lon=slice(160, 280), lat=slice(-90, -65)),
)

In [None]:
# Precipitation composite scaled by 86400 (seconds per day).
# NOTE(review): presumably converts kg m-2 s-1 to mm/day, matching the 'mm/d'
# axis label used for the regional PDFs - confirm the source units.
precip = xr.open_dataarray(data_dir+'cm2p1/precip_composite_ucomp_0.0_lag-60-60_full.nc')*86400
tsurf = xr.open_dataarray(data_dir+'cm2p1/t_surf_composite_ucomp_0.0_lag-60-60_full.nc')
# drop the 'dayofyear' coordinate before regridding
del tsurf['dayofyear']
# NOTE(review): rename=True presumably renames coordinates to lat/lon - confirm.
tsurf = ac.StandardGrid(tsurf,rename=True)

In [None]:
# Per-variable colour-scale settings for the regional PDFs; 'mean' is the
# half-width of the symmetric range onto which the composite mean is mapped.
clims = dict(
    precip=dict(mean=0.3, perc=0.2),
    tsurf=dict(mean=0.4, perc=0.2),
)
# Row titles for the pre / post / delta panels.
hist_bin_names = ['early stage', 'late stage', 'late - early stage']

In [None]:
# Degree shifts (lon, lat) applied to each region's centre when its inset PDF
# is placed on the map; (0, 0) leaves the inset at the region centre.
offsets = {
    name: {'lon': dlon, 'lat': dlat}
    for name, (dlon, dlat) in [
        ('EAUS', (5, 5)),
        ('WAUS', (-10, -10)),
        ('NZ', (5, -5)),
        ('SAF', (0, 0)),
        ('SSA', (0, 0)),
        ('IA', (0, 0)),
        ('PA', (0, 0)),
    ]
}

In [None]:
import colorsys
class MplColorHelper:
    """Map scalar values onto a matplotlib colormap.

    Values are normalised linearly between ``start_val`` and ``stop_val`` with
    ``matplotlib.colors.Normalize`` before the colormap lookup.

    Attributes
    ----------
    cmap : matplotlib colormap used for the lookup.
    cmap_name : str or None
        The name passed in (string input) or the colormap's ``name`` attribute.
        Previously it was only set for string input.
    norm, scalarMap : the Normalize / ScalarMappable used by the getters.
    """

    def __init__(self, cmap, start_val, stop_val):
        import matplotlib as mpl
        from matplotlib import cm
        from matplotlib import pyplot as plt
        if isinstance(cmap,str):
            self.cmap_name = cmap
            self.cmap = plt.get_cmap(cmap)
        else:
            self.cmap_name = getattr(cmap,'name',None)
            self.cmap = cmap
        self.norm = mpl.colors.Normalize(vmin=start_val, vmax=stop_val)
        self.scalarMap = cm.ScalarMappable(norm=self.norm, cmap=self.cmap)

    def get_rgb(self, val):
        """Return the colour for ``val`` as an RGBA tuple (alpha included)."""
        return self.scalarMap.to_rgba(val)

    def get_hls(self, val):
        """Return the colour for ``val`` as an (h, l, s) tuple; alpha is dropped."""
        rgba = self.get_rgb(val)
        return colorsys.rgb_to_hls(*rgba[:-1])


In [None]:
def SetupMap(proj='PlateCarree',cm=180,latrange=[-90,0],figsize=[24,6]):
    """Create a coastline map via ``ac.Projection`` and fit the figure size.

    The figure is resized so that its aspect ratio matches the one reported by
    the created figure. The larger side of the requested ``figsize`` is kept,
    depending on which ratio is wider. The extent is set in PlateCarree
    coordinates to longitudes ``[cm, 180 + cm]`` and latitudes ``latrange``.

    Returns
    -------
    fig, ax, transform
        As returned by ``ac.Projection``.
    """
    fig, map_ax, transform = ac.Projection(proj, coast=True,
                                           kw_args={'central_longitude': cm},
                                           fig_args={'figsize': figsize})
    want_w, want_h = figsize[0], figsize[1]
    fig_ratio = fig.get_figwidth()/fig.get_figheight()
    if want_w/want_h > fig_ratio:
        new_w, new_h = want_w, want_w/fig_ratio
    else:
        new_w, new_h = want_h*fig_ratio, want_h
    fig.set_figheight(new_h)
    fig.set_figwidth(new_w)
    extent = [0 + cm, 180 + cm, latrange[0], latrange[1]]
    map_ax.set_extent(extent, crs=ccrs.PlateCarree())
    return fig, map_ax, transform

In [None]:
def AddInset(proj,main_ax,region,latrange=[-90,0],offset={'lon':0,'lat':0}):
    """Add a small transparent inset axes anchored at a region's centre.

    The centre of ``region`` (mean of its lon and lat slice bounds, shifted by
    ``offset`` degrees) is converted from PlateCarree through map, display and
    axes coordinates of ``main_ax``. The inset is horizontally centred on that
    point with its lower edge on it. Its width is 0.1 of the axes width. Its
    height is ``0.1 * 360 / (latrange[1] - latrange[0])`` of the axes height.
    ``proj`` is accepted but not used.

    Returns
    -------
    matplotlib Axes
        The new inset axes (face colour 'none').
    """
    from cartopy import crs as ccrs
    size = 0.1
    width = size
    height = size*(360/(latrange[1]-latrange[0]))
    center_lon = offset['lon']+0.5*(region['lon'].start+region['lon'].stop)
    center_lat = offset['lat']+0.5*(region['lat'].start+region['lat'].stop)
    # geographic -> projected map coordinates -> display -> axes fraction
    map_xy = main_ax.projection.transform_point(center_lon, center_lat, ccrs.PlateCarree())
    display_xy = main_ax.transData.transform(map_xy)
    frac_x, frac_y = main_ax.transAxes.inverted().transform(display_xy)
    inset = main_ax.inset_axes([frac_x-width/2, frac_y, width, height])
    inset.set_facecolor('none')
    return inset

In [None]:
from scipy import stats

In [None]:
lag_offsets = [20,60]
lat_range = [-90,10]

varn = ['precip','tsurf']
lims = {'tsurf':[-4,4],'precip':[-3,3]}
long_name = {'precip':'Precipitation anomaly [mm/d]','tsurf':'Surface temperature anomaly [K]'}

cases = ['pre','post','delta']

sns.set_context('talk',font_scale=1)

proj = 'LambertCylindrical'
cm = 180

# longitude extent handed to set_extent (PlateCarree degrees)
minlon = 0 + cm
maxlon = 180 + cm
# lon/lat aspect of the plotted box, used to size the figure
data_aspect = 360/np.diff(lat_range)[0]
plot_vars = [precip,tsurf]


def _lag_window_mean(series, case):
    """Average ``series`` over the lag window of ``case``.

    'pre' uses lags in [-lag_offsets[1], -lag_offsets[0]] and 'post' uses
    [lag_offsets[0], lag_offsets[1]]; both are label-based, so both ends are
    included. 'delta' is post minus pre.
    """
    if case == 'pre':
        return series.sel(lag=slice(-lag_offsets[1],-lag_offsets[0])).mean('lag')
    if case == 'post':
        return series.sel(lag=slice(lag_offsets[0],lag_offsets[1])).mean('lag')
    if case == 'delta':
        return _lag_window_mean(series,'post') - _lag_window_mean(series,'pre')
    raise ValueError('unknown case: {0}'.format(case))


for v,var in enumerate(plot_vars):
    varname = varn[v]
    fig,axs,transf = ac.Projection(proj,nrows=len(cases),coast=False
                                   ,kw_args={'central_longitude':180}
                                   ,fig_args={'figsize':[24,0.8*24/data_aspect*len(cases)]}
                                  )
    for c,case in enumerate(cases):
        axs[c].set_extent([minlon,maxlon,lat_range[0],lat_range[1]],crs=ccrs.PlateCarree())
        axs[c].coastlines(color='lightgray')
        for region,bounds in regions.items():
            with sns.axes_style('ticks'):
                ax = AddInset(proj,axs[c],bounds,lat_range,offsets[region])
            # unweighted lon/lat box mean, then the lag-window mean for this case
            # (presumably leaves one value per event - confirm)
            tmp = _lag_window_mean(var.sel(bounds).mean(['lon','lat']),case)
            pos = np.sum(tmp > 0).values/len(tmp)
            neg = np.sum(tmp < 0).values/len(tmp)
            meanv = tmp.mean().values
            # Fill colour: the mean value mapped onto cmaps[varname] over
            # [-clims[varname]['mean'], +clims[varname]['mean']].
            clim = clims[varname]['mean']
            hls = MplColorHelper(cmaps[varname],-clim,clim).get_hls(meanv)
            clr = colorsys.hls_to_rgb(*hls)
            # Fill opacity from a one-sample KS test against scipy.stats.norm:
            # fully opaque for p <= 0.05, fading linearly to transparent at p >= 0.1.
            # (The old 'above 5% invisible' branch was dead code overwritten here.)
            # NOTE(review): the reference is the standard normal N(0,1) in the
            # variable's own units - confirm that this is the intended test.
            _,pval = stats.ks_1samp(tmp.values,stats.norm.cdf)
            alph = max(0,min(1,(0.1-pval)*20))
            sns.kdeplot(tmp,ax=ax,fill=True,color=clr,linewidth=3,edgecolor='black',alpha=alph)
            if lims[varname] is not None:
                ax.set_xlim(*lims[varname])
            ax.axvline(0,color='k')
            # positive / negative fractions written either side of zero
            ytxt = 0.15*ax.get_ylim()[1]
            xtxt = 0.1*ax.get_xlim()[1]
            bclr = (1,1,1,0.9)
            ax.text( xtxt,ytxt,'{:.0%}'.format(pos),ha='left',backgroundcolor=bclr,size='medium')
            ax.text(-xtxt,ytxt,'{:.0%}'.format(neg),ha='right',backgroundcolor=bclr,size='medium')
            sns.despine(ax=ax,left=True,offset=0)
            ttle = '{0} ({1:.2f})'.format(region,meanv)
            ax.set_title(ttle,backgroundcolor=bclr)
            ax.set_ylabel('')
            ax.set_yticks([])
            ax.set_xlabel('')
        axs[c].set_title(hist_bin_names[c],fontsize='x-large')
    ac.AddPanelLabels(axs,'upper left',ypos=1.09)
    fig.suptitle(long_name[varname],fontsize='x-large',y=.93)
    fileName = 'figures/regional-pdf_{0}.pdf'.format(varname)
    SaveFig(fig,fileName)

# Tropical zonal wind anomalies

In [None]:
# Current seaborn/matplotlib colour cycle (palette=None is the default).
palette = sns.color_palette(palette=None)

In [None]:
# Zonal-wind composite restricted to 5S-5N.
utrop = xr.open_dataarray(data_dir+'cm2p1/ucomp_composite_ucomp_0.0_lag-60-60_full.nc').sel(lat=slice(-5,5))
# Keep 'level' from 850 to the end of the axis. NOTE(review): this means 850 hPa
# and above only if 'level' is stored in decreasing order - confirm. Then take an
# unweighted latitude mean and average into `lag_bins` equal-width lag bins.
utrop = utrop.sel(level=slice(850,None)).mean('lat').groupby_bins('lag',lag_bins).mean()
# replace the interval labels of the bins with lag_bin_names
utrop = utrop.assign_coords({'lag_bins':lag_bin_names})
# T-test of the event sample against 0 along 'event' (ac.StatTest)
pvalu = ac.StatTest(utrop,0,test='T',dim='event')

In [None]:
# Event-mean tropical wind profile per lag bin. The thin line is the full
# profile; the default-width, labelled line keeps only levels where pvalu < 0.05.
# Draw on an explicitly created figure instead of relying on plt.gcf().
fig,ax = plt.subplots()
utrop_mean = utrop.mean('event')
utrop_sig = utrop_mean.where(pvalu<0.05)
for lb,lbin in enumerate(utrop.lag_bins.values):
    utrop_mean.isel(lag_bins=lb).plot.line(y='level',linewidth=1,color=palette[lb],ax=ax)
    utrop_sig.isel(lag_bins=lb).plot.line(y='level',color=palette[lb],label=lbin,ax=ax)
ax.set_ylim(1,850)
# make the x-axis symmetric about zero
xlims = ax.get_xlim()
ax.set_xlim(min(xlims[0],-xlims[1]),max(-xlims[0],xlims[1]))
ac.LogPlot(ax)
sns.despine(left=True)
ax.set_title('Tropical zonal mean zonal wind')
ax.set_xlabel('u [m/s]')
ax.legend(loc='lower right',fontsize='small')
SaveFig(fig,'figures/qbo_composite.pdf')

# Final warming date

In [None]:
event_times = xr.open_dataarray(data_dir+'cm2p1/event_time_vector.nc')
final_warmings = xr.open_dataarray(data_dir+'cm2p1/final_warmings.nc')
# final-warming day of year for every year
fw_all = final_warmings.dt.dayofyear
# only years with event_times > 0 (NOTE(review): presumably years with an SSW - confirm)
fw_ssw = final_warmings.where(event_times>0).dropna('year').dt.dayofyear
#
# Summary statistics. xarray .std() uses ddof=0; scipy.stats.skew defaults to
# the biased estimator (bias=True).
mean_all = fw_all.mean().values
sigma_all= fw_all.std().values
mean_ssw = fw_ssw.mean().values
sigma_ssw= fw_ssw.std().values
skew_all = stats.skew(fw_all)
skew_ssw = stats.skew(fw_ssw)

In [None]:
# Legend labels: mean and standard deviation rounded to the nearest day (int()
# truncated them before) and skewness to one decimal.
all_labl = r'all:    $\mu$={0:.0f}, $\sigma$={1:.0f}, s={2:.1f}'.format(float(mean_all),float(sigma_all),float(skew_all))
ssw_labl = r'SSW: $\mu$={0:.0f}, $\sigma$={1:.0f}, s={2:.1f}'.format(float(mean_ssw),float(sigma_ssw),float(skew_ssw))
# explicit figure so the KDEs never land on a previously active figure
fig,ax = plt.subplots()
sns.kdeplot(fw_all,label=all_labl,color=palette[0],fill=True,ax=ax)
sns.kdeplot(fw_ssw,fill=False,ax=ax,label=ssw_labl,color=palette[1])
ax.set_xlim(280,365)
sns.despine(left=True)
ax.set_ylim(0,0.06)
ax.legend(loc='upper left')
ax.set_ylabel('Final Warming density')
ax.set_title('Final Warming Day')
SaveFig(fig,'figures/final_warming_pdf.pdf')