#!/usr/bin/env python
import math

import numpy as np

import scipy
import scipy.special
import scipy.interpolate
import scipy.integrate as sciint
import matplotlib.pyplot as plt
from numpy import cfloat, double, dtype, linalg as la, matrix
import scipy.linalg
import pandas as pd
import seaborn as sns
import torch

class Analysis:
    """Statistical diagnostics for sampled spectra and lineshapes.

    Every method is static, so call them through the class, e.g.
    ``Analysis.Get_N_moments(x, y)``.
    """

    @staticmethod
    def _normalized_density(x, y, threshold=0.001):
        """Return a copy of *y* with samples below *threshold* zeroed, rescaled to unit Simpson area over *x*."""
        yp = np.copy(y)
        # Zero the small tails so numerical noise does not distort the moments.
        yp[yp < threshold] = 0
        return yp / sciint.simpson(yp, x=x)

    @staticmethod
    def _mean_and_variance(x, y_norm):
        """Return the mean and variance of the unit-area density *y_norm* sampled on *x*."""
        mean = sciint.simpson(x * y_norm, x=x)
        var = sciint.simpson((x - mean) ** 2 * y_norm, x=x)
        return mean, var

    @staticmethod
    def _standardized_moment(x, y_norm, mean, var, order):
        """Return the order-th central moment of *y_norm* divided by var**(order/2)."""
        return sciint.simpson(y_norm * (x - mean) ** order / var ** (order / 2), x=x)

    @staticmethod
    def Get_Mean_Var_Skew_Kurtosis(x, y):
        """Return (mean, variance, skewness, excess kurtosis) of the distribution *y* on grid *x*.

        Samples of *y* below 0.001 are set to zero and the remainder is
        normalised to unit Simpson area before the moments are taken.
        The kurtosis returned here is the *excess* kurtosis (kurtosis - 3).
        """
        x = np.asarray(x)
        y_norm = Analysis._normalized_density(x, y)
        mean, var = Analysis._mean_and_variance(x, y_norm)
        skew = Analysis._standardized_moment(x, y_norm, mean, var, 3)
        kurtosis = Analysis._standardized_moment(x, y_norm, mean, var, 4) - 3
        return mean, var, skew, kurtosis

    @staticmethod
    def Get_N_moments(x, y):
        """Return ``[mean, var, skew, kurtosis, m5, m6, m7]`` of the distribution *y* on grid *x*.

        Same preprocessing as :meth:`Get_Mean_Var_Skew_Kurtosis`. Entries 2-6
        are standardised central moments of orders 3-7. Note that the
        kurtosis here is *not* reduced by 3.
        """
        x = np.asarray(x)
        y_norm = Analysis._normalized_density(x, y)
        mean, var = Analysis._mean_and_variance(x, y_norm)
        standardized = [Analysis._standardized_moment(x, y_norm, mean, var, order) for order in range(3, 8)]
        return [mean, var, *standardized]

    @staticmethod
    def Cumulant_FC_Analysis(FC_Spectra, Cumulant_specta):
        """Compare a cumulant spectrum with a reference FC spectrum.

        Both arguments are 2-D arrays whose row 0 is the abscissa and row 1
        the values. Only the first ``FC_Spectra.shape[1]`` columns of
        ``Cumulant_specta`` are compared.

        Returns
        -------
        goodness : float
            Sum over grid points of ``|FC - cumulant|`` (an L1-style distance;
            smaller means closer).
        neg_area : float
            Sum of ``|cumulant|`` over the points where the cumulant values are negative.
        norm_val : float
            Simpson integral of the FC values over their abscissa.

        Raises
        ------
        ValueError
            If ``Cumulant_specta`` has fewer columns than ``FC_Spectra``.
        """
        n_points = FC_Spectra.shape[1]
        if Cumulant_specta.shape[1] < n_points:
            raise ValueError(
                f"Cumulant_specta has {Cumulant_specta.shape[1]} columns, need at least {n_points}"
            )
        reference = FC_Spectra[1, :n_points]
        approx = Cumulant_specta[1, :n_points]
        goodness = np.sum(np.abs(reference - approx))
        neg_area = np.sum(np.abs(approx[approx < 0]))
        norm_val = sciint.simpson(FC_Spectra[1, :], x=FC_Spectra[0, :])
        return goodness, neg_area, norm_val

    @staticmethod
    def Lineshape_Analysis(lineshape, num_points=200):
        """Sum the positive and negative parts of a complex lineshape.

        Only row 1 of *lineshape*, restricted to its first *num_points*
        columns (default 200; fewer if the array is shorter), is used.

        Returns
        -------
        tuple
            ``(re_pos, re_neg, im_pos, im_neg)``: the sums of the positive real
            parts, the absolute negative real parts, the positive imaginary
            parts and the absolute negative imaginary parts.
        """
        values = np.asarray(lineshape[1, :num_points])
        real_part = values.real
        imag_part = values.imag
        re_pos = real_part[real_part > 0].sum()
        re_neg = np.abs(real_part[real_part < 0]).sum()
        im_pos = imag_part[imag_part > 0].sum()
        im_neg = np.abs(imag_part[imag_part < 0]).sum()
        return re_pos, re_neg, im_pos, im_neg


class Del_U_Operators:
    """Quantities derived from the excited-minus-ground potential difference of coupled harmonic modes.

    All methods are static; call them through the class, e.g.
    ``Del_U_Operators.Calculate_Omega_squared(J, gs, es)``.
    """

    @staticmethod
    def Calculate_Omega_squared(J_mat, gs_freqs, es_freqs):
        """Return the quadratic-coefficient matrix Omega of the potential difference.

        Omega[n, m] = 0.5 * sum_j J[n, j] * es[j]**2 * J[m, j] - 0.5 * delta_nm * gs[n]**2,
        i.e. 0.5 * (J diag(es**2) J^T - diag(gs**2)). Only the leading
        ``len(gs_freqs)`` modes of ``J_mat`` and ``es_freqs`` are used.
        """
        dim = len(gs_freqs)
        J = np.asarray(J_mat)[:dim, :dim]
        es_sq = np.asarray(es_freqs)[:dim] ** 2
        gs_sq = np.asarray(gs_freqs) ** 2
        # (J * es_sq) scales column j of J by es[j]**2, so the product is J diag(es^2) J^T.
        return 0.5 * ((J * es_sq) @ J.T) - 0.5 * np.diag(gs_sq)

    @staticmethod
    def Calculate_Xi(Omega_squared_mat, gs_freqs, KbT):
        """Return Xi = 0.5 * sum_i Omega[i, i] / gs[i] * coth(gs[i] / (2 * KbT)).

        ``KbT`` must be in the same units as the frequencies.
        """
        beta = 1 / KbT
        gs = np.asarray(gs_freqs, dtype=float)
        diag = np.diagonal(np.asarray(Omega_squared_mat))[:gs.shape[0]]
        return np.sum((diag / gs) * (1 / np.tanh(beta * gs * 0.5))) / 2

    @staticmethod
    def Calculate_Lamda_naught(K_vecs, J_mat, es_freqs):
        """Return lambda_0 = 0.5 * sum_j es[j]**2 * ((J^T K)[j])**2.

        This equals the triple sum 0.5 * sum_{j,n,m} K[m] J[m,j] es[j]^2 J[n,j] K[n],
        evaluated in O(n^2) instead of O(n^3).
        """
        dim = len(es_freqs)
        K = np.asarray(K_vecs)[:dim]
        J = np.asarray(J_mat)[:dim, :dim]
        projected = J.T @ K  # projected[j] = sum_m K[m] * J[m, j]
        return 0.5 * np.sum(np.asarray(es_freqs) ** 2 * projected ** 2)

    @staticmethod
    def Two_mode_trace_and_off_diag_sum_omega_sqr(gs_freqs, es_freqs, theta):
        """Return (trace, off-diagonal sum) of Omega for two modes rotated by angle *theta*.

        trace = 0.5 * (es0^2 - gs0^2 + es1^2 - gs1^2)
        off   = Omega[0, 1] + Omega[1, 0] = 0.5 * (es0^2 - es1^2) * sin(2 * theta)
        """
        tr = 0.5*(es_freqs[0]**2 - gs_freqs[0]**2 + es_freqs[1]**2 - gs_freqs[1]**2)
        off = 0.5*(es_freqs[0]**2 - es_freqs[1]**2) * np.sin(2*theta)
        return tr, off

class Two_mode_state:
    """Parameter set of a two-mode system.

    Holds the ground- and excited-state frequencies, the shift vector and the
    rotation angle used to build the 2x2 mode-mixing matrix.
    """

    def __init__(self, gs_freqs, es_freqs, k_vecs, Theta):
        # Stored exactly as passed (no copies or conversions).
        self.Ground_frequencies = gs_freqs
        self.Excited_frequencies = es_freqs
        self.Shift_vectors = k_vecs
        self.Rotation = Theta

    def map_to_new_parameters_conserving_Omega_sq(state_1, new_ground_1, new_excited_1, new_theta, new_k):
        """Return a new Two_mode_state that keeps the off-diagonal sum and trace of Omega^2 of *state_1*.

        The first ground/excited frequencies, the angle and the shift vector
        are the given new values. The second excited frequency (gamma) is
        chosen to conserve 0.5*(e0^2 - e1^2)*sin(2*theta), and the second
        ground frequency to conserve the trace. When no real solution exists
        np.sqrt yields NaN (with a RuntimeWarning) rather than raising.
        """
        old_excited = state_1.Excited_frequencies
        old_theta = state_1.Rotation
        # Off-diagonal invariant of the old state (up to the common factor 0.5).
        old_off_diag = (old_excited[0]**2 - old_excited[1]**2) * np.sin(2*old_theta)
        gamma = np.sqrt(new_excited_1**2 - old_off_diag / np.sin(2*new_theta))
        cos_t, sin_t = np.cos(old_theta), np.sin(old_theta)
        rotation = np.array([[cos_t, -sin_t], [sin_t, cos_t]])
        omega_sq = Del_U_Operators.Calculate_Omega_squared(rotation, state_1.Ground_frequencies, old_excited)
        old_trace = np.trace(omega_sq)
        second_ground = np.sqrt(gamma**2 + new_excited_1**2 - 2*old_trace - new_ground_1**2)
        return Two_mode_state([new_ground_1, second_ground], [new_excited_1, gamma], new_k, new_theta)
        
class Debye_Solvent:
    """Debye solvent bath: its spectral density and the resulting lineshape function g2(t).

    All methods are static; call them through the class.
    """

    @staticmethod
    def Compute_spectral_density(Reorg_E, Omega_cutoff, Omega_max, num_points):
        """Sample J(omega) = 2*Reorg_E*omega / (Omega_cutoff*(1 + (omega/Omega_cutoff)**2)).

        The grid is omega_i = i * Omega_max / num_points for i = 0..num_points-1,
        so it starts at 0 and stops one step short of Omega_max.

        Returns
        -------
        ndarray, shape (2, num_points)
            Row 0: the omega grid; row 1: J(omega).
        """
        d_omega = Omega_max / num_points
        omega = np.arange(num_points) * d_omega
        spec_density = np.zeros((2, num_points))
        spec_density[0, :] = omega
        spec_density[1, :] = (2 * Reorg_E * omega) / (Omega_cutoff * (1 + (omega / Omega_cutoff) ** 2))
        return spec_density

    @staticmethod
    def Compute_g2_solvent(Reorg_E, Omega_cutoff, Omega_max, num_points_in_spectral_density, time_axis, KbT):
        """Return the solvent lineshape function g2(t) by Simpson quadrature over the Debye density.

        g2(t) = (1/pi) * integral dw J(w)/w^2 * [coth(w/(2*KbT)) * (1 - cos(w t)) - i*(sin(w t) - w t)]

        ``KbT`` must be in the same units as the frequencies (they are divided).

        Returns
        -------
        ndarray of complex, shape (2, len(time_axis))
            Row 0: the time axis; row 1: g2(t).
        """
        time_axis = np.asarray(time_axis)
        density = Debye_Solvent.Compute_spectral_density(Reorg_E, Omega_cutoff, Omega_max, num_points_in_spectral_density)
        omega_grid = density[0, :]
        # The omega = 0 sample (index 0) is 0/0 in J/omega^2; it is left at 0.
        # NOTE(review): the analytic omega -> 0 limit of the real part appears to be
        # 2*Reorg_E*KbT*t**2/Omega_cutoff rather than 0 — confirm the intended treatment.
        omega = omega_grid[1:]
        weight = density[1, 1:] / omega ** 2
        coth = 1 / np.tanh(omega / (2 * KbT))  # time-independent, hoisted out of the loop
        g2_t = np.zeros((2, time_axis.shape[0]), dtype=complex)
        g2_t[0, :] = time_axis
        integrand = np.zeros(num_points_in_spectral_density, dtype=complex)
        for i, t in enumerate(time_axis):
            integrand[1:] = weight * (coth * (1 - np.cos(omega * t)) - 1j * (np.sin(omega * t) - omega * t))
            g2_t[1, i] = (1 / math.pi) * sciint.simpson(integrand, x=omega_grid)
        return g2_t
class Basis_Sets:
    """Pair of ground-state basis arrays, as returned by Form_gs_basis_sets.

    In that function each array has the coordinate grid in row 0 and one
    sampled basis function per following row.
    """

    def __init__(self, First_gs_basis, Second_gs_basis):
        # References are stored as-is; no copy is made.
        self.First_gs_basis, self.Second_gs_basis = First_gs_basis, Second_gs_basis
class Diagonalize:
    """Plain container; judging by its attribute names, a diagonalisation result.

    It holds a transformation matrix and a vector of eigenvalues. It is not
    referenced elsewhere in this module.
    """

    def __init__(self, Invertable_matrix, Eigen_value_vector):
        # Both values are stored without copying or validation.
        self.Invertable_matrix, self.Eigen_value_vector = Invertable_matrix, Eigen_value_vector
def Inner_Product(Psi_I,Psi_J,X_dat):
    """Return <Psi_I|Psi_J> = integral Psi_I(x) * Psi_J(x) dx by Simpson's rule on the grid X_dat.

    No complex conjugation is applied to Psi_I; this matches the real-valued
    basis functions used in this module. The product is formed directly, so
    its dtype follows normal numpy promotion. An integer Psi_I does not
    truncate the integrand, and a complex Psi_J keeps its imaginary part.
    """
    integrand = np.asarray(Psi_I) * np.asarray(Psi_J)
    return sciint.simpson(integrand, x=np.asarray(X_dat))

def Compute_Overlap_Matrix(gs_basis,es_basis):
    """Return the matrix of overlaps <gs_i|es_j> between two basis-set arrays.

    Row 0 of each array is the coordinate grid and rows 1.. are basis
    functions. Only ``gs_basis``'s grid is used for integration, and the
    matrix size (number of functions) is taken from ``gs_basis``.
    It is a handy diagnostic for orthonormality.
    """
    n_funcs = gs_basis.shape[0] - 1
    grid = gs_basis[0, :]
    overlaps = np.zeros((n_funcs, n_funcs))
    for row, col in np.ndindex(n_funcs, n_funcs):
        overlaps[row, col] = Inner_Product(gs_basis[row + 1, :], es_basis[col + 1, :], grid)
    return overlaps



def Form_gs_basis_sets(n_max,first_gs_wf_freq,second_gs_wf_freq,num_points):
    """Build harmonic-oscillator eigenfunction bases for two ground-state modes on a common grid.

    The grid q has ``num_points`` evenly spaced points spanning
    [-x_domain/2, x_domain/2], with x_domain = 3*sqrt(2*(n_max+0.5))/sqrt(first_gs_wf_freq).

    For each mode with frequency w, function n (0 <= n <= n_max) is
    (w/pi)^(1/4) * exp(-w q^2 / 2) * H_n(sqrt(w) q), normalised as follows:

    * if n_max <= 14, by the analytic factor 1/sqrt(2^n n!);
    * otherwise, numerically via Inner_Product.

    Returns
    -------
    Basis_Sets
        Two arrays of shape (n_max+2, num_points); row 0 is the grid q and
        row n+1 holds function n.
    """
    x_domain = 3*np.sqrt(2*(n_max+0.5))/np.sqrt(first_gs_wf_freq)
    x_step = x_domain/(num_points-1)
    first_gs_basis = np.zeros((n_max+2,num_points))
    second_gs_basis = np.zeros_like(first_gs_basis)

    # Same grid points as evaluating -0.5*x_domain + i*x_step one i at a time.
    q = -0.5 * x_domain + np.arange(num_points) * x_step
    first_gs_basis[0, :] = q
    second_gs_basis[0, :] = q

    # Dimensionless coordinates and Gaussian envelopes do not depend on n, so compute them once.
    q_first = q * np.sqrt(first_gs_wf_freq)
    q_second = q * np.sqrt(second_gs_wf_freq)
    envelope_first = np.power(first_gs_wf_freq/math.pi, 0.25) * np.exp(-q_first**2/2)
    envelope_second = np.power(second_gs_wf_freq/math.pi, 0.25) * np.exp(-q_second**2/2)

    for n in range(n_max + 1):
        herm = scipy.special.hermite(n)  # evaluates element-wise on arrays
        first_gs_basis[n+1, :] = envelope_first * herm(q_first)
        second_gs_basis[n+1, :] = envelope_second * herm(q_second)
        if n_max <= 14:
            norm = 1 / (np.sqrt((2**n) * math.factorial(n)))
            first_gs_basis[n+1, :] *= norm
            second_gs_basis[n+1, :] *= norm
        else:
            # For large n_max the basis is normalised numerically on the grid instead.
            norm_first = np.sqrt(Inner_Product(first_gs_basis[n+1, :], first_gs_basis[n+1, :], first_gs_basis[0, :]))
            norm_second = np.sqrt(Inner_Product(second_gs_basis[n+1, :], second_gs_basis[n+1, :], second_gs_basis[0, :]))
            second_gs_basis[n+1, :] /= norm_second
            first_gs_basis[n+1, :] /= norm_first
    return Basis_Sets(first_gs_basis,second_gs_basis)

def H_psi_j(psi_j, x_dat, freq_i):
    """Apply a 1-D harmonic-oscillator Hamiltonian to a sampled wavefunction.

    Returns H psi = -0.5 * psi''(x) + 0.5 * freq_i**2 * x**2 * psi(x) on the
    grid ``x_dat``. psi'' is the second derivative of an interpolating
    (``s=0``) quartic UnivariateSpline through the samples, so ``x_dat`` must
    be increasing. Mostly used to test basis sets.

    The result is a new float array; an integer ``psi_j`` is not used as
    the output dtype.
    """
    x = np.asarray(x_dat)
    psi = np.asarray(psi_j)
    second_derv = scipy.interpolate.UnivariateSpline(x, psi, s=0, k=4).derivative(2)
    kinetic = -0.5 * second_derv(x)
    potential = 0.5 * x**2 * psi * freq_i**2
    return kinetic + potential
def Form_dual_Hg_matrix(dim, gs_freq_1, gs_freq_2):
    """Return the diagonal two-mode ground-state Hamiltonian as a complex64 torch matrix.

    Entry (i, j) of the product basis, with 0 <= i, j <= dim, sits at flat
    index i*(dim+1) + j and has energy
    gs_freq_1*(0.5 + i) + gs_freq_2*(0.5 + j). The matrix is therefore
    (dim+1)^2 x (dim+1)^2.
    """
    levels = range(dim + 1)
    energies = [gs_freq_1 * (0.5 + n_one) + gs_freq_2 * (0.5 + n_two) for n_one in levels for n_two in levels]
    diagonal = torch.tensor(energies)
    return torch.diag(diagonal).type(torch.cfloat)

def Delta_on_mode(psi_j,x_dat, es_freq_one, es_freq_two,gs_freq_one, gs_freq_two, mode:int, J_mat, k_vects):
    """Apply the single-mode part of Delta (excited minus ground potential) to psi_j.

    For two-mode systems, column ``mode`` of J_mat gives this mode's
    quadratic and linear coefficients:

        S = 0.5 * ((es1*J[0,mode])**2 + (es2*J[1,mode])**2 - gs_mode**2)
        L = J[0,mode]*k[0]*es1**2 + J[1,mode]*k[1]*es2**2

    and the return value is (S*q**2 - L*q) * psi_j on the grid q = x_dat.

    Raises
    ------
    ValueError
        If ``mode`` is not 0 or 1.
    """
    if mode == 0:
        col, gs_freq = 0, gs_freq_one
    elif mode == 1:
        col, gs_freq = 1, gs_freq_two
    else:
        raise ValueError(f"mode must be 0 or 1, got {mode!r}")
    Squared_V_prefactor = 0.5*((es_freq_one*J_mat[0,col])**2 + (es_freq_two*J_mat[1,col])**2 - gs_freq**2)
    Linear_V_prefactor = J_mat[0,col] * k_vects[0] * es_freq_one**2 + J_mat[1,col] * k_vects[1] * es_freq_two**2
    q = np.asarray(x_dat)
    psi = np.asarray(psi_j)
    return q**2 * Squared_V_prefactor * psi - q * Linear_V_prefactor * psi
def One_mode_delta_matrix(basis_Set, es_freq_one,es_freq_two,gs_freq_one,gs_freq_two, mode:int, J_mat,k_vects):
    """Return the symmetric matrix <psi_i|Delta_mode|psi_j> of the single-mode potential change.

    ``basis_Set`` has the coordinate grid in row 0 and basis functions in
    rows 1... Delta_mode is the multiplicative operator applied by
    Delta_on_mode. The upper triangle is evaluated and mirrored into the
    lower one.
    """
    n_funcs = basis_Set.shape[0] - 1
    grid = basis_Set[0, :]
    Mat = np.zeros((n_funcs, n_funcs))
    for j in range(n_funcs):
        # Delta acting on psi_j does not depend on i, so evaluate it once per column.
        delta_psi_j = Delta_on_mode(basis_Set[j + 1, :], grid, es_freq_one, es_freq_two, gs_freq_one, gs_freq_two, mode, J_mat, k_vects)
        for i in range(j + 1):
            Mat[i, j] = Mat[j, i] = Inner_Product(basis_Set[i + 1, :], delta_psi_j, grid)
    return Mat

def q_overlap_mat(basis_set):
    """Return the symmetric matrix <psi_i|q|psi_j>, needed for the inseparable two-mode term.

    ``basis_set`` has the coordinate grid q in row 0 and basis functions in
    rows 1... Each pair with i <= j is integrated once and mirrored.
    """
    n_funcs = basis_set.shape[0] - 1
    grid = basis_set[0, :]
    Mat = np.zeros((n_funcs, n_funcs))
    for j in range(n_funcs):
        # q * psi_j depends only on j; build it once, element-wise.
        q_psi_j = grid * basis_set[j + 1, :]
        for i in range(j + 1):
            Mat[i, j] = Mat[j, i] = Inner_Product(basis_set[i + 1, :], q_psi_j, grid)
    return Mat
def Two_mode_delta_matrix(delta_on_one,delta_on_two,q_mat_one,q_mat_two,q_prefactor, k_term):
    """Assemble the two-mode Delta matrix in the product basis as a complex128 torch tensor.

    With d = delta_on_one.shape[0], product state (a, b) has flat index
    a*d + b. The matrix is

        q_prefactor * (q1 kron q2)       # <1|q|1><2|q|2>, couples both modes
        + (delta_on_one kron I)          # single-mode term on mode 1
        + (I kron delta_on_two)          # single-mode term on mode 2
        + k_term * I                     # constant shift

    Only the lower triangle (row >= column) of that sum is kept and
    mirrored. The result is therefore symmetric even if the inputs are not.
    """
    delta_one = np.asarray(delta_on_one)
    dim = delta_one.shape[0]
    delta_one = delta_one[:dim, :dim]
    delta_two = np.asarray(delta_on_two)[:dim, :dim]
    q_one = np.asarray(q_mat_one)[:dim, :dim]
    q_two = np.asarray(q_mat_two)[:dim, :dim]
    eye = np.eye(dim)
    full = (np.kron(q_prefactor * q_one, q_two)
            + np.kron(delta_one, eye)
            + np.kron(eye, delta_two)
            + k_term * np.eye(dim * dim))
    symmetric = np.tril(full) + np.tril(full, -1).T
    return torch.as_tensor(symmetric).type(torch.cdouble)
def Form_Mega_Matrix(Hg, Delta, num_blocks):
    """Build the block upper-bidiagonal "mega" matrix as a complex64 torch tensor.

    For num_blocks = 4 the layout is::

        | -iHg  Delta   0      0    |
        |  0   -iHg   Delta    0    |
        |  0    0     -iHg   Delta  |
        |  0    0      0     -iHg   |

    Every block is Hg.shape[0] square. Delta is cast to complex64 when it is
    copied in.
    """
    block = Hg.shape[0]
    total = num_blocks * block
    mega = torch.zeros(total, total, dtype=torch.cfloat)
    diag_block = Hg * -1j
    for b in range(num_blocks):
        lo, hi = b * block, (b + 1) * block
        mega[lo:hi, lo:hi] = diag_block
        if b + 1 < num_blocks:
            mega[lo:hi, hi:hi + block] = Delta
    return mega

# Evaluate moment blocks of exp(mega_matrix * t) for every time in t_dat.
#
# Parameters (shapes inferred from the indexing below; confirm against callers):
#   num_blocks  -- number of dim x dim diagonal blocks in mega_matrix
#                  (Run_Matrix_FC_Comparison passes 4 here and to Form_Mega_Matrix).
#   t_dat       -- 1-D sequence of times, iterated one value at a time.
#   Hg_diag     -- dim x dim matrix; only its diagonal is used, to build the
#                  diagonal matrix exp(+1j * diag(Hg) * t).
#   mega_matrix -- (num_blocks*dim) x (num_blocks*dim) matrix, e.g. from Form_Mega_Matrix.
#   rho         -- dim x dim matrix (the thermal density matrix in Run_Matrix_FC_Comparison).
#
# Returns a complex64 tensor of shape (num_blocks-1, len(t_dat)) on the CUDA device:
#   row 0    -- the time values themselves;
#   row j>=1 -- trace(exp(1j*Hg*t) @ B_j @ rho), where B_j is the block of
#               exp(mega_matrix*t) in block-row (num_blocks-2-j) and the last
#               block-column. For the block-bidiagonal layout of Form_Mega_Matrix
#               this is presumably the (j+1)-th order term; block-row
#               num_blocks-2 (presumably first order) is never read.
#
# NOTE(review): the device is hard-coded to 'cuda', so this needs a GPU.
# NOTE(review): the module-level statement after this function passes it to
# torch.jit.script at import time, so edits must stay TorchScript-compatible.
@torch.no_grad()
def torch_compute_moments_faster(num_blocks: int,t_dat, Hg_diag, mega_matrix,rho):
    moments = torch.zeros(num_blocks-1, t_dat.shape[0], dtype=torch.cfloat, device='cuda')
    t_mega = mega_matrix
    t_Hg = Hg_diag
    
    t_mega = t_mega.cuda()
    t_Hg = t_Hg.cuda()
    rho = rho.cuda()
    dim = Hg_diag.size(0)
    Hgdiag = t_Hg.diagonal()
    i = 0
    for t_now in t_dat:
        #print(f"{i/t_dat.shape[0]}")
        moments[0,i] = t_now
        # exp(+i * diag(Hg) * t), built in place on a copy of the diagonal.
        temp = Hgdiag.clone()
        temp.mul_(1j*t_now)
        temp.exp_()
        e_i_Hg_t = torch.diag(temp).type(torch.cdouble)
        A_t = torch.multiply(t_mega,t_now)
        e_A_t = torch.matrix_exp(A_t).type(torch.cdouble)
        # Block-row start offsets, highest first: [(num_blocks-2)*dim, ..., dim, 0].
        # (Independent of t; rebuilt on every iteration.)
        y_bounds = torch.flipud(torch.arange(0,(num_blocks-1)*dim, dim))
        j = 1
        while j < y_bounds.size(0):
            # Block-row num_blocks-2-j of the last block-column of exp(A t).
            val = e_i_Hg_t @ e_A_t[y_bounds[j]:y_bounds[j-1], dim*(num_blocks-1):dim*num_blocks]@rho
            moments[j,i] = torch.trace(val)
            j+=1
        i+=1
    return moments

# TorchScript-compile torch_compute_moments_faster when this module is imported.
# NOTE(review): the visible caller (Run_Matrix_FC_Comparison) uses the un-scripted
# function. Confirm the body (tensor-valued slice bounds, Tensor.type, the
# @torch.no_grad() decorator) is accepted by torch.jit.script on the installed
# PyTorch; if not, this line raises at import time.
torch_compute_moments_nograd_faster = torch.jit.script(torch_compute_moments_faster)
def FT_lineshape(e_gt, freq_grid):
    """Fourier-transform a time-domain response into a spectrum on *freq_grid*.

    Usage notes:

    * Work in Hartree frequency units. Compute avg = <Delta*Rho> and add the
      adiabatic energy E_adiabat (in Hartree) to it.
    * Define the frequency grid and transform either an FC response carrying
      a -1j*t*E_adiabat exponent, or a cumulant response carrying a
      -1j*t*avg_freq exponent.
    * The returned frequency axis is already converted to eV for plotting.

    Parameters
    ----------
    e_gt : ndarray, shape (2, n_t)
        Row 0 is the time axis, row 1 the response sampled on it.
    freq_grid : ndarray, shape (n_w,)
        Frequencies (Hartree) at which to evaluate the transform.

    Returns
    -------
    ndarray, shape (2, n_w)
        Row 0: ``freq_grid * 27.211396132`` (eV). Row 1: the real part of
        prefactor(omega) * integral e_gt(t) * exp(1j*omega*t) dt (Simpson's
        rule), with prefactor(omega) = 40*pi^2*0.0072973525693*omega / (3*ln 10).
    """
    hartree_in_ev = 27.211396132
    prefactor_numerator = 40.0*math.pi**2.0*0.0072973525693
    prefactor_denominator = 3.0*math.log(10.0)
    times = e_gt[0, :]
    response = e_gt[1, :]
    spectrum = np.zeros((2, freq_grid.shape[0]))
    for idx in range(freq_grid.shape[0]):
        omega = freq_grid[idx]
        prefactor = prefactor_numerator*omega/prefactor_denominator
        integrand = (response * np.exp(complex(0, 1) * omega * times)).astype(complex)
        step = (times[1] - times[0]).real
        transform = scipy.integrate.simpson(y=integrand, x=times, dx=step)
        spectrum[1, idx] = prefactor*transform.real
        spectrum[0, idx] = omega * hartree_in_ev
    return spectrum

def Moments_to_cumulants(moments):
    """Convert raw Dyson-series moments into cumulants.

    Parameters
    ----------
    moments : array_like, shape (n_rows, n_times)
        Row 0 carries the time axis and is passed through unchanged.
        Row k >= 1 is the order-(k+1) moment without its (-i)**(k+1) phase.
        This is the layout produced by torch_compute_moments_faster; the
        first-order term is assumed to be zero and is not stored.

    Returns
    -------
    ndarray of complex128, same shape
        Row 0: the time axis. Row n-1 (n >= 2): the order-n cumulant
        K_n = mu_n - sum_{m=2}^{n-2} (m/n) * K_m * mu_{n-m}, where
        mu_n = (-i)**n * moments[n-1]. This is the power-series logarithm
        recursion with mu_1 = 0.
    """
    # Complex copy so the phase factors can be applied in place even for real input.
    corrected_moments = np.array(moments, dtype=complex)
    n_rows = corrected_moments.shape[0]
    # Every moment row, including the last, gets its Dyson phase (-i)**order.
    for row in range(1, n_rows):
        corrected_moments[row, :] *= (-1j) ** (row + 1)
    cumulants = np.zeros_like(corrected_moments)
    cumulants[0, :] = corrected_moments[0, :]
    for n in range(2, n_rows + 1):
        cumulants[n - 1, :] = corrected_moments[n - 1, :]
        for m in range(2, n - 1):
            cumulants[n - 1, :] += -(m / n) * np.multiply(cumulants[m - 1, :], corrected_moments[(n - m) - 1, :])
    return cumulants



def Run_Matrix_FC_Comparison(n_max,gs_1, gs_2,es_1,es_2,Temp,g_2_solvent, E_adibat, J_mat, k_vecs):
    """Compare an FC lineshape with the matrix/cumulant lineshape for a two-mode system.

    Builds the two-mode ground-state basis, the Delta matrix and the
    mega matrix. It then computes moments on the time axis read from
    "FC_lineshape_function.dat", converts them to cumulants and
    Fourier-transforms both responses. Both responses include the solvent
    term g_2_solvent[1, :].

    Returns
    -------
    tuple
        (FC_FT, Matrix_FT), the two transformed spectra.

    NOTE(review): ``Spec_pkg_handler`` is not imported by this module, so the
    Get_FC_LineShape call raises NameError unless that name is injected elsewhere.
    NOTE(review): FT_lineshape in this module takes ``(e_gt, freq_grid)``, but both
    calls below pass four positional arguments (they appear to target an older
    signature), so they raise TypeError as written.
    NOTE(review): torch_compute_moments_faster requires a CUDA device.
    """
    # Boltzmann constant in eV/K divided by eV per Hartree: presumably k_B*T in Hartree.
    KbT = Temp * (8.6173303*10.0**(-5.0)/27.211396132)
    tim_k = np.transpose(J_mat)@k_vecs
    basis_sets = Form_gs_basis_sets(n_max, gs_1,gs_2, 1000)
    Hg = Form_dual_Hg_matrix(n_max,gs_1,gs_2)
    delta_one = One_mode_delta_matrix(basis_sets.First_gs_basis, es_1,es_2,gs_1,gs_2,0,J_mat,k_vecs)
    delta_two = One_mode_delta_matrix(basis_sets.Second_gs_basis,es_1,es_2,gs_1,gs_2,1,J_mat,k_vecs)
    q_1 = q_overlap_mat(basis_sets.First_gs_basis)
    q_2 = q_overlap_mat(basis_sets.Second_gs_basis)
    shared_q_factor = (es_1**2*J_mat[0,0]*J_mat[0,1] + es_2**2 * J_mat[1,0] * J_mat[1,1])
    k_factor = 0.5*(es_1*k_vecs[0])**2 + 0.5*(es_2*k_vecs[1])**2
    Full_delta = Two_mode_delta_matrix(delta_one,delta_two,q_1,q_2, shared_q_factor, k_factor)
    # Thermal (Boltzmann) density matrix over the product states: exp(-E/KbT), normalised.
    rho = torch.diag(Hg)
    rho = torch.exp(torch.multiply(rho, - 1/KbT))
    rho = torch.diag(rho / torch.sum(rho)).type(torch.cdouble)
    # Thermal average of Delta; subtracting it from the diagonal makes the first moment vanish.
    avg_freq = torch.trace(Full_delta@rho).item().real
    Full_delta = Full_delta - torch.multiply(torch.eye(Hg.shape[0]), avg_freq)
    Mega = Form_Mega_Matrix(Hg, Full_delta,4)
    fc_line_frame = pd.read_csv("FC_lineshape_function.dat", sep=" ", header=None)
    fc_line_Array = np.array(fc_line_frame)
    # NOTE(review): the time axis is read BEFORE Get_FC_LineShape (below) regenerates
    # this file, so it comes from whatever run last wrote it ("HUGE WEAK SPOT").
    t_dat = fc_line_Array[:,0]###THIS IS A HUGE WEAK SPOT
    moments = torch_compute_moments_faster(4,t_dat,Hg,Mega,rho)
    cpu_moments = moments.cpu().numpy()
    cumulants = Moments_to_cumulants(cpu_moments)
    Spec_pkg_handler.Get_FC_LineShape([gs_1,gs_2],[es_1,es_2], tim_k, np.transpose(J_mat), 600)
    FC_data = np.array(pd.read_csv("FC_lineshape_function.dat", sep = " ", header = None))
    FC_response = np.zeros_like(g_2_solvent, dtype=complex)
    FC_response[0,:] = t_dat
    FC_response[1,:] = np.exp(-(FC_data[:,1] + 1j*FC_data[:,2] + g_2_solvent[1,:]))
    FC_FT = FT_lineshape(FC_response,E_adibat, 15*gs_1,400)
    Matrix_response = np.zeros_like(FC_response, dtype= complex)
    Matrix_response[0,:] = t_dat[:]
    # cumulants[1] and cumulants[2] are the 2nd- and 3rd-order cumulants (4 blocks -> 3 moment rows).
    Matrix_response[1,:] = np.exp(cumulants[1,:] + cumulants[2,:] - g_2_solvent[1,:])
    Matrix_FT = FT_lineshape(Matrix_response, E_adibat + avg_freq, 12 * gs_1, 400)
    return FC_FT,Matrix_FT

    
def _bin_average(data_set, x_ax, y_ax, z_ax, dx, dy, x_lim, y_lim, min_density, scaling_factor):
    """Average column z_ax of data_set over rectangular bins of columns (x_ax, y_ax).

    Bins are [x, x+dx) x [y, y+dy) for x = 0, dx, 2dx, ... while x < x_lim and
    y = 0, dy, ... while y < y_lim. A bin holding more than min_density rows
    contributes its centre (times scaling_factor, rounded to 2 decimals) and
    its mean z value. Returns the lists (xs, ys, zs).
    """
    x_col = data_set[:, x_ax]
    y_col = data_set[:, y_ax]
    z_col = data_set[:, z_ax]
    xs, ys, zs = [], [], []
    x = 0
    while x < x_lim:
        in_x = (x_col >= x) & (x_col < x + dx)
        y = 0
        while y < y_lim:
            in_bin = in_x & (y_col >= y) & (y_col < y + dy)
            count = np.count_nonzero(in_bin)
            if count > min_density:
                zs.append(z_col[in_bin].sum() / count)
                ys.append(round((y + 0.5*dy) * scaling_factor, 2))
                xs.append(round((x + 0.5*dx) * scaling_factor, 2))
            y = y + dy
        x += dx
    return xs, ys, zs


def Heat_map_quantities(data_set, x_axis_name, y_axis_name, z_axis_name, x_vol, y_vol, x_lim, y_lim, z_lim, min_vol, scaling_factor):
    """Show a seaborn heatmap of the mean of one data column over a 2-D binning of two others.

    Parameters
    ----------
    data_set : 2-D ndarray
        Columns are ordered as in ``names`` below.
    x_axis_name, y_axis_name, z_axis_name : str
        Entries of ``names``: the two binned columns and the averaged column.
    x_vol, y_vol : float
        Bin widths along x and y (must be positive).
    x_lim, y_lim : float
        Bins start at 0 and continue while their lower edge is below these limits.
    z_lim : float
        Colour-scale maximum (``vmax``) of the heatmap.
    min_vol : int
        A bin is plotted only if it contains more than this many rows.
    scaling_factor : float
        Multiplier applied to bin centres before rounding them to 2 decimals.

    Raises
    ------
    ValueError
        If a name is not in ``names`` or a bin width is not positive.
    """
    names = ["3RD_GOODNESS","3RD_NONPHYS","2ND_GOODNESS","2ND_NONPHYS","MEAN","VAR","SKEW","KURT","TRACE","ANTI"]
    x_ax = names.index(x_axis_name)
    y_ax = names.index(y_axis_name)
    z_ax = names.index(z_axis_name)
    if x_vol <= 0 or y_vol <= 0:
        # A non-positive width would never advance past the limits.
        raise ValueError(f"bin widths must be positive, got x_vol={x_vol!r}, y_vol={y_vol!r}")
    xs, ys, zs = _bin_average(data_set, x_ax, y_ax, z_ax, x_vol, y_vol, x_lim, y_lim, min_vol, scaling_factor)
    d = {x_axis_name: xs, y_axis_name: ys, z_axis_name : zs}
    df = pd.DataFrame(d, index=None)
    df = df.pivot(index=y_axis_name, columns=x_axis_name, values=z_axis_name)
    plot = sns.heatmap(df, vmax=z_lim)
    plot.invert_yaxis()
    plt.title(f"{x_axis_name} vs {y_axis_name} : {z_axis_name}" )
    plt.show()
