import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import sys
# sys.path.insert(0, '~/Source/MITgcm/utils/python/MITgcmutils/MITgcmutils/')
# Must add the path to import the package containing the mds module
# export PYTHONPATH=$PYTHONPATH:~/Source/MITgcm/utils/python/MITgcmutils/MITgcmutils
from MITgcmutils import mds

# Use LaTeX fonts
#plt.rc('text', usetex=True)
#plt.rc('font', family='serif')


# =============================================================================
# INPUT
# =============================================================================
# pathlist = ['/media/by/D67A71047A70E2A3/prince_run_output/FRISP2019/tests/4292719/',
# '/media/by/D67A71047A70E2A3/prince_run_output/FRISP2019/tests/4293232/',
# '/media/by/D67A71047A70E2A3/prince_run_output/FRISP2019/tests/4293236/',
# '/media/by/D67A71047A70E2A3/prince_run_output/FRISP2019/tests/4293275/',
# '/media/by/D67A71047A70E2A3/prince_run_output/FRISP2019/tests/4293496/']
# iters = [195840, 195840, 195840, 195840, 195840]
# num_plots = 5
# THs = [0.1, 0.25, 0.5, 0.75, 1.0]

# MITgcm output directories: either a single directory from which every entry
# of iters is read, or one directory per entry of iters.
pathlist = ['./'] #['../frisp_results/trestore86400/nonhydrostatic/Tn10/',
           # '../frisp_results/trestore86400/nonhydrostatic/Tn05/',
           # '../frisp_results/trestore86400/nonhydrostatic/T00/',
           # '../frisp_results/trestore86400/nonhydrostatic/T05/',
           # '../frisp_results/trestore86400/nonhydrostatic/default/']
# iters = [910080, 921600, 927360, 927360, 910080]
# Model iterations (timesteps) at which to read results; one melt-rate profile
# is plotted per entry.
iters = [120,240,360,1440,2880] #iters = [5760, 63360,115200, 172800,288000]#, 927360, 927360, 921600, 910080]
num_plots = len(iters)
# x-axis values of the right-hand panel, one per entry of iters.
# NOTE(review): these values (2-100) do not look like the restoring temperatures
# that THlabel describes -- confirm what they represent for this run.
THs = [2, 22,40,  60,100]#-1.0, -0.5, 0.0, 0.5, 1.0] # Restoring temperatures
THlabel = r'$\theta_{\mathrm{r}}$ ($^\circ$C)'
# Legend entries for the melt-rate profiles.  They are generated from THs so the
# labels cannot drift out of sync with the plotted values (they were previously
# hard-coded to -1.0 ... 1.0 while THs held 2 ... 100).
leg = [r'$\theta_\mathrm{r}=%s$ ($^\circ$C)' % th for th in THs]
legloc = 'upper left'
plot_T_fit = True
Tfit_labels = [r'$(T_\mathrm{a} - T_\mathrm{f,gl})^2$ fit, $y = 2$km',
               r'$(T_\mathrm{a} - T_\mathrm{f,gl})^2$ fit, $y = 4$km',
               r'$(T_\mathrm{a} - T_\mathrm{f,gl})^2$ fit, $y = 6$km']

# pathlist = ['../frisp_results/trestore86400/nonhydrostatic/default/',
#             '../frisp_results/trestore86400/nonhydrostatic/H585/',
#             '../frisp_results/trestore86400/nonhydrostatic/H570/',
#             '../frisp_results/trestore86400/nonhydrostatic/H555/',
#             '../frisp_results/trestore86400/nonhydrostatic/H540/']
# iters = [910080, 1019520, 1077120, 1209600, 1900800]
# num_plots = 5
# THs = [0.0, 1.5, 3.0, 4.5, 6.0]
# THlabel = r'Bed slope (m/km)'
# leg = [r'$s=0.0$ (m/km)',
#        r'$s=1.5$ (m/km)',
#        r'$s=3.0$ (m/km)',
#        r'$s=4.5$ (m/km)',
#        r'$s=6.0$ (m/km)',]
# legloc = 'lower right'
# plot_T_fit = False

plot_MR_fit = False
filtering = True
# Grid indices along y at which melt rates are sampled; each needs a color (and
# a fit label when plot_T_fit is set).
yinds = [120, 80, 40]
colors = ['#ADD8E6', '#87A96B', '#FF4F00']

# Consistency checks so a mismatched configuration fails here with a clear
# message instead of deep inside the plotting code.
if len(THs) != num_plots:
  raise ValueError('THs must have one entry per plotted iteration')
if len(pathlist) not in (1, num_plots):
  raise ValueError('pathlist must hold one directory or one per iteration')
if len(colors) < len(yinds):
  raise ValueError('colors must have at least one entry per yinds entry')
if plot_T_fit and len(Tfit_labels) < len(yinds):
  raise ValueError('Tfit_labels must have at least one entry per yinds entry')
# =============================================================================

# Filtering function to make results easier on the eyes
def savitzky_golay(y, window_size, order, deriv=0, rate=1):
  """Smooth (or differentiate) a 1-D signal with a Savitzky-Golay filter.

  Adapted from
  https://stackoverflow.com/questions/22988882/how-to-smooth-a-curve-in-python

  Parameters
  ----------
  y : array_like
      1-D signal; must have at least (window_size - 1) // 2 + 1 samples.
  window_size : int
      Length of the filter window; must be a positive odd number.
  order : int
      Order of the polynomial fitted in each window; window_size must be at
      least order + 2.
  deriv : int, optional
      Order of the derivative to return (0 = smoothing only).
  rate : float, optional
      Sample rate; the derivative is scaled by rate**deriv.

  Returns
  -------
  numpy.ndarray
      Filtered signal with the same length as y.

  Raises
  ------
  ValueError
      If window_size or order cannot be converted to int, or y is not 1-D or
      is too short for the window.
  TypeError
      If window_size is not a positive odd number or is too small for order.
  """
  from math import factorial

  try:
    # Builtin int: the np.int alias was removed in NumPy 1.24.
    window_size = abs(int(window_size))
    order = abs(int(order))
  except ValueError as err:
    raise ValueError("window_size and order have to be of type int") from err
  if window_size % 2 != 1 or window_size < 1:
    raise TypeError("window_size size must be a positive odd number")
  if window_size < order + 2:
    raise TypeError("window_size is too small for the polynomials order")
  half_window = (window_size - 1) // 2

  y = np.asarray(y)
  # The reflective padding below needs half_window samples besides the end
  # points; shorter inputs would otherwise return an array of the wrong length.
  if y.ndim != 1 or y.size < half_window + 1:
    raise ValueError("y must be 1-D with at least (window_size - 1) // 2 + 1 samples")

  # precompute coefficients: Vandermonde matrix of window offsets, rows k**0..k**order
  # (plain ndarray instead of np.mat, which NumPy 2.0 removed)
  offsets = np.arange(-half_window, half_window + 1, dtype=float)
  b = offsets[:, np.newaxis] ** np.arange(order + 1)
  m = np.linalg.pinv(b)[deriv] * rate**deriv * factorial(deriv)
  # pad the signal at the extremes with
  # values taken from the signal itself
  firstvals = y[0] - np.abs(y[1:half_window+1][::-1] - y[0])
  lastvals = y[-1] + np.abs(y[-half_window-1:-1][::-1] - y[-1])
  y = np.concatenate((firstvals, y, lastvals))
  return np.convolve(m[::-1], y, mode='valid')

plt.rc('xtick',labelsize=16)
plt.rc('ytick',labelsize=16)

# Factor converting the sampled melt rates to m/yr (the axis label is m a^-1).
# NOTE(review): assumes diag_Shelfice record 10 is a melt rate in m/s -- confirm
# against the diagnostics field list of the run.
SEC_PER_YEAR = 60.0*60.0*24*365.25

# Grounding-line freezing temperature used by the quadratic fit
# MR = k*(TH - TF_GL)**2.  The fit and the fitted curve must share this value
# (they previously used -2.3 and -2.34 respectively).
TF_GL = -2.34

# Initialize plotting environment: melt-rate profiles against y on the left
# (two thirds of the width) and sampled melt rates against THs on the right,
# sharing the melt-rate axis.
fig = plt.figure(figsize=(16,7))
ax_vy  = plt.subplot2grid((1,3),(0,0),colspan=2)
ax_vTH = plt.subplot2grid((1,3),(0,2),colspan=1, sharey=ax_vy)

ax_vy.set_xlim((0.0,8.0))
ax_vy.set_ylim((0.0,7.0))
ax_vy.set_xlabel(r'$y$ (km)', fontsize = 22)
ax_vy.set_ylabel(r'Melt rate (m a$^{-1}$)', fontsize = 22)
ax_vTH.set_xlabel(THlabel, fontsize = 22)

THs_arr = np.asarray(THs, dtype=float)
# Pad the x-range by 1/8 of the data span; fall back to 1 so a constant THs does
# not produce a zero-width axis.
vTHpad = (THs_arr.max()-THs_arr.min())/8
if vTHpad == 0:
  vTHpad = 1.0
ax_vTH.set_xlim((THs_arr.min()-vTHpad, THs_arr.max()+vTHpad))

# Melt rates sampled at yinds: one row per plotted iteration, one column per yi.
saved_MRs = []

for j, itr in enumerate(iters[:num_plots]):
  # A single output directory serves every iteration; otherwise the j-th
  # directory pairs with the j-th iteration.
  results_path = pathlist[0] if len(pathlist) == 1 else pathlist[j]

  # Read in data from MITgcm output files
  YC = mds.rdmds(results_path + 'YC')
  diag_Shelfice = mds.rdmds(results_path + 'diag_Shelfice', itrs = itr)
  SHI_Tfrz = diag_Shelfice[11,:,:]

  print(SHI_Tfrz.min())

  # Use the first x column so both the filtered and the unfiltered paths give
  # 1-D profiles along y (previously only the filtered path was reduced to 1-D).
  y = YC[:,0]
  MR = diag_Shelfice[10,:,0]
  if filtering:
    MR = savitzky_golay(MR,15,1)

  # Plot the melt rate profile, labelled with the matching legend entry
  ax_vy.plot(1e-3*y, MR*SEC_PER_YEAR, lw=1.5,
    label=leg[j] if j < len(leg) else None)

  # Save off some melt rates at particular locations
  saved_MRs.append([MR[yi] for yi in yinds])

saved_MRs = np.array(saved_MRs)

for j, yi in enumerate(yinds):
  MRs_yr = saved_MRs[:,j]*SEC_PER_YEAR
  if plot_T_fit:
    # One-parameter least-squares fit MR = k*(TH - TF_GL)**2 (no intercept).
    basis = (THs_arr - TF_GL)**2
    k = np.linalg.lstsq(basis[:,np.newaxis], saved_MRs[:,j], rcond=None)[0]
    MR_predicted = k[0]*basis
    ax_vTH.plot(THs, MR_predicted*SEC_PER_YEAR, '--',
      color=colors[j], lw=2.0, label=Tfit_labels[j])

    ax_vTH.plot(THs_arr, MRs_yr, 'o',
      color = colors[j], lw=2, label=None)

  else:
    ax_vTH.plot(THs, MRs_yr, '-o',
      color = colors[j], lw=2, label=None)

  # Mark the sampled locations on the profiles.
  ax_vy.plot(np.full(len(MRs_yr), 1e-3*y[yi]), MRs_yr,
    'o', color = colors[j], lw=2)

ax_vy.legend(loc=legloc, fontsize=16)
if plot_T_fit:
  ax_vTH.legend(loc='upper left', fontsize=16)

plt.show()
