#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Nov 16 11:43:12 2021

@author: davidhealy
"""
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd 

from scipy import signal # further FFT functionality
from obspy.clients.fdsn import Client
from obspy import UTCDateTime

#   Plot window, expressed as offsets in seconds relative to the event time.
T_START = 0     #   seconds of data to include before the event time
T_END = 60*60*24*14    #   seconds of data to include after the event time (14 days)

#   Station identifiers (station, network, location, channel) used for the
#   waveform request, plus the FDSN client that serves it.
namStation = 'R4DD7'    #   RHS Bridgewater   
#namStation = 'R72C7'    #   Aberdeen Uni   
netStation = 'AM'
locStation = '00'
chaStation = 'EHZ'      #   vertical component of geophone 
client = Client('RASPISHAKE')

#   Reference ('event') time for the plot window, and the labels used as
#   prefixes in the panel titles.
EQ_TIME = "2022-06-15T00:00:00"
dtEvent = UTCDateTime(EQ_TIME)
sEventName1 = 'Raw data'
sEventName2 = 'Filtered data'

#   Absolute start and end of the requested window: UTCDateTime arithmetic
#   with plain numbers is in seconds.
t1 = dtEvent - T_START
t2 = dtEvent + T_END

# Download the waveform for this station (with the instrument response
# attached) and merge the returned segments into a single trace; gaps are
# filled via fill_value='latest'.
st1 = client.get_waveforms(netStation, namStation, locStation, chaStation,
                           starttime=t1, endtime=t2, attach_response=True)
st1.merge(method=0, fill_value='latest')

# Take the working copy BEFORE trimming the raw stream, so the demean and the
# response removal still operate on exactly the data that was downloaded.
st1v = st1.copy()

# Trim the raw stream to the requested window as well. The server may return
# samples outside [t1, t2]; without this the raw trace can hold more samples
# than the processed traces below, and anything that assumes they share one
# time axis fails with a length mismatch.
st1.trim(t1, t2)

# Instrument-corrected trace: remove the mean, deconvolve the response
# (ObsPy's default output units), then cut to the requested window.
st1v.detrend(type="demean")
st1v.remove_response()
st1v.trim(t1, t2)

# High-pass (> 1 Hz) version of the corrected trace. Filtering does not change
# the trace's time span, so no further trim is needed.
st1f = st1v.copy()
st1f.filter("highpass", freq=1., corners=2)

# Band-pass (5-20 Hz) version of the corrected trace.
st2f = st1v.copy()
st2f.filter("bandpass", freqmin=5., freqmax=20., corners=4)

#   Build dataframes of amplitudes indexed by absolute sample time
#   (a datetime index gives much better x-axis formatting).
#
#   Each index comes from the trace's own POSIX timestamps, which
#     * is vectorised: no per-sample UTCDateTime object is created, which
#       matters for a multi-day record with ~1e8 samples;
#     * is anchored at the real start time of the trace rather than at
#       dtEvent, so it stays correct when T_START != 0 or data starts late;
#     * always has the same length as the data it labels, even if the raw and
#       processed traces differ in sample count.
index = pd.to_datetime(st1[0].times("timestamp"), unit="s")
vamplRaw = pd.DataFrame(st1[0].data, index=index)
#   processed amplitudes are scaled by 1000 (plotted as mm/s)
vamplVel = pd.DataFrame(
    st1f[0].data*1000.,
    index=pd.to_datetime(st1f[0].times("timestamp"), unit="s"))
vamplVel2 = pd.DataFrame(
    st2f[0].data*1000.,
    index=pd.to_datetime(st2f[0].times("timestamp"), unit="s"))

#   Spectrogram of the raw trace.
window_size = 256
#   take the sampling rate from the trace metadata instead of hard-coding it
recording_rate = st1[0].stats.sampling_rate
frequencies, times, amplitudes = signal.spectrogram(st1[0].data, fs=recording_rate,
                                                    window='hamming', nperseg=window_size,
                                                    noverlap=window_size - 100,
                                                    detrend=False, scaling="density")
#   NOTE(review): with scaling="density" and the default mode, scipy returns a
#   power spectral density, for which dB is conventionally 10*log10. The
#   factor 20 is kept because the colour limits used when plotting are tuned
#   to it - confirm which convention is intended.
decibels = 20 * np.log10(amplitudes)

# Figure layout: four stacked panels -
#   0: raw counts, 1: high-passed velocity, 2: spectrogram of the raw trace,
#   3: band-passed velocity.
def _seed_parts(trace):
    """Return (network, station, location, channel) from a trace's stats."""
    meta = trace.stats
    return meta.network, meta.station, meta.location, meta.channel


def _draw_waveform(ax, frame, title, ylabel, ylim, xlim=None):
    """Plot column 0 of *frame* against its index on *ax* and decorate the panel.

    The x-limits are only set when *xlim* is given; otherwise they are left
    for matplotlib to determine.
    """
    ax.plot(frame.index, frame[0], linewidth=1)
    ax.grid(True)
    ax.set_xlabel("Date")
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    if xlim is not None:
        ax.set_xlim(*xlim)
    ax.set_ylim(*ylim)


fig, axs = plt.subplots(4, 1, figsize=(10,12))

# Both upper waveform panels share the time span of the raw data.
raw_span = (vamplRaw.index[0], vamplRaw.index[-1])
velocity_range = (-0.025, 0.025)

_draw_waveform(
    axs[0], vamplRaw,
    "{:} - RaspberryShake {:}.{:}.{:}.{:} - raw".format(
        sEventName1, *_seed_parts(st1[0])),
    "Counts", (5000, 25000), xlim=raw_span)

_draw_waveform(
    axs[1], vamplVel,
    "{:} - RaspberryShake {:}.{:}.{:}.{:}".format(
        sEventName2, *_seed_parts(st1f[0])),
    "Velocity (mm/s)", velocity_range, xlim=raw_span)

# Spectrogram panel: x is seconds from the start of the trace, not a date.
# (No colorbar is drawn for it.)
pcm = axs[2].pcolormesh(times, frequencies, decibels,
                        cmap="plasma",
                        shading='auto',
                        vmin=-20, vmax=80)
axs[2].set_ylabel("Frequency (Hz)")
axs[2].set_xlabel("Time from start (s)")
axs[2].set_title("Spectrogram")
axs[2].set_ylim(1., 50.)

# The band-passed panel gets no explicit x-limits; its title reuses the
# identifiers of the high-passed trace.
_draw_waveform(
    axs[3], vamplVel2,
    "{:} - RaspberryShake {:}.{:}.{:}.{:}".format(
        sEventName2, *_seed_parts(st1f[0])),
    "Velocity (mm/s)", velocity_range)

# pyplot-level autoscale acts on the current axes only.
plt.autoscale(enable=True, axis='x', tight=True)
plt.tight_layout()
plt.savefig('plotSeismogramSpectrogram_paper.png', dpi=300)
