import numpy as np
import matplotlib.pyplot as plt
from scipy.io import netcdf_file


plt.close()
fs = 12  # single font size used for every text element configured below

# Typography for the whole figure, applied in one update: LaTeX text
# rendering, the Times font family, and the same size for body text,
# axes/figure titles, axis labels, tick labels and legend entries.
plt.rcParams.update({
    'text.usetex': True,
    'font.size': fs,
    'axes.titlesize': fs,
    'axes.labelsize': fs,
    'xtick.labelsize': fs,
    'ytick.labelsize': fs,
    'legend.fontsize': fs,
    'figure.titlesize': fs,
    'font.family': 'Times',
})

# Colour-scale limits shared by every panel (symmetric about zero).
cmax = 1.5e-4
cmin = -1.5e-4

# Dataset key -> scale factor (0.32, 0.50, 0.76); the keys match the numeric
# part of the NetCDF variable names (e.g. 'mr_76_su').
ns = {
    '32': 0.32,
    '50': 0.50,
    '76': 0.76,
}

# Simulation grid and physical scales.
nfile = '14'
NX = 1024
NY = 1024
NZ = 512
L0 = 3e4  # meters
LX = 2 * np.pi * L0
LY = 2 * np.pi * L0
LZ = np.pi * L0
scale = LZ / (NZ * 1e3)  # km per vertical grid point (LZ is in meters)
T0 = 1e3  # seconds
U0 = 35  # meters/sec
Nu = 0.015
shape = (NX, NY, NZ)  # derived from the grid constants instead of repeated literals

n_panels = 6
ncols = 2
nrows = 3

# NOTE(review): the plotted slices are assumed to be NZ rows (height) by
# NX columns (radius, with r = 0 at the centre column) -- confirm against
# the data file.
# Drop the top 10 km of the vertical grid from the view.
y_top = NZ - 10 / scale
yf = int(y_top)          # upper y limit (grid points) of every panel
xf = NX                  # upper x limit (grid points) of every panel
panel_aspect = y_top / xf
x_center = NX // 2       # column where the radius axis is labelled 0

max_km = int(yf * scale)
xmax_km = int((xf - x_center) * scale)
km_ticks = np.arange(0, max_km + 1, 20)    # height ticks in km
km_xticks = np.arange(-80, 80 + 1, 20)     # radius ticks in km
tick_locations = km_ticks / scale          # convert km to grid points
xtick_locations = x_center + km_xticks / scale
cmap = 'seismic_r'

# Panel layout in figure fractions: two columns of width `width`, three
# rows whose height follows the data aspect ratio, separated by dw / db.
left = 0.09
bottom = 0.2
dw = 0.02
db = 0.02
width = 0.44
height = width * panel_aspect
cbar_height = db

l1 = left + width + dw      # left edge of the right-hand column
b1 = bottom + height + db   # bottom edge of the middle row
b2 = b1 + height + db       # bottom edge of the top row

# Square figure, H inches on a side.
H = 6
fig = plt.figure(figsize=(H,) * 2)

# Thin horizontal strip for the shared colorbar: spans both panel columns
# and sits one gap (db) below the bottom row of panels.
strip_bottom = bottom - cbar_height - db
strip_width = 2 * width + dw
cbar_ax = fig.add_axes([left, strip_bottom, strip_width, cbar_height])
#-------------------------------------------------------------------------------------

# NetCDF variables to plot, in the order they are bound to mr1 .. mr6.
fig3_variables = (
    'mr_76_su', 'mr_76_th', 'mr_76_ss',
    'mr_50_su', 'mr_32_su', 'mr_32_ss',
)

with netcdf_file('../../2nd-submission/data/fig3_data.nc', 'r') as nc:
    # Copy each array so it no longer references the file's data buffer
    # once the file is closed (scipy memory-maps files by default).
    mr1, mr2, mr3, mr4, mr5, mr6 = (
        nc.variables[name][:].copy() for name in fig3_variables
    )

#-------------------------------------------------------------------------------------

# Data-coordinate position of the panel letter (A-F) in every panel.
xt = 20
yt = 380

# Tick labels for the panels that show an axis (radius is labelled
# symmetrically about the centre column).
radius_ticklabels = ['80', '', '40', '', '0', '', '40', '', '80']
height_ticklabels = ['0', '', '40', '', '80']


def _draw_panel(data, x0, y0, letter, *, radius_axis=False, height_axis=False):
    """Draw one field panel on ``fig`` and return its ``(axes, image)``.

    Parameters
    ----------
    data : 2-D array
        Field shown with ``imshow`` (drawn bottom-up, ``origin='lower'``),
        using the shared colour limits ``cmin``/``cmax`` and ``cmap``.
    x0, y0 : float
        Lower-left corner of the panel in figure fractions; the panel is
        ``width`` x ``height``.
    letter : str
        Panel label written at ``(xt, yt)`` in data coordinates.
    radius_axis : bool
        If True, label the x axis in km of radius along the top edge;
        otherwise hide the x ticks.
    height_axis : bool
        If True, label the y axis in km of height; otherwise hide the
        y ticks.
    """
    ax = fig.add_axes([x0, y0, width, height])
    im = ax.imshow(data, origin='lower', vmax=cmax, vmin=cmin, cmap=cmap)
    # Same view window in every panel so tick positions line up across the grid.
    ax.set_xlim(0, xf)
    ax.set_ylim(0, yf)

    if radius_axis:
        ax.set_xticks(xtick_locations)
        ax.set_xticklabels(radius_ticklabels)
        ax.xaxis.tick_top()
        ax.xaxis.set_label_position('top')
        ax.set_xlabel('Radius [km]')
    else:
        ax.set_xticks([])

    if height_axis:
        ax.set_yticks(tick_locations)
        ax.set_yticklabels(height_ticklabels)
        ax.set_ylabel('Height [km]', color='black')
    else:
        ax.set_yticks([])

    ax.text(xt, yt, letter)
    return ax, im


# Panels are added in the original order (A, D, E, B, C, F) so the figure's
# axes list is unchanged.
# dset '76': mr1 = mr_76_su, mr2 = mr_76_th, mr3 = mr_76_ss.
_draw_panel(mr1, left, b2, 'A', radius_axis=True, height_axis=True)
_draw_panel(mr2, l1, b2, 'D', radius_axis=True)
_draw_panel(mr3, l1, b1, 'E')

# dset '50': mr4 = mr_50_su.
_draw_panel(mr4, left, b1, 'B', height_axis=True)

# dset '32': mr5 = mr_32_su, mr6 = mr_32_ss.
_draw_panel(mr5, left, bottom, 'C', height_axis=True)
ax, im = _draw_panel(mr6, l1, bottom, 'F')

# Annotations placed in the data coordinates of the last (bottom-right) panel.
ax.text(700, 380, 't=26 min', fontsize=11)
ax.text(-150, -250, r'$r^\prime - r_{vs}(z)$', fontsize=12)

# All panels share vmin/vmax/cmap, so the last image drives the colorbar.
cbar = fig.colorbar(im, cax=cbar_ax, orientation='horizontal')
# linspace always yields exactly 7 ticks. arange with a float step could
# round to 8 ticks, which would no longer match the 7 labels below.
tick_locs = np.linspace(cmin, cmax, 7)
tick_labels = ['', '-1e-4', '', '0', '', '1e-4', '']
cbar.set_ticks(tick_locs)
cbar.set_ticklabels(tick_labels)

plt.show()


