#!/usr/bin/env python3
# -*- coding: utf-8 -*-
'''

Functions to plot single-station plane-wave or receiver-function seismograms.
They were modified from plotting functions in the telewavesim package (see below).

From the original script:
    
# Copyright 2019 Pascal Audet

# This file is part of Telewavesim.

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

'''

import matplotlib.pyplot as plt
import numpy as np


def rf_wiggles_RaS(str1, str2, tr1, tr2, sta, btyp='baz', tmin=-10., tmax=30,
                   scale=None, save=False, ftitle='Figure_rf_wiggle_baz',
                   wvtype='P'):
    """
    Plot real and synthetic receiver functions in side-by-side panels.

    The figure has four panels:
    - top left: the stacked real trace ``tr1``;
    - bottom left: the real traces of ``str1``, sorted by ``btyp``;
    - top right: the stacked synthetic trace ``tr2``;
    - bottom right: the synthetic traces of ``str2``, sorted the same way.

    Positive amplitudes are filled red and negative amplitudes blue.

    Args:
        str1 (obspy.stream): Stream of real receiver functions
        str2 (obspy.stream): Stream of synthetic receiver functions
        tr1 (obspy.trace):
            Stack of the real traces (normally obtained from the
            ``utils.stack_all`` function)
        tr2 (obspy.trace): Stack of the synthetic traces
        sta (str): Station name (unused by this function)
        btyp (str, optional): Sorting of the lower panels: 'baz', 'slow'
            or 'dist'
        tmin (float, optional): Lower bound of time axis (s)
        tmax (float, optional): Upper bound of time axis (s)
        scale (float, optional): Amplitude multiplier applied to every
            sorted wiggle; when falsy a default depending on ``btyp`` is
            used
        save (bool, optional): If True save the figure as
            ``ftitle + '.eps'``, otherwise show it interactively
        ftitle (str, optional): File name (without extension) of the saved
            figure
        wvtype (str, optional): Wave type ('P', 'S' or 'SKS'); selects the
            y-axis limits for 'slow' and 'dist' sorting. Other values leave
            matplotlib's autoscaling in place.

    Returns:
        None

    Raises:
        ValueError: If ``btyp`` is not 'baz', 'slow' or 'dist'.
    """

    if not (btyp == 'baz' or btyp == 'slow' or btyp == 'dist'):
        raise ValueError('type has to be "baz" or "slow" or "dist"')

    print()
    print('Plotting Wiggles by '+btyp)

    # Time axis centred on zero, built from the first trace of str1.
    # Assumes every trace (str1, str2, tr1, tr2) shares its npts and
    # sampling rate -- TODO confirm.
    nn = str1[0].stats.npts
    sr = str1[0].stats.sampling_rate
    time = np.arange(-nn/2, nn/2)/sr

    def wiggle_position(tr):
        """Return the header value used as the vertical offset of ``tr``."""
        if btyp == 'baz':
            return tr.stats.baz
        # NOTE(review): 'dist' also reads ``stats.slow`` (kept from the
        # original code) while the 'dist' y-limits are in degrees -- confirm
        # that distance-binned traces store the distance in ``stats.slow``.
        return tr.stats.slow

    def fill_wiggle(ax, base, data, gain):
        """Fill ``base + data*gain`` around ``base``: negative blue, positive red."""
        curve = base + data*gain
        ax.fill_between(time, base, curve, where=data+1e-6 <= 0.,
                        facecolor='blue', linewidth=0)
        ax.fill_between(time, base, curve, where=data+1e-6 >= 0.,
                        facecolor='red', linewidth=0)

    # Y-limits and label of the sorted panels. ylim is None for an unknown
    # wvtype, which leaves autoscaling on.
    if btyp == 'baz':
        ylim = (-5, 370)
        ylabel = 'Back-azimuth (deg)'
    elif btyp == 'slow':
        ylim = {'P': (0.036, 0.088), 'S': (0.07, 0.125),
                'SKS': (0.03, 0.06)}.get(wvtype)
        ylabel = 'Slowness (s/km)'
    else:
        ylim = {'P': (28., 92.), 'S': (53., 107.),
                'SKS': (83., 117.)}.get(wvtype)
        ylabel = 'Distance (deg)'

    # Default wiggle gains per sorting type. The synthetic panel uses a
    # smaller back-azimuth gain (150 vs 180), as in the original code.
    real_gain = {'baz': 180, 'slow': 0.02, 'dist': 20}[btyp]
    synth_gain = {'baz': 150, 'slow': 0.02, 'dist': 20}[btyp]
    if scale:
        real_gain = synth_gain = scale

    # Initialize figure and lay the four panels out explicitly
    fig = plt.figure()
    ax1 = fig.add_axes([0.1, 0.825, 0.3, 0.05])
    ax2 = fig.add_axes([0.1, 0.1, 0.3, 0.7])
    ax3 = fig.add_axes([0.45, 0.825, 0.3, 0.05])
    ax4 = fig.add_axes([0.45, 0.1, 0.3, 0.7])

    # Both stack panels are scaled by tr1's peak amplitude, presumably so
    # the real and synthetic stacks share one amplitude scale.
    amp = np.max(np.abs(tr1.data))

    # Top left: stack of the real traces
    fill_wiggle(ax1, 0., tr1.data, 1.)
    ax1.set_ylim(-amp, amp)
    ax1.set_yticks(())
    ax1.set_xticks(())
    ax1.set_title('Real RRF\'s')
    ax1.set_xlim(tmin, tmax)

    # Bottom left: sorted real traces
    for tr in str1:
        fill_wiggle(ax2, wiggle_position(tr), tr.data, real_gain)
    ax2.set_xlim(tmin, tmax)
    if ylim is not None:
        ax2.set_ylim(*ylim)
    ax2.set_ylabel(ylabel)
    ax2.set_xlabel('Time (sec)')
    ax2.grid(ls=':')

    # Top right: stack of the synthetic traces
    fill_wiggle(ax3, 0., tr2.data, 1.)
    ax3.set_xlim(tmin, tmax)
    ax3.set_ylim(-amp, amp)
    ax3.set_yticks(())
    ax3.set_xticks(())
    ax3.set_title('Synthetic RRF\'s')

    # Bottom right: sorted synthetic traces
    for tr in str2:
        fill_wiggle(ax4, wiggle_position(tr), tr.data, synth_gain)
    ax4.set_xlim(tmin, tmax)
    # This panel hides its y tick labels and relies on the left panel's
    # axis, so it must use exactly the same y-limits.
    if ylim is not None:
        ax4.set_ylim(*ylim)
    ax4.set_xlabel('Time (sec)')
    ax4.set_yticklabels([])
    ax4.grid(ls=':')

    if save:
        fig.savefig(ftitle+'.eps', dpi=300, bbox_inches='tight', format='eps')
    else:
        plt.show()

    plt.close(fig)

def rf_wiggles_RaS_label(str1, str2, tr1, tr2, sta, l1, l2, btyp='baz', tmin=-10., tmax=30,
                   scale=None, save=False, ftitle='Figure_rf_wiggle_baz',
                   wvtype='P'):
    """
    Plot real and synthetic receiver functions side by side, with captions.

    The panel layout is the same as in ``rf_wiggles_RaS``, shifted up by
    0.1 in figure coordinates:
    - top left: the stacked real trace ``tr1``;
    - bottom left: the real traces of ``str1``, sorted by ``btyp``;
    - top right: the stacked synthetic trace ``tr2``;
    - bottom right: the synthetic traces of ``str2``, sorted the same way.

    The station name is written above the panels and two caption lines
    below them. Positive amplitudes are filled red, negative blue.

    Args:
        str1 (obspy.stream): Stream of real receiver functions
        str2 (obspy.stream): Stream of synthetic receiver functions
        tr1 (obspy.trace):
            Stack of the real traces (normally obtained from the
            ``utils.stack_all`` function)
        tr2 (obspy.trace): Stack of the synthetic traces
        sta (str): Station name, drawn in bold above the panels
        l1 (str): First caption line, drawn in bold below the panels
        l2 (str): Second caption line, drawn in bold below ``l1``
        btyp (str, optional): Sorting of the lower panels: 'baz', 'slow'
            or 'dist'
        tmin (float, optional): Lower bound of time axis (s)
        tmax (float, optional): Upper bound of time axis (s)
        scale (float, optional): Amplitude multiplier applied to every
            sorted wiggle; when falsy a default depending on ``btyp`` is
            used
        save (bool, optional): If True save the figure as
            ``ftitle + '.png'`` (600 dpi), otherwise show it interactively
        ftitle (str, optional): File name (without extension) of the saved
            figure
        wvtype (str, optional): Wave type ('P', 'S' or 'SKS'); selects the
            y-axis limits for 'slow' and 'dist' sorting. Other values leave
            matplotlib's autoscaling in place.

    Returns:
        None

    Raises:
        ValueError: If ``btyp`` is not 'baz', 'slow' or 'dist'.
    """

    if not (btyp == 'baz' or btyp == 'slow' or btyp == 'dist'):
        raise ValueError('type has to be "baz" or "slow" or "dist"')

    print()
    print('Plotting Wiggles by '+btyp)

    # Time axis centred on zero, built from the first trace of str1.
    # Assumes every trace (str1, str2, tr1, tr2) shares its npts and
    # sampling rate -- TODO confirm.
    nn = str1[0].stats.npts
    sr = str1[0].stats.sampling_rate
    time = np.arange(-nn/2, nn/2)/sr

    def wiggle_position(tr):
        """Return the header value used as the vertical offset of ``tr``."""
        if btyp == 'baz':
            return tr.stats.baz
        # NOTE(review): 'dist' also reads ``stats.slow`` (kept from the
        # original code) while the 'dist' y-limits are in degrees -- confirm
        # that distance-binned traces store the distance in ``stats.slow``.
        return tr.stats.slow

    def fill_wiggle(ax, base, data, gain):
        """Fill ``base + data*gain`` around ``base``: negative blue, positive red."""
        curve = base + data*gain
        ax.fill_between(time, base, curve, where=data+1e-6 <= 0.,
                        facecolor='blue', linewidth=0)
        ax.fill_between(time, base, curve, where=data+1e-6 >= 0.,
                        facecolor='red', linewidth=0)

    # Y-limits and label of the sorted panels. ylim is None for an unknown
    # wvtype, which leaves autoscaling on.
    if btyp == 'baz':
        ylim = (-5, 370)
        ylabel = 'Back-azimuth (deg)'
    elif btyp == 'slow':
        ylim = {'P': (0.036, 0.088), 'S': (0.07, 0.125),
                'SKS': (0.03, 0.06)}.get(wvtype)
        ylabel = 'Slowness (s/km)'
    else:
        ylim = {'P': (28., 92.), 'S': (53., 107.),
                'SKS': (83., 117.)}.get(wvtype)
        ylabel = 'Distance (deg)'

    # Default wiggle gains per sorting type. The synthetic panel uses a
    # smaller back-azimuth gain (150 vs 180), as in the original code.
    real_gain = {'baz': 180, 'slow': 0.02, 'dist': 20}[btyp]
    synth_gain = {'baz': 150, 'slow': 0.02, 'dist': 20}[btyp]
    if scale:
        real_gain = synth_gain = scale

    # Initialize figure: station name above, captions below the panels
    fig = plt.figure()
    fig.text(.18, .95 + .1, sta, fontsize=14, fontweight='bold')
    fig.text(.2, .05, l1, fontweight='bold')
    fig.text(.2, 0, l2, fontweight='bold')

    # Panels shifted up by 0.1 to make room for the captions
    ax1 = fig.add_axes([0.1, 0.825 + .1, 0.3, 0.05])
    ax2 = fig.add_axes([0.1, 0.1 + .1, 0.3, 0.7])
    ax3 = fig.add_axes([0.45, 0.825 + .1, 0.3, 0.05])
    ax4 = fig.add_axes([0.45, 0.1 + .1, 0.3, 0.7])

    # Both stack panels are scaled by tr1's peak amplitude, presumably so
    # the real and synthetic stacks share one amplitude scale.
    amp = np.max(np.abs(tr1.data))

    # Top left: stack of the real traces
    fill_wiggle(ax1, 0., tr1.data, 1.)
    ax1.set_ylim(-amp, amp)
    ax1.set_yticks(())
    ax1.set_xticks(())
    ax1.set_title('Real RRF\'s')
    ax1.set_xlim(tmin, tmax)

    # Bottom left: sorted real traces
    for tr in str1:
        fill_wiggle(ax2, wiggle_position(tr), tr.data, real_gain)
    ax2.set_xlim(tmin, tmax)
    if ylim is not None:
        ax2.set_ylim(*ylim)
    ax2.set_ylabel(ylabel)
    ax2.set_xlabel('Time (sec)')
    ax2.grid(ls=':')

    # Top right: stack of the synthetic traces
    fill_wiggle(ax3, 0., tr2.data, 1.)
    ax3.set_xlim(tmin, tmax)
    ax3.set_ylim(-amp, amp)
    ax3.set_yticks(())
    ax3.set_xticks(())
    ax3.set_title('Synthetic RRF\'s')

    # Bottom right: sorted synthetic traces
    for tr in str2:
        fill_wiggle(ax4, wiggle_position(tr), tr.data, synth_gain)
    ax4.set_xlim(tmin, tmax)
    # This panel hides its y tick labels and relies on the left panel's
    # axis, so it must use exactly the same y-limits.
    if ylim is not None:
        ax4.set_ylim(*ylim)
    ax4.set_xlabel('Time (sec)')
    ax4.set_yticklabels([])
    ax4.grid(ls=':')

    if save:
        fig.savefig(ftitle+'.png', dpi=600, bbox_inches='tight', format='png')
    else:
        plt.show()

    plt.close(fig)

def rf_wiggles_RaS_label_stacked(str1, str2, tr1, tr2, sta, l1, l2, btyp='baz', tmin=-10., tmax=30,
                   scale=None, save=False, ftitle='Figure_rf_wiggle_baz',
                   wvtype='P'):
    """
    Overlay synthetic receiver functions on real ones in a single panel.

    Real traces from ``str1`` are drawn as filled wiggles (positive red,
    negative blue). Synthetic traces from ``str2`` are drawn on the same
    axes as dashed black lines. Each trace is offset vertically by its
    back-azimuth or slowness header, depending on ``btyp``. The station
    name and two caption lines are written as bold figure text.

    Args:
        str1 (obspy.stream): Stream of real receiver functions
        str2 (obspy.stream): Stream of synthetic receiver functions
        tr1 (obspy.trace): Stacked real trace (unused by this function)
        tr2 (obspy.trace): Stacked synthetic trace (unused by this function)
        sta (str): Station name, drawn at the top-left of the figure
        l1 (str): First caption line, drawn below the panel
        l2 (str): Second caption line, drawn below ``l1``
        btyp (str, optional): Sorting type: 'baz', 'slow' or 'dist'
        tmin (float, optional): Lower bound of time axis (s)
        tmax (float, optional): Upper bound of time axis (s)
        scale (float, optional): Amplitude multiplier for every wiggle; when
            falsy a default depending on ``btyp`` is used (back-azimuth:
            180 for real, 150 for synthetic traces)
        save (bool, optional): If True save ``ftitle + '.png'`` at 600 dpi,
            otherwise show the figure interactively
        ftitle (str, optional): File name (without extension) of the saved
            figure
        wvtype (str, optional): Wave type ('P', 'S' or 'SKS'); selects the
            y-axis limits for 'slow' and 'dist' sorting

    Returns:
        None

    Raises:
        ValueError: If ``btyp`` is not 'baz', 'slow' or 'dist'.
    """
    if btyp not in ('baz', 'slow', 'dist'):
        raise ValueError('type has to be "baz" or "slow" or "dist"')

    print()
    print('Plotting Wiggles by ' + btyp)

    # Zero-centred time axis taken from the first real trace
    header = str1[0].stats
    npts, rate = header.npts, header.sampling_rate
    time = np.arange(-npts/2, npts/2)/rate

    def position(trace):
        # Vertical offset of a wiggle; 'slow' and 'dist' both read
        # ``stats.slow``. NOTE(review): the 'dist' y-limits are in degrees,
        # so confirm distance-binned traces carry distance in ``stats.slow``.
        return trace.stats.baz if btyp == 'baz' else trace.stats.slow

    def gain(baz_default):
        # User-supplied scale wins; otherwise a per-sorting-type default
        if scale:
            return scale
        if btyp == 'baz':
            return baz_default
        return 0.02 if btyp == 'slow' else 20

    def limits_and_label():
        # (ylim or None, ylabel); None means an unknown wvtype -> autoscale
        if btyp == 'baz':
            return (-5, 370), 'Back-azimuth (deg)'
        if btyp == 'slow':
            table = (('P', (0.036, 0.088)), ('S', (0.07, 0.125)),
                     ('SKS', (0.03, 0.06)))
            label = 'Slowness (s/km)'
        else:
            table = (('P', (28., 92.)), ('S', (53., 107.)),
                     ('SKS', (83., 117.)))
            label = 'Distance (deg)'
        for name, limits in table:
            if wvtype == name:
                return limits, label
        return None, label

    fig = plt.figure()
    plt.clf()
    plt.figtext(0, .95, sta, fontsize=14, fontweight='bold')
    for ypos, caption in zip((.05, 0), (l1, l2)):
        plt.figtext(.04, ypos, caption, fontweight='bold')

    ax = fig.add_axes([0.1, 0.1 + .1, 0.3, 0.7])

    # Real receiver functions: filled wiggles
    real_gain = gain(180)
    for trace in str1:
        base = position(trace)
        curve = base + trace.data*real_gain
        below = trace.data + 1e-6 <= 0.
        above = trace.data + 1e-6 >= 0.
        ax.fill_between(time, base, curve, where=below,
                        facecolor='blue', linewidth=0)
        ax.fill_between(time, base, curve, where=above,
                        facecolor='red', linewidth=0)

    ax.set_xlim(tmin, tmax)
    ylim, ylabel = limits_and_label()
    if ylim is not None:
        ax.set_ylim(*ylim)
    ax.set_ylabel(ylabel)
    ax.set_xlabel('Time (sec)')
    ax.grid(ls=':')

    # Synthetic receiver functions: dashed lines over the real wiggles
    synth_gain = gain(150)
    for trace in str2:
        ax.plot(time, position(trace) + trace.data*synth_gain,
                color='black', linestyle='dashed', linewidth=1)

    if save:
        plt.savefig(ftitle + '.png', dpi=600, bbox_inches='tight', format='png')
    else:
        plt.show()

    plt.close()

def rf_wiggles_RaS_label_stacked_wide(str1, str2, tr1, tr2, sta, l1, l2, btyp='baz', tmin=-10., tmax=30,
                   scale=None, save=False, ftitle='Figure_rf_wiggle_baz',
                   wvtype='P'):
    """
    Overlay synthetic receiver functions on real ones in one wide panel.

    Same as ``rf_wiggles_RaS_label_stacked`` but with a wider axes
    rectangle ([-0.3, 0.2, 1.1, 0.7] in figure coordinates). The axes
    start left of the figure edge; the saved figure uses
    ``bbox_inches='tight'``.

    Real traces from ``str1`` are drawn as filled wiggles (positive red,
    negative blue). Synthetic traces from ``str2`` are drawn on the same
    axes as dashed black lines. Each trace is offset vertically by its
    back-azimuth or slowness header, depending on ``btyp``.

    Args:
        str1 (obspy.stream): Stream of real receiver functions
        str2 (obspy.stream): Stream of synthetic receiver functions
        tr1 (obspy.trace): Stacked real trace (unused by this function)
        tr2 (obspy.trace): Stacked synthetic trace (unused by this function)
        sta (str): Station name, drawn at the top-left of the figure
        l1 (str): First caption line, drawn below the panel
        l2 (str): Second caption line, drawn below ``l1``
        btyp (str, optional): Sorting type: 'baz', 'slow' or 'dist'
        tmin (float, optional): Lower bound of time axis (s)
        tmax (float, optional): Upper bound of time axis (s)
        scale (float, optional): Amplitude multiplier for every wiggle; when
            falsy a default depending on ``btyp`` is used (back-azimuth:
            180 for real, 150 for synthetic traces)
        save (bool, optional): If True save ``ftitle + '.png'`` at 600 dpi,
            otherwise show the figure interactively
        ftitle (str, optional): File name (without extension) of the saved
            figure
        wvtype (str, optional): Wave type ('P', 'S' or 'SKS'); selects the
            y-axis limits for 'slow' and 'dist' sorting

    Returns:
        None

    Raises:
        ValueError: If ``btyp`` is not 'baz', 'slow' or 'dist'.
    """
    if btyp not in ('baz', 'slow', 'dist'):
        raise ValueError('type has to be "baz" or "slow" or "dist"')

    print()
    print('Plotting Wiggles by ' + btyp)

    # Zero-centred time axis taken from the first real trace
    header = str1[0].stats
    npts, rate = header.npts, header.sampling_rate
    time = np.arange(-npts/2, npts/2)/rate

    def position(trace):
        # Vertical offset of a wiggle; 'slow' and 'dist' both read
        # ``stats.slow``. NOTE(review): the 'dist' y-limits are in degrees,
        # so confirm distance-binned traces carry distance in ``stats.slow``.
        return trace.stats.baz if btyp == 'baz' else trace.stats.slow

    def gain(baz_default):
        # User-supplied scale wins; otherwise a per-sorting-type default
        if scale:
            return scale
        if btyp == 'baz':
            return baz_default
        return 0.02 if btyp == 'slow' else 20

    def limits_and_label():
        # (ylim or None, ylabel); None means an unknown wvtype -> autoscale
        if btyp == 'baz':
            return (-5, 370), 'Back-azimuth (deg)'
        if btyp == 'slow':
            table = (('P', (0.036, 0.088)), ('S', (0.07, 0.125)),
                     ('SKS', (0.03, 0.06)))
            label = 'Slowness (s/km)'
        else:
            table = (('P', (28., 92.)), ('S', (53., 107.)),
                     ('SKS', (83., 117.)))
            label = 'Distance (deg)'
        for name, limits in table:
            if wvtype == name:
                return limits, label
        return None, label

    fig = plt.figure()
    plt.clf()
    plt.figtext(0, .95, sta, fontsize=14, fontweight='bold')
    for ypos, caption in zip((.05, 0), (l1, l2)):
        plt.figtext(.04, ypos, caption, fontweight='bold')

    ax = fig.add_axes([-0.3, 0.2, 1.1, 0.7])

    # Real receiver functions: filled wiggles
    real_gain = gain(180)
    for trace in str1:
        base = position(trace)
        curve = base + trace.data*real_gain
        below = trace.data + 1e-6 <= 0.
        above = trace.data + 1e-6 >= 0.
        ax.fill_between(time, base, curve, where=below,
                        facecolor='blue', linewidth=0)
        ax.fill_between(time, base, curve, where=above,
                        facecolor='red', linewidth=0)

    ax.set_xlim(tmin, tmax)
    ylim, ylabel = limits_and_label()
    if ylim is not None:
        ax.set_ylim(*ylim)
    ax.set_ylabel(ylabel)
    ax.set_xlabel('Time (sec)')
    ax.grid(ls=':')

    # Synthetic receiver functions: dashed lines over the real wiggles
    synth_gain = gain(150)
    for trace in str2:
        ax.plot(time, position(trace) + trace.data*synth_gain,
                color='black', linestyle='dashed', linewidth=1)

    if save:
        plt.savefig(ftitle + '.png', dpi=600, bbox_inches='tight', format='png')
    else:
        plt.show()

    plt.close()

def rf_wiggles_RaS_label_stacked_wide_3layer(str1, str2, tr1, tr2, sta, l1, l2, l3, btyp='baz', tmin=-10., tmax=30,
                   scale=None, save=False, ftitle='Figure_rf_wiggle_baz',
                   wvtype='P'):
    """
    Overlay synthetic receiver functions on real ones in one wide panel,
    with three caption lines.

    Same as ``rf_wiggles_RaS_label_stacked_wide`` plus a third caption
    line. The caption lines are stacked downwards from y=0.05 in figure
    coordinates, so ``l3`` sits below the figure's lower edge. It is
    probably only visible in the saved figure, which uses
    ``bbox_inches='tight'`` -- confirm.

    Real traces from ``str1`` are drawn as filled wiggles (positive red,
    negative blue). Synthetic traces from ``str2`` are drawn on the same
    axes as dashed black lines. Each trace is offset vertically by its
    back-azimuth or slowness header, depending on ``btyp``.

    Args:
        str1 (obspy.stream): Stream of real receiver functions
        str2 (obspy.stream): Stream of synthetic receiver functions
        tr1 (obspy.trace): Stacked real trace (unused by this function)
        tr2 (obspy.trace): Stacked synthetic trace (unused by this function)
        sta (str): Station name, drawn at the top-left of the figure
        l1 (str): First caption line
        l2 (str): Second caption line
        l3 (str): Third caption line
        btyp (str, optional): Sorting type: 'baz', 'slow' or 'dist'
        tmin (float, optional): Lower bound of time axis (s)
        tmax (float, optional): Upper bound of time axis (s)
        scale (float, optional): Amplitude multiplier for every wiggle; when
            falsy a default depending on ``btyp`` is used (back-azimuth:
            180 for real, 150 for synthetic traces)
        save (bool, optional): If True save ``ftitle + '.png'`` at 600 dpi,
            otherwise show the figure interactively
        ftitle (str, optional): File name (without extension) of the saved
            figure
        wvtype (str, optional): Wave type ('P', 'S' or 'SKS'); selects the
            y-axis limits for 'slow' and 'dist' sorting

    Returns:
        None

    Raises:
        ValueError: If ``btyp`` is not 'baz', 'slow' or 'dist'.
    """
    if btyp not in ('baz', 'slow', 'dist'):
        raise ValueError('type has to be "baz" or "slow" or "dist"')

    print()
    print('Plotting Wiggles by ' + btyp)

    # Zero-centred time axis taken from the first real trace
    header = str1[0].stats
    npts, rate = header.npts, header.sampling_rate
    time = np.arange(-npts/2, npts/2)/rate

    def position(trace):
        # Vertical offset of a wiggle; 'slow' and 'dist' both read
        # ``stats.slow``. NOTE(review): the 'dist' y-limits are in degrees,
        # so confirm distance-binned traces carry distance in ``stats.slow``.
        return trace.stats.baz if btyp == 'baz' else trace.stats.slow

    def gain(baz_default):
        # User-supplied scale wins; otherwise a per-sorting-type default
        if scale:
            return scale
        if btyp == 'baz':
            return baz_default
        return 0.02 if btyp == 'slow' else 20

    def limits_and_label():
        # (ylim or None, ylabel); None means an unknown wvtype -> autoscale
        if btyp == 'baz':
            return (-5, 370), 'Back-azimuth (deg)'
        if btyp == 'slow':
            table = (('P', (0.036, 0.088)), ('S', (0.07, 0.125)),
                     ('SKS', (0.03, 0.06)))
            label = 'Slowness (s/km)'
        else:
            table = (('P', (28., 92.)), ('S', (53., 107.)),
                     ('SKS', (83., 117.)))
            label = 'Distance (deg)'
        for name, limits in table:
            if wvtype == name:
                return limits, label
        return None, label

    fig = plt.figure()
    plt.clf()
    plt.figtext(0, .95, sta, fontsize=14, fontweight='bold')
    for ypos, caption in zip((.05, 0, -.05), (l1, l2, l3)):
        plt.figtext(.04, ypos, caption, fontweight='bold')

    ax = fig.add_axes([-0.3, 0.2, 1.1, 0.7])

    # Real receiver functions: filled wiggles
    real_gain = gain(180)
    for trace in str1:
        base = position(trace)
        curve = base + trace.data*real_gain
        below = trace.data + 1e-6 <= 0.
        above = trace.data + 1e-6 >= 0.
        ax.fill_between(time, base, curve, where=below,
                        facecolor='blue', linewidth=0)
        ax.fill_between(time, base, curve, where=above,
                        facecolor='red', linewidth=0)

    ax.set_xlim(tmin, tmax)
    ylim, ylabel = limits_and_label()
    if ylim is not None:
        ax.set_ylim(*ylim)
    ax.set_ylabel(ylabel)
    ax.set_xlabel('Time (sec)')
    ax.grid(ls=':')

    # Synthetic receiver functions: dashed lines over the real wiggles
    synth_gain = gain(150)
    for trace in str2:
        ax.plot(time, position(trace) + trace.data*synth_gain,
                color='black', linestyle='dashed', linewidth=1)

    if save:
        plt.savefig(ftitle + '.png', dpi=600, bbox_inches='tight', format='png')
    else:
        plt.show()

    plt.close()

def rf_wiggles_RaS_label_stacked_wide_4layer(str1, str2, tr1, tr2, sta, l1, l2, l3, l4, btyp='baz', tmin=-10., tmax=30,
                   scale=None, save=False, ftitle='Figure_rf_wiggle_baz',
                   wvtype='P'):
    """
    Overlay synthetic receiver functions on real ones in one wide panel,
    with four caption lines.

    Same as ``rf_wiggles_RaS_label_stacked_wide`` plus two extra caption
    lines. The caption lines are stacked downwards from y=0.05 in figure
    coordinates, so ``l3`` and ``l4`` sit below the figure's lower edge.
    They are probably only visible in the saved figure, which uses
    ``bbox_inches='tight'`` -- confirm.

    Real traces from ``str1`` are drawn as filled wiggles (positive red,
    negative blue). Synthetic traces from ``str2`` are drawn on the same
    axes as dashed black lines. Each trace is offset vertically by its
    back-azimuth or slowness header, depending on ``btyp``.

    Args:
        str1 (obspy.stream): Stream of real receiver functions
        str2 (obspy.stream): Stream of synthetic receiver functions
        tr1 (obspy.trace): Stacked real trace (unused by this function)
        tr2 (obspy.trace): Stacked synthetic trace (unused by this function)
        sta (str): Station name, drawn at the top-left of the figure
        l1 (str): First caption line
        l2 (str): Second caption line
        l3 (str): Third caption line
        l4 (str): Fourth caption line
        btyp (str, optional): Sorting type: 'baz', 'slow' or 'dist'
        tmin (float, optional): Lower bound of time axis (s)
        tmax (float, optional): Upper bound of time axis (s)
        scale (float, optional): Amplitude multiplier for every wiggle; when
            falsy a default depending on ``btyp`` is used (back-azimuth:
            180 for real, 150 for synthetic traces)
        save (bool, optional): If True save ``ftitle + '.png'`` at 600 dpi,
            otherwise show the figure interactively
        ftitle (str, optional): File name (without extension) of the saved
            figure
        wvtype (str, optional): Wave type ('P', 'S' or 'SKS'); selects the
            y-axis limits for 'slow' and 'dist' sorting

    Returns:
        None

    Raises:
        ValueError: If ``btyp`` is not 'baz', 'slow' or 'dist'.
    """
    if btyp not in ('baz', 'slow', 'dist'):
        raise ValueError('type has to be "baz" or "slow" or "dist"')

    print()
    print('Plotting Wiggles by ' + btyp)

    # Zero-centred time axis taken from the first real trace
    header = str1[0].stats
    npts, rate = header.npts, header.sampling_rate
    time = np.arange(-npts/2, npts/2)/rate

    def position(trace):
        # Vertical offset of a wiggle; 'slow' and 'dist' both read
        # ``stats.slow``. NOTE(review): the 'dist' y-limits are in degrees,
        # so confirm distance-binned traces carry distance in ``stats.slow``.
        return trace.stats.baz if btyp == 'baz' else trace.stats.slow

    def gain(baz_default):
        # User-supplied scale wins; otherwise a per-sorting-type default
        if scale:
            return scale
        if btyp == 'baz':
            return baz_default
        return 0.02 if btyp == 'slow' else 20

    def limits_and_label():
        # (ylim or None, ylabel); None means an unknown wvtype -> autoscale
        if btyp == 'baz':
            return (-5, 370), 'Back-azimuth (deg)'
        if btyp == 'slow':
            table = (('P', (0.036, 0.088)), ('S', (0.07, 0.125)),
                     ('SKS', (0.03, 0.06)))
            label = 'Slowness (s/km)'
        else:
            table = (('P', (28., 92.)), ('S', (53., 107.)),
                     ('SKS', (83., 117.)))
            label = 'Distance (deg)'
        for name, limits in table:
            if wvtype == name:
                return limits, label
        return None, label

    fig = plt.figure()
    plt.clf()
    plt.figtext(0, .95, sta, fontsize=14, fontweight='bold')
    for ypos, caption in zip((.05, 0, -.05, -.10), (l1, l2, l3, l4)):
        plt.figtext(.04, ypos, caption, fontweight='bold')

    ax = fig.add_axes([-0.3, 0.2, 1.1, 0.7])

    # Real receiver functions: filled wiggles
    real_gain = gain(180)
    for trace in str1:
        base = position(trace)
        curve = base + trace.data*real_gain
        below = trace.data + 1e-6 <= 0.
        above = trace.data + 1e-6 >= 0.
        ax.fill_between(time, base, curve, where=below,
                        facecolor='blue', linewidth=0)
        ax.fill_between(time, base, curve, where=above,
                        facecolor='red', linewidth=0)

    ax.set_xlim(tmin, tmax)
    ylim, ylabel = limits_and_label()
    if ylim is not None:
        ax.set_ylim(*ylim)
    ax.set_ylabel(ylabel)
    ax.set_xlabel('Time (sec)')
    ax.grid(ls=':')

    # Synthetic receiver functions: dashed lines over the real wiggles
    synth_gain = gain(150)
    for trace in str2:
        ax.plot(time, position(trace) + trace.data*synth_gain,
                color='black', linestyle='dashed', linewidth=1)

    if save:
        plt.savefig(ftitle + '.png', dpi=600, bbox_inches='tight', format='png')
    else:
        plt.show()

    plt.close()

def rf_wiggles_RaS_label_stacked_wide_5layer(str1, str2, tr1, tr2, sta, l1, l2, l3, l4, l5, btyp='baz', tmin=-10., tmax=30,
                   scale=None, save=False, ftitle='Figure_rf_wiggle_baz',
                   wvtype='P'):
    """
    Overlay real and synthetic receiver functions on a single wide panel.

    Traces of ``str1`` are drawn as filled wiggles (positive red, negative
    blue) and traces of ``str2`` are drawn on top as dashed black lines.
    Every trace is placed vertically at a header value selected by ``btyp``.
    Five free-text label lines are written underneath the panel.

    Args:
        str1 (obspy.stream): Stream drawn as filled wiggles
        str2 (obspy.stream): Stream drawn as dashed overlay lines
        tr1 (obspy.trace): Not used by this function
        tr2 (obspy.trace): Not used by this function
        sta (str): Station name, written at the top-left of the figure
        l1, l2, l3, l4, l5 (str): Label lines written below the panel
        btyp (str, optional): Vertical sorting key: 'baz', 'slow' or 'dist'
        tmin (float, optional): Lower bound of time axis (s)
        tmax (float, optional): Upper bound of time axis (s)
        scale (float, optional): Amplitude multiplier for both streams;
            when falsy, per-stream defaults depending on ``btyp`` are used
        save (bool, optional): Save to ``ftitle + '.png'`` instead of showing
        ftitle (str, optional): File-name stem of the saved figure
        wvtype (str, optional): 'P', 'S' or 'SKS'; selects the y-limits used
            for 'slow' and 'dist' sorting

    Returns:
        None

    Raises:
        ValueError: if ``btyp`` is not 'baz', 'slow' or 'dist'
    """
    if btyp not in ('baz', 'slow', 'dist'):
        raise ValueError('type has to be "baz" or "slow" or "dist"')

    # Header attribute giving each trace's vertical position.
    # NOTE(review): 'dist' reads stats.slow, yet the 'dist' axis uses distance
    # limits in degrees (28-117) -- a distance header was probably intended;
    # confirm which attribute holds the epicentral distance.
    y_key = {'baz': 'baz', 'slow': 'slow', 'dist': 'slow'}[btyp]

    # Amplitude multipliers: explicit `scale` wins, otherwise per-stream defaults
    fill_gain = scale if scale else {'baz': 180, 'slow': 0.02, 'dist': 20}[btyp]
    dash_gain = scale if scale else {'baz': 150, 'slow': 0.02, 'dist': 20}[btyp]

    # Time axis symmetric about zero (assumes zero lag sits at the middle
    # sample of each trace -- TODO confirm against the RF generator)
    header = str1[0].stats
    n_samples = header.npts
    rate = header.sampling_rate
    time = np.arange(-n_samples/2, n_samples/2)/rate

    fig = plt.figure()
    plt.clf()
    plt.figtext(0, .95, sta, fontsize=14, fontweight='bold')
    for text, ypos in zip((l1, l2, l3, l4, l5), (.05, 0, -.05, -.10, -.15)):
        plt.figtext(.04, ypos, text, fontweight='bold')

    # Single wide axes; the negative left offset is deliberate for layout
    panel = fig.add_axes([-0.3, 0.2, 1.1, 0.7])

    # Filled wiggles: negative lobes blue, positive lobes red
    for trace in str1:
        base = getattr(trace.stats, y_key)
        wiggle = base + trace.data*fill_gain
        panel.fill_between(time, base, wiggle, where=trace.data + 1e-6 <= 0.,
                           facecolor='blue', linewidth=0)
        panel.fill_between(time, base, wiggle, where=trace.data + 1e-6 >= 0.,
                           facecolor='red', linewidth=0)

    panel.set_xlim(tmin, tmax)

    axis_titles = {'baz': 'Back-azimuth (deg)',
                   'slow': 'Slowness (s/km)',
                   'dist': 'Distance (deg)'}
    if btyp == 'baz':
        panel.set_ylim(-5, 370)
    else:
        # Unknown wvtype: no y-limit is set
        bounds = {
            'slow': {'P': (0.036, 0.088), 'S': (0.07, 0.125), 'SKS': (0.03, 0.06)},
            'dist': {'P': (28., 92.), 'S': (53., 107.), 'SKS': (83., 117.)},
        }[btyp]
        if wvtype in bounds:
            panel.set_ylim(*bounds[wvtype])
    panel.set_ylabel(axis_titles[btyp])

    panel.set_xlabel('Time (sec)')
    panel.grid(ls=':')

    # Dashed overlay of the second stream on the same axes
    for trace in str2:
        base = getattr(trace.stats, y_key)
        panel.plot(time, base + trace.data*dash_gain, color='black',
                   linestyle='dashed', linewidth=1)

    if save:
        plt.savefig(ftitle+'.png', dpi=600, bbox_inches='tight', format='png')
    else:
        plt.show()

    plt.close()
    return

def rf_wiggles_RaS_label_stacked_wide_5layer_withPoints(str1, str2, tr1, tr2, sta, l1, l2, l3, l4, l5, th1, th2, th3, th4, th5, vp1, vp2, vp3, vp4, vp5, vR1, vR2, vR3, vR4, vR5, btyp='baz', tmin=-5., tmax=30,
                   scale=None, save=False, ftitle='Figure_rf_wiggle_baz',
                   wvtype='P', isZ = False):
    """
    Plot real and synthetic receiver functions with predicted layer arrivals.

    Traces of ``str1`` are drawn as filled wiggles (positive red, negative
    blue) and traces of ``str2`` as dashed black lines on the same axes.
    For every ``str2`` trace, markers are drawn at the cumulative predicted
    arrival time at the base of each of five flat layers, using the trace's
    ``stats.slow`` as ray parameter p:

    * ``isZ`` falsy: Ps - P delay, sum of h*(sqrt(1/Vs^2 - p^2) - sqrt(1/Vp^2 - p^2))
    * ``isZ`` truthy: sum of 2*h*sqrt(1/Vp^2 - p^2) (two-way vertical P time)

    Args:
        str1 (obspy.stream): Stream drawn as filled wiggles (real RFs)
        str2 (obspy.stream): Stream drawn as dashed lines (synthetic RFs)
        tr1 (obspy.trace): Not used by this function
        tr2 (obspy.trace): Not used by this function
        sta (str): Station name, written at the top-left of the figure
        l1..l5 (str): Label lines written below the panel
        th1..th5 (float): Layer thicknesses, top layer first
            (presumably km, matching the s/km slowness axis -- TODO confirm)
        vp1..vp5 (float): Layer P velocities (presumably km/s)
        vR1..vR5 (float): Layer Vp/Vs ratios; Vs is computed as vp/vR
        btyp (str, optional): Vertical sorting key: 'baz', 'slow' or 'dist'
        tmin (float, optional): Lower bound of time axis (s)
        tmax (float, optional): Upper bound of time axis (s)
        scale (float, optional): Amplitude multiplier for both streams;
            when falsy, per-stream defaults depending on ``btyp`` are used
        save (bool, optional): Save to ``ftitle + '.png'`` instead of showing
        ftitle (str, optional): File-name stem of the saved figure
        wvtype (str, optional): 'P', 'S' or 'SKS'; selects the y-limits used
            for 'slow' and 'dist' sorting
        isZ (bool, optional): Select the two-way P time instead of Ps delay

    Returns:
        None

    Raises:
        ValueError: if ``btyp`` is not 'baz', 'slow' or 'dist'
    """

    if btyp not in ('baz', 'slow', 'dist'):
        raise ValueError('type has to be "baz" or "slow" or "dist"')

    # Header attribute giving each trace's vertical position.
    # NOTE(review): 'dist' reads stats.slow, yet the 'dist' axis uses distance
    # limits in degrees -- a distance header was probably intended; confirm.
    y_attr = {'baz': 'baz', 'slow': 'slow', 'dist': 'slow'}[btyp]

    # Layer model, top layer first; computed once rather than per trace.
    thick = (th1, th2, th3, th4, th5)
    alphas = (vp1, vp2, vp3, vp4, vp5)
    betas = tuple(a / r for a, r in zip(alphas, (vR1, vR2, vR3, vR4, vR5)))
    # Marker style for the arrival at the base of each layer
    styles = ('wo', 'wo', 'wo', 'yo', 'go')

    # Time axis symmetric about zero (assumes zero lag at the middle sample)
    nn = str1[0].stats.npts
    sr = str1[0].stats.sampling_rate
    time = np.arange(-nn/2, nn/2)/sr

    # Initialize figure
    fig = plt.figure()
    plt.clf()
    plt.figtext(0, .95, sta, fontsize=14, fontweight='bold')
    for label, ypos in zip((l1, l2, l3, l4, l5), (.05, 0, -.05, -.10, -.15)):
        plt.figtext(.04, ypos, label, fontweight='bold')

    ax2 = fig.add_axes([-0.3, 0.2, 1.1, 0.7])

    # Filled wiggles for str1
    fill_gain = scale if scale else {'baz': 180, 'slow': 0.02, 'dist': 20}[btyp]
    for tr in str1:
        y = getattr(tr.stats, y_attr)
        wiggle = y + tr.data*fill_gain
        ax2.fill_between(time, y, wiggle, where=tr.data+1e-6 <= 0.,
                         facecolor='blue', linewidth=0)
        ax2.fill_between(time, y, wiggle, where=tr.data+1e-6 >= 0.,
                         facecolor='red', linewidth=0)

    ax2.set_xlim(tmin, tmax)

    if btyp == 'baz':
        ax2.set_ylim(-5, 370)
        ax2.set_ylabel('Back-azimuth (deg)')
    else:
        # Unknown wvtype: no y-limit is set
        limits = {
            'slow': {'P': (0.036, 0.088), 'S': (0.07, 0.125), 'SKS': (0.03, 0.06)},
            'dist': {'P': (28., 92.), 'S': (53., 107.), 'SKS': (83., 117.)},
        }[btyp]
        if wvtype in limits:
            ax2.set_ylim(*limits[wvtype])
        ax2.set_ylabel('Slowness (s/km)' if btyp == 'slow' else 'Distance (deg)')

    ax2.set_xlabel('Time (sec)')
    ax2.grid(ls=':')

    # Dashed str2 traces plus predicted arrival markers
    line_gain = scale if scale else {'baz': 150, 'slow': 0.02, 'dist': 20}[btyp]
    for tr in str2:
        y = getattr(tr.stats, y_attr)
        ax2.plot(time, y+tr.data*line_gain, color='black', linestyle='dashed', linewidth=1)

        p2 = tr.stats.slow * tr.stats.slow
        # The square root must cover (1/v^2 - p^2). If p > 1/v the ray is
        # evanescent in that layer: sqrt gives NaN and the marker is skipped.
        with np.errstate(invalid='ignore'):
            qa = [np.sqrt(1. / (a * a) - p2) for a in alphas]
            if isZ:
                dts = [2 * h * q for h, q in zip(thick, qa)]
            else:
                dts = [h * (np.sqrt(1. / (b * b) - p2) - q)
                       for h, b, q in zip(thick, betas, qa)]
        # Arrivals accumulate from the surface down through the layers
        for t, style in zip(np.cumsum(dts), styles):
            ax2.plot(t, y, style, markersize=3, markeredgecolor='black')

    if save:
        plt.savefig(ftitle+'.png', dpi=600, bbox_inches='tight', format='png')
    else:
        plt.show()

    plt.close()
    return

def rf_wiggles_multiLayer(str1, str2, tr1, tr2, sta, l, th, vp, vR, btyp='baz', tmin=-5., tmax=30,
                   scale=None, save=False, ftitle='Figure_rf_wiggle_baz',
                   wvtype='P', isZ = False):
    """
    Plot real and synthetic receiver functions with predicted layer arrivals.

    Traces of ``str1`` are drawn as filled wiggles (positive red, negative
    blue) and traces of ``str2`` as dashed black lines on the same axes.
    For every ``str2`` trace, one marker per layer is drawn at the cumulative
    predicted arrival time at that layer's base, using the trace's
    ``stats.slow`` as ray parameter p:

    * ``isZ`` falsy: Ps - P delay, sum of h*(sqrt(1/Vs^2 - p^2) - sqrt(1/Vp^2 - p^2))
    * ``isZ`` truthy: sum of 2*h*sqrt(1/Vp^2 - p^2) (two-way vertical P time)

    Any number of layers (>= 1) is supported. Markers are white for the first
    two layers, yellow for the third, and green for all deeper layers.

    Args:
        str1 (obspy.stream): Stream drawn as filled wiggles (real RFs)
        str2 (obspy.stream): Stream drawn as dashed lines (synthetic RFs)
        tr1 (obspy.trace): Not used by this function
        tr2 (obspy.trace): Not used by this function
        sta (str): Station name, written at the top-left of the figure
        l (sequence of str): One label per layer, written below the panel;
            its length defines the number of layers
        th (sequence of float): Layer thicknesses, top layer first
            (presumably km, matching the s/km slowness axis -- TODO confirm)
        vp (sequence of float): Layer P velocities (presumably km/s)
        vR (sequence of float): Layer Vp/Vs ratios; Vs is computed as vp/vR
        btyp (str, optional): Vertical sorting key: 'baz', 'slow' or 'dist'
        tmin (float, optional): Lower bound of time axis (s)
        tmax (float, optional): Upper bound of time axis (s)
        scale (float, optional): Amplitude multiplier for both streams;
            when falsy, per-stream defaults depending on ``btyp`` are used
        save (bool, optional): Save to ``ftitle + '.png'`` instead of showing
        ftitle (str, optional): File-name stem of the saved figure
        wvtype (str, optional): 'P', 'S' or 'SKS'; selects the y-limits used
            for 'slow' and 'dist' sorting
        isZ (bool, optional): Select the two-way P time instead of Ps delay

    Returns:
        None

    Raises:
        ValueError: if ``btyp`` is invalid, ``l`` is empty, or ``th``/``vp``/
            ``vR`` have fewer entries than ``l``
    """

    if btyp not in ('baz', 'slow', 'dist'):
        raise ValueError('type has to be "baz" or "slow" or "dist"')

    # Previously fewer than 2 layers crashed on an unbound local and more
    # than 5 were silently truncated; validate explicitly instead.
    nlayers = len(l)
    if nlayers == 0:
        raise ValueError('at least one layer label is required in l')
    if min(len(th), len(vp), len(vR)) < nlayers:
        raise ValueError('th, vp and vR need one entry per layer label in l')

    # Layer model, top layer first; computed once rather than per trace.
    labels = [l[i] for i in range(nlayers)]
    thick = [th[i] for i in range(nlayers)]
    alphas = [vp[i] for i in range(nlayers)]
    betas = [vp[i] / vR[i] for i in range(nlayers)]
    styles = ['wo' if i < 2 else 'yo' if i == 2 else 'go' for i in range(nlayers)]

    # Header attribute giving each trace's vertical position.
    # NOTE(review): 'dist' reads stats.slow, yet the 'dist' axis uses distance
    # limits in degrees -- a distance header was probably intended; confirm.
    y_attr = {'baz': 'baz', 'slow': 'slow', 'dist': 'slow'}[btyp]

    # Time axis symmetric about zero (assumes zero lag at the middle sample)
    nn = str1[0].stats.npts
    sr = str1[0].stats.sampling_rate
    time = np.arange(-nn/2, nn/2)/sr

    # Initialize figure; layer labels are stacked downward in 0.05 steps
    fig = plt.figure()
    plt.clf()
    plt.figtext(0, .95, sta, fontsize=14, fontweight='bold')
    for i, label in enumerate(labels):
        plt.figtext(.04, .05 * (1 - i), label, fontweight='bold')

    ax2 = fig.add_axes([-0.3, 0.2, 1.1, 0.7])

    # Filled wiggles for str1
    fill_gain = scale if scale else {'baz': 180, 'slow': 0.02, 'dist': 20}[btyp]
    for tr in str1:
        y = getattr(tr.stats, y_attr)
        wiggle = y + tr.data*fill_gain
        ax2.fill_between(time, y, wiggle, where=tr.data+1e-6 <= 0.,
                         facecolor='blue', linewidth=0)
        ax2.fill_between(time, y, wiggle, where=tr.data+1e-6 >= 0.,
                         facecolor='red', linewidth=0)

    ax2.set_xlim(tmin, tmax)

    if btyp == 'baz':
        ax2.set_ylim(-5, 370)
        ax2.set_ylabel('Back-azimuth (deg)')
    else:
        # Unknown wvtype: no y-limit is set
        limits = {
            'slow': {'P': (0.036, 0.088), 'S': (0.07, 0.125), 'SKS': (0.03, 0.06)},
            'dist': {'P': (28., 92.), 'S': (53., 107.), 'SKS': (83., 117.)},
        }[btyp]
        if wvtype in limits:
            ax2.set_ylim(*limits[wvtype])
        ax2.set_ylabel('Slowness (s/km)' if btyp == 'slow' else 'Distance (deg)')

    ax2.set_xlabel('Time (sec)')
    ax2.grid(ls=':')

    # Dashed str2 traces plus predicted arrival markers
    line_gain = scale if scale else {'baz': 150, 'slow': 0.02, 'dist': 20}[btyp]
    for tr in str2:
        y = getattr(tr.stats, y_attr)
        ax2.plot(time, y+tr.data*line_gain, color='black', linestyle='dashed', linewidth=1)

        p2 = tr.stats.slow * tr.stats.slow
        # The square root must cover (1/v^2 - p^2). If p > 1/v the ray is
        # evanescent in that layer: sqrt gives NaN and the marker is skipped.
        with np.errstate(invalid='ignore'):
            qa = [np.sqrt(1. / (a * a) - p2) for a in alphas]
            if isZ:
                dts = [2 * h * q for h, q in zip(thick, qa)]
            else:
                dts = [h * (np.sqrt(1. / (b * b) - p2) - q)
                       for h, b, q in zip(thick, betas, qa)]
        # Arrivals accumulate from the surface down through the layers
        for t, style in zip(np.cumsum(dts), styles):
            ax2.plot(t, y, style, markersize=3, markeredgecolor='black')

    if save:
        plt.savefig(ftitle+'.png', dpi=600, bbox_inches='tight', format='png')
    else:
        plt.show()

    plt.close()
    return