#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Apr  1 16:08:18 2023

@author: ben

functions for teleseismic S processing. plotting scripts modified from telewavesim
"""
import matplotlib.pyplot as plt
import numpy as np
import pdb as pdb
#import os
import glob
from obspy import read, Stream
#from obspy.taup import TauPyModel
from obspy.realtime import signal as sig
import os.path

def import_S(stn,inDir):
    """
    Load S-wave vertical (Z) and radial (R) traces for one station.

    Every sub-directory of ``inDir`` that contains a ``list_r_S.txt`` index
    is scanned. Each index line that contains the station name is split on
    single spaces and its columns are used as follows:

    * column 0 - radial SAC file name; the vertical file name is obtained by
      replacing ``HR.SAC`` with ``HZ.SAC``;
    * column 2 - time shift in seconds, added to both trim bounds;
    * column 4 - a value of ``-1`` flips the polarity of both traces;
    * column 7 - slowness (s/km per the original comments), stored as
      ``stats.slow`` on both traces.

    Both traces are trimmed to start 50 s (plus the shift) after their
    original start and end at their original end plus the shift. Missing
    samples are zero-padded. The vertical amplitude spectrum is then used
    to pick rough bandpass corner estimates. These are stored in the SAC
    headers ``kuser0`` (upper) and ``kuser1`` (lower) of both traces.

    Args:
        stn (str): Station name. ``"A,B"`` matches index lines containing
            either ``A`` or ``B``.
        inDir (str): Directory containing one sub-directory per event.

    Returns:
        tuple: ``(st_Z, st_R)`` obspy Streams with every trace normalized.
    """
    events = sorted(glob.glob(inDir + '/*/'))  # one directory per event
    # Drop a trailing 'plots' directory if present. The emptiness check
    # avoids an IndexError when inDir has no sub-directories.
    if events and 'plots' in events[-1]:
        del events[-1]
    st_R = Stream()
    st_Z = Stream()

    # Optional second station name, given as "stnB,stn". None (instead of a
    # placeholder string) so it can never match an index line by accident.
    stnB = None
    if ',' in stn:
        parts = stn.split(',')
        stnB, stn = parts[0], parts[1]

    # Fraction of the peak vertical spectral amplitude used to pick the
    # bandpass corner estimates.
    minPower = .44

    for evnt in events:
        index_file = evnt + "list_r_S.txt"
        if not os.path.isfile(index_file):
            continue
        # Read with a context manager so the index file is always closed.
        with open(index_file, "r") as fh:
            lines = fh.readlines()

        for line in lines:
            entry = line.strip()
            if not (stn in entry or (stnB is not None and stnB in entry)):
                continue

            l = entry.split(' ')
            ev_st_R = read(evnt + l[0])
            ev_st_Z = read(evnt + l[0].replace("HR.SAC", "HZ.SAC"))
            tr_Z, tr_R = ev_st_Z[0], ev_st_R[0]

            # Slowness is taken from the index file rather than recomputed
            # with TauP.
            slow = float(l[7])
            shift = float(l[2])
            flip = int(l[4]) == -1

            for tr in (tr_Z, tr_R):
                tr.stats.slow = slow
                if flip:
                    sig.scale(tr, factor=-1.0)
                tr.trim(starttime=tr.stats.starttime + (50 + shift),
                        endtime=tr.stats.endtime + shift,
                        pad=True, fill_value=0)

            # Normalized amplitude spectrum of the trimmed vertical trace.
            amp = np.abs(np.fft.rfft(tr_Z.data))
            fft_norm = amp / amp.max()
            above = np.where(fft_norm >= minPower)[0]
            # NOTE(review): bin index / 100 equals frequency in Hz only when
            # the trimmed record is 100 s long (rfft bin spacing is
            # 1 / (npts * delta)) -- confirm this holds for the input data.
            hcorner = above[-1] / 100
            lcorner = above[0] / 100
            for tr in (tr_Z, tr_R):
                tr.stats.sac.kuser0 = hcorner
                tr.stats.sac.kuser1 = lcorner

            st_Z.append(tr_Z)
            st_R.append(tr_R)

    st_Z.normalize()
    st_R.normalize()
    return st_Z,st_R

def bin_S(st, width):
    """
    Stack traces whose ray parameters lie within ``width`` of a bin centre.

    Traces are visited in stream order. For each trace, the nearest existing
    bin (by ``stats.slow``) is found. If it is closer than ``width``, the
    trace is stacked into that bin. Otherwise the trace starts a new bin.

    When a trace joins a bin:

    * its data are added to the bin and the sum is re-normalized, so the new
      trace carries the same weight as the whole existing stack;
    * the bin slowness becomes the midpoint of the previous bin slowness and
      the new trace's slowness (a running midpoint, not the arithmetic mean
      of all members).

    Args:
        st (obspy.Stream): Traces carrying a ``stats.slow`` attribute. All
            traces stacked together must have the same number of samples.
        width (float): Maximum slowness difference (same units as
            ``stats.slow``) for a trace to join an existing bin.

    Returns:
        tuple: ``(st_binned, slows_binned)``, the stacked traces and the list
        of their final slownesses. The traces in ``st`` are not modified.
    """
    slows = []  # current slowness of each bin, parallel to st_binned
    st_binned = Stream()

    for tr in st:
        slow = tr.stats.slow
        if slows:
            # Index of the nearest existing bin (first one on ties).
            idx = min(range(len(slows)), key=lambda i: abs(slows[i] - slow))
            nn = slows[idx]
            if abs(slow - nn) < width:
                binned = st_binned[idx]
                binned.data = binned.data + tr.data
                binned.normalize()
                binned.stats.slow = (slow + nn) / 2
                slows[idx] = binned.stats.slow
                continue
        # Seed the bin with a copy so stacking never alters the caller's trace.
        st_binned.append(tr.copy())
        slows.append(slow)

    slows_binned = [tr.stats.slow for tr in st_binned]
    return st_binned,slows_binned

def pw_wiggles_S(str1, str2, sta, btyp='slow', t1=None, tmin=0., tmax=30,
                   scale=None, save=False, ftitle='Figure_pw_wiggles_baz',
                   wvtype='P', fmt='png'):
    """
    Draw two side-by-side wiggle panels with one filled trace per row.

    The left panel shows ``str1`` (titled "Vertical") and the right panel
    shows ``str2`` (titled "Radial"). Each trace is offset vertically by its
    back-azimuth (``btyp='baz'``) or by ``stats.slow`` (``'slow'`` and
    ``'dist'``). Positive samples are filled red and negative samples blue.

    Args:
        str1 (obspy.stream): Left-panel traces. Its first trace defines the
            shared time axis.
        str2 (obspy.stream): Right-panel traces.
        sta (str): Station name (not used by this function).
        btyp (str, optional): 'baz', 'slow' or 'dist'. Any other value
            prints a message and returns without plotting.
        t1 (float, optional): Time (s) of a dashed gold reference line.
        tmin (float, optional): Time of sample 0 and lower x-limit (s).
        tmax (float, optional): Upper x-limit (s).
        scale (float, optional): Amplitude gain. When falsy, a per-``btyp``
            default is used.
        save (bool, optional): Save to ``ftitle.fmt`` at 300 dpi instead of
            showing the figure.
        ftitle (str, optional): Output file name without extension.
        wvtype (str, optional): 'P', 'S' or 'SKS'; selects fixed y-limits.
        fmt (str, optional): Image format passed to ``savefig``.

    Returns:
        None
    """
    if btyp not in ('baz', 'slow', 'dist'):
        print('type has to be "baz" or "slow" or "dist"')
        return

    print()
    print('Plotting Wiggles by '+btyp)

    head = str1[0].stats
    time = np.arange(head.npts) / head.sampling_rate + tmin

    fig = plt.figure()
    plt.clf()
    left = fig.add_axes([0.1, 0.1, 0.3, 0.83])
    right = fig.add_axes([0.45, 0.1, 0.3, 0.83])

    slow_lims = {'P': (0.038, 0.082), 'S': (0.1115, 0.1275),
                 'SKS': (0.03, 0.06)}
    dist_lims = {'P': (28., 92.), 'S': (53., 107.), 'SKS': (83., 117.)}
    ylabels = {'baz': 'Back-azimuth (deg)', 'slow': 'Slowness (s/km)',
               'dist': 'Distance (deg)'}

    def offset_and_gain(tr, baz_gain):
        # Vertical position of a trace and the gain applied to its samples.
        # NOTE(review): btyp='dist' places traces at stats.slow although the
        # 'dist' y-limits are in degrees -- looks like a copy-paste slip;
        # confirm which header holds the distance.
        offset = tr.stats.baz if btyp == 'baz' else tr.stats.slow
        if scale:
            return offset, scale
        return offset, {'baz': baz_gain, 'slow': 0.02, 'dist': 20}[btyp]

    def fill_wiggles(ax, stream, baz_gain, title):
        # Red/blue filled wiggles, optional t1 marker, and the time window.
        for tr in stream:
            offset, gain = offset_and_gain(tr, baz_gain)
            top = offset + tr.data * gain
            ax.fill_between(time, offset, top, where=tr.data + 1e-6 <= 0.,
                            facecolor='blue', linewidth=0)
            ax.fill_between(time, offset, top, where=tr.data + 1e-6 >= 0.,
                            facecolor='red', linewidth=0)
            ax.set_title(title)
        if t1 is not None:
            ax.axvline(t1, c='gold', ls='--',
                       lw=plt.rcParams['lines.linewidth'] / 2)
        ax.set_xlim(tmin, tmax)

    def set_y_range(ax):
        # Fixed y-limits per sort type; unknown wvtype leaves autoscaling on.
        if btyp == 'baz':
            ax.set_ylim(-5, 370)
            return
        lims = (slow_lims if btyp == 'slow' else dist_lims).get(wvtype)
        if lims is not None:
            ax.set_ylim(*lims)

    fill_wiggles(left, str1, 100, 'Vertical')
    set_y_range(left)
    left.set_ylabel(ylabels[btyp])
    left.set_xlabel('Time (sec)')
    left.grid(ls=':')

    fill_wiggles(right, str2, 150, 'Radial')
    set_y_range(right)
    right.set_xlabel('Time (sec)')
    right.set_yticklabels([])
    right.grid(ls=':')

    if save:
        plt.savefig(ftitle+'.'+fmt, dpi=300, bbox_inches='tight', format=fmt)
    else:
        plt.show()

def pw_wiggles_S_1Layer(str1, str2, str1_synth, str2_synth, sta, l1='Layer 1 info', btyp='slow', t1=None, tmin=-10., tmax=30,
                   scale=None, save=False, ftitle='Figure_pw_wiggles_baz',
                   wvtype='S', fmt='png'):
    """
    Draw observed wiggles with dashed synthetic overlays in two panels.

    The left panel ("Vertical") fills ``str1`` and overlays ``str1_synth``.
    The right panel ("Radial") fills ``str2`` and overlays ``str2_synth``.
    Traces are offset vertically by back-azimuth (``btyp='baz'``) or by
    ``stats.slow`` (``'slow'`` and ``'dist'``). Positive samples are filled
    red and negative samples blue; synthetics are black dashed lines.

    Args:
        str1 (obspy.stream): Observed left-panel traces. Its first trace
            defines the shared time axis, and its slownesses set the
            y-range for ``wvtype='S'``.
        str2 (obspy.stream): Observed right-panel traces.
        str1_synth (obspy.stream): Synthetics overlaid on the left panel.
        str2_synth (obspy.stream): Synthetics overlaid on the right panel.
        sta (str): Station name, drawn as bold text at the top.
        l1 (str, optional): Caption drawn below the panels.
        btyp (str, optional): 'baz', 'slow' or 'dist'. Any other value
            prints a message and returns without plotting.
        t1 (float, optional): Time (s) of a dashed gold reference line.
        tmin (float, optional): Time of sample 0 and lower x-limit (s).
        tmax (float, optional): Upper x-limit (s).
        scale (float, optional): Amplitude gain. When falsy, a per-``btyp``
            default is used.
        save (bool, optional): Save to ``ftitle.fmt`` at 600 dpi instead of
            showing the figure.
        ftitle (str, optional): Output file name without extension.
        wvtype (str, optional): 'P', 'S' or 'SKS'; selects y-limits.
        fmt (str, optional): Image format passed to ``savefig``.

    Returns:
        None
    """
    if btyp not in ('baz', 'slow', 'dist'):
        print('type has to be "baz" or "slow" or "dist"')
        return

    print()
    print('Plotting Wiggles by '+btyp)

    head = str1[0].stats
    time = np.arange(head.npts) / head.sampling_rate + tmin

    fig = plt.figure()
    plt.clf()

    plt.figtext(0.28,1,sta, fontsize=14, fontweight='bold')
    plt.figtext(.2,-.05,l1, fontweight='bold')

    left = fig.add_axes([-0.1, 0.1, 0.5, 0.83])
    right = fig.add_axes([0.45, 0.1, 0.5, 0.83])

    slow_lims = {'P': (0.038, 0.082), 'SKS': (0.03, 0.06)}
    dist_lims = {'P': (28., 92.), 'S': (53., 107.), 'SKS': (83., 117.)}
    ylabels = {'baz': 'Back-azimuth (deg)', 'slow': 'Slowness (s/km)',
               'dist': 'Distance (deg)'}
    slows = []  # slownesses of the observed left-panel traces

    def offset_and_gain(tr, baz_gain):
        # Vertical position of a trace and the gain applied to its samples.
        # NOTE(review): btyp='dist' places traces at stats.slow although the
        # 'dist' y-limits are in degrees -- confirm which header holds the
        # distance.
        offset = tr.stats.baz if btyp == 'baz' else tr.stats.slow
        if scale:
            return offset, scale
        return offset, {'baz': baz_gain, 'slow': 0.02, 'dist': 20}[btyp]

    def fill_wiggles(ax, stream, baz_gain, title, seen=None):
        # Red/blue filled wiggles, optional t1 marker, and the time window.
        for tr in stream:
            if seen is not None:
                seen.append(tr.stats.slow)
            offset, gain = offset_and_gain(tr, baz_gain)
            top = offset + tr.data * gain
            ax.fill_between(time, offset, top, where=tr.data + 1e-6 <= 0.,
                            facecolor='blue', linewidth=0)
            ax.fill_between(time, offset, top, where=tr.data + 1e-6 >= 0.,
                            facecolor='red', linewidth=0)
            ax.set_title(title)
        if t1 is not None:
            ax.axvline(t1, c='gold', ls='--',
                       lw=plt.rcParams['lines.linewidth'] / 2)
        ax.set_xlim(tmin, tmax)

    def overlay_synthetics(ax, stream):
        # Synthetics drawn as black dashed lines on top of the observations.
        for tr in stream:
            offset, gain = offset_and_gain(tr, 150)
            ax.plot(time, offset + tr.data * gain, color='black',
                    linestyle='dashed', linewidth=1)

    def set_y_range(ax):
        # For S slowness sorting the window follows the observed traces.
        if btyp == 'baz':
            ax.set_ylim(-5, 370)
            return
        if btyp == 'slow' and wvtype == 'S':
            ax.set_ylim(min(slows) - .0005, max(slows) + .0005)
            return
        lims = (slow_lims if btyp == 'slow' else dist_lims).get(wvtype)
        if lims is not None:
            ax.set_ylim(*lims)

    fill_wiggles(left, str1, 100, 'Vertical', seen=slows)
    set_y_range(left)
    left.set_ylabel(ylabels[btyp])
    left.set_xlabel('Time (sec)')
    left.grid(ls=':')
    overlay_synthetics(left, str1_synth)

    fill_wiggles(right, str2, 150, 'Radial')
    set_y_range(right)
    right.set_xlabel('Time (sec)')
    right.set_yticklabels([])
    right.grid(ls=':')
    overlay_synthetics(right, str2_synth)

    if save:
        plt.savefig(ftitle+'.'+fmt, dpi=600, bbox_inches='tight', format=fmt)
    else:
        plt.show()

def pw_wiggles_S_2Layer(str1, str2, str1_synth, str2_synth, sta, l1='Layer 1 info', l2='Layer 2 info', btyp='slow', t1=None, tmin=-10., tmax=30,
                   scale=None, save=False, ftitle='Figure_pw_wiggles_baz',
                   wvtype='S', fmt='png'):
    """
    Draw observed wiggles with dashed synthetic overlays and two captions.

    The left panel ("Vertical") fills ``str1`` and overlays ``str1_synth``.
    The right panel ("Radial") fills ``str2`` and overlays ``str2_synth``.
    Traces are offset vertically by back-azimuth (``btyp='baz'``) or by
    ``stats.slow`` (``'slow'`` and ``'dist'``). Positive samples are filled
    red and negative samples blue; synthetics are black dashed lines.

    Args:
        str1 (obspy.stream): Observed left-panel traces. Its first trace
            defines the shared time axis, and its slownesses set the
            y-range for ``wvtype='S'``.
        str2 (obspy.stream): Observed right-panel traces.
        str1_synth (obspy.stream): Synthetics overlaid on the left panel.
        str2_synth (obspy.stream): Synthetics overlaid on the right panel.
        sta (str): Station name, drawn as bold text at the top.
        l1 (str, optional): First caption drawn below the panels.
        l2 (str, optional): Second caption drawn below ``l1``.
        btyp (str, optional): 'baz', 'slow' or 'dist'. Any other value
            prints a message and returns without plotting.
        t1 (float, optional): Time (s) of a dashed gold reference line.
        tmin (float, optional): Time of sample 0 and lower x-limit (s).
        tmax (float, optional): Upper x-limit (s).
        scale (float, optional): Amplitude gain. When falsy, a per-``btyp``
            default is used.
        save (bool, optional): Save to ``ftitle.fmt`` at 600 dpi instead of
            showing the figure.
        ftitle (str, optional): Output file name without extension.
        wvtype (str, optional): 'P', 'S' or 'SKS'; selects y-limits.
        fmt (str, optional): Image format passed to ``savefig``.

    Returns:
        None
    """
    if btyp not in ('baz', 'slow', 'dist'):
        print('type has to be "baz" or "slow" or "dist"')
        return

    print()
    print('Plotting Wiggles by '+btyp)

    head = str1[0].stats
    time = np.arange(head.npts) / head.sampling_rate + tmin

    fig = plt.figure()
    plt.clf()

    plt.figtext(0.28,1,sta, fontsize=14, fontweight='bold')
    plt.figtext(.2,-.05,l1, fontweight='bold')
    plt.figtext(.2,-.1,l2, fontweight='bold')

    left = fig.add_axes([-0.1, 0.1, 0.5, 0.83])
    right = fig.add_axes([0.45, 0.1, 0.5, 0.83])

    slow_lims = {'P': (0.038, 0.082), 'SKS': (0.03, 0.06)}
    dist_lims = {'P': (28., 92.), 'S': (53., 107.), 'SKS': (83., 117.)}
    ylabels = {'baz': 'Back-azimuth (deg)', 'slow': 'Slowness (s/km)',
               'dist': 'Distance (deg)'}
    slows = []  # slownesses of the observed left-panel traces

    def offset_and_gain(tr, baz_gain):
        # Vertical position of a trace and the gain applied to its samples.
        # NOTE(review): btyp='dist' places traces at stats.slow although the
        # 'dist' y-limits are in degrees -- confirm which header holds the
        # distance.
        offset = tr.stats.baz if btyp == 'baz' else tr.stats.slow
        if scale:
            return offset, scale
        return offset, {'baz': baz_gain, 'slow': 0.02, 'dist': 20}[btyp]

    def fill_wiggles(ax, stream, baz_gain, title, seen=None):
        # Red/blue filled wiggles, optional t1 marker, and the time window.
        for tr in stream:
            if seen is not None:
                seen.append(tr.stats.slow)
            offset, gain = offset_and_gain(tr, baz_gain)
            top = offset + tr.data * gain
            ax.fill_between(time, offset, top, where=tr.data + 1e-6 <= 0.,
                            facecolor='blue', linewidth=0)
            ax.fill_between(time, offset, top, where=tr.data + 1e-6 >= 0.,
                            facecolor='red', linewidth=0)
            ax.set_title(title)
        if t1 is not None:
            ax.axvline(t1, c='gold', ls='--',
                       lw=plt.rcParams['lines.linewidth'] / 2)
        ax.set_xlim(tmin, tmax)

    def overlay_synthetics(ax, stream):
        # Synthetics drawn as black dashed lines on top of the observations.
        for tr in stream:
            offset, gain = offset_and_gain(tr, 150)
            ax.plot(time, offset + tr.data * gain, color='black',
                    linestyle='dashed', linewidth=1)

    def set_y_range(ax):
        # For S slowness sorting the window follows the observed traces.
        if btyp == 'baz':
            ax.set_ylim(-5, 370)
            return
        if btyp == 'slow' and wvtype == 'S':
            ax.set_ylim(min(slows) - .0005, max(slows) + .0005)
            return
        lims = (slow_lims if btyp == 'slow' else dist_lims).get(wvtype)
        if lims is not None:
            ax.set_ylim(*lims)

    fill_wiggles(left, str1, 100, 'Vertical', seen=slows)
    set_y_range(left)
    left.set_ylabel(ylabels[btyp])
    left.set_xlabel('Time (sec)')
    left.grid(ls=':')
    overlay_synthetics(left, str1_synth)

    fill_wiggles(right, str2, 150, 'Radial')
    set_y_range(right)
    right.set_xlabel('Time (sec)')
    right.set_yticklabels([])
    right.grid(ls=':')
    overlay_synthetics(right, str2_synth)

    if save:
        plt.savefig(ftitle+'.'+fmt, dpi=600, bbox_inches='tight', format=fmt)
    else:
        plt.show()

def _pw3_trace_position(tr, btyp, scale, baz_scale):
    """Return ``(y, maxval)``: the y-axis offset of trace *tr* and its amplitude multiplier.

    Args:
        tr (obspy.Trace): trace whose ``stats`` carry ``baz`` / ``slow`` attributes
        btyp (str): 'baz', 'slow' or 'dist'
        scale (float or None): user amplitude scale; any truthy value overrides
            the per-``btyp`` default (so ``0``/``None`` fall back to the default)
        baz_scale (float): default amplitude multiplier when ``btyp == 'baz'``

    Raises:
        ValueError: if *btyp* is not one of the supported sort types
    """
    if btyp == 'baz':
        y, default = tr.stats.baz, baz_scale
    elif btyp == 'slow':
        y, default = tr.stats.slow, 0.02
    elif btyp == 'dist':
        # NOTE(review): 'dist' mode positions traces by stats.slow (unchanged
        # from the original), yet the 'dist' y-limits below are in degrees;
        # traces will sit outside the view unless callers store distance in
        # stats.slow -- confirm.
        y, default = tr.stats.slow, 20
    else:
        raise ValueError('type has to be "baz" or "slow" or "dist"')
    return y, (scale if scale else default)


def _pw3_set_ylim(ax, btyp, wvtype, slows):
    """Set the y-limits of *ax* for sort type *btyp* and wave type *wvtype*.

    For S waves sorted by slowness, the limits wrap the observed slownesses
    *slows* with a 0.0005 s/km margin; other combinations use fixed ranges.
    Unknown *wvtype* values leave the limits untouched (matplotlib autoscale).
    """
    if btyp == 'baz':
        ax.set_ylim(-5, 370)
    elif btyp == 'slow':
        if wvtype == 'P':
            ax.set_ylim(0.038, 0.082)
        elif wvtype == 'S':
            ax.set_ylim(min(slows) - .0005, max(slows) + .0005)
        elif wvtype == 'SKS':
            ax.set_ylim(0.03, 0.06)
    elif btyp == 'dist':
        if wvtype == 'P':
            ax.set_ylim(28., 92.)
        elif wvtype == 'S':
            ax.set_ylim(53., 107.)
        elif wvtype == 'SKS':
            ax.set_ylim(83., 117.)


def _pw3_fill_wiggles(ax, stream, time, btyp, scale, baz_scale):
    """Draw each trace of *stream* on *ax* as a filled wiggle (positive red, negative blue)."""
    for tr in stream:
        y, maxval = _pw3_trace_position(tr, btyp, scale, baz_scale)
        # The 1e-6 offset makes samples at (numerically) zero fall into both fills.
        ax.fill_between(time, y, y+tr.data*maxval, where=tr.data+1e-6 <= 0.,
                        facecolor='blue', linewidth=0)
        ax.fill_between(time, y, y+tr.data*maxval, where=tr.data+1e-6 >= 0.,
                        facecolor='red', linewidth=0)


def _pw3_overlay_synth(ax, stream, time, btyp, scale, baz_scale):
    """Draw each synthetic trace of *stream* on *ax* as a dashed black line."""
    for tr in stream:
        y, maxval = _pw3_trace_position(tr, btyp, scale, baz_scale)
        ax.plot(time, y+tr.data*maxval, color='black', linestyle='dashed', linewidth=1)


def pw_wiggles_S_3layer(str1, str2, str1_synth, str2_synth, sta, l1='Layer 1 info', l2='Layer 2 info', l3='Layer 3 info', btyp='slow', t1=None, tmin=-10., tmax=30,
                   scale=None, save=False, ftitle='Figure_pw_wiggles_baz',
                   wvtype='S', fmt='png'):
    """
    Plot observed wiggle seismograms with synthetic overlays, sorted by
    back-azimuth, slowness or distance.

    The left panel ('Vertical') shows *str1* as filled wiggles with *str1_synth*
    dashed on top; the right panel ('Radial') does the same for *str2* /
    *str2_synth*. Three free-text lines (*l1*, *l2*, *l3*) are written below
    the figure, and the station name above it.

    Args:
        str1 (obspy.Stream): observed traces for the left ('Vertical') panel;
            ``str1[0]`` defines the shared time axis
        str2 (obspy.Stream): observed traces for the right ('Radial') panel
        str1_synth (obspy.Stream): synthetics overlaid on the left panel
        str2_synth (obspy.Stream): synthetics overlaid on the right panel
        sta (str): station name, printed above the figure
        l1, l2, l3 (str, optional): annotation lines printed below the figure
        btyp (str, optional): sort type, one of 'baz', 'slow' or 'dist'
        t1 (float, optional):
            Predicted arrival time that will be drawn as vertical line (s)
        tmin (float, optional): time of the first sample and lower x bound (s)
        tmax (float, optional): Upper bound of time axis (s)
        scale (float, optional): amplitude multiplier; when falsy a default
            per sort type is used
        save (bool, optional): save to ``ftitle + '.' + fmt`` instead of showing
        ftitle (str, optional): base file name of the saved figure
        wvtype (str, optional): wave type ('P', 'S' or 'SKS'), selects y-limits
        fmt (str, optional): file format passed to ``plt.savefig``

    Returns:
        None. An unsupported *btyp* prints a message and returns early.
    """
    if btyp not in ('baz', 'slow', 'dist'):
        print('type has to be "baz" or "slow" or "dist"')
        return

    print()
    print('Plotting Wiggles by '+btyp)

    # Shared time axis; assumes every trace has str1[0]'s npts and sampling rate.
    nn = str1[0].stats.npts
    sr = str1[0].stats.sampling_rate
    time = np.arange(nn)/sr + tmin

    fig = plt.figure()

    plt.figtext(0.16, 1, sta, fontsize=14, fontweight='bold')
    for y_fig, label in ((-.05, l1), (-.1, l2), (-.15, l3)):
        plt.figtext(.2, y_fig, label, fontweight='bold')

    # Manually placed axes for finer control than subplots
    ax1 = fig.add_axes([-0.1, 0.1, 0.5, 0.83])
    ax2 = fig.add_axes([0.45, 0.1, 0.5, 0.83])

    # Observed slownesses from str1 set the 'S'/'slow' y-limits of BOTH panels.
    slows = [tr.stats.slow for tr in str1]

    # Default back-azimuth amplitude scales per panel. Synthetics use the same
    # scale as the observed traces they overlay so amplitudes are comparable
    # (previously the vertical synthetics used 150 against 100 for the data).
    baz_scale_vert = 100
    baz_scale_rad = 150

    # ---- Left panel: vertical component ----
    _pw3_fill_wiggles(ax1, str1, time, btyp, scale, baz_scale_vert)
    ax1.set_title('Vertical')

    if t1 is not None:
        ax1.axvline(t1, c='gold', ls='--',
                    lw=plt.rcParams['lines.linewidth']/2)

    ax1.set_xlim(tmin, tmax)
    _pw3_set_ylim(ax1, btyp, wvtype, slows)
    ax1.set_ylabel({'baz': 'Back-azimuth (deg)',
                    'slow': 'Slowness (s/km)',
                    'dist': 'Distance (deg)'}[btyp])
    ax1.set_xlabel('Time (sec)')
    ax1.grid(ls=':')

    _pw3_overlay_synth(ax1, str1_synth, time, btyp, scale, baz_scale_vert)

    # ---- Right panel: radial component ----
    _pw3_fill_wiggles(ax2, str2, time, btyp, scale, baz_scale_rad)
    ax2.set_title('Radial')

    if t1 is not None:
        ax2.axvline(t1, c='gold', ls='--',
                    lw=plt.rcParams['lines.linewidth']/2)

    ax2.set_xlim(tmin, tmax)
    _pw3_set_ylim(ax2, btyp, wvtype, slows)
    ax2.set_xlabel('Time (sec)')
    ax2.set_yticklabels([])  # y axis is shared visually with the left panel
    ax2.grid(ls=':')

    _pw3_overlay_synth(ax2, str2_synth, time, btyp, scale, baz_scale_rad)

    if save:
        plt.savefig(ftitle+'.'+fmt, dpi=600, bbox_inches='tight', format=fmt)
    else:
        plt.show()