#!/usr/bin/env python

# This script plots spectra for a chosen field

#=====perform the various generic imports========
import warnings
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.colors as clrs
# Global matplotlib defaults applied to every figure this script creates:
# a high resolution for saved images, large tick labels and thick axes
# frame lines.
mpl.rcParams.update({
    'savefig.dpi': 200,
    'xtick.labelsize': 30,
    'ytick.labelsize': 30,
    'axes.linewidth': 3,
})
# Keep DeprecationWarning messages out of the interactive session:
warnings.simplefilter("ignore", DeprecationWarning)

#========================================
#=====----------Main code----------======
#========================================

# Announce what the script does and list the fields that can be plotted:
print()
print('   ---------------------------------------------')
print('   | Plots 1d power spectra for a chosen field |')
print('   ---------------------------------------------')
print()
print(' Choose one of the following:')
print()
print('   (1)  relative vorticity, zeta')
print('   (2)  buoyancy, b')
print()
opt_in = input(' Option (default 1)? ')
try:
    # An empty reply selects the default, option 1.
    option = int(opt_in or 1)
except ValueError:
    # Non-numeric input is reported below like any other unknown option.
    option = None

# Each option fixes the data file (read from spectra/) and the y-axis label.
# The labels are raw strings so the LaTeX backslashes reach matplotlib
# untouched rather than being parsed as invalid Python escape sequences.
if option == 1:
    file1 = 'zspec.asc'
    labely = r'$\log_{10}\,|\hat{\zeta}|^2$'
elif option == 2:
    file1 = 'bspec.asc'
    labely = r'$\log_{10}\,|\hat{b}|^2$'
else:
    print(' *** not an available option! stopping! ***')
    # A bare `exit` is only a name lookup and does not stop the script;
    # raise SystemExit so that execution really ends here.
    raise SystemExit(1)

# Open the input file. The context manager guarantees the file is closed
# even if parsing the header or the data raises.
with open('spectra/'+file1, 'r') as in_file1:
    # The last entry on the first header line is kmax:
    first_line = in_file1.readline()
    kmax = int(first_line.split()[-1])
    # Rewind so that the header line is included in the bulk read below:
    in_file1.seek(0)
    # Read every number in the file into one flat 1d array:
    raw_data1 = np.fromfile(file=in_file1, dtype=float, sep='\n')

nx = kmax

# Each time frame occupies 2*kmax+2 values: a two-value header followed by
# kmax (k, spectrum) pairs.  Integer division counts the complete frames.
nframes = len(raw_data1) // (2*kmax+2)

# Report the number of frames:
print()
print(' Number of time frames found: %i' %nframes)
print()

# Without at least one complete frame there is nothing to plot; stop here
# with a clear message rather than failing later with an IndexError.
if nframes == 0:
    print(' *** no complete time frames in spectra/'+file1+'! stopping! ***')
    raise SystemExit(1)

# Shape the data array(s) for plotting and work out the maximum spectrum
# value over all time.
#
# Every frame is kmax+1 rows of two numbers: the first row is a header whose
# first entry is the time, and the remaining kmax rows are (k, spectrum)
# pairs.  Only the nframes complete frames are used, so a trailing partial
# frame in the file is ignored instead of breaking the reshape.
frame_rows = raw_data1[:nframes*(2*kmax+2)].reshape((nframes, kmax+1, 2))
time = list(frame_rows[:, 0, 0])
# Row 0 of each frame is the header, so the data rows are 1..nx:
k = frame_rows[:, 1:nx+1, 0].copy()
s1 = frame_rows[:, 1:nx+1, 1].copy()

# Largest spectrum value over all frames, never below the -1000 floor:
smax = -1000.0
for frame_peak in np.amax(s1, axis=1):
    smax = max(smax, frame_peak)

# Move smax up to the next multiple of 0.5 above it.  The +1000 offset keeps
# the argument of int() positive, so truncation always rounds downward.
smax = 0.5*int(2.0*(smax+1000.0)+1.0)-1000.0

# Default lower plot limit: 3 units of log_10(spectrum) below the rounded
# maximum for option 1, and 6 units below it for option 2.
if option == 1:
    smin = smax - 3.0
elif option == 2:
    smin = smax - 6.0

# Ask for the vertical plot range; an empty reply accepts the default shown.
upper_reply = input(f' Maximum log_10(spectrum) to plot (default {smax!s})? ')
ymax = float(upper_reply if upper_reply else smax)

lower_reply = input(f' Minimum log_10(spectrum) to plot (default {smin!s})? ')
ymin = float(lower_reply if lower_reply else smin)

print()

#======================================================================
# ic is the index of the time frame currently displayed; the key-press
# handler on_press updates it to step through the frames.
ic = 0

# Initiate a plotting window and plot the first frame into it:
fig = plt.figure(1, figsize=[10,10])
fig.subplots_adjust(bottom=0.2, left=0.2)
ax = fig.add_subplot(111)
im1 = ax.plot(k[ic], s1[ic], 'k-', lw=2)

# Set the plot title (the frame time), the axis labels and the vertical
# limits.  The x label is a raw string so that its LaTeX backslashes are not
# parsed as invalid Python escape sequences.
ax.set_title('$t =$'+str('%.3f'%time[ic]), fontsize=30)
ax.set_xlabel(r'$\log_{10}\,k$', fontsize=30)
ax.set_ylabel(labely, fontsize=30)
ax.tick_params(length=10, width=3)
ax.set_ylim(ymin, ymax)

def on_press(event):
    """Re-plot the window for another time frame when a key is pressed.

    The keys -/= (i.e. -/+ without the shift key) cycle backwards and
    forwards through the frames, wrapping around at either end.  Any other
    key is ignored.  The axis limits in effect before the key press are
    restored after the replot.

    Reads the module-level k, s1, time, labely and nframes, and updates the
    module-level frame index ic.
    """
    global ic

    # Map the key to a frame step; unrelated keys leave the plot untouched.
    step = {'=': 1, '-': -1}.get(event.key)
    if step is None:
        return

    axes = event.canvas.figure.get_axes()[0]
    # Python's % is non-negative for a positive divisor, so this wraps in
    # both directions.
    ic = (ic + step) % nframes

    # Remember the current limits, since clearing the axes resets them:
    xlim = axes.get_xlim()
    ylim = axes.get_ylim()

    # Clear the axes for replot:
    axes.clear()

    # Replot the relevant data:
    axes.plot(k[ic], s1[ic], 'k-', lw=2)

    # Set the title, axis labels and plot limits:
    axes.set_title('$t =$'+str('%.3f'%time[ic]), fontsize=30)
    axes.set_xlabel(r'$\log_{10}\,k$', fontsize=30)
    axes.set_ylabel(labely, fontsize=30)
    axes.tick_params(length=10, width=3)
    axes.set_xlim(xlim)
    axes.set_ylim(ylim)

    # Finally redraw the canvas that received the key press:
    event.canvas.draw_idle()

# Register on_press as the handler for key presses on the figure's canvas,
# then show the figure:
canvas = fig.canvas
cid = canvas.mpl_connect('key_press_event', on_press)
plt.show()
