#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 31 18:19:05 2023

@author: ben

Script to import and plot observed S-wave arrivals and compare them with synthetic S waveforms computed from velocity models derived from P receiver functions and autocorrelograms.

"""
from functions_sWave import *
from telewavesim import utils
from obspy.taup import TauPyModel
import pickle
import os

# Stations to process: each entry is [station code, number of model layers].
stations = [['US.NATX', 3]]

# Input directory with the S-arrival data and output directory for the figures.
# The trailing '/' is kept because both paths are later joined by plain string
# concatenation.
indir = os.getcwd() + '/S_DATA/'
outDir = os.getcwd() + '/sWave_Compare/'

# exist_ok avoids the race between checking for the directory and creating it.
os.makedirs(outDir, exist_ok=True)

# Set to True to also fit the S arrivals for a best-fitting layered model and
# plot synthetics for that model (replaces the old "uncomment this block" step).
MODEL_S_WAVES = False

# Stations whose model pickles are stored under a 'BS-' prefix.
_BS_STATIONS = {'441', 'U40', 'X40', 'U38', 'X37', 'Z38', 'Z41'}


def _model_path(station):
    """Return the path of the pickled velocity model for *station*."""
    if station in _BS_STATIONS:
        name = 'BS-' + station
    elif station == 'HNVL,438A':
        name = 'BS-HNVL'
    else:
        name = station.replace('.', '-')
    return os.getcwd() + '/models/' + name + '_model.p'


def _load_model(station):
    """Unpickle and return the velocity model for *station*, closing the file."""
    # NOTE: pickle can execute arbitrary code; only load model files you produced.
    with open(_model_path(station), 'rb') as f:
        return pickle.load(f)


def _layer_label(n, thick, vp, vpvs):
    """Return the legend label for layer *n* (1-based).

    Thickness, Vp and Vp/Vs are rounded to two decimals.
    """
    return ('Layer ' + str(n) + ' - ' + str(round(thick, 2)) + ' km, '
            + str(round(vp, 2)) + ' km/s Vp, ' + str(round(vpvs, 2)) + ' Vp/Vs')


def _model_labels(model, nlayers):
    """Return legend labels for the first *nlayers* layers of *model*."""
    return [_layer_label(i + 1, model.thickn[i], model.vp[i],
                         model.vp[i] / model.vs[i])
            for i in range(nlayers)]


def _synthetics(model, slows, binned, dt, npts=2000):
    """Compute SV plane-wave synthetics for each binned ray parameter.

    Parameters
    ----------
    model : velocity model passed to ``utils.run_plane``.
    slows : sequence of ray parameters (one per bin).
    binned : binned observed traces; ``binned[i]`` supplies the band-pass
        corners (SAC headers ``kuser1`` = low, ``kuser0`` = high) for ``slows[i]``.
    dt : sample interval of the synthetics.
    npts : number of samples per synthetic trace.

    Returns
    -------
    (stZ, stN) : two Streams, one trace per ray parameter, taken from
        ``st[2]`` and ``st[1]`` of each synthetic stream after sorting.
    """
    stZ_syn = Stream()
    stN_syn = Stream()
    for i, slw in enumerate(slows):
        obs = binned[i]
        st = utils.run_plane(model, slw, npts, dt, baz=0, wvtype='SV')
        # Filter the synthetic to the same band as the observed bin.
        st.filter('bandpass', freqmin=obs.stats.sac.kuser1,
                  freqmax=obs.stats.sac.kuser0, corners=5, zerophase=True)
        # presumably a polarity flip of the third component (sig.scale helper)
        sig.scale(st[2], factor=-1.0)
        # Shift the window so the minimum of st[1] lands 10 s after the new
        # start time (i.e. at t=0 for plots using tmin=-10).
        shift = 10 - st[1].data.argmin() * dt
        st.trim(st[0].stats.starttime - shift, st[0].stats.endtime - shift,
                pad=True, fill_value=0)
        st.normalize()
        st.sort()
        stZ_syn.append(st[2])
        stN_syn.append(st[1])
    return stZ_syn, stN_syn


# For every station: load its RF-derived model, import and bin the observed S
# arrivals, compute matching synthetics and plot observed vs. synthetic.
for station, nlayers in stations:
    print(station)
    print(str(nlayers) + ' layers')

    model = _load_model(station)

    Z, R = import_S(station, indir)  # import S arrivals
    if len(Z) == 0:
        print('No S arrivals found for ' + station + '; skipping')
        continue

    # resample to match synthetics
    Z.resample(20)
    R.resample(20)

    # ray-parameter bin width to use
    bin_width = .0003
    Z_binned, slows_Z = bin_S(Z, bin_width)
    R_binned, _ = bin_S(R, bin_width)

    # match with synthetics
    dt = Z[0].stats.delta
    stZ, stN = _synthetics(model, slows_Z, Z_binned, dt)

    if nlayers == 3:
        lb1, lb2, lb3 = _model_labels(model, 3)
        pw_wiggles_S_3layer(Z_binned, R_binned, stZ, stN, station + ' from RF model', lb1, lb2, lb3, tmin=-10, tmax=20, scale=.0005, save=True, ftitle=outDir + station, fmt='png')

        if MODEL_S_WAVES:
            cost, pos = modelStation_S_3layer(Z, model, 5, 48)

            vR = [pos[3], pos[4], pos[5]]
            vp = [pos[6], pos[7], pos[8], 7.8]
            vs = [vp[0] / vR[0], vp[1] / vR[1], vp[2] / vR[2], 4.48]
            thic = [pos[0], pos[1], pos[2], 10]
            model_best = utils.Model(thic, [2550, 2750, 3100, 3300], vp, vs)

            stZ_best, stN_best = _synthetics(model_best, slows_Z, Z_binned, dt)
            lb1, lb2, lb3 = [_layer_label(i + 1, thic[i], vp[i], vR[i]) for i in range(3)]
            pw_wiggles_S_3layer(Z_binned, R_binned, stZ_best, stN_best, station + ' modeled from S waves', lb1, lb2, lb3, tmin=-10, tmax=30, scale=.0005, save=True, ftitle=outDir + station + 'modeled', fmt='png')

    elif nlayers == 2:
        lb1, lb2 = _model_labels(model, 2)
        pw_wiggles_S_2Layer(Z_binned, R_binned, stZ, stN, station, lb1, lb2, tmin=-10, tmax=30, scale=.0005, save=True, ftitle=outDir + station, fmt='png')

        if MODEL_S_WAVES:
            cost, pos = modelStation_S(Z, model, 2, 48)

            vR = [pos[2], pos[3]]
            vp = [pos[4], pos[5], 7.8]
            vs = [vp[0] / vR[0], vp[1] / vR[1], 4.48]
            thic = [pos[0], pos[1], 10]
            model_best = utils.Model(thic, [2550, 2750, 3300], vp, vs)

            stZ_best, stN_best = _synthetics(model_best, slows_Z, Z_binned, dt)
            lb1, lb2 = [_layer_label(i + 1, thic[i], vp[i], vR[i]) for i in range(2)]
            pw_wiggles_S_2Layer(Z_binned, R_binned, stZ_best, stN_best, station + ' modeled', lb1, lb2, tmin=-10, tmax=30, scale=.0005, save=True, ftitle=outDir + station + 'modeled', fmt='png')

    elif nlayers == 1:
        (lb1,) = _model_labels(model, 1)
        pw_wiggles_S_1Layer(Z_binned, R_binned, stZ, stN, station, lb1, tmin=-10, tmax=30, scale=.0005, save=True, ftitle=outDir + station, fmt='png')

        if MODEL_S_WAVES:
            # NOTE(review): same call as the 2-layer branch although only
            # pos[0..2] are used here; confirm the arguments for a 1-layer model.
            cost, pos = modelStation_S(Z, model, 2, 48)

            vR = [pos[1]]
            vp = [pos[2], 7.8]
            vs = [vp[0] / vR[0], 4.48]
            thic = [pos[0], 10]
            model_best = utils.Model(thic, [2750, 3300], vp, vs)

            stZ_best, stN_best = _synthetics(model_best, slows_Z, Z_binned, dt)
            lb1 = _layer_label(1, thic[0], vp[0], vR[0])
            pw_wiggles_S_1Layer(Z_binned, R_binned, stZ_best, stN_best, station + ' modeled', lb1, tmin=-10, tmax=30, scale=.0005, save=True, ftitle=outDir + station + 'modeled', fmt='png')

    else:
        print('Unsupported number of layers (' + str(nlayers) + ') for ' + station + '; expected 1, 2 or 3, no plot made')