# -*- coding: utf-8 -*-
"""
Created on Mon Mar 25 10:15:10 2024

functions for plotting solutions

@author: ben sadler
"""

import numpy as np
import matplotlib.pyplot as plt
import csv
from obspy.clients.fdsn import Client
import pdb as pdb
import seaborn as sns

from matplotlib.patches import Ellipse
import matplotlib.transforms as transforms
import corner

def plotSol_Options(d,station,bestCost,bestPos,costs_localMin=[],positions=[],costs=[],nlayers=3, pointCutoff=20, cloudCutoff=.5, writeCSV=True, cloud = False, drawEllipse=False, minMax_ofSubset=False, splitVariable=None, splitValue=0, splitVariable2=None, splitValue2=0, side='greater', plotMinMax=False, oneCorner=True):
    
    '''
    Iterates through search history and plots search positions/costs for analysis.
    Also writes out solutions in several formats (just best solutions, standard deviation, and min/max) into a .csv file.
    
    
    d is the directory
    station is the station being plotted
    bestCost and bestPos are the minimum cost from the algorithm, and bestPos is the corresponding parameter combination
    costs_localMin are the costs associated with the local minima, and positions are the corresponding parameter combinations
    nlayers is the number of layers used for modeling
    pointCutoff is the percent cutoff for plotting misfits on scatter plots. If the best misfit/cost is 0, and the worst is 1, then all tested solutions <.2 will be plotted with a 20% cutoff.
    cloudCutoff is the percent cutoff for plotting misfits as translucent clouds in scatter plots AND corner plots
    writeCSV is whether or not to write a CSV with solutions
    cloud is whether or not to plot a translucent "cloud" of low misfit points onto the scatter plot
    drawEllipse is whether to plot an ellipse based on the cloudCutoff
    minMax_ofSubset controls whether the minimum and maximum values of subset will be written out. this is done to decide whether solution subsets are narrow enough
    
    these control whether low misfits will be split into groups. this is useful if the smaller misfit groups are multimodal.
    the variable is which parameter to use, and the value is the actual parameter value used to split the groups. can be used for up to 2 values
        splitVariable 
        splitValue
        splitVariable2
        splitValue2
        side
        
    plotMinMax is whether to plot the min and max values on scatter plots (default False)
    oneCorner is whether to make 1 big corner plot or separate corner plots for each layer. one corner plot is better if you have a big monitor, I think, but becomes hard to read in print
    '''
    
    client = Client('IRIS')
    
    print('Plotting Searches: ' + station)
    
    #get location of station via IRIS client
    if 'BS.' in station or 'BS-' in station:
        if 'RLOK' in station:
            stn = client.get_stations(network='OK', station='RLOK', format='xml', level='channel')
        elif 'U40' in station:
            stn = client.get_stations(network='TA', station='U40A', format='xml', level='channel')
        elif 'X40' in station:
            stn = client.get_stations(network='TA', station='X40A', format='xml', level='channel')
        elif 'Z38' in station:
            stn = client.get_stations(network='TA', station='Z38A', format='xml', level='channel')
        elif 'U38' in station:
            stn = client.get_stations(network='TA', station='U38A', format='xml', level='channel')
        elif 'X37' in station:
            stn = client.get_stations(network='TA', station='X37A', format='xml', level='channel')
        elif 'Z41' in station:
            stn = client.get_stations(network='TA', station='Z41A', format='xml', level='channel')
        elif '140' in station:
            stn = client.get_stations(network='TA', station='140A', format='xml', level='channel')
        elif '441' in station:
            stn = client.get_stations(network='TA', station='441A', format='xml', level='channel')
        elif 'HNVL' in station:
            stn = client.get_stations(network='TA', station='438A', format='xml', level='channel')
        elif 'MANK' in station:
            stn = client.get_stations(network='XI', station='MANK', format='xml', level='channel')
        elif 'JAFL' in station:
            stn = client.get_stations(network='XI', station='JAFL', format='xml', level='channel')
        elif 'MPUR' in station:
            stn = client.get_stations(network='XI', station='MPUR', format='xml', level='channel')
        elif 'BN15' in station:
            stn = client.get_stations(network='XR', station='BN08', format='xml', level='channel')
            print(stn)
            
    elif '_' in station:
        stn = client.get_stations(network=station.replace('-','.').split('.')[0], station=station.replace('-','.').split('.')[1].split('_')[0], format='xml', level='channel')
    else:    
        stn = client.get_stations(network=station.replace('-','.').split('.')[0], station=station.replace('-','.').split('.')[1], format='xml', level='channel')
    
    '''
    #load in solutions and search history
    sols = pickle.load(open(d + station + '.p','rb'))
    [pos_history, cost_history, bounds, options, [wR, wZ, wRC]] = pickle.load(open(d + station + '_optimizer.p','rb'))
    kwargs = pickle.load(open(d + station + '_kwargs.p','rb'))
    '''
    
    #best solution for each parameter
    if nlayers == 5:
        l1 = bestPos[0]
        l2 = bestPos[1]
        l3 = bestPos[2]
        l4 = bestPos[3]
        l5 = bestPos[4]
        k1 = bestPos[10]
        k2 = bestPos[11]
        k3 = bestPos[12]
        k4 = bestPos[13]
        k5 = bestPos[14]
        vp1 = bestPos[5]
        vp2 = bestPos[6]
        vp3 = bestPos[7]
        vp4 = bestPos[8]
        vp5 = bestPos[9]
        
    elif nlayers == 4:
        l1 = bestPos[0]
        l2 = bestPos[1]
        l3 = bestPos[2]
        l4 = bestPos[3]
        l5 = 0
        k1 = bestPos[8]
        k2 = bestPos[9]
        k3 = bestPos[10]
        k4 = bestPos[11]
        k5 = 0
        vp1 = bestPos[4]
        vp2 = bestPos[5]
        vp3 = bestPos[6]
        vp4 = bestPos[7]
        vp5 = 0
        
    elif nlayers == 3:
        l1 = bestPos[0]
        l2 = bestPos[1]
        l3 = bestPos[2]
        l4 = 0
        l5 = 0
        k1 = bestPos[6]
        k2 = bestPos[7]
        k3 = bestPos[8]
        k4 = 0
        k5 = 0
        vp1 = bestPos[3]
        vp2 = bestPos[4]
        vp3 = bestPos[5]
        vp4 = 0
        vp5 = 0
    elif nlayers == 2:
        l1 = bestPos[0]
        l2 = bestPos[1]
        l3 = 0
        l4 = 0
        l5 = 0
        k1 = bestPos[4]
        k2 = bestPos[5]
        k3 = 0
        k4 = 0
        k5 = 0
        vp1 = bestPos[2]
        vp2 = bestPos[3]
        vp3 = 0
        vp4 = 0
        vp5 = 0
    elif nlayers == 1:
        l1 = bestPos[0]
        l2 = 0
        l3 = 0
        l4 = 0
        l5 = 0
        k1 = bestPos[2]
        k2 = 0
        k3 = 0
        k4 = 0
        k5 = 0
        vp1 = bestPos[1]
        vp2 = 0
        vp3 = 0
        vp4 = 0
        vp5 = 0
    
    if writeCSV == True:
        m = l1 + l2 + l3 + l4 + l5
        bestSol = [station,stn[0][0].latitude,stn[0][0].longitude,str(round(bestCost,2)),str(round(l1,2)),str(round(vp1,2)),str(round(k1,2)),str(round(l2,2)),str(round(vp2,2)),str(round(k2,2)),str(round(l3,2)),str(round(vp3,2)),str(round(k3,2)),str(round(l4,2)),str(round(vp4,2)),str(round(k4,2)),str(round(l4,2)),str(round(vp4,2)),str(round(k4,2)),str(round(m,2))]
        
        with open(d +'bestSols.csv','a') as f:
            write = csv.writer(f)
            write.writerow(bestSol)       
    
    theta = []
    l1s = []
    l2s = []
    l3s = []
    l4s = []
    l5s = []
    vp1s = []
    vp2s = []
    vp3s = []
    vp4s = []
    vp5s = []
    k1s = []
    k2s = []
    k3s = []
    k4s = []
    k5s = []
    
    thetas = []
    
    # Using readlines()
    allSearches = d + station.replace('-','.') + '.txt'
    print(allSearches)
    file1 = open(allSearches, 'r')
    Lines = file1.readlines()
 
    linesSorted = [x for y, x in sorted(zip(costs, Lines))]
    print('Lines length:')
    print(len(Lines))
    costs = sorted(costs)
    print('Costs length:')
    print(len(costs))
    
    linesSorted.reverse()
    costs.reverse()
    
    count = 0
    for line in linesSorted:
        count += 1
        #print("Line{}: {}".format(count, line.strip()))
        if nlayers == 5:
            thetaStr = line.replace('[','').replace(']','').replace('\n','').split()
            theta = float(thetaStr[0]),float(thetaStr[1]),float(thetaStr[2]),float(thetaStr[3]),float(thetaStr[4]), \
                float(thetaStr[5]),float(thetaStr[6]),float(thetaStr[7]),float(thetaStr[8]),float(thetaStr[9]), \
                    float(thetaStr[10]),float(thetaStr[11]),float(thetaStr[12]),float(thetaStr[13]),float(thetaStr[14])
            
            thetas.append(theta)     
            l1s.append(theta[0])
            l2s.append(theta[1])
            l3s.append(theta[2])
            l4s.append(theta[3])
            l5s.append(theta[4])
            
            vp1s.append(theta[5])
            vp2s.append(theta[6])
            vp3s.append(theta[7])
            vp4s.append(theta[8])
            vp5s.append(theta[9])
            
            k1s.append(theta[10])
            k2s.append(theta[11])
            k3s.append(theta[12])
            k4s.append(theta[13])
            k5s.append(theta[14])
            
        elif nlayers == 4:
            thetaStr = line.replace('[','').replace(']','').replace('\n','').split()
            theta = float(thetaStr[0]),float(thetaStr[1]),float(thetaStr[2]),float(thetaStr[3]),\
                float(thetaStr[4]),float(thetaStr[5]),float(thetaStr[6]),float(thetaStr[7]),\
                    float(thetaStr[8]),float(thetaStr[9]),float(thetaStr[10]),float(thetaStr[11])
                
            thetas.append(theta)           
            l1s.append(theta[0])
            l2s.append(theta[1])
            l3s.append(theta[2])
            l4s.append(theta[3])
            l5s.append(0)
            
            vp1s.append(theta[4])
            vp2s.append(theta[5])
            vp3s.append(theta[6])
            vp4s.append(theta[7])
            vp5s.append(0)
            
            k1s.append(theta[8])
            k2s.append(theta[9])
            k3s.append(theta[10])
            k4s.append(theta[11])
            k5s.append(0)
            
        elif nlayers == 3:
            thetaStr = line.replace('[','').replace(']','').replace('\n','').split()
            theta = float(thetaStr[0]),float(thetaStr[1]),float(thetaStr[2]),\
                float(thetaStr[3]),float(thetaStr[4]),float(thetaStr[5]),\
                    float(thetaStr[6]),float(thetaStr[7]),float(thetaStr[8])
                
            thetas.append(theta)           
            l1s.append(theta[0])
            l2s.append(theta[1])
            l3s.append(theta[2])
            l4s.append(0)
            l5s.append(0)
            
            vp1s.append(theta[3])
            vp2s.append(theta[4])
            vp3s.append(theta[5])
            vp4s.append(0)
            vp5s.append(0)
            
            k1s.append(theta[6])
            k2s.append(theta[7])
            k3s.append(theta[8])
            k4s.append(0)
            k5s.append(0)
        
        elif nlayers == 2:
            thetaStr = line.replace('[','').replace(']','').replace('\n','').split()
            theta = float(thetaStr[0]),float(thetaStr[1]),\
                float(thetaStr[2]),float(thetaStr[3]),\
                    float(thetaStr[4]),float(thetaStr[5])
                
            thetas.append(theta)           
            l1s.append(theta[0])
            l2s.append(theta[1])
            l3s.append(0)
            l4s.append(0)
            l5s.append(0)
            
            vp1s.append(theta[2])
            vp2s.append(theta[3])
            vp3s.append(0)
            vp4s.append(0)
            vp5s.append(0)
            
            k1s.append(theta[4])
            k2s.append(theta[5])
            k3s.append(0)
            k4s.append(0)
            k5s.append(0)
            
        elif nlayers == 1:
            thetaStr = line.replace('[','').replace(']','').replace('\n','').split()
            theta = float(thetaStr[0]),float(thetaStr[1]),\
                float(thetaStr[2])
                
            thetas.append(theta)           
            l1s.append(theta[0])
            l2s.append(0)
            l3s.append(0)
            l4s.append(0)
            l5s.append(0)
            
            vp1s.append(theta[1])
            vp2s.append(0)
            vp3s.append(0)
            vp4s.append(0)
            vp5s.append(0)
            
            k1s.append(theta[2])
            k2s.append(0)
            k3s.append(0)
            k4s.append(0)
            k5s.append(0)
    '''
    l1_loc = []
    l2_loc = []
    l3_loc = []
    l4_loc = []
    l5_loc = []
    vp1_loc = []
    vp2_loc = []
    vp3_loc = []
    vp4_loc = []
    vp5_loc = []
    k1_loc = []
    k2_loc = []
    k3_loc = []
    k4_loc = []
    k5_loc = []
    
    for p in positions:
        if nlayers == 5:
            l1_loc.append(p[0])
            l2_loc.append(p[1])
            l3_loc.append(p[2])
            l4_loc.append(p[3])
            l5_loc.append(p[4])
            
            vp1_loc.append(p[5])
            vp2_loc.append(p[6])
            vp3_loc.append(p[7])
            vp4_loc.append(p[8])
            vp5_loc.append(p[9])
            
            k1_loc.append(p[10])
            k2_loc.append(p[11])
            k3_loc.append(p[12])
            k4_loc.append(p[13])
            k5_loc.append(p[14])
            
        elif nlayers == 4:
            l1_loc.append(p[0])
            l2_loc.append(p[1])
            l3_loc.append(p[2])
            l4_loc.append(p[3])
            
            vp1_loc.append(p[4])
            vp2_loc.append(p[5])
            vp3_loc.append(p[6])
            vp4_loc.append(p[7])
            
            k1_loc.append(p[8])
            k2_loc.append(p[9])
            k3_loc.append(p[10])
            k4_loc.append(p[11])
            
        elif nlayers == 3:
            l1_loc.append(p[0])
            l2_loc.append(p[1])
            l3_loc.append(p[2])
            
            vp1_loc.append(p[3])
            vp2_loc.append(p[4])
            vp3_loc.append(p[5])
            
            k1_loc.append(p[6])
            k2_loc.append(p[7])
            k3_loc.append(p[8])
            
        elif nlayers == 2:
            l1_loc.append(p[0])
            l2_loc.append(p[1])
            
            vp1_loc.append(p[2])
            vp2_loc.append(p[3])
            
            k1_loc.append(p[4])
            k2_loc.append(p[5])
            
        else:
            l1_loc.append(p[0])
            l2_loc.append(p[1])
            
            vp1_loc.append(p[2])
            vp2_loc.append(p[3])
            
            k1_loc.append(p[4])
            k2_loc.append(p[5])
    '''        
    # old definition for cutoff
    #cutoff= max(costs)-((max(costs)-min(costs))*((100-pointCutoff)/100))
    
    # new definition for cutoff
    cutoff = min(costs) * (100/(100-pointCutoff))
    
    cutIdx = costs.index(costs[min(range(len(costs)), key = lambda i: abs(costs[i]-cutoff))])
    
    # old definition for cloud cutoff
    #cutoff_cloud= max(costs)-((max(costs)-min(costs))*((100-cloudCutoff)/100))
    
    # new definition for cloud cutoff
    cutoff_cloud = min(costs) * (100/(100-cloudCutoff))
    cutIdx_cloud = costs.index(costs[min(range(len(costs)), key = lambda i: abs(costs[i]-cutoff_cloud))])
    
    l1s_minMax = []
    k1s_minMax = []
    v1s_minMax = []
    
    l2s_minMax = []
    k2s_minMax = []
    v2s_minMax = []
    
    l3s_minMax = []
    k3s_minMax = []
    v3s_minMax = []
    
    l4s_minMax = []
    k4s_minMax = []
    v4s_minMax = []
    
    l5s_minMax = []
    k5s_minMax = []
    v5s_minMax = []
    
    if minMax_ofSubset:
        splitVariable_use = locals()[splitVariable]
        
        if splitVariable2 != None:
            splitVariable2_use = locals()[splitVariable2]
            
        for i in range(len(splitVariable_use[cutIdx_cloud:])):
            
            if splitVariable_use[cutIdx_cloud:][i] < splitValue:
                
                if splitVariable2 != None:
                    if splitVariable2_use[cutIdx_cloud:][i] < splitValue2:
                        check2 = True
                    else:
                        check2 = False
                    
                else:
                    check2 = True
                    
                if check2 == True:
                    l1s_minMax.append(l1s[cutIdx_cloud:][i])
                    k1s_minMax.append(k1s[cutIdx_cloud:][i])
                    v1s_minMax.append(vp1s[cutIdx_cloud:][i])
                    
                    if nlayers >= 2:
                        l2s_minMax.append(l2s[cutIdx_cloud:][i])
                        k2s_minMax.append(k2s[cutIdx_cloud:][i])
                        v2s_minMax.append(vp2s[cutIdx_cloud:][i])
                    else:
                        l2s_minMax.append(0)
                        k2s_minMax.append(0)
                        v2s_minMax.append(0)
                        
                    if nlayers >= 3:
                        l3s_minMax.append(l3s[cutIdx_cloud:][i])
                        k3s_minMax.append(k3s[cutIdx_cloud:][i])
                        v3s_minMax.append(vp3s[cutIdx_cloud:][i])
                    else:
                        l3s_minMax.append(0)
                        k3s_minMax.append(0)
                        v3s_minMax.append(0)
                        
                    if nlayers >= 4:
                        l4s_minMax.append(l4s[cutIdx_cloud:][i])
                        k4s_minMax.append(k4s[cutIdx_cloud:][i])
                        v4s_minMax.append(vp4s[cutIdx_cloud:][i])
                    else:
                        l4s_minMax.append(0)
                        k4s_minMax.append(0)
                        v4s_minMax.append(0)
                        
                    if nlayers >= 5:
                        l5s_minMax.append(l5s[cutIdx_cloud:][i])
                        k5s_minMax.append(k5s[cutIdx_cloud:][i])
                        v5s_minMax.append(vp5s[cutIdx_cloud:][i])
                    else:
                        l5s_minMax.append(0)
                        k5s_minMax.append(0)
                        v5s_minMax.append(0)
    
    #plot up search history with costs on 2D plots
    #pdb.set_trace()
    #xy = np.vstack([l1s[cutIdx:],k1s[cutIdx:]])
    #z = gaussian_kde(xy)(xy)
    
    #plt.scatter(l1s[cutIdx:], k1s[cutIdx:], s=100)
    
    #plt.scatter(l1s[cutIdx:],k1s[cutIdx:],s=1, c=costs[cutIdx:], cmap='binary_r', vmin=min(costs[cutIdx:]), vmax=max(costs[cutIdx:]))
    l1_std = np.std(l1s[cutIdx_cloud:])
    k1_std = np.std(k1s[cutIdx_cloud:])
    v1_std = np.std(vp1s[cutIdx_cloud:])
    
    #set up dummy values for std so the csv writer doesn't break
    l2_std = 0
    k2_std = 0
    v2_std = 0
    
    l3_std = 0
    k3_std = 0
    v3_std = 0
    
    l4_std = 0
    k4_std = 0
    v4_std = 0
    
    l5_std = 0
    k5_std = 0
    v5_std = 0
    
    #testing out corner plots
    #ndim, nsamples = 2, 10000
    #np.random.seed(42)
    #samples = np.random.randn(ndim * nsamples).reshape([nsamples, ndim])
    
    '''
    l1s_corner = np.array(l1s[cutIdx_cloud:], dtype='float32')
    k1s_corner = np.array(k1s[cutIdx_cloud:], dtype='float32')
    v1s_corner = np.array(vp1s[cutIdx_cloud:], dtype='float32')
                  
    data = np.vstack([l1s_corner, k1s_corner, v1s_corner]).transpose()
    print(data)
    figure = corner.corner(data,labels=[r"$l1$",r"$k1$",r"$vp1$"])
    '''
    
    #make corner plots
    if nlayers == 1:
        l1s_corner = np.array(l1s[cutIdx_cloud:], dtype='float32')
        k1s_corner = np.array(k1s[cutIdx_cloud:], dtype='float32')
        v1s_corner = np.array(vp1s[cutIdx_cloud:], dtype='float32')
                      
        data = np.vstack([l1s_corner, k1s_corner, v1s_corner]).transpose()
        lbls = [r"$l1$",r"$k1$",r"$vp1$"]
        #print([(min(l1s),max(l1s)), (1.6, 2.1),  (min(vp1s),max(vp1s))])
        figure = corner.corner(data, range=[(min(l1s),max(l1s)), (1.6, 2.1),  (min(vp1s),max(vp1s))], labels=lbls)
        # Extract the axes
        axes = np.array(figure.axes).reshape((3, 3))
        
        value1 = [l1,k1,vp1]
        #value1 = np.mean(data, axis=0)
        print(value1)
        
        # Loop over the diagonal
        for i in range(3):
            ax = axes[i, i]
            ax.axvline(value1[i], color="g")
            #ax.axvline(value2[i], color="r")
        
        # Loop over the histograms
        for yi in range(3):
            for xi in range(yi):
                ax = axes[yi, xi]
                ax.axvline(value1[xi], color="g")
                #ax.axvline(value2[xi], color="r")
                ax.axhline(value1[yi], color="g")
                #ax.axhline(value2[yi], color="r")
                ax.plot(value1[xi], value1[yi], "sg")
                #ax.plot(value2[xi], value2[yi], "sr")
        
        figure.savefig(d + station + '_corners', dpi=300)
        
    if nlayers == 2:
        l1s_corner = np.array(l1s[cutIdx_cloud:], dtype='float32')
        k1s_corner = np.array(k1s[cutIdx_cloud:], dtype='float32')
        v1s_corner = np.array(vp1s[cutIdx_cloud:], dtype='float32')
        
        l2s_corner = np.array(l2s[cutIdx_cloud:], dtype='float32')
        k2s_corner = np.array(k2s[cutIdx_cloud:], dtype='float32')
        v2s_corner = np.array(vp2s[cutIdx_cloud:], dtype='float32')
           
        data = np.vstack([l1s_corner, k1s_corner, v1s_corner, l2s_corner, k2s_corner, v2s_corner]).transpose()
        lbls = [r"$l1$",r"$k1$",r"$vp1$",r"$l2$",r"$k2$",r"$vp2$"]
        figure = corner.corner(data, range=[(min(l1s),max(l1s)), (1.6, 2.1),  (min(vp1s),max(vp1s)), (min(l2s),max(l2s)), (1.6, 2.1), (min(vp2s),max(vp2s))], labels=lbls)
        # Extract the axes
        axes = np.array(figure.axes).reshape((6, 6))
        
        value1 = [l1,k1,vp1,l2,k2,vp2]
        #value1 = np.mean(data, axis=0)
        print(value1)
        
        # Loop over the diagonal
        for i in range(6):
            ax = axes[i, i]
            ax.axvline(value1[i], color="g")
            #ax.axvline(value2[i], color="r")
        
        # Loop over the histograms
        for yi in range(6):
            for xi in range(yi):
                ax = axes[yi, xi]
                ax.axvline(value1[xi], color="g")
                #ax.axvline(value2[xi], color="r")
                ax.axhline(value1[yi], color="g")
                #ax.axhline(value2[yi], color="r")
                ax.plot(value1[xi], value1[yi], "sg")
                #ax.plot(value2[xi], value2[yi], "sr")
        
        figure.savefig(d + station + '_corners', dpi=300)
        
    elif nlayers == 3:
        l1s_corner = np.array(l1s[cutIdx_cloud:], dtype='float32')
        k1s_corner = np.array(k1s[cutIdx_cloud:], dtype='float32')
        v1s_corner = np.array(vp1s[cutIdx_cloud:], dtype='float32')
        
        l2s_corner = np.array(l2s[cutIdx_cloud:], dtype='float32')
        k2s_corner = np.array(k2s[cutIdx_cloud:], dtype='float32')
        v2s_corner = np.array(vp2s[cutIdx_cloud:], dtype='float32')
        
        l3s_corner = np.array(l3s[cutIdx_cloud:], dtype='float32')
        k3s_corner = np.array(k3s[cutIdx_cloud:], dtype='float32')
        v3s_corner = np.array(vp3s[cutIdx_cloud:], dtype='float32')
        
        if oneCorner==True:
            data = np.vstack([l1s_corner, k1s_corner, v1s_corner, l2s_corner, k2s_corner, v2s_corner, l3s_corner, k3s_corner, v3s_corner]).transpose()
            lbls = [r"$l1$",r"$k1$",r"$vp1$",r"$l2$",r"$k2$",r"$vp2$",r"$l3$",r"$k3$",r"$vp3$"]
            figure = corner.corner(data, range=[(min(l1s),max(l1s)), (1.6, 2.1),  (min(vp1s),max(vp1s)), (min(l2s),max(l2s)), (1.6, 2.1), (min(vp2s),max(vp2s)),  (min(l3s),max(l3s)), (1.6, 2.1), (min(vp3s),max(vp3s))], labels=lbls)
            # Extract the axes
            axes = np.array(figure.axes).reshape((9, 9))
            
            value1 = [l1,k1,vp1,l2,k2,vp2,l3,k3,vp3]
            #value1 = np.mean(data, axis=0)
            print(value1)
            
            # Loop over the diagonal
            for i in range(9):
                ax = axes[i, i]
                ax.axvline(value1[i], color="g")
                #ax.axvline(value2[i], color="r")
            
            # Loop over the histograms
            for yi in range(9):
                for xi in range(yi):
                    ax = axes[yi, xi]
                    ax.axvline(value1[xi], color="g")
                    #ax.axvline(value2[xi], color="r")
                    ax.axhline(value1[yi], color="g")
                    #ax.axhline(value2[yi], color="r")
                    ax.plot(value1[xi], value1[yi], "sg")
                    #ax.plot(value2[xi], value2[yi], "sr")
            
            figure.savefig(d + station + '_corners', dpi=300)
        else:
            data1 = np.vstack([l1s_corner, k1s_corner, v1s_corner]).transpose()
            lbls1 = [r"$l1$",r"$k1$",r"$vp1$"]
            figure1 = corner.corner(data1, range=[(min(l1s),max(l1s)), (1.6, 2.1),  (min(vp1s),max(vp1s))], labels=lbls1)
            # Extract the axes
            axes1 = np.array(figure1.axes).reshape((3, 3))
            
            value1 = [l1,k1,vp1]
            #value1 = np.mean(data, axis=0)
            print(value1)
            
            # Loop over the diagonal
            for i in range(3):
                ax = axes1[i, i]
                ax.axvline(value1[i], color="g")
                #ax.axvline(value2[i], color="r")
            
            # Loop over the histograms
            for yi in range(3):
                for xi in range(yi):
                    ax = axes1[yi, xi]
                    ax.axvline(value1[xi], color="g")
                    #ax.axvline(value2[xi], color="r")
                    ax.axhline(value1[yi], color="g")
                    #ax.axhline(value2[yi], color="r")
                    ax.plot(value1[xi], value1[yi], "sg")
                    #ax.plot(value2[xi], value2[yi], "sr")
            
            figure1.savefig(d + station + '_corners_L1', dpi=300)
            
            data2 = np.vstack([l2s_corner, k2s_corner, v2s_corner]).transpose()
            lbls2 = [r"$l2$",r"$k2$",r"$vp2$"]
            figure2 = corner.corner(data2, range=[(min(l2s),max(l2s)), (1.6, 2.1),  (min(vp2s),max(vp2s))], labels=lbls2)
            # Extract the axes
            axes2 = np.array(figure2.axes).reshape((3, 3))
            
            value2 = [l2,k2,vp2]
            #value1 = np.mean(data, axis=0)
            print(value2)
            
            # Loop over the diagonal
            for i in range(3):
                ax = axes2[i, i]
                ax.axvline(value2[i], color="g")
                #ax.axvline(value2[i], color="r")
            
            # Loop over the histograms
            for yi in range(3):
                for xi in range(yi):
                    ax = axes2[yi, xi]
                    ax.axvline(value2[xi], color="g")
                    #ax.axvline(value2[xi], color="r")
                    ax.axhline(value2[yi], color="g")
                    #ax.axhline(value2[yi], color="r")
                    ax.plot(value2[xi], value2[yi], "sg")
                    #ax.plot(value2[xi], value2[yi], "sr")
            
            figure2.savefig(d + station + '_corners_L2', dpi=300)
            
            data3 = np.vstack([l3s_corner, k3s_corner, v3s_corner]).transpose()
            lbls3 = [r"$l3$",r"$k3$",r"$vp3$"]
            figure3 = corner.corner(data3, range=[(min(l3s),max(l3s)), (1.6, 2.1),  (min(vp3s),max(vp3s))], labels=lbls3)
            # Extract the axes
            axes3 = np.array(figure3.axes).reshape((3, 3))
            
            value3 = [l3,k3,vp3]
            #value1 = np.mean(data, axis=0)
            print(value3)
            
            # Loop over the diagonal
            for i in range(3):
                ax = axes3[i, i]
                ax.axvline(value3[i], color="g")
                #ax.axvline(value2[i], color="r")
            
            # Loop over the histograms
            for yi in range(3):
                for xi in range(yi):
                    ax = axes3[yi, xi]
                    ax.axvline(value3[xi], color="g")
                    #ax.axvline(value2[xi], color="r")
                    ax.axhline(value3[yi], color="g")
                    #ax.axhline(value2[yi], color="r")
                    ax.plot(value3[xi], value3[yi], "sg")
                    #ax.plot(value2[xi], value2[yi], "sr")
            
            figure3.savefig(d + station + '_corners_L3', dpi=300)
            
            dataL = np.vstack([l1s_corner, l2s_corner, l3s_corner]).transpose()
            lblsL = [r"$l1$",r"$l2$",r"$l3$"]
            figureL = corner.corner(dataL, range=[(min(l1s),max(l1s)), (min(l2s),max(l2s)),  (min(l3s),max(l3s))], labels=lblsL)
            # Extract the axes
            axesL = np.array(figureL.axes).reshape((3, 3))
            
            valueL = [l1,l2,l3]
            #value1 = np.mean(data, axis=0)
            print(valueL)
            
            # Loop over the diagonal
            for i in range(3):
                ax = axesL[i, i]
                ax.axvline(valueL[i], color="g")
                #ax.axvline(value2[i], color="r")
            
            # Loop over the histograms
            for yi in range(3):
                for xi in range(yi):
                    ax = axesL[yi, xi]
                    ax.axvline(valueL[xi], color="g")
                    #ax.axvline(value2[xi], color="r")
                    ax.axhline(valueL[yi], color="g")
                    #ax.axhline(value2[yi], color="r")
                    ax.plot(valueL[xi], valueL[yi], "sg")
                    #ax.plot(value2[xi], value2[yi], "sr")
            
            figureL.savefig(d + station + '_corners_Ls', dpi=300)
        
    elif nlayers == 4:
        print('Lengths of data: ')
        print(len(l1s))
        print(len(k1s))
        print(len(vp1s))
        
        print(len(l2s))
        print(len(k2s))
        print(len(vp2s))
        
        print(len(l3s))
        print(len(k3s))
        print(len(vp3s))
        
        print(len(l4s))
        print(len(k4s))
        print(len(vp4s))
        
        l1s_corner = np.array(l1s[cutIdx_cloud:], dtype='float32')
        k1s_corner = np.array(k1s[cutIdx_cloud:], dtype='float32')
        v1s_corner = np.array(vp1s[cutIdx_cloud:], dtype='float32')
        
        l2s_corner = np.array(l2s[cutIdx_cloud:], dtype='float32')
        k2s_corner = np.array(k2s[cutIdx_cloud:], dtype='float32')
        v2s_corner = np.array(vp2s[cutIdx_cloud:], dtype='float32')
        
        l3s_corner = np.array(l3s[cutIdx_cloud:], dtype='float32')
        k3s_corner = np.array(k3s[cutIdx_cloud:], dtype='float32')
        v3s_corner = np.array(vp3s[cutIdx_cloud:], dtype='float32')
        
        l4s_corner = np.array(l4s[cutIdx_cloud:], dtype='float32')
        k4s_corner = np.array(k4s[cutIdx_cloud:], dtype='float32')
        v4s_corner = np.array(vp4s[cutIdx_cloud:], dtype='float32')
                      
        data = np.vstack([l1s_corner, k1s_corner, v1s_corner, l2s_corner, k2s_corner, v2s_corner, l3s_corner, k3s_corner, v3s_corner, l4s_corner, k4s_corner, v4s_corner]).transpose()
        #print(data)
        lbls = [r"$l1$",r"$k1$",r"$vp1$",r"$l2$",r"$k2$",r"$vp2$",r"$l3$",r"$k3$",r"$vp3$",r"$l4$",r"$k4$",r"$vp4$"]
        print('Lengths of corner plots: ')
        print(len(l1s_corner))
        print(len(k1s_corner))
        print(len(v1s_corner))
        
        print(len(l2s_corner))
        print(len(k2s_corner))
        print(len(v2s_corner))
        
        print(len(l3s_corner))
        print(len(k3s_corner))
        print(len(v3s_corner))
        
        print(len(l4s_corner))
        print(len(k4s_corner))
        print(len(v4s_corner))
        
        print('ranges for plotting: ')
        print([(min(l1s),max(l1s)), (1.6, 2.1),  (min(vp1s),max(vp1s)), (min(l2s),max(l2s)), (1.6, 2.1), (min(vp2s),max(vp2s)),  (min(l3s),max(l3s)), (1.6, 2.1), (min(vp3s),max(vp3s)),  (min(l4s),max(l4s)), (1.6, 2.1), (min(vp4s),max(vp4s))])
        
        figure = corner.corner(data, range=[(min(l1s),max(l1s)), (1.6, 2.1),  (min(vp1s),max(vp1s)), (min(l2s),max(l2s)), (1.6, 2.1), (min(vp2s),max(vp2s)),  (min(l3s),max(l3s)), (1.6, 2.1), (min(vp3s),max(vp3s)),  (min(l4s),max(l4s)), (1.6, 2.1), (min(vp4s),max(vp4s))], labels=lbls)
        # Extract the axes
        axes = np.array(figure.axes).reshape((12, 12))
        
        value1 = [l1,k1,vp1,l2,k2,vp2,l3,k3,vp3,l4,k4,vp4]
        #value1 = np.mean(data, axis=0)
        print(value1)
        
        # Loop over the diagonal
        for i in range(12):
            ax = axes[i, i]
            ax.axvline(value1[i], color="g")
            #ax.axvline(value2[i], color="r")
        
        # Loop over the histograms
        for yi in range(12):
            for xi in range(yi):
                ax = axes[yi, xi]
                ax.axvline(value1[xi], color="g")
                #ax.axvline(value2[xi], color="r")
                ax.axhline(value1[yi], color="g")
                #ax.axhline(value2[yi], color="r")
                ax.plot(value1[xi], value1[yi], "sg")
                #ax.plot(value2[xi], value2[yi], "sr")
        
        figure.savefig(d + station + '_corners', dpi=300)
        
    plotMisfits(l1s,k1s,l1,k1,costs,cutIdx,cutIdx_cloud,'_l1_k1_BH.png',cloud,drawEllipse,d,station,'Layer 1 Thickness','Layer 1 Vp/Vs','Layer 1: Thickness vs Vp/Vs',plotMinMax,l1s_minMax,k1s_minMax)
    plotMisfits(l1s,vp1s,l1,vp1,costs,cutIdx,cutIdx_cloud,'_l1_v1_BH.png',cloud,drawEllipse,d,station,'Layer 1 Thickness','Layer 1 Vp/Vs','Layer 1: Thickness vs Vp',plotMinMax,l1s_minMax,v1s_minMax)
    
    if nlayers >= 2:
        l2_std = np.std(l2s[cutIdx_cloud:])
        k2_std = np.std(k2s[cutIdx_cloud:])
        v2_std = np.std(vp2s[cutIdx_cloud:])
    
        plotMisfits(l2s,k2s,l2,k2,costs,cutIdx,cutIdx_cloud,'_l2_h2_BH.png',cloud,drawEllipse,d,station,'Layer 2 Thickness','Layer 2 Vp/Vs','Layer 2: Thickness vs Vp/Vs',plotMinMax,l2s_minMax,k2s_minMax)
        plotMisfits(l2s,vp2s,l2,vp2,costs,cutIdx,cutIdx_cloud,'_l2_v2_BH.png',cloud,drawEllipse,d,station,'Layer 2 Thickness','Layer 2 Vp','Layer 2: Thickness vs Vp',plotMinMax,l2s_minMax,v2s_minMax)
        plotMisfits(l1s,l2s,l1,l2,costs,cutIdx,cutIdx_cloud,'_layers12_BH.png',cloud,drawEllipse,d,station,'Layer 1 Thickness','Layer 2 Thickness','Layer 1 Thickness vs Layer 2 Thickness',plotMinMax,l1s_minMax,l2s_minMax)
        
    if nlayers >= 3:
        l3_std = np.std(l3s[cutIdx_cloud:])
        k3_std = np.std(k3s[cutIdx_cloud:])
        v3_std = np.std(vp3s[cutIdx_cloud:])
        
        plotMisfits(l3s,k3s,l3,k3,costs,cutIdx,cutIdx_cloud,'_l3_h3_BH.png',cloud,drawEllipse,d,station,'Layer 3 Thickness','Layer 3 Vp/Vs','Layer 3: Thickness vs Vp/Vs',plotMinMax,l3s_minMax,k3s_minMax)
        plotMisfits(l3s,vp3s,l3,vp3,costs,cutIdx,cutIdx_cloud,'_l3_v3_BH.png',cloud,drawEllipse,d,station,'Layer 3 Thickness','Layer 3 Vp','Layer 3: Thickness vs Vp',plotMinMax,l3s_minMax,v3s_minMax)
        plotMisfits(l2s,l3s,l2,l3,costs,cutIdx,cutIdx_cloud,'_layers23_BH.png',cloud,drawEllipse,d,station,'Layer 2 Thickness','Layer 3 Thickness','Layer 2 Thickness vs Layer 3 Thickness',plotMinMax,l2s_minMax,l3s_minMax)
    if nlayers >= 4:
        l4_std = np.std(l4s[cutIdx_cloud:])
        k4_std = np.std(k4s[cutIdx_cloud:])
        v4_std = np.std(vp4s[cutIdx_cloud:])
        
        plotMisfits(l4s,k4s,l4,k4,costs,cutIdx,cutIdx_cloud,'_l4_h4_BH.png',cloud,drawEllipse,d,station,'Layer 4 Thickness','Layer 4 Vp/Vs','Layer 4: Thickness vs Vp/Vs',plotMinMax,l4s_minMax,k4s_minMax)
        plotMisfits(l4s,vp4s,l4,vp4,costs,cutIdx,cutIdx_cloud,'_l4_v4_BH.png',cloud,drawEllipse,d,station,'Layer 4 Thickness','Layer 4 Vp','Layer 4: Thickness vs Vp',plotMinMax,l4s_minMax,v4s_minMax)
        plotMisfits(l3s,l4s,l3,l4,costs,cutIdx,cutIdx_cloud,'_layers34_BH.png',cloud,drawEllipse,d,station,'Layer 3 Thickness','Layer 4 Thickness','Layer 3 Thickness vs Layer 4 Thickness',plotMinMax,l3s_minMax,l4s_minMax)
    if nlayers == 5:
        l5_std = np.std(l5s[cutIdx_cloud:])
        k5_std = np.std(k5s[cutIdx_cloud:])
        v5_std = np.std(vp5s[cutIdx_cloud:])
        
        plotMisfits(l5s,k5s,l5,k5,costs,cutIdx,cutIdx_cloud,'_l5_h5_BH.png',cloud,drawEllipse,d,station,'Layer 5 Thickness','Layer 5 Vp/Vs','Layer 5: Thickness vs Vp/Vs',plotMinMax,l5s_minMax,k5s_minMax)
        plotMisfits(l5s,vp5s,l5,vp5,costs,cutIdx,cutIdx_cloud,'_l5_v5_BH.png',cloud,drawEllipse,d,station,'Layer 5 Thickness','Layer 5 Vp','Layer 5: Thickness vs Vp',plotMinMax,l5s_minMax,v5s_minMax)
        plotMisfits(l4s,l5s,l4,l5,costs,cutIdx,cutIdx_cloud,'_layers45_BH.png',cloud,drawEllipse,d,station,'Layer 4 Thickness','Layer 5 Thickness','Layer 4 Thickness vs Layer 5 Thickness',plotMinMax,l4s_minMax,l5s_minMax)
    
    #pdb.set_trace()
    if writeCSV == True:
        m = l1 + l2 + l3 + l4 + l5
        bestSol_std = [station,stn[0][0].latitude,stn[0][0].longitude,str(round(bestCost,2)),str(round(l1,2)),str(round(l1_std,2)), \
                   str(round(vp1,2)),str(round(v1_std,2)),str(round(k1,2)),str(round(k1_std,2)),str(round(l2,2)),str(round(l2_std,2)), \
                       str(round(vp2,2)),str(round(v2_std,2)),str(round(k2,2)),str(round(k2_std,2)),str(round(l3,2)),str(round(l3_std,2)), \
                           str(round(vp3,2)),str(round(v3_std,2)),str(round(k3,2)),str(round(k3_std,2)),str(round(l4,2)),str(round(l4_std,2)), \
                               str(round(vp4,2)),str(round(v4_std,2)),str(round(k4,2)),str(round(k4_std,2)),str(round(l5,2)),str(round(l5_std,2)), \
                                   str(round(vp5,2)),str(round(v5_std,2)),str(round(k5,2)),str(round(v5_std,2)),str(round(m,2))]
        
        with open(d +'bestSols_std.csv','a') as f:
            write = csv.writer(f)
            write.writerow(bestSol_std)     
        
        if minMax_ofSubset:
            bestSol_minMax = [station,stn[0][0].latitude,stn[0][0].longitude,str(round(bestCost,2)),str(round(l1,2)),str(round(min(l1s_minMax),2)),str(round(max(l1s_minMax),2)), \
                       str(round(vp1,2)),str(round(min(v1s_minMax),2)),str(round(max(v1s_minMax),2)),str(round(k1,2)),str(round(min(k1s_minMax),2)),str(round(max(k1s_minMax),2)), \
                           str(round(l2,2)),str(round(min(l2s_minMax),2)),str(round(max(l2s_minMax),2)), \
                           str(round(vp2,2)),str(round(min(v2s_minMax),2)),str(round(max(v2s_minMax),2)),str(round(k2,2)),str(round(min(k2s_minMax),2)),str(round(max(k2s_minMax),2)), \
                               str(round(l3,2)), str(round(min(l3s_minMax),2)),str(round(max(l3s_minMax),2)), \
                               str(round(vp3,2)),str(round(min(v3s_minMax),2)),str(round(max(v3s_minMax),2)),str(round(k3,2)),str(round(min(k3s_minMax),2)),str(round(max(k3s_minMax),2)), \
                                   str(round(l4,2)),str(round(min(l4s_minMax),2)),str(round(max(l4s_minMax),2)), \
                                   str(round(vp4,2)),str(round(min(v4s_minMax),2)),str(round(max(v4s_minMax),2)),str(round(k4,2)),str(round(min(k4s_minMax),2)),str(round(max(k4s_minMax),2)), \
                                       str(round(l5,2)),str(round(min(l5s_minMax),2)),str(round(max(l5s_minMax),2)), \
                                       str(round(vp5,2)),str(round(min(v5s_minMax),2)),str(round(max(v5s_minMax),2)),str(round(k5,2)),str(round(min(k5s_minMax),2)),str(round(max(k5s_minMax),2)),str(round(m,2))]
        else:
            bestSol_minMax = [station,stn[0][0].latitude,stn[0][0].longitude,str(round(bestCost,2)),str(round(l1,2)),str(round(min(l1s[cutIdx_cloud:]),2)),str(round(max(l1s[cutIdx_cloud:]),2)), \
                       str(round(vp1,2)),str(round(min(vp1s[cutIdx_cloud:]),2)),str(round(max(vp1s[cutIdx_cloud:]),2)),str(round(k1,2)),str(round(min(k1s[cutIdx_cloud:]),2)),str(round(max(k1s[cutIdx_cloud:]),2)), \
                           str(round(l2,2)),str(round(min(l2s[cutIdx_cloud:]),2)),str(round(max(l2s[cutIdx_cloud:]),2)), \
                           str(round(vp2,2)),str(round(min(vp2s[cutIdx_cloud:]),2)),str(round(max(vp2s[cutIdx_cloud:]),2)),str(round(k2,2)),str(round(min(k2s[cutIdx_cloud:]),2)),str(round(max(k2s[cutIdx_cloud:]),2)), \
                               str(round(l3,2)), str(round(min(l3s[cutIdx_cloud:]),2)),str(round(max(l3s[cutIdx_cloud:]),2)), \
                               str(round(vp3,2)),str(round(min(vp3s[cutIdx_cloud:]),2)),str(round(max(vp3s[cutIdx_cloud:]),2)),str(round(k3,2)),str(round(min(k3s[cutIdx_cloud:]),2)),str(round(max(k3s[cutIdx_cloud:]),2)), \
                                   str(round(l4,2)),str(round(min(l4s[cutIdx_cloud:]),2)),str(round(max(l4s[cutIdx_cloud:]),2)), \
                                   str(round(vp4,2)),str(round(min(vp4s[cutIdx_cloud:]),2)),str(round(max(vp4s[cutIdx_cloud:]),2)),str(round(k4,2)),str(round(min(k4s[cutIdx_cloud:]),2)),str(round(max(k4s[cutIdx_cloud:]),2)), \
                                       str(round(l5,2)),str(round(min(l5s[cutIdx_cloud:]),2)),str(round(max(l5s[cutIdx_cloud:]),2)), \
                                       str(round(vp5,2)),str(round(min(vp5s[cutIdx_cloud:]),2)),str(round(max(vp5s[cutIdx_cloud:]),2)),str(round(k5,2)),str(round(min(k5s[cutIdx_cloud:]),2)),str(round(max(k5s[cutIdx_cloud:]),2)),str(round(m,2))]
        
        with open(d +'bestSols_minMax.csv','a') as f:
            write = csv.writer(f)
            write.writerow(bestSol_minMax)   

def plotMisfits(x,y,xMin,yMin,costs,cutIdx,cutIdx_cloud,name,cloud,drawEllipse,d,station,xlabel,ylabel,title, plotSubSplit = False, xSplitMinMax=None, ySplitMinMax=None):
    """
    Scatter-plot two parameters against each other, coloured by misfit, and save the figure.

    Parameters
    ----------
    x, y : sequence
        Values of the two parameters, one entry per tested solution. They are
        sliced with ``cutIdx`` / ``cutIdx_cloud`` and their min/max set the
        axis limits.
        NOTE(review): presumably ordered so that the tail holds the lowest
        misfits -- confirm against the caller.
    xMin, yMin : float
        Coordinates drawn as a large 'X' marker (callers pass the best solution).
    costs : sequence
        Misfit of every tested solution; ``costs[cutIdx:]`` colours the points
        and fixes the colour-bar range.
    cutIdx : int
        Start index of the slice drawn as misfit-coloured points.
    cutIdx_cloud : int
        Start index of the slice drawn as translucent salmon points and used
        for the KDE cloud and (when no subset is plotted) the ellipse.
    name : str
        File-name suffix; the figure is written to ``d + station + name``.
    cloud : bool
        If true, overlay a filled seaborn KDE of the ``cutIdx_cloud`` slice.
    drawEllipse : bool
        If true, draw a confidence ellipse (see ``confidence_ellipse``).
    d, station : str
        Output directory prefix and station name used to build the file path.
    xlabel, ylabel, title : str
        Axis labels and figure title.
    plotSubSplit : bool
        If true, additionally scatter ``xSplitMinMax`` / ``ySplitMinMax`` in
        green and fit the ellipse to that subset instead of the cloud slice.
    xSplitMinMax, ySplitMinMax : sequence, optional
        Subset of points used when ``plotSubSplit`` is true.

    Returns
    -------
    None
    """
    # One constant drives both the ellipse size and its label so the two
    # cannot disagree (the label used to say 1 sigma for a 2 sigma ellipse).
    ellipse_n_std = 2
    ellipse_label = r'$%d\sigma$' % ellipse_n_std

    fig, ax = plt.subplots()
    try:
        shown_costs = costs[cutIdx:]
        points = ax.scatter(x[cutIdx:], y[cutIdx:], s=1, c=shown_costs, cmap='binary_r',
                            vmin=min(shown_costs), vmax=max(shown_costs))
        cbar = fig.colorbar(points, ax=ax)
        cbar.set_label('Misfit')

        x_cloud = x[cutIdx_cloud:]
        y_cloud = y[cutIdx_cloud:]
        ax.scatter(x_cloud, y_cloud, s=4, color='salmon', alpha=.3)
        if cloud:
            sns.kdeplot(x=x_cloud, y=y_cloud, fill=True, levels=5, thresh=0.5, cut=0, alpha=.8, ax=ax)

        if plotSubSplit:
            ax.scatter(xSplitMinMax, ySplitMinMax, s=4, color='green', alpha=.3)
            # the ellipse describes the split subset rather than the whole cloud
            ellipse_x, ellipse_y, ellipse_color = xSplitMinMax, ySplitMinMax, 'orange'
        else:
            ellipse_x, ellipse_y, ellipse_color = x_cloud, y_cloud, 'cornflowerblue'

        if drawEllipse:
            confidence_ellipse(ellipse_x, ellipse_y, ax, n_std=ellipse_n_std,
                               label=ellipse_label, edgecolor=ellipse_color)

        ax.scatter(xMin, yMin, s=25, color='salmon', marker="X", edgecolor='black')
        ax.set_xlim([min(x), max(x)])
        ax.set_ylim([min(y), max(y)])
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_title(title)

        fig.savefig(d + station + name, dpi=300)
    finally:
        # Always release the figure, even if plotting or saving failed;
        # otherwise pyplot keeps it alive for the rest of the session.
        plt.close(fig)

def confidence_ellipse(x, y, ax, n_std=2.0, facecolor='none', **kwargs):
    """
    Add the covariance confidence ellipse of *x* and *y* to *ax*.

    Adapted from
    https://matplotlib.org/stable/gallery/statistics/confidence_ellipse.html

    The ellipse starts as a normalised shape centred on the origin whose
    half-axes are ``sqrt(1 + r)`` and ``sqrt(1 - r)`` (``r`` being the Pearson
    correlation), and is then rotated by 45 degrees, stretched by ``n_std``
    standard deviations along each axis and moved to the sample means.

    Parameters
    ----------
    x, y : array-like, shape (n, )
        Input data; both must have the same length.

    ax : matplotlib.axes.Axes
        Axes that receives the ellipse patch.

    n_std : float
        Number of standard deviations used to scale the ellipse.

    facecolor : color
        Fill colour handed to the patch (``'none'`` leaves it unfilled).

    **kwargs
        Forwarded to `~matplotlib.patches.Ellipse`

    Returns
    -------
    matplotlib.patches.Ellipse

    Raises
    ------
    ValueError
        If *x* and *y* differ in length.
    """
    if len(x) != len(y):
        raise ValueError("x and y must be the same size")

    covariance = np.cov(x, y)
    var_x = covariance[0, 0]
    var_y = covariance[1, 1]
    correlation = covariance[0, 1]/np.sqrt(var_x * var_y)

    # For a 2-D dataset the eigenvalues of the normalised covariance are
    # 1 + r and 1 - r, which gives the half-axes directly.
    patch = Ellipse(
        (0, 0),
        width=np.sqrt(1 + correlation) * 2,
        height=np.sqrt(1 - correlation) * 2,
        facecolor=facecolor,
        **kwargs
    )

    # Standard deviation (square root of the variance) times the requested
    # number of standard deviations, per axis.
    stretch_x = np.sqrt(var_x) * n_std
    stretch_y = np.sqrt(var_y) * n_std
    centre_x = np.mean(x)
    centre_y = np.mean(y)

    # Affine2D methods modify the transform in place, so the steps can be
    # applied one after another: rotate, then scale, then translate.
    placement = transforms.Affine2D()
    placement.rotate_deg(45)
    placement.scale(stretch_x, stretch_y)
    placement.translate(centre_x, centre_y)

    patch.set_transform(placement + ax.transData)
    return ax.add_patch(patch)