#!/Users/kangwang/opt/anaconda3/bin/python3
import numpy as np
from numpy.core.arrayprint import dtype_short_repr
from numpy.fft import fft
from numpy.lib import fromregex
from numpy.lib.index_tricks import RClass
from numpy.lib.nanfunctions import nanpercentile
from  scipy.sparse import linalg as sp2
import scipy.sparse
import time
# import numba as nb
import timeit
np.set_printoptions(precision = 4)
np.set_printoptions(suppress=True)
# 
def shift(a,l,N):
    """Fold a non-negative grid index onto the symmetric (FFT) range.

    Indices up to the half-width *l* are kept; larger ones are moved to the
    negative side by subtracting the grid size *N*.  If *a* compares neither
    ``<=`` nor ``>`` to *l* (e.g. NaN), ``None`` is returned.
    """
    if a > l:
        return a - N
    if a <= l:
        return a
    return None
def fftshift(a,N):
    """Convert a signed index into its position on an FFT grid of size *N*.

    Non-negative indices are returned as-is; negative ones wrap to ``a + N``.
    Returns ``None`` if *a* compares neither ``>= 0`` nor ``< 0`` (e.g. NaN).
    """
    if a >= 0:
        return a
    if a < 0:
        return a + N
    return None
def ASR_old(M):
    """Blockwise acoustic-sum-rule correction (older, loop-based version).

    *M* is a square matrix of size ``3*natom`` made of 3x3 atom blocks.  For
    each block-row the mean of its 3x3 blocks is subtracted from every block
    in that row, so the blocks of each block-row sum to zero afterwards.
    Progress messages are printed.  Returns a corrected copy of *M*.
    """
    print('doing ASR')
    rows, cols = np.shape(M)
    print('dimension of matrix', rows, cols)
    corrected = M.copy()
    assert rows % 3 == 0 and rows == cols
    n_atoms = rows // 3
    print('start ASR')
    for bi in range(n_atoms):
        row_span = slice(3 * bi, 3 * bi + 3)
        block_sum = np.zeros((3, 3), dtype='float')
        for bj in range(n_atoms):
            block_sum = block_sum + M[row_span, 3 * bj:3 * bj + 3]
        mean_block = block_sum / n_atoms
        for bk in range(n_atoms):
            col_span = slice(3 * bk, 3 * bk + 3)
            corrected[row_span, col_span] = M[row_span, col_span] - mean_block
    return corrected
def ASR_part(M):
    """Apply one row-wise acoustic-sum-rule pass to a square matrix.

    *M* has size ``3*natom``.  Each row is viewed as ``natom`` triples
    (x, y, z columns of one atom), and the mean triple over all atoms is
    subtracted from every triple.  Afterwards
    ``sum_j M1[r, 3*j + beta] == 0`` for every row ``r`` and component
    ``beta``, up to rounding (or truncation for integer input, which keeps
    its dtype as with plain assignment).

    Args:
        M: dense square ndarray (or ``np.matrix``) with a side divisible by 3.

    Returns:
        A corrected copy of *M*; *M* itself is not modified.
    """
    (a, b) = np.shape(M)
    M1 = M.copy()
    assert (a % 3 == 0) and (a == b)
    natom = a // 3
    # ndarray.copy() is C-contiguous, so this reshape is a view into M1 and
    # the in-place subtraction below updates M1 without a full-size temporary.
    blocks = np.asarray(M1).reshape(a, natom, 3)
    means = blocks.sum(axis=1, keepdims=True) / natom
    # casting='unsafe' mirrors the old row assignment for integer matrices.
    np.subtract(blocks, means, out=blocks, casting='unsafe')
    return M1
def ASR(M, n=8):
    """Iteratively impose the acoustic sum rule and symmetrize *M*.

    Each of the *n* iterations applies ``ASR_part`` (subtract the per-row,
    per-component mean over atoms) and then ``(M1 + M1.T) / 2``.  Alternating
    the two makes the result approach a matrix that is both symmetric and
    satisfies the row sum rule.  A timestamped progress line is printed for
    every iteration.

    Args:
        M: square matrix accepted by ``ASR_part``.
        n: number of ASR/symmetrization iterations (default 8, the value
            previously hard-coded).

    Returns:
        The corrected matrix; *M* is not modified.
    """
    M1 = M.copy()
    for i in range(n):
        print(f'start {i} ASR', time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
        M1 = ASR_part(M1)
        M1 = (M1 + np.transpose(M1)) / 2
    return M1
def ASR_h(M, n=8):
    """Quiet variant of ``ASR``: same iterations, no progress output.

    Each of the *n* iterations applies ``ASR_part`` and then symmetrizes with
    ``(M1 + M1.T) / 2``.

    Args:
        M: square matrix accepted by ``ASR_part``.
        n: number of iterations (default 8, the value previously hard-coded).

    Returns:
        The corrected matrix; *M* is not modified.
    """
    M1 = M.copy()
    for _ in range(n):
        M1 = ASR_part(M1)
        M1 = (M1 + np.transpose(M1)) / 2
    return M1
def swapxy(n_sc,fsc,natom=6):
    """Return a copy of *fsc* with the x and y cell indices of its row blocks swapped.

    Cells are numbered x-fastest, z-slowest, and each cell owns
    ``3 * natom`` consecutive rows.  The row block of cell (i, j, k) is moved
    to the position of cell (j, i, k).  Columns are left untouched.  Needs
    ``n_sc[0, 0] == n_sc[1, 1]``; otherwise an IndexError is raised.
    """
    nx, ny, nz = n_sc[0, 0], n_sc[1, 1], n_sc[2, 2]
    # cell_of[i, j, k] = i + nx * (j + ny * k)
    cell_of = np.arange(nx * ny * nz).reshape(nz, ny, nx).transpose(2, 1, 0)
    swapped_of = np.zeros((nx, ny, nz), dtype='int')
    for i, j, k in np.ndindex(nx, ny, nz):
        swapped_of[j, i, k] = cell_of[i, j, k]
    width = 3 * natom
    out = fsc.copy()
    for i, j, k in np.ndindex(nx, ny, nz):
        src = cell_of[i, j, k]
        dst = swapped_of[i, j, k]
        out[dst * width:(dst + 1) * width, :] = fsc[src * width:(src + 1) * width, :]
    return out
def irreducible(n_sc,cutoff):
    """List the supercell translations whose force-constant blocks are built directly.

    Along each axis ``d`` the number of supercell images kept per side is
    ``lsc[d] = max(1, floor((cutoff[d] / n_sc[d, d] - 1) / 2))``.  Only
    translations with a non-positive first component (``i`` in
    ``[-lsc[0], 0]``) are returned.  The caller (``fconstant_sc``) fills the
    blocks of the negated translations with the Hermitian conjugate.

    Args:
        n_sc: 3x3 supercell matrix; only the diagonal is read.
        cutoff: three primitive-cell extents of the force constants.

    Returns:
        list of length-3 integer numpy arrays ``[i, j, k]``.
    """
    lsc = [max(1, int(np.floor((cutoff[d] / n_sc[d, d] - 1) / 2))) for d in range(3)]
    # The k range previously used lsc[1] as its upper bound, which dropped
    # (or overran) translations whenever lsc[1] != lsc[2].
    return [
        np.array([i, j, k])
        for i in range(-lsc[0], 1)
        for j in range(-lsc[1], lsc[1] + 1)
        for k in range(-lsc[2], lsc[2] + 1)
    ]
def hermitian(f):
    """Return the conjugate transpose (Hermitian adjoint) of *f*."""
    return np.conj(f).T


def fconstant_sc(n_sc,D_pc,cutoff,ifnanoparticles=False):
    """Build supercell force-constant blocks from primitive-cell dynamical matrices.

    ``D_pc`` is inverse-FFT'd over its first three axes (the q-grid) and its
    real part ``f`` is then indexed by a primitive-cell translation (wrapped
    with ``fftshift``).  It is accumulated into ``dim x dim`` blocks with
    ``dim = 3 * natom * ns``, where ``natom = 6`` is hard-coded and ``ns`` is
    the number of primitive cells in the supercell.

    Args:
        n_sc: 3x3 supercell matrix.  Its diagonal gives the cell counts, and
            ``np.dot(i, n_sc)`` converts a supercell translation ``i`` into
            primitive-cell units.
        D_pc: array whose first three axes form the q-grid.
            NOTE(review): presumably shape (nx, ny, nz, 18, 18) as returned by
            ``dyna.pc_dynamical`` -- confirm.
        cutoff: [n1, l1, m1].  Primitive-cell translations with a component
            larger in magnitude than ``int((n - 1) / 2)`` are ignored.
        ifnanoparticles: if True, only the R = 0 block is built and returned.

    Returns:
        If ``ifnanoparticles`` is True: the (dim, dim) block for R = 0.
        Otherwise: an array of shape (2*lsc[0]+1, 2*lsc[1]+1, 2*lsc[2]+1,
        dim, dim) indexed by supercell translation in FFT order.  Blocks for
        the translations from ``irreducible`` are computed directly, and
        their negatives are set to ``hermitian(F)``.
        NOTE(review): ``fsc`` is only created when ``ifnanoparticles == False``;
        any value that is neither ``== True`` nor ``== False`` reaches
        ``return fsc`` unassigned.
    """
    # D_pc = np.fft.fftn(generate_flist(lattice),axes=(0,1,2))
    print('start fconstant_sc constructing',time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
    # Inverse FFT over the q-grid axes; only the real part is kept.
    f = np.real(np.fft.ifftn(D_pc,axes=(0,1,2)))
    [n1,l1,m1] = cutoff
    # Half-widths of the primitive-cell translation range taken from f.
    l0 = [int((n1-1)/2),int((l1-1)/2),int((m1-1)/2)]
    # Supercell images per side along each axis (clamped to at least 1).
    lsc = [int(np.floor((cutoff[i]/n_sc[i,i]-1)/2)) for i in range(len(l0))]
    for i in range(3):
        if lsc[i] < 1:
            lsc[i] = 1
    # n1, l1, m1 are rebound here to the sizes of the supercell translation grid.
    n1 = 2*lsc[0]+1; l1 = 2*lsc[1]+1; m1 = 2*lsc[2]+1
    # NOTE(review): llist is built but never used below.
    llist = []
    for i in range(-lsc[0],lsc[0]+1):
        for j in range(-lsc[1],lsc[1]+1):
            for k in range(-lsc[2],lsc[2]+1):
                frac = np.array([i,j,k])
                llist.append(frac)
    # f_pc=dict()
    # for k in range(-l0[2],l0[2]+1):
    #     for j in range(-l0[1],l0[1]+1):
    #         for i in range(-l0[0],l0[0]+1):
    #             frac = np.array([i,j,k])
    #             f_pc[f'{frac}'] = f[fftshift(i,2*l0[0]+1),fftshift(j,2*l0[1]+1),fftshift(k,2*l0[2]+1)]
    natom = 6
    ns = n_sc[0,0]*n_sc[1,1]*n_sc[2,2]
    dim = 3*ns*natom
    # v[index] = integer (i, j, k) position of primitive cell `index`, x fastest.
    index=0;v=np.zeros((ns,3),dtype='int')
    for k in range(n_sc[2,2]):
        for j in range(n_sc[1,1]):
            for i in range(n_sc[0,0]):
                v[index] = np.array([i,j,k])
                index += 1
    
    reduced_l = irreducible(n_sc,cutoff)
    if ifnanoparticles == True:
        # Nanoparticle: only the R = 0 supercell translation is needed.
        reduced_l = np.array([[0,0,0]])
        # fsc=scipy.sparse.lil_matrix((dim,dim),dtype='float')
    else:
        fsc=np.zeros((n1,l1,m1,dim,dim),dtype='float')
    # Debug progress marker.
    print(2)
    for i in reduced_l:
            F = np.zeros((dim,dim),dtype='float')
            # Supercell translation i expressed in primitive-cell units.
            R=np.dot(i,n_sc)
            for j in range(ns):
                if j%100==0:
                    print(j)
                for k in range(ns):
                    # Primitive-cell translation from cell j to cell k of image R.
                    vector = v[k] - v[j] + R
                    if abs(vector[0])<=l0[0] and abs(vector[1])<=l0[1] and abs(vector[2])<=l0[2]:
                        # F[3*j*natom:3*natom*(j+1),3*k*natom:3*natom*(k+1)] += f_pc[f'{vector}']
                        F[3*j*natom:3*natom*(j+1),3*k*natom:3*natom*(k+1)] += f[fftshift(vector[0],2*l0[0]+1),fftshift(vector[1],2*l0[1]+1),fftshift(vector[2],2*l0[2]+1)]
                    else:
                        pass
            if ifnanoparticles == True:
                print('finsh fsc',time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
                return F
            elif ifnanoparticles == False:
                # Store F at translation i and its conjugate transpose at -i.
                fsc[fftshift(i[0],n1),fftshift(i[1],l1),fftshift(i[2],m1)] = F
                fsc[fftshift(-i[0],n1),fftshift(-i[1],l1),fftshift(-i[2],m1)] = hermitian(F)
    return fsc   

def Dyna(fsc,q=None):
    """Fourier-sum supercell force-constant blocks into a dynamical matrix at *q*.

    Args:
        fsc: array of shape ``(x, y, z, a, b)``.  ``fsc[i, j, k]`` is the
            block for translation index (i, j, k).  Along an axis of length
            ``n``, indices above ``int((n - 1) / 2)`` stand for the negative
            translation ``index - n`` (FFT ordering).
        q: wave vector in the same reduced units as those translations.  The
            phase applied is ``exp(-2*pi*1j * q . R)``.  Defaults to Gamma
            (a fresh zero vector, rather than a shared array default).

    Returns:
        If ``|q| < 1e-5``: the plain sum over translations (dtype of *fsc*).
        Otherwise: the complex ``(a, b)`` matrix
        ``sum_R fsc[R] * exp(-2*pi*1j * q . R)``.
    """
    q = np.zeros(3) if q is None else np.asarray(q)
    if np.linalg.norm(q) < 1e-5:
        return np.sum(fsc, axis=(0, 1, 2))

    def _signed_offsets(n):
        # Same mapping as shift(i, int((n-1)/2), n), vectorised.
        idx = np.arange(n)
        return np.where(idx <= (n - 1) // 2, idx, idx - n)

    x, y, z = np.shape(fsc)[:3]
    ox = _signed_offsets(x)[:, None, None]
    oy = _signed_offsets(y)[None, :, None]
    oz = _signed_offsets(z)[None, None, :]
    # Plain complex dtype: the 'complex_' alias was removed in NumPy 2.
    phases = np.exp(-1j * 2 * np.pi * (q[0] * ox + q[1] * oy + q[2] * oz)).astype(complex)
    return np.tensordot(phases, fsc, axes=([0, 1, 2], [0, 1, 2]))



class dyna:
    """Phonon data of a 6-atom primitive cell read from a CASTEP output file.

    Also provides helpers that build supercell / nanoparticle dynamical
    matrices and write the resulting modes to ``.phonon`` / ``.xsf`` files.

    Several sizes are hard-coded for a 6-atom cell:
    - 18x18 dynamical matrices;
    - 54 text rows per real/imaginary block;
    - per-cell atom types ['8', '8', '8', '8', '22', '22'] (O x4, Ti x2).
    """

    def __init__(self, filename, n_sc):
        """Parse *filename* and build the Cartesian atom list of the supercell.

        Args:
            filename: CASTEP output whose basename starts with
                ``<x>.<y>.<z>.`` (e.g. ``15.15.25.-7.castep``).  x, y and z
                become ``self.x/y/z`` and are used as the q-grid of the
                primitive-cell dynamical matrices.
            n_sc: 3x3 integer supercell matrix.  Its diagonal gives the
                number of primitive cells along each axis.

        Sets ``lattice``, ``coord``, ``pre`` (the value after
        'External pressure/stress (GPa)'), ``dy`` (dynamical matrices),
        ``qp`` (their q-points), ``x/y/z``, ``atomlist``, ``type`` and
        ``n_sc``.
        """
        import os

        with open(filename) as handle:
            data = handle.readlines()
        qp = []
        dy = []
        i = 0

        # Parse the grid from the basename so dotted directory names
        # (e.g. "./15.15.25.-7.castep") do not break it.
        self.x, self.y, self.z = os.path.basename(filename).split('.')[:3]
        for ii, line in enumerate(data):
            if 'Real Lattice' in line:
                lattice = np.loadtxt(data[ii + 1:ii + 4])[:, 0:3]
                self.lattice = np.array(lattice, dtype='float')
            elif 'Fractional coordinates of atoms' in line:
                frac = [data[o + 3].split()[3:6] for o in range(ii, 6 + ii)]
                frac = np.array(frac, dtype='float')
                # Shift coordinates below -0.01 up by one lattice vector.
                frac[frac < -0.01] += 1
                self.coord = np.array(frac, dtype='float')
            elif 'External pressure/stress (GPa)' in line:
                self.pre = float(data[ii + 1].split()[0])
            elif 'q-pts' in line:
                i = ii
                break

        while i < len(data):
            if 'q-pt=' in data[i]:
                skip = ('Acoustic sum rule correction' in data[i + 2]) or ('q->0 along' in data[i + 1])
                if not skip:
                    qq = data[i].split()[-5:-2]
                    qq[-1] = qq[-1][:-1]
                    qp.append(qq)
                    # Walk backwards to the nearest 'real part' header,
                    # remembering the last 'imaginary part' header passed.
                    for j in range(i, -1, -1):
                        if 'real part' in data[j]:
                            real = j
                            break
                        elif 'imaginary part' in data[j]:
                            im = j
                    Re = np.zeros((54, 6), dtype='complex')
                    Im = np.zeros((54, 6), dtype='complex')
                    for k in range(54):
                        Re[k] = data[real + k + 1].split()[-6:]
                        Im[k] = data[im + k + 1].split()[-6:]
                    comp = Re + Im * 1j
                    dy.append(comp.reshape((18, 18)))
            i += 1
        self.dy = np.array(dy, dtype='complex')
        self.qp = np.array(qp, dtype='float')
        self.x = int(self.x)
        self.y = int(self.y)
        self.z = int(self.z)

        # Cells are enumerated x fastest, z slowest; 6 atoms per cell.
        # NOTE(review): positions are lattice @ frac (columns of `lattice` as
        # cell vectors), while ellipse/cube use n_sc @ lattice (rows).  This
        # is identical for diagonal lattices; confirm for non-orthogonal cells.
        atom = []
        types = []
        cell_types = ['8', '8', '8', '8', '22', '22']
        for k in range(n_sc[2, 2]):
            for j in range(n_sc[1, 1]):
                for i in range(n_sc[0, 0]):
                    for l in self.coord:
                        atom.append(np.dot(self.lattice, (l + np.array([i, j, k]))))
                    types.extend(cell_types)
        self.atomlist = np.array(atom)
        self.type = np.array(types)
        self.n_sc = n_sc

    def _center(self):
        """Return the centre 0.5 * (sc[0] + sc[1] + sc[2]), with sc = n_sc @ lattice."""
        sc = np.dot(self.n_sc, self.lattice)
        return 0.5 * (sc[0] + sc[1] + sc[2])

    def ellipse(self, a, b, c):
        """Return indices of atoms inside the ellipsoid with semi-axes a, b, c.

        The ellipsoid is axis-aligned and centred on the supercell centre.
        Its boundary is included.
        """
        v0 = self._center()
        p = np.array([1 / a, 1 / b, 1 / c])
        dis = np.array([np.linalg.norm((i - v0) * p) for i in self.atomlist])
        return np.where(dis <= 1)[0]

    def cylinder(self, radius, leg, direction=None):
        """Return indices of atoms inside a cylinder through the supercell centre.

        Args:
            radius: maximum distance from the axis (inclusive).
            leg: total length; atoms must lie strictly within leg/2 of the
                centre along the axis.
            direction: axis vector; defaults to [0, 0, 1].
        """
        if direction is None:
            direction = np.array([0, 0, 1])
        v0 = self._center()
        norm_dir = np.linalg.norm(direction)
        dis = np.array([np.linalg.norm(np.cross(direction, v0 - i)) / norm_dir for i in self.atomlist])
        ll = []
        for i in np.where(dis <= radius)[0]:
            if np.abs(np.dot(direction, v0 - self.atomlist[i])) / norm_dir < leg / 2:
                ll.append(i)
        return ll

    def cube(self, a, b, c):
        """Return indices of atoms in the axis-aligned box a x b x c around the centre."""
        v0 = self._center()
        l = []
        for i, pos in enumerate(self.atomlist):
            v = pos - v0
            if (np.abs(v[0]) <= a / 2) and (np.abs(v[1]) <= b / 2) and np.abs(v[2]) <= c / 2:
                l.append(i)
        return l

    def checkq(self, q1):
        """Find the stored q-point equivalent to integer grid point *q1*.

        Matching compares cos(2*pi*q/n) component-wise, so q and -q match.
        The x/y-swapped version of *q1* is also tried.

        Returns:
            (index, 0) for a direct match, (index, 1) for a swapped match,
            or the string 'error' when nothing matches.
        """
        l = np.array([self.x, self.y, self.z])
        scale = np.array([2 * np.pi / self.x, 2 * np.pi / self.y, 2 * np.pi / self.z])
        r1 = np.cos(q1 * scale)
        r2 = np.array([r1[1], r1[0], r1[2]])
        qq = [[int(np.round(self.qp[i, j] * l[j])) for j in range(3)] for i in range(len(self.qp))]
        qp = np.cos(qq * scale)
        for i, j in enumerate(qp):
            if np.linalg.norm(r1 - j) < 1e-5:
                return i, 0
            elif np.linalg.norm(r2 - j) < 1e-5:
                return i, 1
        return 'error'

    def findq(self, q1, qq):
        """Return the index in *qq* of grid point *q1* after folding it with ``shift``.

        Prints a message and returns 'error' if no entry matches.
        """
        l = np.array([self.x, self.y, self.z])
        q2 = np.array([shift(q1[i], int((l[i] - 1) / 2), l[i]) for i in range(3)])
        for i, j in enumerate(qq):
            if np.linalg.norm(q2 - j) < 1e-5:
                return i
        print('not finding equivalent q points')
        return 'error'

    def pc_dynamical(self):
        """Arrange the stored dynamical matrices on the (x, y, z) q-grid.

        Returns:
            complex array of shape (x, y, z, 18, 18).
        """
        data = self.dy
        a, b, c = self.x, self.y, self.z
        l = np.array([a, b, c])
        print('dynamica abc', l)
        qq = [[int(np.round(self.qp[i, j] * l[j])) for j in range(3)] for i in range(len(self.qp))]
        D = np.zeros((a, b, c, 18, 18), dtype='complex')
        for k in range(c):
            for j in range(b):
                for i in range(a):
                    D[i, j, k] = data[self.findq(np.array([i, j, k]), qq)]
        print('finish dynamical abc')
        return D

    def Dynamical_to_be_done(self):
        """Unfinished variant of ``pc_dynamical`` that matches q-points via ``checkq``.

        The swap flag returned by ``checkq`` is currently ignored.
        """
        data = self.dy
        a, b, c = self.x, self.y, self.z
        D = np.zeros((a, b, c, 18, 18), dtype='complex')
        for k in range(c):
            for j in range(b):
                for i in range(a):
                    (index, t) = self.checkq(np.array([i, j, k]))
                    D[i, j, k] = data[index]
        return D

    def _commensurate_qlist(self):
        """Return q = (i/n0, j/n1, k/n2) for i, j, k in [-(n-1)//2, (n-1)//2], i outermost.

        n0, n1 and n2 are the diagonal entries of ``n_sc``.
        """
        n_sc = self.n_sc
        l1 = int((n_sc[0, 0] - 1) / 2)
        l2 = int((n_sc[1, 1] - 1) / 2)
        l3 = int((n_sc[2, 2] - 1) / 2)
        return [
            np.array([i / n_sc[0, 0], j / n_sc[1, 1], k / n_sc[2, 2]])
            for i in range(-l1, l1 + 1)
            for j in range(-l2, l2 + 1)
            for k in range(-l3, l3 + 1)
        ]

    def _bloch_phases(self, q):
        """Return exp(2*pi*1j * q . (x, y, z)) for every supercell cell.

        The result is a csingle array of shape (nz, ny, nx, 1).  Its C-order
        (z slowest, x fastest) matches the cell order of ``self.atomlist``.
        """
        a, b, c = self.n_sc[0, 0], self.n_sc[1, 1], self.n_sc[2, 2]
        phases = np.zeros((c, b, a, 1), dtype='csingle')
        # Previously shaped (a, b, c, 1) with the z loop bounded by a, which
        # mis-assigned phases whenever n_sc[0, 0] != n_sc[2, 2].
        for kk in range(c):
            for jj in range(b):
                for ii in range(a):
                    phases[kk, jj, ii] = np.exp(1j * np.dot(q, np.array([ii, jj, kk])) * 2 * np.pi)
        return phases

    def pc_to_sc_eig(self):
        """Diagonalise the stored primitive-cell dynamical matrices.

        One matrix is diagonalised per commensurate q-point.  The Bloch
        construction of supercell eigenvectors is done in ``pc_to_sc_eig1``.

        Returns:
            (fre, Eig): the concatenated eigenvalues (complex) and an
            all-zero (18*l, 18*l) placeholder matrix, l being the cell count.
        """
        n_sc = self.n_sc
        l = n_sc[0, 0] * n_sc[1, 1] * n_sc[2, 2]
        Eig = np.zeros((l * 18, l * 18), dtype=complex)
        qlist = self._commensurate_qlist()
        fre = np.zeros(len(qlist) * 18, dtype=complex)
        for i in range(len(qlist)):
            u, w = np.linalg.eig(self.dy[i])
            fre[i * 18:(i + 1) * 18] = u
            print(i)
        return np.array(fre), Eig

    def pc_to_sc_eig1(self):
        """Build normalised supercell Bloch eigenvectors from the primitive-cell modes.

        For the i-th commensurate q, each eigenvector of ``self.dy[i]`` is
        repeated over all cells with phase exp(2*pi*1j q . R).
        NOTE(review): assumes ``self.dy[i]`` belongs to the i-th q of
        ``_commensurate_qlist``; ``dy`` is in file order -- confirm.

        Returns:
            (fre, Eig): the eigenvalues and a (18*l, 18*l) csingle matrix
            whose columns are the normalised supercell eigenvectors.
        """
        n_sc = self.n_sc
        l = n_sc[0, 0] * n_sc[1, 1] * n_sc[2, 2]
        a, b, c = n_sc[0, 0], n_sc[1, 1], n_sc[2, 2]
        qlist = self._commensurate_qlist()
        fre = np.zeros(len(qlist) * 18, dtype=complex)
        Eig = np.zeros((l * 18, l * 18), dtype='csingle')
        for i in range(l):
            if i % 100 == 0:
                print(i)
            u, w = np.linalg.eig(self.dy[i])
            fre[i * 18:(i + 1) * 18] = u
            Rlist = self._bloch_phases(qlist[i])
            for j in range(18):
                # (z, y, x, 18) in C-order matches the atomlist ordering.
                eig = np.ones((c, b, a, 18)) * w[:, j] * Rlist
                col = i * 18 + j
                Eig[:, col] = eig.reshape(-1,)
                Eig[:, col] = Eig[:, col] / np.linalg.norm(Eig[:, col])
        return np.array(fre), Eig

    def pc_to_sc_eig2(self, nano):
        """Project vector *nano* onto the supercell Bloch eigenvectors.

        The eigenvectors are built as in ``pc_to_sc_eig1``.

        Returns:
            (coeff, fre): ``coeff[m] = <eig_m | nano>`` and the eigenvalues.
        """
        n_sc = self.n_sc
        l = n_sc[0, 0] * n_sc[1, 1] * n_sc[2, 2]
        a, b, c = n_sc[0, 0], n_sc[1, 1], n_sc[2, 2]
        qlist = self._commensurate_qlist()
        fre = np.zeros(len(qlist) * 18, dtype=complex)
        coeff = np.zeros(l * 18, dtype=complex)
        for i in range(l):
            u, w = np.linalg.eig(self.dy[i])
            fre[i * 18:(i + 1) * 18] = u
            Rlist = self._bloch_phases(qlist[i])
            if i % 100 == 0:
                print(i)
            for j in range(18):
                eig = (np.ones((c, b, a, 18)) * w[:, j] * Rlist).reshape(-1,)
                eig = eig / np.linalg.norm(eig)
                coeff[i * 18 + j] = np.dot(np.conj(eig), nano)
        return coeff, fre

    def sc_bulk_dy(self):
        """Return the Gamma-point dynamical matrix of the bulk supercell.

        ``pc_dynamical`` is cached in ``D{z}.{int(pre)}.npy``; a missing or
        unreadable cache is rebuilt.
        """
        cache = f'D{self.z}.{int(self.pre)}.npy'
        try:
            D = np.load(cache)
        except (OSError, ValueError, EOFError):
            D = self.pc_dynamical()
            np.save(cache, D)
        cutoff = [self.x, self.y, self.z]
        fsc = fconstant_sc(self.n_sc, D, cutoff, ifnanoparticles=False)
        return Dyna(fsc)

    def sc_bulk_eig(self):
        """Diagonalise the ASR-corrected bulk supercell matrix.

        Saves the results to bulk_fre.npy and bulk_eig.npy.

        Returns:
            (u, Eig) from ``np.linalg.eig``.
        """
        D = ASR(self.sc_bulk_dy())
        print('start eig', time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
        u, Eig = np.linalg.eig(D)
        print('start saving', time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
        np.save('bulk_fre.npy', u)
        np.save('bulk_eig.npy', Eig)
        return u, Eig

    @staticmethod
    def _xyz_indices(atoms):
        """Return the matrix rows 3*i, 3*i+1, 3*i+2 for every atom index i."""
        ll = []
        for i in atoms:
            ll.extend((i * 3, i * 3 + 1, i * 3 + 2))
        return ll

    @staticmethod
    def _select_mode(vectors, mode):
        """Return column *mode* of a 2-D eigenvector array; 1-D input is returned as-is."""
        arr = np.asarray(vectors)
        return arr[:, mode] if arr.ndim == 2 else arr

    def nano_to_bulk(self, nano_eig, l):
        """Embed a nanoparticle vector into a zero bulk vector.

        The entries are placed at the x/y/z rows of the atoms listed in *l*.
        """
        bulk = np.zeros((len(self.atomlist) * 3,), dtype=complex)
        bulk[self._xyz_indices(l)] = nano_eig
        return bulk

    def bulk_to_nano(self, bulk_eig_matrix, l):
        """Restrict a bulk matrix to the atoms in *l* and normalise each column."""
        ll = self._xyz_indices(l)
        nano = bulk_eig_matrix[ll]
        nano = nano[:, ll]
        for i in range(len(nano)):
            nano[:, i] = nano[:, i] / np.linalg.norm(nano[:, i])
        return nano

    def bulk_percent(self, bulk_Eig, l, nano_eig, mode):
        """Expand nanoparticle mode *mode* in the bulk eigenvector basis.

        *nano_eig* may be a matrix of eigenvectors in columns (column *mode*
        is used) or a single vector.  Returns ``inv(bulk_Eig) @ embedded``.
        Previously this passed three arguments to ``nano_to_bulk``, which
        raised TypeError.
        """
        nano = self.nano_to_bulk(self._select_mode(nano_eig, mode), l)
        return np.dot(np.linalg.inv(bulk_Eig), nano)

    def sc_nano_dy(self, shape, read_f=False, clamping=False):
        """Cut the supercell force-constant matrix down to a nanoparticle.

        Args:
            shape: ['ellipse', a, b, c], ['cylinder', radius, leg, direction]
                or ['cube', a, b, c].
            read_f: unused; kept for interface compatibility.
            clamping: for ellipses, divide the rows/columns of the surface
                shell (atoms within semi-axes +1.4 but outside the particle)
                by 1000, and include that shell in the particle.

        Returns:
            (l, f): the selected atom indices and the square sub-matrix of
            their x/y/z rows and columns.
        """
        try:
            f = np.load(f'fsc1{self.z}.{int(self.pre)}.npy')
        except (OSError, ValueError, EOFError):
            try:
                D = np.load(f'D{self.z}.{int(self.pre)}.npy')
            except (OSError, ValueError, EOFError):
                D = self.pc_dynamical()
                np.save(f'D{self.z}.{int(self.pre)}.npy', D)
            print('start fsc', time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
            f = fconstant_sc(self.n_sc, D, [self.x, self.y, self.z], ifnanoparticles=True)
            # Deliberately not cached: for a 15x15x25 supercell this matrix
            # is about 30 GB, and loading it back may be too expensive.
            print('finish fsc', time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))

        assert shape[0] in ['ellipse', 'cylinder', 'cube']
        if shape[0] == 'ellipse':
            [a, b, c] = shape[1:]
            l = self.ellipse(a, b, c)
            l1 = self.ellipse(a + 1.4, b + 1.4, c + 1.4)
            print("total atom:", len(l))
        elif shape[0] == 'cylinder':
            [radius, leg, direction] = shape[1:]
            l = self.cylinder(radius, leg, direction)
            # No surface shell for this shape (l1 used to be undefined here).
            l1 = l
        else:
            [a, b, c] = shape[1:]
            l = self.cube(a, b, c)
            l1 = l
        inside = {int(i) for i in l}
        Atom_surface = [i for i in l1 if int(i) not in inside]
        print('l_inside:', np.sort(l))
        lsurface = self._xyz_indices(Atom_surface)
        if clamping:
            f[lsurface, :] = f[lsurface, :] / 1000
            f[:, lsurface] = f[:, lsurface] / 1000
            l = l1
        ll = self._xyz_indices(l)

        print('start slicing', time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
        f = f[ll]
        print('finish row', time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
        f = np.transpose(f)
        print('finish transpose', time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
        f = f[ll]
        print('finish column', time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
        return l, f

    def outputstructure(self, ll):
        """Write the atoms with indices in *ll* to the XSF-style file TiO2.xsf.

        The supercell vectors are written as both PRIMVEC and CONVVEC, and
        each selected atom as ``<type> x y z``.
        """
        sc = np.dot(self.n_sc, self.lattice)
        selected = {int(i) for i in ll}
        with open("TiO2.xsf", 'w') as f:
            f.write(f'#Ti O\nCRYSTAL\nPRIMVEC\n')
            for i in range(3):
                f.write(f'{sc[i,0]} {sc[i,1]} {sc[i,2]}\n')
            f.write('CONVVEC\n')
            for i in range(3):
                f.write(f'{sc[i,0]} {sc[i,1]} {sc[i,2]}\n')
            f.write(f'PRIMCOORD\n {int(len(ll))} 1\n')
            for i in range(len(self.atomlist)):
                if i in selected:
                    f.write(f'{self.type[i]} {self.atomlist[i,0]} {self.atomlist[i,1]} {self.atomlist[i,2]}\n')

    @staticmethod
    def _signed_root_frequencies(u1):
        """Return (sorted eigenvalues, signed square roots of their real parts).

        A negative value v maps to -sqrt(|v|).
        """
        fre2 = np.sort(u1.copy())
        fre = []
        for value in np.real(fre2):
            root = float(value) ** 0.5
            fre.append(np.imag(root) * -1 if float(value) < 0 else root)
        return fre2, fre

    @staticmethod
    def _eigvec_line(mode, ion, vec):
        """Format one eigenvector line of a .phonon file (imaginary parts written as 0)."""
        return (f"   {mode}   {ion} {vec[0]:.12f}  0.000000000000      "
                f"{vec[1]:.12f}  0.000000000000     {vec[2]:.12f}  0.000000000000\n")

    def _write_phonon_header(self, f, n_ions, selected=None):
        """Write the .phonon header for *n_ions* ions up to the q-pt line.

        Only atoms whose index is in *selected* are listed; all atoms are
        listed when *selected* is None.  The "fractional" coordinates are
        computed as inv(sc) @ position, as before.
        """
        f.write(f' BEGIN header\n Number of ions         {n_ions}\n Number of branches     {3*n_ions}\n Number of wavevectors  1\n Frequencies in         cm-1\n IR intensities in      (D/A)**2/amu\n Raman activities in    A**4 amu**(-1)\n Unit cell vectors (A)\n')
        sc = np.dot(self.n_sc, self.lattice)
        for i in range(3):
            f.write(f'    {sc[i,0]:.06f}    {sc[i,1]:.06f}    {sc[i,2]:.06f}\n')
        f.write('Fractional Co-ordinates\n')
        inv = np.linalg.inv(sc)
        keep = None if selected is None else {int(s) for s in selected}
        nu = 1
        for i, pos in enumerate(self.atomlist):
            if keep is not None and i not in keep:
                continue
            v = np.dot(inv, pos)
            if self.type[i] == '8':
                atom, mass = 'O', 15.999400
            elif self.type[i] == '22':
                atom, mass = 'Ti', 47.867000
            else:
                atom, mass = 'err', 0.0
            f.write(f'{nu} {v[0]:.06f} {v[1]:.06f} {v[2]:.06f} {atom} {mass}\n')
            nu += 1
        f.write(" END header\n")
        f.write('q-pt=    1    0.000000  0.000000  0.000000      1.0\n')

    def eigenvector(self, ll, D, radius):
        """Diagonalise the nanoparticle matrix and write p_{pre}_{radius}.phonon.

        ``kk = int(5 / (radius / 32)**2)`` eigenpairs of the real part of
        ``ASR(D)`` are obtained with ``eigsh(..., sigma=-1e4, which='LA')``.
        Written frequencies are signed square roots of the eigenvalues (no
        unit conversion is applied, despite the 'cm-1' header).  Modes are
        written in ascending eigenvalue order.

        Returns:
            (u1, w1) as returned by eigsh.
        """
        kk = int(5 / (radius / 32) ** 2)
        D1 = np.real(ASR(D))
        print('ll', len(ll))
        print('D1', np.shape(D1))
        print('start eig', time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
        u1, w1 = sp2.eigsh(D1, k=kk, sigma=-1e4, which="LA")
        print(u1)
        print('finish eig', time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
        _, fre = self._signed_root_frequencies(u1)
        with open(f"p_{self.pre}_{radius}.phonon", 'w') as f:
            self._write_phonon_header(f, len(ll), ll)
            for i, j in enumerate(fre):
                f.write(f'{i+1}  {j:.06f}\n')
            f.write("                        Phonon Eigenvectors\n")
            f.write("Mode Ion                X                                   Y                                   Z\n")
            # Stable argsort reproduces the old "first remaining equal value" matching.
            for i, index in enumerate(np.argsort(u1, kind='stable')[:kk]):
                for m, n in enumerate(w1[:, index].reshape(-1, 3)):
                    f.write(self._eigvec_line(i + 1, m + 1, n))
        print('Successful')
        return u1, w1

    def nano_in_bulk(self, u1, w1, l_in):
        """Write the 50 lowest nanoparticle modes, embedded in the supercell.

        Output goes to p_{pre}.nano.bulk.phonon.  Eigenvector entries of
        atoms not in *l_in* are written as zero.
        """
        _, fre = self._signed_root_frequencies(u1)
        latom = self._xyz_indices(np.sort(l_in))
        with open(f"p_{self.pre}.nano.bulk.phonon", 'w') as f:
            self._write_phonon_header(f, len(self.atomlist))
            for i, j in enumerate(fre[:50]):
                f.write(f'{i+1}  {j:.06f}\n')
            f.write("                        Phonon Eigenvectors\n")
            f.write("Mode Ion                X                                   Y                                   Z\n")
            for i, index in enumerate(np.argsort(u1, kind='stable')[:50]):
                nano = np.zeros(len(self.atomlist) * 3,)
                nano[latom] = w1[:, index]
                for m, n in enumerate(nano.reshape(-1, 3)):
                    f.write(self._eigvec_line(i + 1, m + 1, n))
        print('Successful')

    def eig_bulk(self, u1, w1):
        """Write hand-picked bulk modes (indices in ``l_low``) to p_{pre}.bulk.lowest1.phonon."""
        # Earlier selection: [2483, 2788, 2485, 2484, 3117, 2754, 3035, 2752, 3858, 2750]
        l_low = [3036, 2482, 2485, 3859, 2484, 2509, 2753, 2751, 2533, 3115]
        fre = np.array(u1)
        with open(f"p_{self.pre}.bulk.lowest1.phonon", 'w') as f:
            self._write_phonon_header(f, len(self.atomlist))
            for mode, value in zip(l_low, fre[l_low]):
                # sign(value) * sqrt(|value|)
                ff = value / abs(value) * np.sqrt(abs(value))
                f.write(f'{mode+1}  {ff:.06f}\n')
            f.write("                        Phonon Eigenvectors\n")
            f.write("Mode Ion                X                                   Y                                   Z\n")
            for i in l_low:
                for m, n in enumerate(w1[:, i].reshape(-1, 3)):
                    f.write(self._eigvec_line(i + 1, m + 1, n))
        print('Successful')

    def find_lincom(self, w2, mode, l, bulk_eig=None, bulk_fre=None):
        """Print and return bulk modes ordered by overlap with nanoparticle mode *mode*.

        *bulk_eig* and *bulk_fre* default to the module-level names ``EIG``
        and ``u``.  NOTE(review): those are not defined in this file
        (presumably set interactively), so pass them explicitly.
        """
        if bulk_eig is None:
            bulk_eig = EIG  # noqa: F821
        if bulk_fre is None:
            bulk_fre = u  # noqa: F821
        n1 = self.nano_to_bulk(self._select_mode(w2, mode), l)
        c1 = np.dot(bulk_eig, n1)
        index1 = np.argsort(-1 * np.abs(c1))
        print('The first 10 coeff are:', index1[:10])
        print('The corresponding fre are:', bulk_fre[index1[:10]])
        return index1
        return index1

# sphere example
def test(pre,aa):
    """Spherical-nanoparticle example.

    Reads ``15.15.25.{pre}.castep`` into a 15x15x25 supercell, cuts out the
    sphere (ellipse with equal semi-axes) of radius *aa*, and writes its
    modes via ``dyna.eigenvector``.
    """
    print(aa)
    print('start reading data', time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
    # Odd supercell sizes covering a sphere of diameter 2*aa along a and c.
    # Kept for the commented-out n_sc alternative below; currently unused.
    lat_a, lat_c = 4.6341768, 2.9457560
    n1 = int(np.ceil(aa * 2 / lat_a))
    n3 = int(np.ceil(aa * 2 / lat_c))
    if not n1 % 2:
        n1 += 1
    if not n3 % 2:
        n3 += 1
    # n_sc = np.array([[n1,0,0],[0,n1,0],[0,0,n3]])
    n_sc = np.array([[15, 0, 0], [0, 15, 0], [0, 0, 25]])
    p0 = dyna(f'15.15.25.{pre}.castep', n_sc)
    shape = ['ellipse', aa, aa, aa]
    print('start constructing sc DM', time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
    l, cube0 = p0.sc_nano_dy(shape, clamping=False)
    print('finish reading data and constructing sc DM', time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
    u1, w1 = p0.eigenvector(l, cube0, aa)

def test2(aa):
    pass
import sys

if __name__ == "__main__":
    # Usage: python <script> RADIUS
    # RADIUS (an integer) becomes the semi-axis of the spherical nanoparticle
    # passed to test().  The -7 tag selects the file "15.15.25.-7.castep".
    # Guarded so that importing this module does not read argv or start the
    # long-running calculation.
    r = int(sys.argv[1])
    test(-7, r)


