#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed May  4 09:08:28 2022

@author: ben
"""

import numpy as np
import obspy
from obspy.signal import cross_correlation as cc
from telewavesim import utils as ut
from telewavesim import wiggle as wg
from pyswarms.single.local_best import LocalBestPSO

import pickle
import matplotlib.pyplot as plt

import os
from pyswarms.utils.plotters import plot_cost_history
import time

import csv
from obspy.clients.fdsn import Client
from multiprocessing import Pool

from obspy.core import Stream, Trace
import glob
from obspy.taup import TauPyModel

from functions_modifiedWiggle import rf_wiggles_RaS_label_stacked_wide
import pdb as pdb

def whitenDelph(tr, npts_smooth):
    """Spectrally whiten ``tr`` in place and return it.

    Follows the spectral-whitening approach Jon Delph uses in his MATLAB
    script /v1.0.1/EQ_AUTOCORR_prewhite_wtempnorm.m. The amplitude spectrum
    is divided by its own moving average, computed with ``movingmean`` over
    ``npts_smooth`` samples and floored at 0.1. The phase spectrum is kept
    unchanged.

    Parameters
    ----------
    tr : obspy Trace (or any object with ``.data``)
        Trace whose ``data`` is replaced by the whitened, real-valued series
        of the same length.
    npts_smooth : int
        Moving-average window, in frequency samples.
    """
    nsamples = len(tr.data)
    spectrum = np.fft.fft(tr.data)
    amplitude = np.abs(spectrum)   # amplitude spectrum
    phase = np.angle(spectrum)     # phase spectrum

    # Smooth the amplitude spectrum; the 0.1 floor avoids dividing by ~0.
    smoothed = movingmean(amplitude, npts_smooth)
    smoothed[smoothed < 0.1] = 0.1
    flattened = amplitude / smoothed

    # ALANS FILTER: recombine with the original phase. irfft with n=nsamples
    # only reads the first nsamples//2 + 1 bins of the full spectrum.
    tr.data = np.fft.irfft(flattened * np.exp(1j * phase), nsamples)
    return tr

def movingmean(series, window):
    """Return the centred moving average of ``series`` over ``window`` samples.

    Ported into Python from Jon Delph's function of the same name, so the
    index arithmetic below follows the original's 1-based convention. Near
    either end the averaging window is clipped to the available samples
    rather than padded.

    Returns a float array with the same length as ``series``.
    """
    half = (window - 1) / 2
    count = len(series)
    averaged = np.zeros(count)
    for idx in range(count):
        pos = idx + 1  # 1-based position of the current sample
        if pos <= half:
            # leading edge: window clipped at the first sample
            span = np.arange(0, (pos + half) + 1)
        elif pos >= count - half:
            # trailing edge: window clipped at the last sample
            span = np.arange(idx - half, count + 1)
        else:
            span = np.arange(idx - half, (pos + half) + 1)
        lo, hi = int(span[0]), int(span[-1])
        averaged[idx] = np.mean(series[lo:hi])
    return averaged

def time_taper(tr,taper_mins):
    """Apply a Hann taper to both ends of ``tr`` (in place) and return it.

    Ported into Python from Jon Delph's function of the same name. The first
    and last ``nsamp + 1`` samples are multiplied by the rising and falling
    halves of ``np.hanning(2 * nsamp)``. Here ``nsamp`` is ``taper_mins``
    minutes expressed in samples, rounded to a whole number. Samples in
    between are left unchanged.

    Parameters
    ----------
    tr : obspy Trace (or any object with ``.data`` and ``.stats.delta``)
        Trace to taper. It must have at least ``nsamp + 1`` samples.
    taper_mins : float
        Taper length at each end, in minutes. If it amounts to less than one
        sample, no taper is applied.
    """
    # Round to a whole sample count. With a fractional count (e.g. a float32
    # SAC delta) the begin and end slices below end up with different
    # lengths, and the assignment fails.
    nsamp = int(round(taper_mins * 60 / tr.stats.delta))

    hann_taper = np.ones(len(tr.data))
    if nsamp >= 1:
        tap = np.hanning(nsamp * 2)  # full window, split into rising/falling halves
        hann_taper[0:nsamp + 1] = tap[0:nsamp + 1]
        hann_taper[len(hann_taper) - nsamp - 1:] = tap[len(tap) - nsamp - 1:]

    tr.data = tr.data * hann_taper
    return tr
    
# Upper (exclusive) ray-parameter edge of every bin but the last; anything at
# or above the last edge falls in the final '80' bin. The ray parameter is
# read from SAC header user2 (presumably s/km -- confirm with the RF step).
_RP_BIN_EDGES = [.0425, .0475, .0525, .0575, .0625, .0675, .0725, .0775]
# Bin labels: bin-centre ray parameter x 1000 (e.g. '45' for .045).
_RP_BIN_LABELS = ['40', '45', '50', '55', '60', '65', '70', '75', '80']
# Component keys, in the order their stacks are written to the output:
# radial RF, transverse RF, vertical autocorrelation, radial autocorrelation.
_BIN_COMPONENTS = ['r', 't', 'z', 'rC']


def _rp_bin_label(rp):
    """Return the ray-parameter bin label ('40'..'80') that ``rp`` falls in."""
    for edge, label in zip(_RP_BIN_EDGES, _RP_BIN_LABELS):
        if rp < edge:
            return label
    # rp >= last edge (or NaN, where every comparison is False) -> last bin
    return _RP_BIN_LABELS[-1]


def _one_sided_autocorr(tr):
    """Autocorrelate ``tr`` up to a 50 s lag and keep lags 0..+50 s."""
    xcorr_time = 5 / 6  # maximum lag, in minutes
    shift = int((xcorr_time * 60) / tr.stats.delta)
    corr = cc.correlate(tr, tr, shift, normalize='naive', method='auto')
    # correlate returns 2*shift + 1 samples with zero lag at index `shift`.
    # This used to be a hard-coded [500:], which equals `shift` only at
    # 10 samples/s.
    return corr[shift:]


def _autocorr_taper(ncorr, delta):
    """Return the ~2 s two-ended Hann taper applied to the autocorrelations."""
    # start at 2 seconds; rounded so the slice lengths below always agree
    nsamp = int(round(2 / delta))
    tap = np.hanning(nsamp * 2)
    hann_taper = np.ones(ncorr)
    # The beginning taper starts at tap[1], so the first sample is not forced
    # to zero. This is kept exactly as in the original code.
    hann_taper[0:nsamp + 1] = tap[1:nsamp + 2]
    hann_taper[ncorr - nsamp - 1:] = tap[len(tap) - nsamp - 1:]
    return hann_taper


def _finish_autocorr(template, corr, hann_taper, pad_delta, corner, zFilt_low, zFilt_high, channel):
    """Taper, bandpass, truncate and zero-pad an autocorrelation into a Trace.

    The result is a copy of ``template`` (keeping its header). Its data is
    ``int(30 / pad_delta)`` zeros followed by the first 3/5 (+1 sample) of
    the filtered autocorrelation, and its channel is set to ``channel``.
    """
    filt = Trace(data=np.array(corr) * np.array(hann_taper))
    filt.stats.delta = template.stats.delta
    filt.filter('bandpass', freqmin=zFilt_low, freqmax=zFilt_high, corners=corner, zerophase=True)
    data = filt.data
    # keep the first 3/5 of the lags and pad to the same length as the synthetics
    data = data[:(int(len(data) * (3 / 5)) + 1)]
    out = template.copy()
    out.data = np.append(np.zeros(int(30 / pad_delta)), data)
    out.stats.channel = channel
    return out


def _event_autocorrs(f, rrf, npts, zFilt_low, zFilt_high, corner):
    """Build the 'autoZ' and 'autoR' autocorrelation traces for RF file ``f``.

    The vertical record is the sibling ``.z`` file. The radial record is the
    sibling ``.r`` file, or else ``.n``/``.e`` rotated NE->RT using the SAC
    ``baz`` header of ``rrf``. Each record is cut to 145-180 s after its start
    time, bandpassed, end-tapered (5 s), whitened and then autocorrelated.
    """
    z = obspy.read(f.replace('.eqr', '.z'))
    z[0].trim(starttime=z[0].stats.starttime + 145, endtime=z[0].stats.starttime + 180, pad=True, fill_value=0)
    z.filter('bandpass', freqmin=zFilt_low, freqmax=zFilt_high, corners=corner, zerophase=True)
    z[0] = time_taper(z[0], 1/12)
    z[0] = whitenDelph(z[0], npts)

    if os.path.isfile(f.replace('.eqr', '.r')):
        r = obspy.read(f.replace('.eqr', '.r'))
    else:
        north = obspy.read(f.replace('.eqr', '.n'))
        east = obspy.read(f.replace('.eqr', '.e'))
        rt = (north + east).rotate(method='NE->RT', back_azimuth=rrf[0].stats.sac.baz)
        r = Stream()
        r.append(rt[0])

    # NOTE(review): R is filtered *before* trimming while Z (above) is trimmed
    # before filtering; the edge effects differ -- confirm this is intended.
    r.filter('bandpass', freqmin=zFilt_low, freqmax=zFilt_high, corners=corner, zerophase=True)
    r[0].trim(starttime=r[0].stats.starttime + 145, endtime=r[0].stats.starttime + 180, pad=True, fill_value=0)
    r[0] = time_taper(r[0], 1/12)
    r[0] = whitenDelph(r[0], npts)

    corr_z = _one_sided_autocorr(z[0])
    corr_r = _one_sided_autocorr(r[0])
    # The taper is sized from Z and applied to both; assumes Z and R share delta.
    hann_taper = _autocorr_taper(len(corr_z), z[0].stats.delta)
    pad_delta = z[0].stats.delta
    zAuto = _finish_autocorr(z[0], corr_z, hann_taper, pad_delta, corner, zFilt_low, zFilt_high, 'autoZ')
    rAuto = _finish_autocorr(r[0], corr_r, hann_taper, pad_delta, corner, zFilt_low, zFilt_high, 'autoR')
    return zAuto, rAuto


def binRF(stn,inDir,dest,polFactor=1,npts=11,zFilt_low=.08,zFilt_high=1.25,cull_Cmpnts=None,cull_rP=[0]):
    '''
    Bin receiver functions and P autocorrelations by ray parameter, stack each
    bin, and write every station's stacks to ``dest + 'bins/' + s + '.mseed'``.

    For each station ``s``, every file under ``inDir + s + '/*/'`` whose path
    contains ``s`` and ``'.eqr'`` is a radial RF. The sibling files (same path
    with ``.eqr`` replaced) supply the transverse RF (``.eqt``) and the raw
    records used for the autocorrelations. The ray parameter comes from SAC
    header ``user2``.

    Parameters
    ----------
    stn : list of str
        Station names, in format ['s1','s2']. Names like 'BS-XXX' also set
        network 'BS' and station 'XXX' on the output traces.
    inDir : str
        Input directory prefix (string-concatenated; end it with a separator).
    dest : str
        Output directory prefix (string-concatenated; end it with a separator).
    polFactor : int or float
        Polarity multiplier. Occasionally, stations have the wrong polarity
        for the RFs due to bad metadata; use a polFactor of -1 for those.
        Otherwise, use 1.
    npts : int
        Moving-average window (samples) for the spectral whitening.
    zFilt_low, zFilt_high : float
        Bandpass corner frequencies (Hz) for records and autocorrelations.
    cull_Cmpnts : iterable of str or None
        Component keys ('r', 't', 'z', 'rC') whose stacks should be zeroed at
        the ray parameters in ``cull_rP``.
    cull_rP : iterable of float
        Bin-centre ray parameters (e.g. .045) to zero for ``cull_Cmpnts``.

    The output holds 36 traces: r, t, z (autoZ) and rC (autoR) stacks for bins
    .04-.08 in steps of .005. Each is normalised to unit peak, multiplied by
    polFactor and trimmed/padded to 60 s. Empty bins are filled with zeroed
    copies of the first non-empty bin.

    Raises
    ------
    ValueError
        If a station has no receiver functions at all.
    KeyError
        If ``cull_Cmpnts``/``cull_rP`` name a component or bin that does not
        exist.
    '''
    corner = 2  # bandpass filter corners used throughout

    destination = dest + 'bins/'
    print(destination)
    # makedirs also creates `dest`; unlike the old `mkdir` shell call it needs
    # no quoting and tolerates existing directories.
    os.makedirs(destination, exist_ok=True)

    for s in stn:
        print(s)
        bins = {cmpnt: {label: Stream() for label in _RP_BIN_LABELS} for cmpnt in _BIN_COMPONENTS}

        # loop through event directories, picking this station's radial RFs
        events = sorted(glob.glob(inDir + s + '/*'))
        for event_dir in events:
            print(event_dir)
            files = sorted(glob.glob(event_dir + '/*'))
            for f in files:
                if s not in f or '.eqr' not in f:
                    continue
                rrf = obspy.read(f)
                trf = obspy.read(f.replace('.eqr', '.eqt'))
                zAuto, rAuto = _event_autocorrs(f, rrf, npts, zFilt_low, zFilt_high, corner)

                # bin rf's based on ray parameter
                label = _rp_bin_label(rrf[0].stats.sac.user2)
                bins['r'][label].append(rrf[0])
                bins['t'][label].append(trf[0])
                bins['z'][label].append(zAuto)
                bins['rC'][label].append(rAuto)

        # stack rp bins
        stacks = {cmpnt: {label: bins[cmpnt][label].stack() for label in _RP_BIN_LABELS}
                  for cmpnt in _BIN_COMPONENTS}

        # Fill EVERY empty bin with a zeroed copy of the first non-empty bin so
        # the output always has the same 36 traces. All four components are
        # appended together, so checking 'r' is sufficient.
        filled = [label for label in _RP_BIN_LABELS if len(stacks['r'][label]) != 0]
        if not filled:
            raise ValueError('No receiver functions found for station ' + s)
        donor = filled[0]
        for label in _RP_BIN_LABELS:
            if len(stacks['r'][label]) != 0:
                continue
            for cmpnt in _BIN_COMPONENTS:
                blank = stacks[cmpnt][donor][0].copy()
                blank.data = blank.data * 0
                stacks[cmpnt][label].append(blank)

        if cull_Cmpnts is not None:
            for cmpnt in cull_Cmpnts:
                print(cmpnt)
                for rP in cull_rP:
                    print(rP)
                    # round() guards against e.g. .055 * 1000 -> 55.00000000000001
                    culled = stacks[cmpnt][str(int(round(rP * 1000)))][0]
                    culled.data = culled.data * 0

        # move all stacked rp bins into one Stream, then normalize and save into mseed
        sAll = Stream(traces=[stacks[cmpnt][label][0]
                              for cmpnt in _BIN_COMPONENTS for label in _RP_BIN_LABELS])

        for tr in sAll:
            normFactor = max(abs(tr.data))
            if normFactor != 0:
                tr.data = tr.data / normFactor

        for tr in sAll:
            # NOTE(review): polFactor is also applied to the autoZ/autoR traces.
            # An autocorrelation does not change sign when the input polarity
            # flips, so confirm this is intended.
            tr.data = tr.data * polFactor
            start = tr.stats.starttime
            tr.trim(starttime=start, endtime=start + 60, pad=True, fill_value=0)
            tr.stats.starttime = start
            tr.stats.wvtype = 'P'

            if s.split('-')[0] == 'BS':
                tr.stats.network = 'BS'
                tr.stats.station = s.split('-')[1]

        sAll.write(destination + s + '.mseed')
