import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import sys
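# `from MITgcmutils import mds` needs the directory that *contains* the
# MITgcmutils package on the import path, e.g.
#   export PYTHONPATH=$PYTHONPATH:~/Source/MITgcm/utils/python/MITgcmutils
# or, from Python (note that sys.path does not expand '~' by itself):
#   sys.path.insert(0, os.path.expanduser('~/Source/MITgcm/utils/python/MITgcmutils'))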
from MITgcmutils import mds

# Use LaTeX fonts
#plt.rc('text', usetex=True)
#plt.rc('font', family='serif')


# =============================================================================
# INPUT
# =============================================================================
# results_path  = '/media/by/D67A71047A70E2A3/prince_run_output/FRISP2019/nonhydrostatic/default/'
# iters = [659520] # time step(s) at which to read temperature
# hb = [-600.0, -600.0]
# results_path  = '/media/by/D67A71047A70E2A3/prince_run_output/FRISP2019/nonhydrostatic/H540/'
# iters = [812160] # time step(s) at which to read temperature
# hb = [-600.0, -540.0]
# results_path  = '../frisp_results/trestore86400/hydrostatic/default/'
# iters = [2736000] # time step(s) at which to read temperature
# hb = [-600.0, -600.0]

# FINAL RESULTS HERE:
#results_path  = '../frisp_results/trestore600/nonhydrostatic/default/'
# Directory holding the MITgcm output files (file names are appended to it,
# so keep the trailing slash).
results_path = './'
# Output iterations to plot: 5760, 11520, ..., 115200 (20 snapshots).
iters = [5760 * k for k in range(1, 21)]
# Bed depth (m) at the grounding line and at the open boundary.
hb = [-600.0, -600.0]
# Alternative case (raised bed at the open boundary):
# results_path  = '../frisp_results/trestore600/nonhydrostatic/H540/'
# iters = [1054080] # time step(s) at which to read temperature
# hb = [-600.0, -540.0]

# Grid indices; xind selects the x-column used for the (y, z) section.
xind, yind, zind = 0, 49, 14
# Grid dimensions.
nx, ny, nz = 1, 200, 102
# y-indices of the velocity profiles (one entry of `colors` per profile).
yind_profiles = [120, 80, 40]
# Velocity-vector overlay settings (overlay currently disabled):
# aspect_ratio = 0.025
# vector_skip = 4
# vector_scale = 0.001
plot_ice_profile = True

# Geometry: domain end points in y (km) ...
yb = [0.0, 10.0]
# ... and ice-ocean interface depth (m) at the grounding line / open boundary.
hi = [-598.0, -500.0]
z_top = -498.0

# Profile colours: light blue, sage green, orange.
colors = ['#ADD8E6', '#87A96B', '#FF4F00']
# =============================================================================


plt.rc('xtick', labelsize=14)
plt.rc('ytick', labelsize=14)

# Model time step (s). Only used to convert iteration numbers into elapsed
# days for the output filename.
# NOTE(review): hard-coded; confirm it matches deltaT of the run being plotted.
_DT_SECONDS = 30.0


def _interp_along_y(y_km, endpoints):
    """Return a depth varying linearly from endpoints[0] to endpoints[1].

    ``endpoints[0]`` applies at ``y_km == yb[0]`` and ``endpoints[1]`` at
    ``y_km == yb[1]`` (y in km). ``y_km`` may be a scalar or a NumPy array.
    """
    slope = (endpoints[1] - endpoints[0])
    return endpoints[0] + (y_km - yb[0]) * slope / (yb[1] - yb[0])


def _make_figure():
    """Create the figure and its four labelled axes.

    Returns ``(fig, ax)``: ``ax[0]`` is the (y, z) theta section spanning the
    top row; ``ax[1]`` shows v against u; ``ax[2]`` and ``ax[3]`` show u and v
    against the fractional cavity height ("H fraction").
    """
    fig = plt.figure(figsize=(16, 7))
    ax = [
        fig.add_subplot(2, 1, 1),
        fig.add_subplot(2, 3, 4),
        fig.add_subplot(2, 3, 5),
        fig.add_subplot(2, 3, 6),
    ]
    fig.subplots_adjust(hspace=0.3, wspace=0.3)

    # Section: whole domain in y, from the deepest bed to the shallowest ice base.
    ax[0].set_xlim((yb[0], yb[1]))
    ax[0].set_ylim((min(hb), max(hi)))
    ax[0].set_xlabel(r'$y$ (km)', fontsize=22)
    ax[0].set_ylabel(r'$z$ (m)', fontsize=22)

    # v-u panel: dashed zero axes spanning the full axes whatever the limits.
    ax[1].axvline(0, linestyle='--', color='gray')
    ax[1].axhline(0, linestyle='--', color='gray')
    ax[1].set_xlabel(r'$u$ (m/s)', fontsize=22)
    ax[1].set_ylabel(r'$v$ (m/s)', fontsize=22)
    ax[1].xaxis.set_major_locator(plt.MaxNLocator(7))
    ax[1].yaxis.set_major_locator(plt.MaxNLocator(7))

    # Profile panels: dashed zero-velocity line and a solid line at
    # H fraction 1 (the ice-ocean interface).
    for panel, xlabel in ((ax[2], r'$u$ (m/s)'), (ax[3], r'$v$ (m/s)')):
        panel.set_ylim((0, 1))
        panel.axvline(0, linestyle='--', color='gray')
        panel.axhline(1, linestyle='-', color='gray')
        panel.set_xlabel(xlabel, fontsize=22)
        panel.set_ylabel(r'H fraction', fontsize=22)
        panel.xaxis.set_major_locator(plt.MaxNLocator(7))

    return fig, ax


# Grid geometry is identical for every output iteration, so read it once.
YC = mds.rdmds(results_path + 'YC')
RC = mds.rdmds(results_path + 'RC')

# (y, z) coordinate matrices shaped like Theta[:, :, xind].T, i.e.
# (number of y points, number of z levels); sized from the grid actually read.
RC2, YC2 = np.meshgrid(RC[:, 0, 0].astype(float), YC[:, 0].astype(float))
y_km = 1e-3 * YC2

# Mask for the contour plot: True (hidden) where a cell centre lies above the
# ice-ocean interface or below the bed, i.e. outside the cavity.
HI = _interp_along_y(y_km, hi)
HB = _interp_along_y(y_km, hb)
mask = (RC2 > HI) | (RC2 < HB)

# Fractional cavity height (0 at the bed, 1 at the ice base) and the
# in-cavity selector for each profile location; also grid-only, so compute once.
profiles = []
for yprof in yind_profiles:
    y = y_km[yprof, 0]
    hbed = _interp_along_y(y, hb)
    hcav = _interp_along_y(y, hi) - hbed  # cavity thickness at this y
    hfrac = (RC2[yprof, :] - hbed) / hcav
    profiles.append((yprof, hfrac, (hfrac > 0.0) & (hfrac < 1.0)))

for iternum, iteration in enumerate(iters):
    print('iteration ' + str(iternum + 1) + ' of ' + str(len(iters)))

    # Fields are stacked along the first axis of each diagnostics file.
    diag_U = mds.rdmds(results_path + 'diag_U', itrs=iteration)
    U = diag_U[0, :, :, :]
    V = diag_U[1, :, :, :]

    diag_Tracers = mds.rdmds(results_path + 'diag_Tracers', itrs=iteration)
    Theta = diag_Tracers[0, :, :, :]

    fig, ax = _make_figure()

    # Theta section, blanked outside the cavity.
    Tmasked = np.ma.array(Theta[:, :, xind].T, mask=mask)
    cplot = ax[0].contourf(y_km, RC2, Tmasked, 20, cmap='bwr')
    cbar = fig.colorbar(cplot, ax=ax[0], shrink=1.0)
    cbar.set_label(label=r'$\theta$ ($^\circ$C)', fontsize=22)

    # Velocity profiles (in-cavity points only), plus matching markers on the
    # section so the top and bottom panels can be related.
    for j, (yprof, hfrac, in_cavity) in enumerate(profiles):
        u = U[in_cavity, yprof, xind]
        v = V[in_cavity, yprof, xind]
        ax[1].plot(u, v, 'o-', color=colors[j])
        ax[2].plot(u, hfrac[in_cavity], 'o-', color=colors[j])
        ax[3].plot(v, hfrac[in_cavity], 'o-', color=colors[j])
        ax[0].plot(y_km[yprof, in_cavity], RC2[yprof, in_cavity], 'o',
                   color=colors[j])

    # Symmetric, 10%-padded limits shared by all velocity axes so the flow
    # direction is easy to read; based on the full depth of every profile.
    if profiles:
        Umagmax = max(np.max(np.abs(U[:, yind_profiles, xind])),
                      np.max(np.abs(V[:, yind_profiles, xind])))
        # Skip when all velocities are zero (or NaN): equal limits are degenerate.
        if Umagmax > 0:
            Uplotlims = (-1.1 * Umagmax, 1.1 * Umagmax)
            ax[1].set_xlim(Uplotlims)
            ax[1].set_ylim(Uplotlims)
            ax[2].set_xlim(Uplotlims)
            ax[3].set_xlim(Uplotlims)

    days = float(iteration) * _DT_SECONDS / 86400.
    fig.savefig("theta_Benfrisp_no_coriolis_time_{}days.png".format(days))
    plt.close(fig)
