#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Nov 16 18:46:09 2022

@author: grahamr
"""
import numpy as np
import matplotlib.pyplot as plt
import netCDF4 as nc
import os
import matplotlib.patches as mpatches
import matplotlib.lines as mlines
from matplotlib.ticker import ScalarFormatter
from matplotlib.ticker import (MultipleLocator, AutoMinorLocator)
from mpl_toolkits import mplot3d

'''
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
A script to reproduce the first eight plots of Graham and Pierrehumbert 2024
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
To reproduce Figure 9, run isca_weathering_calc.py on the data included
in this Zenodo record.


'''
#%% Figure 1

# Land/sea mask used by the simulations. The default is the original author's
# local path; set the ISCA_LAND_NC environment variable to point at another
# copy without editing the script.
continents_path = os.environ.get(
    'ISCA_LAND_NC', "C:/Users/sobcr/Documents/isca_data/land.nc")

# Read coordinates and mask into memory, then close the file. The context
# manager releases the handle even if one of the variables is missing.
with nc.Dataset(continents_path) as continents_data:
    lat = np.squeeze(continents_data.variables['lat'][:])
    lon = np.squeeze(continents_data.variables['lon'][:])
    land_mask = np.squeeze(continents_data.variables['land_mask'][:])

plt.figure()
# NOTE(review): if land_mask holds 0/1 values, level 0 coincides with the
# field minimum (matplotlib may warn that no level lies inside the data
# range); 0.5 would sit on the coastline. Left as-is — confirm against the
# published Figure 1.
plt.contour(lon, lat, land_mask, colors='k', levels=[0])
plt.xlabel('Longitude [degrees]')
plt.ylabel('Latitude [degrees]')
plt.title('Continental configuration in simulations')

#%% Populating arrays with values from calculations

# Stellar constants of the simulations [W m^-2] and, position by position,
# the matplotlib colour code each one is drawn with in every figure.
S_list = [675, 750, 800, 1000, 1250]
color_list = ['b', 'c', 'm', 'r', 'g']


def _s_patch(colour, label):
    """Return a solid-colour legend swatch for one stellar constant."""
    return mpatches.Patch(color=colour, label=label)


def _marker_handle(marker, label):
    """Return a black legend entry drawn with ``marker`` for one quantity."""
    return mlines.Line2D([], [], color='black', marker=marker,
                         markersize=10, label=label)


# Colour swatches, one per stellar constant.
red = _s_patch('red', r'S = 1000 W m$^{-2}$')
green = _s_patch('green', r'S = 1250 W m$^{-2}$')
magenta = _s_patch('magenta', r'S = 800 W m$^{-2}$')
cyan = _s_patch('cyan', r'S = 750 W m$^{-2}$')
blue = _s_patch('blue', r'S = 675 W m$^{-2}$')

# Marker entries that tell the quantities apart in the multi-series figures.
star = _marker_handle('*', r'Global-mean $S_{\rm absorbed}$')
dot = _marker_handle('o', 'Global-mean latent\nheat flux')
plus = _marker_handle('P', 'Net global-mean upward\nsensible heat flux')
x = _marker_handle('x', 'Net global-mean upward\nsurface longwave flux')

# Per-run global diagnostics, keyed by stellar constant S [W m^-2]. Each list
# lines up element-wise with pressures[S] (the figures plot them pairwise).
#
# NOTE(review): the S = 675 entries start at 2 bar. Values for a 1-bar
# S = 675 run are not used by any figure (the reason is not stated in this
# script); they are recorded here for reference. Confirm before re-adding:
#   mac_weathering 13.109873891321348, whak_weathering 9.702083350570764,
#   albedo 0.1691010828103949, sensible 13.493888988861805,
#   lw 37.035705702876086, instell 109.89

# CO2 consumption by silicate weathering, MAC formulation [Tmol yr^-1].
mac_weathering = {
    675: [17.600118554956136, 16.01936963818386, 14.489892550159395],
    750: [16.831139535156677, 13.632047181427271, 13.082953036831736],
    800: [14.761976085314812, 12.201122518509933, 12.069717584618655],
    1000: [8.577042141455403, 6.907487589122413],
    1250: [5.571764386750302, 6.001713720599974, 6.162981750606316],
}

# CO2 consumption by silicate weathering, WHAK formulation [Tmol yr^-1].
whak_weathering = {
    675: [43.39293419495854, 104.08100601225117, 196.04874058327968],
    750: [38.616800384272494, 205.9612949917343, 436.77799849892244],
    800: [91.67896219683881, 477.85676631738335, 854.2373959150348],
    1000: [1847.9871264633657, 4746.779560796297],
    1250: [30.757718549831548, 39.089319877592565, 54.36122854878972],
}

# Albedo diagnostic (dimensionless; not used by the figures below).
albedo = {
    675: [0.2279736573344702, 0.26822546824275895, 0.2986480208769898],
    750: [0.16732018222446177, 0.2249312829507736, 0.264131438453193],
    800: [0.16606964979690056, 0.22282320428856248, 0.2616534133998029],
    1000: [0.16050946855720608, 0.21421735086704752],
    1250: [0.10777428080612968, 0.10740811095955875, 0.10685977117612054],
}

# Net global-mean upward sensible heat flux [W m^-2].
sensible = {
    675: [3.644728580048031, 0.5926879131090256, -0.9405131166213648],
    750: [6.147461871699133, 0.953884267057392, -0.8911973730402294],
    800: [4.527586612018108, 0.9875279234879657, -0.8310762462093999],
    1000: [7.16874888086675, 5.803818268925583],
    1250: [10.95040510519872, 10.912780338203854, 11.262139716967496],
}

# Net global-mean upward surface longwave flux [W m^-2] (arrays, so that
# instell[S] - lw[S] works element-wise).
lw = {
    675: np.array([20.733599503718285, 14.359066073914688, 9.388610113835758]),
    750: np.array([33.080209895495734, 17.293483561730785, 9.076827035448138]),
    800: np.array([31.358313182478877, 14.509741491581039, 6.214731612668568]),
    1000: np.array([25.04292159205062, 7.049890007171377]),
    1250: np.array([83.4644885093483, 79.57604800227108, 74.76540838657456]),
}

# Global-mean absorbed instellation [W m^-2].
instell = {
    675: np.array([89.475, 76.610, 66.821]),
    750: np.array([117.004394319221, 93.26307363431306, 78.06367042241558]),
    800: np.array([121.04, 94.965, 79.458]),
    1000: np.array([134.14, 101.39]),
    1250: np.array([218.9545968179087, 217.4396499963897, 215.2024609693146]),
}

# Absorbed instellation minus net upward surface longwave flux [W m^-2],
# element-wise over each S's pCO2 sweep.
delta_rad = {S: instell[S] - lw[S] for S in S_list}


# Latent heat of vaporisation of water [J kg^-1] — the value Isca uses.
# (Not referenced elsewhere in the script as shown.)
latent_heat_water = 2.5e6

# Global-mean latent heat flux [W m^-2] per run, aligned with pressures[S].
# NOTE(review): a 1-bar S = 675 value (59.487) is excluded from the figures;
# the reason is not stated here.
latent = {
    675: [65.979, 61.264, 58.091],
    750: [77.55483172761629, 74.70489560274545, 69.3082238492803139],
    800: [85.77, 79.15, 73.08],
    1000: [101.67, 87.44],
    1250: [124.75827483733296, 126.55854868994565, 129.47895484778294],
}

# pCO2 of each run [bar], increasing within each S. The S = 1250 runs are at
# 200-400 microbar; all others are at 1-4 bar (see the broken x-axes).
pressures = {
    675: np.array([2, 3, 4]),
    750: np.array([1, 2, 3]),
    800: np.array([1, 2, 3]),
    1000: np.array([1, 2]),
    1250: np.array([200e-6, 300e-6, 400e-6]),
}


# Runoff diagnostic per run (definition/units not stated in this script; not
# used by the figures below).
runoff = {
    675: [0.2013436518906263, 0.32451806235059116, 0.35344011223304045],
    750: [0.2700240350882417, 0.44049385970600957, 0.6362029291981199],
    800: [0.31245399255637496, 0.5669555950335301, 0.8389878719784156],
    1000: [0.45019831027166785, 0.820646700658827],
    1250: [0.8178690496792347, 0.8627310715238912, 0.839845874690808],
}


def _kg_m2_s_to_mm_day(rate):
    """Convert precipitation rates from kg m^-2 s^-1 to mm day^-1.

    1 kg m^-2 of liquid water is 1 mm of depth, so only the time unit
    changes. A "day" is taken as 1/360 of a 3.15e7-second year (presumably
    matching a 360-day model calendar).

    NOTE(review): 3.15e7 / 360 = 87,500 s, about 1.3% more than the 86,400 s
    in a day. It is kept unchanged so the precipitation figures match those
    produced originally; confirm whether 86400 was intended.
    """
    return np.array(rate) * 3.15e7 / 360


# Global-mean precipitation [mm day^-1] per run, aligned with pressures[S].
# NOTE(review): a 1-bar S = 675 value (2.375e-5 kg m^-2 s^-1) is excluded
# from the figures; the reason is not stated here.
precip = {
    675: _kg_m2_s_to_mm_day([2.642e-5, 2.452e-5, 2.331e-5]),
    750: _kg_m2_s_to_mm_day([3.1005934803089034e-05, 2.9913089720056806e-05,
                             2.7612831524977955e-05]),
    800: _kg_m2_s_to_mm_day([3.429e-5, 3.165e-5, 2.89e-5]),
    1000: _kg_m2_s_to_mm_day([4.08e-5, 3.5e-5]),
    1250: _kg_m2_s_to_mm_day([4.991279257832359e-05, 5.067768976333801e-05,
                              5.183638238986899e-05]),
}

# Mean precipitation over land per run (presumably mm day^-1, like precip —
# confirm). A 1-bar S = 675 value (1.6517321339331008) is excluded.
land_precip = {
    675: np.array([1.7108704148981448, 1.599632779945908, 1.4874170502454565]),
    750: np.array([1.9923503347744822, 1.737508876099421, 1.7661193782697375]),
    800: np.array([1.9706672490488952, 1.845079292307217, 1.9782899002208112]),
    1000: np.array([1.568300239476611, 1.5828761376103935]),
    1250: np.array([2.8116247243364745, 2.8576623829901577, 2.77857147002758]),
}


# Total and land surface areas of the model planet (presumably m^2 — the
# planetary value matches Earth's surface area in m^2).
planetary_surface_area = 510064471909788.25
land_surface_area = 119173695324258.66

# Fraction of total precipitation falling on land, and the ratio of land-mean
# to global-mean precipitation, per stellar constant (assuming land_precip and
# precip share units).
frac_land_precip = {
    S: land_precip[S] * land_surface_area / (precip[S] * planetary_surface_area)
    for S in S_list
}
land_ratio = {S: land_precip[S] / precip[S] for S in S_list}

# Global-mean surface temperature [K] per run, aligned with pressures[S].
# NOTE(review): a 1-bar S = 675 value (277.546) is excluded from the figures;
# the reason is not stated here.
tsurf = {
    675: [293.046, 302.004, 308.5050],
    750: [289.5928909596203, 305.37715610789644, 314.27905849246986],
    800: [296.601, 312.389, 320.128],
    1000: [318.462, 332.846],
    1250: [298.10836637208524, 299.6059961007419, 301.5111424070665],
}




# Secant sensitivities across each S sweep, taken between the lowest- and
# highest-pCO2 runs: change in delta_rad and in latent heat flux per kelvin
# of global-mean surface warming [W m^-2 K^-1].
dRdT = {
    S: (delta_rad[S][-1] - delta_rad[S][0]) / (tsurf[S][-1] - tsurf[S][0])
    for S in S_list
}
dLdT = {
    S: (latent[S][-1] - latent[S][0]) / (tsurf[S][-1] - tsurf[S][0])
    for S in S_list
}
# Fraction of absorbed instellation leaving the surface as latent heat, per
# run. The values appear to equal latent[S] / instell[S] rounded to 8 decimal
# places. A 1-bar S = 675 value (0.54133224) is excluded, like the other
# S = 675 diagnostics. Not used by the figures below.
latent_frac = {
    675: np.array([0.73740151, 0.79968672, 0.86935245]),
    750: np.array([0.66283691, 0.80101258, 0.88784224]),
    800: np.array([0.70860872, 0.83346496, 0.91973118]),
    1000: np.array([0.75793947, 0.86241247]),
    1250: np.array([0.56979062, 0.58203988, 0.60166113]),
}



#%% Helper shared by the broken-axis pCO2 figures (Figures 2-5, 7 and 8)

def _broken_pco2_figure(series, ylabel, title, legend_panel=0,
                        extra_handles=(), size=(8, 5), ylog=False):
    """Plot per-S diagnostics against pCO2 on a two-panel broken x-axis.

    Every series is drawn on both panels. The left panel's x-range
    (1.5e-4 to 5e-4 bar) holds the S = 1250 runs and the right panel's
    (0.95 to 4.1 bar) holds the others, so each panel's limits hide the
    points that belong to the other panel.

    Reads the module-level ``S_list``, ``color_list`` and ``pressures`` and
    the colour swatches ``blue``, ``cyan``, ``magenta``, ``red``, ``green``.

    Parameters
    ----------
    series : sequence of (dict, str)
        ``(data, marker)`` pairs; ``data[S]`` is plotted against
        ``pressures[S]`` with format string ``colour + marker`` for each S.
    ylabel : str
        Shared y-axis label, drawn on the left panel.
    title : str
        Figure suptitle.
    legend_panel : {0, 1}
        Panel that carries the legend (0 = left, 1 = right).
    extra_handles : sequence of matplotlib artists
        Legend entries listed before the five S swatches.
    size : (float, float)
        Figure size in inches.
    ylog : bool
        If True, both panels get a logarithmic y-axis.

    Returns
    -------
    fig, (ax1, ax2)
        The figure and its left/right axes.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2)
    for S, colour in zip(S_list, color_list):
        for data, marker in series:
            ax1.plot(pressures[S], data[S], colour + marker)
            ax2.plot(pressures[S], data[S], colour + marker)

    ax1.set_xscale('log')
    ax2.set_xscale('log')
    ax1.set_xlim(150e-6, 500e-6)
    ax2.set_xlim(0.95, 4.1)

    # Hide the facing spines so the pair reads as one broken axis, and keep
    # y tick labels on the left panel only.
    ax1.spines.right.set_visible(False)
    ax2.spines.left.set_visible(False)
    ax1.yaxis.tick_left()
    ax2.yaxis.tick_right()
    ax2.tick_params(labelright=False)

    # Label the bar-scale panel's major and minor log ticks as plain integers.
    # A formatter keeps each label attached to its tick. Fixed label lists
    # read off the current ticks trigger matplotlib's FixedFormatter warning
    # and drift if the view changes.
    ax2.xaxis.set_major_formatter(lambda value, _pos: str(int(value)))
    ax2.xaxis.set_minor_formatter(lambda value, _pos: str(int(value)))

    # Slanted "break" marks on the facing edges, in axes coordinates and
    # unclipped so they straddle the panel borders.
    d = .5  # vertical-to-horizontal extent of each slanted mark
    break_marks = dict(marker=[(-1, -d), (1, d)], markersize=12,
                       linestyle='none', color='k', mec='k', mew=1,
                       clip_on=False)
    ax1.plot([1, 1], [1, 0], transform=ax1.transAxes, **break_marks)
    ax2.plot([0, 0], [1, 0], transform=ax2.transAxes, **break_marks)

    ax1.set_ylabel(ylabel)
    fig.text(0.5, 0.03, r'$p$CO$_2$ [bars]', ha='center')
    (ax1, ax2)[legend_panel].legend(
        handles=[*extra_handles, blue, cyan, magenta, red, green], loc='best')
    fig.suptitle(title)
    fig.set_size_inches(*size)
    if ylog:
        ax1.set_yscale('log')
        ax2.set_yscale('log')
    return fig, (ax1, ax2)


#%% Figure 2
fig, (ax1, ax2) = _broken_pco2_figure(
    [(tsurf, '.-')],
    r'Global-mean surface temperature [K]',
    r'High $p$CO$_2$ simulations display greater ECS, as expected')

#%% Figure 3
fig, (ax1, ax2) = _broken_pco2_figure(
    [(instell, '*-'), (latent, 'o-')],
    r'Global-mean latent heat flux or absorbed instellation [W m$^{-2}$]',
    r'Planets at low instellations and high $p$CO$_2$ approach'
    + ' \nthe energetic limit on global-mean evaporation \n set by absorbed instellation',
    legend_panel=1, extra_handles=[dot, star], size=(8, 6))

#%% Figure 4
fig, (ax1, ax2) = _broken_pco2_figure(
    [(lw, 'x-'), (sensible, 'P-')],
    r'Net surface sensible or longwave flux [W m$^{-2}$]',
    'Surface sensible and longwave fluxes are throttled\n'
    + r'under high $p$CO$_2$, low instellation conditions',
    legend_panel=1, extra_handles=[plus, x])

#%% Figure 5
fig, (ax1, ax2) = _broken_pco2_figure(
    [(precip, '.-')],
    r'Global-mean precipitation [mm day$^{-1}$]',
    r'Under energetically-limited conditions, $p$CO$_2$'
    + '\nand precipitation decouple',
    legend_panel=1)

#%% Figure 6
plt.figure()
for S, colour in zip(S_list, color_list):
    # Each line carries its own S label (the legend below uses the explicit
    # swatches, but get_legend_handles_labels() should still be correct).
    plt.plot(tsurf[S], precip[S], colour + '.-',
             label=rf'$S$ = {S} W m$^{{-2}}$')

plt.xlabel('Global-mean surface temperature [K]')
plt.ylabel(r'Global-mean precipitation [mm day$^{-1}$]')
plt.title('Under energetically-limited conditions,\n global-mean surface temperature'
          + '\nand precipitation decouple')
plt.legend(handles=[blue, cyan, magenta, red, green], loc='best')
fig = plt.gcf()

#%% Figure 7
fig, (ax1, ax2) = _broken_pco2_figure(
    [(whak_weathering, 's-')],
    r'CO$_2$ consumption from silicate weathering [Tmol year$^{-1}$]',
    r'WHAK weathering increases rapidly with $p$CO$_2$'
    + '\ndespite sluggish hydrology',
    ylog=True)

#%%  Figure 8
fig, (ax1, ax2) = _broken_pco2_figure(
    [(mac_weathering, 's-')],
    r'CO$_2$ consumption from silicate weathering [Tmol year$^{-1}$]',
    'With MAC weathering, energetically-limited\nprecipitation destabilizes the carbon cycle')

# Display every open figure only once all of them are fully built. Calling
# show() mid-construction renders incomplete figures (and, with the inline
# backend, closes them before their labels are added).
plt.show()



#%%
