# -*- coding: utf-8 -*-
"""
Code for generating Fig.3 & 4 for the paper: 
Model Sensitivity to Insolation Forcing and its Implications for the Holocene Temperature Conundrum

Last updated on Dec 10 2024

@author: Yuntao Bao
Contact: bao.291@osu.edu
"""

import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
import pandas as pd
import cartopy.crs as  ccrs
import cartopy.feature as cfeature
from cartopy.util import add_cyclic_point
from scipy import stats 
from scipy.interpolate import griddata
from scipy.stats import pearsonr

def contour_map(fig,img_extent,spec):
    """Add the standard map decorations used by every panel.

    Draws 110m coastlines (black, linewidth 0.8) and semi-transparent lakes,
    then sets the aspect to 'auto'.

    Parameters
    ----------
    fig : cartopy GeoAxes
        The axes to decorate (despite the name, an axes, not a Figure).
    img_extent, spec :
        Accepted for call compatibility; not used.
    """
    decorations = (
        (cfeature.COASTLINE.with_scale('110m'), {'color': 'black', 'linewidth': 0.8}),
        (cfeature.LAKES, {'alpha': 0.5}),
    )
    for feature, kwargs in decorations:
        fig.add_feature(feature, **kwargs)
    fig.set_aspect('auto')

def read_PMIP(dirn, var, model, idT):
    """Return the midHolocene-minus-piControl anomaly of `var` for one model.

    Reads ``<dirn><var>_<model>_piControl_mon_ltm.nc`` and
    ``<dirn><var>_<model>_midHolocene_mon_ltm.nc``, averages the entries of
    `var` selected by `idT` along its first axis, and returns
    (midHolocene - piControl).

    Parameters
    ----------
    dirn : str
        Directory prefix, concatenated directly with the file name (so it
        should end with a path separator, e.g. './').
    var : str
        Variable name; used in the file names and as the dataset key.
    model : str
        Model name used in the file names.
    idT : sequence of int
        Positions along the first axis of `var` to average (month indices
        in this script).

    Returns
    -------
    ndarray
        2-D difference field. Points where |difference| < 1e-8 are set to NaN.
    """
    fname1 = dirn + var + '_' + model + '_piControl_mon_ltm.nc'
    fname11 = dirn + var + '_' + model + '_midHolocene_mon_ltm.nc'
    # Context managers release the file handles; .values loads the selected
    # data into memory before the files are closed.
    with xr.open_dataset(fname1) as ds1, xr.open_dataset(fname11) as ds11:
        T_pi = np.mean(ds1[var][idT, :, :].values, 0)
        T_mh = np.mean(ds11[var][idT, :, :].values, 0)
    T_diff = T_mh - T_pi
    # (Numerically) zero differences are treated as missing.
    # NOTE(review): presumably this masks cells with no data in either run; confirm.
    T_diff[np.abs(T_diff) < 1e-8] = np.nan
    return T_diff
    
def signal_noise_ratio(Tdiff_mean, Tdiff_std):
    """Mask points whose |mean| is at least as large as |spread|.

    Computes |Tdiff_mean| / |Tdiff_std| element-wise and returns an array of
    the same dtype holding 1 where that ratio is >= 1 and NaN everywhere else
    (ratio < 1, or an undefined ratio such as 0/0).
    """
    snr = np.abs(Tdiff_mean) / np.abs(Tdiff_std)
    robust = snr >= 1.
    snr[robust] = 1
    # ~robust also covers NaN ratios, which are already NaN.
    snr[~robust] = np.nan
    return snr

def T_test_sign(data_all, data_warm, data_cold, alpha=0.05):
    """Flag grid points whose sample mean differs significantly from zero.

    A two-sided one-sample t-test (H0: mean == 0) is applied independently at
    every grid point along the first (sample/model) axis of each input.

    Parameters
    ----------
    data_all, data_warm, data_cold : array_like, shape (n_samples, Ny, Nx)
        Samples for the whole ensemble and for the warm / cold subsets.
    alpha : float, optional
        Significance level; a point is flagged when p <= alpha (default 0.05).

    Returns
    -------
    tag_Pall, tag_Pwarm, tag_Pcold : ndarray of float32, shape (Ny, Nx)
        1 where p <= alpha and NaN elsewhere. This includes points where the
        p-value itself is NaN, e.g. NaNs in the samples.
    """
    def _tag(samples):
        # One vectorised call tests every grid point at once, instead of a
        # per-point Python double loop. Results match the per-point test up
        # to floating-point round-off.
        pval = stats.ttest_1samp(samples, popmean=0, axis=0).pvalue
        # float32 so the threshold comparison and the returned dtype are float32.
        pval = np.asarray(pval, dtype='float32')
        # NaN p-values compare False and therefore stay NaN.
        return np.where(pval <= alpha, np.float32(1), np.float32(np.nan))

    return _tag(data_all), _tag(data_warm), _tag(data_cold)
    

var = 'tas'
modNlist = ['EC-Earth3-LR', 'ACCESS-ESM1-5', 'FGOALS-g3', 'CESM2', 'INM-CM4-8', 'MRI-ESM2-0',
            'GISS-E2-1-G', 'NorESM2-LM', 'NESM3', 'MPI-ESM1-2-LR', 'MIROC-ES2L', 'FGOALS-f3-L', 
            'AWI-ESM-1-1-LR', 'IPSL-CM6A-LR', 'iTRACE']
Nc = len(modNlist)
# ECS (double CO2), one value per model in the same order as modNlist
ECS = np.array([4.3, 3.9, 2.9, 5.2, 1.9, 3.2, 
                2.7, 2.5, 4.7, 2.9, 2.7, 3.0, 
                3.6, 4.6, 3.3])
# ECS is matched to modNlist purely by position, so the lengths must agree;
# otherwise models and sensitivities would silently be misaligned.
if ECS.size != Nc:
    raise ValueError('ECS has %d entries but modNlist has %d models' % (ECS.size, Nc))
# midHolocene cooling: ECS scaled linearly by 20/280 (presumably a 20 ppm
# CO2 reduction relative to a 280 ppm piControl baseline)
ECS_diff = 20/280*ECS

# Month indices along the time axis of the monthly climatologies
# (0 = January, judging by DJF = [11, 0, 1]).
idT_ANN = list(range(12))
idT_MAM = [2,3,4]
idT_JJA = [5,6,7]  
idT_SON = [8,9,10]     
idT_DJF = [11,0,1]   

dirn = './'        
# Read temp and calculate diff (midHolocene - piControl) from PMIP4 models,
# one 2-D field per model and season.
Tdiff_ANN = []
Tdiff_JJA = []
Tdiff_SON = []
Tdiff_DJF = []
Tdiff_MAM = []
for model in modNlist:
    Tdiff_ANN.append(read_PMIP(dirn, var, model, idT_ANN))
    Tdiff_JJA.append(read_PMIP(dirn, var, model, idT_JJA))
    Tdiff_SON.append(read_PMIP(dirn, var, model, idT_SON))
    Tdiff_DJF.append(read_PMIP(dirn, var, model, idT_DJF))
    Tdiff_MAM.append(read_PMIP(dirn, var, model, idT_MAM))

# Read temperature ratio pattern from 4xCO2 simulations.
# The file stores, per model, a field 'Tdiff_ratio_<model>' and its grid
# ('lon_<model>', 'lat_<model>'). Scaling the field by ECS_diff gives the
# CO2-forced temperature pattern T_CO2F.
T_CO2F = []
lonList = []
latList = []
T_CO2F_glob = np.zeros((Nc), dtype='float32')
# Context manager so the netCDF handle is released; .values loads each
# array into memory while the file is still open.
with xr.open_dataset(dirn+var+'_ANN_abrupt-4xCO2_PI_diff_ratio_ltm.nc') as ds:
    for i in range(Nc):
        T_CO2ratio = ds['Tdiff_ratio_'+modNlist[i]].values
        T_CO2F.append(ECS_diff[i]*T_CO2ratio)
        lonList.append(ds['lon_'+modNlist[i]].values)
        latList.append(ds['lat_'+modNlist[i]].values)

        # cos(latitude) weights normalised to unit mean over latitude rows.
        # NOTE(review): with NaNs in the field, nanmean drops cells without
        # renormalising the weights, so this is an exact area-weighted mean
        # only for NaN-free fields.
        coslat1 = np.cos(np.deg2rad(latList[i]))
        weight1 = coslat1 / np.nanmean(coslat1)
        T_CO2F_glob[i] = np.nanmean(np.nanmean(T_CO2F[i]*weight1[:,None],0),0)
        
# Insolation-only response: the full anomaly minus the CO2 contribution
# (-T_CO2F), i.e. Tdiff_ANN + T_CO2F. Then take cos(lat)-weighted means
# globally and over 30-90N.
T_insolF = []
T_insolF_glob = np.zeros(Nc, dtype='float32')
T_diff_glob = np.zeros(Nc, dtype='float32')
T_diff_NH = np.zeros((Nc, 5), dtype='float32')   # columns: ANN, JJA, SON, DJF, MAM
T_CO2F_NH = np.zeros(Nc, dtype='float32')
for i in range(Nc):
    lat_i = latList[i]
    T_insolF.append(Tdiff_ANN[i] + T_CO2F[i])

    # Global weights: cos(lat) normalised to unit mean over all latitude rows.
    # NOTE(review): nanmean drops NaN cells without renormalising the weights,
    # so these are exact area-weighted means only for NaN-free fields.
    coslat1 = np.cos(np.deg2rad(lat_i))
    weight1 = coslat1 / np.nanmean(coslat1)
    w_glob = weight1[:, None]
    T_insolF_glob[i] = np.nanmean(np.nanmean(T_insolF[i] * w_glob, 0), 0)
    T_diff_glob[i] = np.nanmean(np.nanmean(Tdiff_ANN[i] * w_glob, 0), 0)

    # Same weighting restricted to latitude rows at or north of 30N.
    idx_need = np.where(lat_i >= 30)[0]
    coslat2 = np.cos(np.deg2rad(lat_i[idx_need]))
    weight2 = coslat2 / np.nanmean(coslat2)
    w_nh = weight2[:, None]
    seasonal_fields = (Tdiff_ANN[i], Tdiff_JJA[i], Tdiff_SON[i], Tdiff_DJF[i], Tdiff_MAM[i])
    for col, field in enumerate(seasonal_fields):
        T_diff_NH[i, col] = np.nanmean(np.nanmean(field[idx_need, :] * w_nh, 0), 0)
    T_CO2F_NH[i] = np.nanmean(np.nanmean(T_CO2F[i][idx_need, :] * w_nh, 0), 0)
    
# Multimodel mean
# Common target grid for the ensemble statistics: 2 deg steps in longitude
# (0-358) and 1.5 deg steps in latitude (-90 to 90).
long = np.arange(0, 360, 2)
latg = np.arange(-90, 90.1, 1.5)
lon_mesh, lat_mesh = np.meshgrid(long, latg)
Nxg, Nyg = long.size, latg.size

# Model subsets by position in modNlist: indices 0-5 form the "warm" group,
# index 9 onward the "cold" group (6-8 belong to neither).
all_sel = slice(None)
warm_sel = slice(0, 6)
cold_sel = slice(9, None)

# Interpolate each model's native-grid fields onto the common grid.
Tdiff_remap = np.zeros((Nc, Nyg, Nxg), dtype='float32')
TinsolF_remap = np.zeros_like(Tdiff_remap)
TCO2F_remap = np.zeros_like(Tdiff_remap)
for i in range(Nc):
    Xgrid, Ygrid = np.meshgrid(lonList[i], latList[i])
    src_pts = (Ygrid.flatten(), Xgrid.flatten())
    # NOTE(review): linear interpolation in plain (lat, lon) space with no
    # longitude wrap-around; target points outside a model's grid hull get
    # griddata's default fill value (NaN).
    for src_field, dest in ((Tdiff_ANN[i], Tdiff_remap),
                            (T_insolF[i], TinsolF_remap),
                            (T_CO2F[i], TCO2F_remap)):
        dest[i] = griddata(src_pts, src_field.flatten(), (lat_mesh, lon_mesh), method='linear')

# The CO2-forced mid-Holocene anomaly is the negated T_CO2F pattern
# (T_insolF was formed as Tdiff_ANN + T_CO2F).
neg_TCO2F_remap = -TCO2F_remap
group_sels = (all_sel, warm_sel, cold_sel)

# Ensemble means: all / warm / cold models
Tdiff_ens, Tdiff_warm, Tdiff_cold = [np.nanmean(Tdiff_remap[s], 0) for s in group_sels]
TinsolF_ens, TinsolF_warm, TinsolF_cold = [np.nanmean(TinsolF_remap[s], 0) for s in group_sels]
TCO2F_ens, TCO2F_warm, TCO2F_cold = [np.nanmean(neg_TCO2F_remap[s], 0) for s in group_sels]

# Inter-model spread (np.nanstd, default ddof=0)
Tdiff_ens_sig, Tdiff_warm_sig, Tdiff_cold_sig = [np.nanstd(Tdiff_remap[s], 0) for s in group_sels]
TinsolF_ens_sig, TinsolF_warm_sig, TinsolF_cold_sig = [np.nanstd(TinsolF_remap[s], 0) for s in group_sels]
TCO2F_ens_sig, TCO2F_warm_sig, TCO2F_cold_sig = [np.nanstd(neg_TCO2F_remap[s], 0) for s in group_sels]

# Significance masks from T_test_sign (1 = significant, NaN = not)
Tdiff_tag_all, Tdiff_tag_warm, Tdiff_tag_cold = T_test_sign(
    Tdiff_remap, Tdiff_remap[warm_sel], Tdiff_remap[cold_sel])
TinsolF_tag_all, TinsolF_tag_warm, TinsolF_tag_cold = T_test_sign(
    TinsolF_remap, TinsolF_remap[warm_sel], TinsolF_remap[cold_sel])
TCO2F_tag_all, TCO2F_tag_warm, TCO2F_tag_cold = T_test_sign(
    neg_TCO2F_remap, neg_TCO2F_remap[warm_sel], neg_TCO2F_remap[cold_sel])


#%% Scatter plot
# Marker labels: model k (1-based) in the lists below is drawn as str(k).
lablist_all = [str(k) for k in range(1, 16)]
modNlist_all = ['EC-Earth3-LR', 'ACCESS-ESM1-5', 'FGOALS-g3', 'CESM2', 'INM-CM4-8', 'MRI-ESM2-0',
            'GISS-E2-1-G', 'NorESM2-LM', 'NESM3', 'MPI-ESM1-2-LR', 'MIROC-ES2L', 'FGOALS-f3-L', 
            'AWI-ESM-1-1-LR', 'IPSL-CM6A-LR', 'iTRACE']
Nc_all = len(lablist_all)
# Subset actually plotted (currently all models). An earlier variant left out
# AWI-ESM-1-1-LR, noting that no AWI 4XCO2 data could be found.
lablist1 = lablist_all
modNlist1 = modNlist_all
Nc1 = len(lablist1)

# Marker colour/alpha per model group, by position in the lists above:
# indices 0-5 red, 6-8 orange, 9 onward blue.
scatter_groups = ((slice(0, 6), 'red', 0.6),
                  (slice(6, 9), 'orange', 0.7),
                  (slice(9, None), 'blue', 0.6))

# Plot figure
fig = plt.figure(figsize=(6,11))

# (a) CO2-only response vs full response (global means)
ax1 = fig.add_axes([0.18, 0.68, 0.68, 0.26])
for sel, clr, alp in scatter_groups:
    ax1.scatter(T_diff_glob[sel], -T_CO2F_glob[sel], color=clr, alpha=alp, s=200)
ax1.plot(np.linspace(-0.7,0,100), np.linspace(-0.7,0,100), color='gray', linestyle='dotted')
ax1.set_xlim([-0.6,0])
ax1.set_ylim([-0.6,0])
ax1.text(-0.57, -0.05, '(a) Response: $CO_2$ vs. Full', fontweight='bold', fontsize=12)
# Raw strings: '\D', '\ ' and '\c' are invalid escape sequences in ordinary
# literals (SyntaxWarning since Python 3.12); the label text is unchanged.
ax1.set_xlabel(r'$\Delta T_{2m}\ (^\circ C)$', fontweight='bold', fontsize=13)
ax1.set_ylabel(r'$\Delta T^{CO2}_{2m}\ (^\circ C)$', fontweight='bold', fontsize=13)
ax1.xaxis.set_label_coords(0.5, 0.1)
for i in range(Nc1):
    ax1.text(T_diff_glob[i]-0.014, -T_CO2F_glob[i]-0.006, lablist1[i], fontweight='bold', color='white')
R2, p2 = pearsonr(T_diff_glob, -T_CO2F_glob)
ax1.text(-0.55,-0.12,'R=%.2f' %(R2), fontsize=12, fontweight='bold', style='italic')

# (b) Insolation-only response vs full response, plus the model legend
ax2 = fig.add_axes([0.18, 0.38, 0.68, 0.26])
for sel, clr, alp in scatter_groups:
    ax2.scatter(T_diff_glob[sel], T_insolF_glob[sel], color=clr, alpha=alp, s=200)
ax2.plot(np.linspace(-0.6,0.4,100), np.linspace(-0.6,0.4,100), color='gray', linestyle='dotted')
ax2.set_xlim([-0.6,0.4])
ax2.set_ylim([-0.6,0.4])
ax2.text(-0.57, 0.31, '(b) Response: Insolation vs. Full', fontweight='bold', fontsize=12)
ax2.set_xlabel(r'$\Delta T_{2m}\ (^\circ C)$', fontweight='bold', fontsize=13)
ax2.set_ylabel(r'$\Delta T^{Insolation}_{2m}\ (^\circ C)$', fontweight='bold', fontsize=13)
ax2.xaxis.set_label_coords(0.5, 0.1)
for i in range(Nc1):
    ax2.text(T_diff_glob[i]-0.014, T_insolF_glob[i]-0.013, lablist1[i], fontweight='bold', color='white')
    if i<6:
        clr='red'
    elif i>=9:
        clr='blue'
    else:
        clr='orange'
    # Legend entry "k: model name", coloured by group
    ax2.text(-0.01,0.15-i*0.05, lablist1[i]+': '+modNlist1[i], fontsize=10, fontweight='bold', color=clr)

R1, p1 = pearsonr(T_diff_glob, T_insolF_glob)
# NOTE(review): the '**' markers in panels (b) and (c) are hard-coded, not
# derived from p1/p3 (p1, p2, p3 are computed but unused); confirm they match.
ax2.text(-0.53,0.2,'R=%.2f**' %(R1), fontsize=12, fontweight='bold', style='italic')

# (c) CO2-only response vs insolation-only response
ax3 = fig.add_axes([0.18, 0.08, 0.68, 0.26])
for sel, clr, alp in scatter_groups:
    ax3.scatter(T_insolF_glob[sel], -T_CO2F_glob[sel], color=clr, alpha=alp, s=200)
ax3.plot(np.linspace(-0.5,0.4,100), -np.linspace(-0.5,0.4,100), color='gray', linestyle='dotted')
ax3.set_xlim([-0.4,0.4])
ax3.set_ylim([-0.4,0])
ax3.set_yticks(np.arange(-0.4,0.01,0.1))
ax3.text(-0.37, -0.04, '(c) Response: Insolation vs. $CO_2$', fontweight='bold', fontsize=12)
ax3.set_xlabel(r'$\Delta T^{Insolation}_{2m}\ (^\circ C)$', fontsize=13)
ax3.set_ylabel(r'$\Delta T^{CO2}_{2m}\ (^\circ C)$', fontsize=13)
ax3.xaxis.set_label_coords(0.5, 0.1)
# Nc1 (not Nc) for consistency with the other panels; both are 15 here.
for i in range(Nc1):
    ax3.text(T_insolF_glob[i]-0.015, -T_CO2F_glob[i]-0.006, lablist1[i], fontweight='bold', color='white')
R3, p3 = pearsonr(T_insolF_glob, -T_CO2F_glob)
ax3.text(-0.34,-0.09,'R=%.2f**' %(R3), fontsize=12, fontweight='bold', style='italic')

plt.savefig('Fig3.jpg', dpi=380)

#%% Model average
# Custom diverging colormap: RdBu_r with 5 colours trimmed from each end and
# the 6 central colours replaced by a single white entry.
cmap = plt.get_cmap('RdBu_r')
cmapList = [cmap(i) for i in range(cmap.N)]
del cmapList[251:]
del cmapList[:5]
mid = len(cmapList) // 2
del cmapList[mid - 3:mid + 3]
# NOTE(review): RGBA components are expected in [0, 1], so 255. is out of
# range. It renders white because matplotlib clips the colormap lookup table
# to [0, 1], which also makes the neighbouring LUT entries pure white.
# Changing it to (1., 1., 1., 1.) would make the zero-centred contour band
# slightly off-white, so the value is deliberately left unchanged here.
cmapList.insert(len(cmapList) // 2, (255., 255., 255., 1.))

mycmap = cmap.from_list('my_cmap', cmapList, cmap.N)

clevs1 =  np.array([-2.4,-2,-1.6,-1.2,-0.8,-0.6,-0.4,-0.2,0.2,0.4,0.6,0.8,1.2,1.6,2,2.4])
TitList = ['(a)','(b)','(c)','(d)','(e)','(f)','(g)','(h)','(i)','(j)','(k)','(l)','(m)']

proj = ccrs.Robinson()
leftlon, rightlon, lowerlat, upperlat = (-180,180.1,-90,90.1)
img_extent = [leftlon, rightlon, lowerlat, upperlat]

# Set before any hatched contourf is created. A collection takes its hatch
# colour from rcParams when it is constructed, so setting it after panel (a)
# left that panel with the default hatch colour.
plt.rcParams.update({'hatch.color': 'gray'})

# Panel grid: rows = Full / CO2-only / insolation-only response, columns =
# all / warm / cold models. Each row holds (label, ensemble-mean maps,
# significance masks, per-model global means); each column holds
# (group name, selection of models for the title's global mean).
panel_rows = (
    (r'$\Delta T^{Full}_{2m}$',
     (Tdiff_ens, Tdiff_warm, Tdiff_cold),
     (Tdiff_tag_all, Tdiff_tag_warm, Tdiff_tag_cold),
     T_diff_glob),
    (r'$\Delta T^{CO_2}_{2m}$',
     (TCO2F_ens, TCO2F_warm, TCO2F_cold),
     (TCO2F_tag_all, TCO2F_tag_warm, TCO2F_tag_cold),
     -T_CO2F_glob),
    (r'$\Delta T^{Insolation}_{2m}$',
     (TinsolF_ens, TinsolF_warm, TinsolF_cold),
     (TinsolF_tag_all, TinsolF_tag_warm, TinsolF_tag_cold),
     T_insolF_glob),
)
panel_cols = (('All', slice(None)), ('Warm', slice(0, 6)), ('Cold', slice(9, None)))

fig = plt.figure(figsize=(12.5,8))
filled = []
for row, (label, maps, tags, glob) in enumerate(panel_rows):
    for col, ((group, sel), field, tag) in enumerate(zip(panel_cols, maps, tags)):
        k = 3 * row + col
        axes = fig.add_subplot(3, 3, k + 1, projection=proj)
        field_cyc, lon0 = add_cyclic_point(field, coord=long)
        # extend='both' on every panel, so out-of-range values are filled the
        # same way everywhere (panel (b) previously passed 'b   oth').
        filled.append(axes.contourf(lon0, latg, field_cyc,
                                    levels=clevs1, extend='both',
                                    transform=ccrs.PlateCarree(), cmap=mycmap))
        # Hatch using the significance mask (1 = significant, NaN = not).
        axes.contourf(long, latg, tag, colors='none',
                      hatches=['..', None], transform=ccrs.PlateCarree())
        contour_map(axes,img_extent,60)
        axes.set_title('%s %s %s models:%.2f$^o$C'
                       % (TitList[k], label, group, np.mean(glob[sel])),
                       loc='center')
# Filled contour sets of panels (a) and (d); c1 drives the shared colorbar.
c1, c2 = filled[0], filled[3]

posit1=fig.add_axes([0.32, 0.06, 0.38, 0.012])
fig.colorbar(c1, cax=posit1, ticks=clevs1, orientation='horizontal',format='%.1f')

plt.savefig('Fig4.jpg', dpi=380)
