import numpy as np
from scipy.sparse import csr_matrix #lil_matrix
import sys
import math

##################################################################################################################
# Functions to generate and operate on bases
##################################################################################################################

# compute co-lexicographical rank of a combination
def CoLexRank(conf):
    """Return the co-lexicographic rank of a combination.

    ``conf`` is a strictly increasing sequence of site indices
    ``c_0 < c_1 < ...``; the rank is ``sum_k C(c_k, k + 1)``.
    """
    return sum(math.comb(site, k + 1) for k, site in enumerate(conf))

# generate combinations ordered co-lexicographically
def GenerateCombinations2(t, n, s):
    """Return the first ``s`` t-subsets of ``range(n)`` in co-lexicographic order.

    Implements Knuth, TAOCP Vol. 4 Fasc. 3, Algorithm L.

    Parameters
    ----------
    t : number of particles (size of each combination)
    n : number of states/sites
    s : number of combinations to emit; normally ``math.comb(n, t)``.
        Larger values are not meaningful (the sequence wraps around).

    Returns
    -------
    list of lists, each an increasing list of ``t`` site indices.
    """
    # c[0..t-1] is the current combination; c[t] = n and c[t+1] = 0 are
    # Knuth's sentinels.
    c = list(range(t)) + [n, 0]
    result = [c[:t]]
    for _ in range(1, s):
        j = 0
        # reset the run of consecutive entries at the bottom ...
        while c[j] + 1 == c[j + 1]:
            c[j] = j
            j += 1
            if j >= t:
                break
        # ... and bump the first entry that can move
        c[j] += 1
        result.append(c[:t])
    return result

def DecimalToBinary(dec, N):
    """Return the ``N`` lowest binary digits of ``dec``, most significant first.

    Parameters
    ----------
    dec : non-negative integer to convert.
    N : number of digits. Bits of ``dec`` above position ``N - 1`` are
        silently dropped.

    Returns
    -------
    numpy.ndarray of int, shape (N,), with ``bits[N - 1]`` the least
    significant bit.
    """
    bits = np.zeros(N, dtype=int)
    for i in range(N - 1, -1, -1):
        # Integer divmod keeps the conversion exact; the previous
        # ``(dec - r) / 2`` went through float and lost bits above 2**53.
        dec, bits[i] = divmod(dec, 2)
    return bits

def BinaryToDecimal(ter):
    """Interpret ``ter`` as a binary number, most significant digit first.

    Inverse of ``DecimalToBinary``. An empty sequence gives 0.
    """
    # walk from the least significant digit (last element) upwards
    return sum(bit * 2**power for power, bit in enumerate(reversed(ter)))

#generate the basis
# that is, first generate the combinations, which describe which atoms are occupied
# then, take into account that each combination corresponds to 2**Npart configurations
# (ordered according to decimal representation of the corresponding binary number)
# because at each atom either - or + orbital can be occupied
def GenerateBasis(Npart, Nsite):
    """Build the site basis for ``Npart`` particles on ``Nsite`` sites.

    Occupied-site sets come from ``GenerateCombinations2`` (co-lex order).
    Each set is expanded into ``2**Npart`` configurations, enumerated by the
    binary digits of 0..2**Npart-1. Digit 0 stores orbital code 1 and digit 1
    stores code 2 on the corresponding occupied site.

    Returns
    -------
    list of numpy int arrays of length ``Nsite``; entry 0 = empty site,
    1 or 2 = occupied with orbital index 0 or 1.
    """
    combos = GenerateCombinations2(Npart, Nsite, math.comb(Nsite, Npart))
    confs = []
    for occupied_sites in combos:
        for code in range(2**Npart):
            bits = DecimalToBinary(code, Npart)
            conf = np.zeros(Nsite, dtype=int)
            for site, bit in zip(occupied_sites, bits):
                conf[site] = bit + 1
            confs.append(conf)
    return confs

# given a configuration, find its index in the basis
# use the co-lexicographical ordering to locate combination, and then binary numbers to locate the -/+ orbital occupations
def FindIndex(conf, Npart):
    """Return the position of ``conf`` in the basis built by ``GenerateBasis``.

    The co-lex rank of the occupied-site set selects a block of
    ``2**Npart`` configurations. The orbital codes (1/2 → bit 0/1, read in
    site order, most significant first) give the offset inside that block.
    """
    occupied = [(site, c) for site, c in enumerate(conf) if c > 0]
    comb_rank = 0
    for k, (site, _) in enumerate(occupied):
        comb_rank = comb_rank + math.comb(site, k + 1)
    orbital_code = 0
    for _, c in occupied:
        orbital_code = orbital_code * 2 + (c - 1)
    return comb_rank * 2**Npart + orbital_code


# shift a config in the angular direction (i.e. rotate config)
def MoveConfigAngular(conf, dL, Npart, Lmax):
    """Rotate ``conf`` by ``dL`` sectors (each sector is ``len(conf)//Lmax`` sites).

    Result site ``i`` takes the content of site ``i - dL*Nsite//Lmax``
    (cyclically). ``conf`` must support fancy indexing (a numpy array).
    ``Npart`` is unused.
    """
    nsite = len(conf)
    offset = dL * nsite // Lmax  # minus below: otherwise it translates left
    return conf[[(site - offset) % nsite for site in range(nsite)]]

# create a shift table (i.e. what configuration j we get if we shift the configuration i by one sector)
def CreateShiftTableAngular(basis, Npart, Lmax):
    """Return ``tab`` with ``tab[i]`` = basis index of config ``i`` rotated by one sector."""
    return np.array(
        [FindIndex(MoveConfigAngular(conf, 1, Npart, Lmax), Npart) for conf in basis],
        dtype=int,
    )

# print the "orbit" (configurations transforming into each other by shifting) with corresponding phases
def PrintAngularOrbitWithPhase(orbit,phases,basis, Lmax):
    """Print an orbit as ``phase*|occupations>`` terms joined by ``+``.

    ``phases[o]`` is printed verbatim for each config index ``o`` in
    ``orbit``. Only the first ``Lmax*(Nsite//Lmax)`` sites of each config are
    printed.
    """
    width = len(basis[0])
    terms = []
    for o in orbit:
        shown = Lmax * (width // Lmax)  # sites covered by full sectors
        digits = "".join(str(basis[o][k]) for k in range(shown))
        terms.append(str(phases[o]) + "*|" + digits)
    print("+".join(terms))


# print the whole angular momentum basis
def PrintAngularBasisBySectorWithPhase(mombasis, basis, Lmax):
    """Print every orbit of every angular momentum ``L`` in ``range(Lmax)``.

    ``mombasis`` is the 4-tuple returned by ``FindAngularMomentumEigenstates``;
    only the orbit lists and the phase table are used.
    """
    orbits_by_L, _, phase, _ = mombasis
    for L in range(Lmax):
        print("L=", L)
        for idx, orbit in enumerate(orbits_by_L[L]):
            sys.stdout.write(str(idx) + " ")
            PrintAngularOrbitWithPhase(orbit, phase[L, :], basis, Lmax)
      

# creates the angular-momentum-conserving basis
def FindAngularMomentumEigenstates(basis, Npart, Lmax):
    """Build the angular-momentum-conserving basis from rotation orbits.

    An "orbit" is the set of configurations reached by repeatedly rotating a
    configuration by one sector (via ``CreateShiftTableAngular``). An orbit
    of length ``olen`` supports the ``olen`` angular momenta
    ``k = k0*Lmax//olen`` (k0 = 0..olen-1). At each such k the orbit is one
    basis state. Orbits shorter than Lmax therefore skip some k.

    Returns
    -------
    OrbitsPerMomentum : list of length Lmax. ``OrbitsPerMomentum[k][iorb]`` is
        a list of config indices (into ``basis``). Each config appears in at
        most one orbit per k. The same list object is shared by every k that
        the orbit supports.
    lin3 : int array (Lmax, Nconf). ``lin3[k, iconf]`` is the index of the
        orbit to which config ``iconf`` belongs at momentum k, or -1.
    phase : int array (Lmax, Nconf). Phase of config ``iconf`` in its orbit,
        in units of 2*pi/Lmax (-1 if unused).
    norm : int array (Lmax, Nconf). Orbit length for config ``iconf``
        (-1 if unused).

    The coefficient of config ``iconf`` in its linear combination is
    ``exp(2j*pi*phase[k, iconf]/Lmax)/sqrt(norm[k, iconf])``.
    """
    Nconf = len(basis)
    shift = CreateShiftTableAngular(basis, Npart, Lmax)
    unassigned = np.ones(Nconf, dtype=bool)
    OrbitsPerMomentum = [[] for _ in range(Lmax)]
    lin3 = np.full((Lmax, Nconf), -1, dtype=int)
    phase = np.full((Lmax, Nconf), -1, dtype=int)
    norm = np.full((Lmax, Nconf), -1, dtype=int)
    counts = np.zeros(Lmax, dtype=int)
    for start in range(Nconf):
        if not unassigned[start]:
            continue  # already part of an earlier orbit
        # rotate until we come back to the starting configuration
        orbit = [start]
        unassigned[start] = False
        member = shift[start]
        while member != start:
            unassigned[member] = False
            orbit.append(member)
            member = shift[member]
        olen = len(orbit)
        for k0 in range(olen):
            k = k0 * Lmax // olen
            OrbitsPerMomentum[k].append(orbit)
            for step, o in enumerate(orbit):
                lin3[k, o] = counts[k]
                phase[k, o] = -step * k
                norm[k, o] = olen
            counts[k] += 1
    return OrbitsPerMomentum, lin3, phase, norm

##################################################################################################################
# Functions to create Hamiltonians
##################################################################################################################


# compute the action of onsite terms on a configuration
def DiagonalElement(conf, interactions):
    """Sum the ``"onsite"`` terms that act on ``conf``.

    An onsite term ``["onsite", site, sigma, value]`` contributes ``value``
    when ``conf[site] == sigma + 1``. All other interaction kinds are ignored.
    """
    return sum(
        term[3]
        for term in interactions
        if term[0] == "onsite" and conf[term[1]] == term[2] + 1
    )


#compute the action of a hopping term on a configuration
def BosonicHoppingPM(conf, interaction):
    """Apply one hard-core hopping term to ``conf``.

    ``interaction = ["hopping", i1, i2, sigma1, sigma2, amplitude]``
    annihilates the particle at (i2, sigma2) and creates one at (i1, sigma1).
    The move is allowed when site i2 holds orbital sigma2 and site i1 is
    empty (or is the same site, i.e. an orbital flip).

    Returns
    -------
    ``(amplitude, new_conf)`` with ``new_conf`` a new list, or ``None`` if the
    term does not act on ``conf``.
    """
    i1, i2 = interaction[1], interaction[2]
    sigma1, sigma2 = interaction[3], interaction[4]
    amplitude = interaction[5]
    target_free = conf[i1] == 0 or i1 == i2
    if not (target_free and conf[i2] == sigma2 + 1):
        return None
    new_conf = list(conf)
    new_conf[i2] = 0
    new_conf[i1] = sigma1 + 1
    return amplitude, new_conf


# generate a Hamiltonian based on "interactions" in a basis with no angular momentum conservation
# iterates over the configurations conf, computes the action of the Hamiltonian terms on that configuration obtaining configuration conf2,
# finds index of conf2 in the basis and saves the corresponding matrix elements to a sparse matrix
def GenerateSparseHamiltonianNoMomentumCoLex(interactions,basis, Ns, Npart):
    """Build the sparse Hamiltonian in the full site basis (no symmetry).

    For each configuration ``|i>`` the onsite and hopping terms are applied.
    Resulting configurations are located with ``FindIndex``, and
    contributions to the same target are summed. The diagonal entry is
    always stored (even if zero), as is every entry reached by a hopping
    term.

    Parameters
    ----------
    interactions : list of ``["onsite", ...]`` / ``["hopping", ...]`` terms.
    basis : list of configurations (as from ``GenerateBasis``).
    Ns : unused; kept for interface compatibility.
    Npart : number of particles (needed by ``FindIndex``).

    Returns
    -------
    scipy.sparse.csr_matrix of shape (Nconf, Nconf).
    """
    Nconf = len(basis)
    hoppings = [term for term in interactions if term[0] == "hopping"]
    data, indices, indptr = [], [], [0]
    for i, conf in enumerate(basis):
        # Accumulate only the touched columns; the previous dense per-row
        # buffers made construction O(Nconf**2).
        row = {i: 0j + DiagonalElement(conf, interactions)}
        for interaction in hoppings:
            el = BosonicHoppingPM(conf, interaction)
            if el is None:
                continue
            elem, conf2 = el
            inew = FindIndex(conf2, Npart)
            row[inew] = row.get(inew, 0j) + elem
        for col in sorted(row):
            data.append(row[col])
            indices.append(col)
        indptr.append(len(data))
    H = csr_matrix((data, indices, indptr), shape=(Nconf, Nconf))
    # Row i holds H|i>; in the Hamiltonian matrix that is column i, so transpose.
    return H.transpose().tocsr()


# generate a Hamiltonian based on "interactions" in a basis with angular momentum conservation
# (conceptually similar to GenerateSparseHamiltonianNoMomentumCoLex), except that now the basis states are the orbits (linear combinations of configs)
# so it has to split the orbit into configs conf, act with the Hamiltonian term at each  conf producing conf2, find which orbit conf2 belongs to,
# and take into account all the coefficients (phase, norm) of the linear combination
def GenerateCRSHamiltonianWithAngularMomentum(L,interactions,basis,Lmax, mombasis):
    """Build the sparse Hamiltonian block at angular momentum ``L``.

    Basis states are the orbits ``|psi> = sum_o exp(2j*pi*phase[L,o]/Lmax)/sqrt(N)|o>``
    from ``mombasis`` (the 4-tuple returned by ``FindAngularMomentumEigenstates``).

    For each orbit, every member configuration is acted on:
    - onsite terms add ``elem/N`` to the diagonal;
    - each hopping term maps ``o`` to ``o2``, whose orbit and phase are looked
      up and which adds ``phase1*phase2*elem/sqrt(N*N2)`` to that orbit's
      column.
    Entries with magnitude <= 1e-15 are dropped.

    Returns
    -------
    scipy.sparse.csr_matrix of shape (Norb, Norb), Norb = number of orbits at L.
    """
    OrbitsPerMomentum, lin3, phase, norm = mombasis
    orbits = OrbitsPerMomentum[L]
    Norb = len(orbits)
    # Particle number = number of occupied sites. Entries are 0 (empty) or
    # orbital index + 1, so sum(basis[0]) is only correct if every occupied
    # site of basis[0] carries code 1.
    Npart = int(np.count_nonzero(basis[0]))
    hoppings = [term for term in interactions if term[0] == "hopping"]
    thres = 1e-15
    data, indices, indptr = [], [], [0]
    for iorbit, orbit in enumerate(orbits):
        row = {}  # column -> accumulated element (avoids a dense Norb buffer)
        norm1 = len(orbit)
        for o in orbit:
            # phase of this config in the source linear combination
            phase1 = np.exp(2j*np.pi*float(phase[L, o])/float(Lmax))
            conf = basis[o]
            elem = DiagonalElement(conf, interactions)
            row[iorbit] = row.get(iorbit, 0j) + elem/float(norm1)
            for interaction in hoppings:
                el = BosonicHoppingPM(conf, interaction)
                if el is None:
                    continue
                elem, conf2 = el
                onew = FindIndex(conf2, Npart)
                iorbitnew = lin3[L, onew]  # orbit containing conf2 (-1: none)
                if iorbitnew < 0:
                    continue
                # conjugated phase of conf2 in the target linear combination
                phase2 = np.exp(-2j*np.pi*float(phase[L, onew])/float(Lmax))
                norm2 = norm[L, onew]
                row[iorbitnew] = row.get(iorbitnew, 0j) + phase1*phase2*elem/np.sqrt(norm1*norm2)
        for col in sorted(row):
            value = row[col]
            if abs(value) > thres:
                data.append(value)
                indices.append(col)
        indptr.append(len(data))
    H = csr_matrix((data, indices, indptr), shape=(Norb, Norb))
    # Row i holds H|psi_i>; in the Hamiltonian matrix that is column i, so transpose.
    return H.transpose().tocsr()

# Green's function of the simplified model
def GreenRS2Simplified(rvec, k):
    """2x2 Green's-function matrix of the simplified model (index 0 = -, 1 = +).

    Keeps only the 1/(k r)**2 terms of ``GreenRS2Full`` and drops the
    exp(i k r) factor. Diagonal = (term1 + term2/2), off-diagonals =
    term2*exp(±2i*phi)/2, all times 1/(4*pi*r).
    """
    x, y = rvec
    r = np.sqrt(x**2 + y**2)
    phi = np.arctan2(y, x)
    prefactor = 1./(4*np.pi*r)
    term1 = -1/(k*r)**2
    term2 = 3/(k*r)**2
    diag = term1 + term2/2.
    G = np.array(
        [[diag, term2*np.exp(2j*phi)/2.],
         [term2*np.exp(-2j*phi)/2., diag]],
        dtype=complex,
    )
    return G*prefactor

def GreenRS2Full(rvec, k):
    """2x2 Green's-function matrix of the full model (index 0 = -, 1 = +).

    With kr = k*|rvec| and phi the polar angle of rvec:
        term1 = 1 + i/kr - 1/kr**2,  term2 = -1 - 3i/kr + 3/kr**2,
        diagonal = term1 + term2/2,  off-diagonal = term2*exp(±2i*phi)/2,
    all times exp(i*kr)/(4*pi*r). This appears to be the in-plane part of
    the free-space dyadic Green's function in a circular basis
    (NOTE(review): confirm the convention).
    """
    x, y = rvec
    r = np.sqrt(x**2 + y**2)
    phi = np.arctan2(y, x)
    prefactor = np.exp(1j*k*r)/(4*np.pi*r)
    term1 = 1 + 1j/(k*r) - 1/(k*r)**2
    term2 = -1 - 3j/(k*r) + 3/(k*r)**2
    diag = term1 + term2/2.
    G = np.array(
        [[diag, term2*np.exp(2j*phi)/2.],
         [term2*np.exp(-2j*phi)/2., diag]],
        dtype=complex,
    )
    return G*prefactor

# interactions (hopping and onsite terms) corresponding to the full Hamiltonian, with rotation-invariance-restoring transform
def RealDipolarInteractionFiniteCircularFullTransform(pos, B, Vquad,k0, rtrunc, Lmax, hermitian):
    """List the hopping and onsite terms of the full model.

    Sites are split into ``Lmax`` consecutive sectors of ``Nsite//Lmax``
    sites. Orbital "-" on a site in sector L gets phase exp(+2i*pi*L/Lmax)
    and orbital "+" gets exp(-2i*pi*L/Lmax) (rotation-invariance-restoring
    transform).

    Hopping i<-j (for i != j with |r_ij| < rtrunc, or all pairs if
    rtrunc < 0) has amplitude
    ``-3*pi/k0 * G[a, b] * conj(phase_i[a]) * phase_j[b]``. ``G`` is
    ``GreenRS2Full(r_j - r_i)``; when ``hermitian`` it is symmetrised as
    ``(G(r) + G(-r)^dagger)/2``.

    Onsite: ``Vquad*r**2 -/+ B`` for orbital -/+, plus ``-0.5j`` when not
    ``hermitian``.

    Parameters
    ----------
    pos : array (Nsite, 2) of site positions.
    """
    Nsite = pos.shape[0]
    Nsector = Nsite//Lmax
    prefactor = -3*np.pi/k0

    def sector_phases(site):
        # (orbital "-", orbital "+") phases of the transformation for this site
        sector = site//Nsector
        return [np.exp(2j*np.pi*sector/Lmax), np.exp(-2j*np.pi*sector/Lmax)]

    interactions = []
    for i in range(Nsite):
        ph_i = sector_phases(i)
        for j in range(Nsite):
            if i == j:
                continue
            ph_j = sector_phases(j)
            rvec = pos[j, :] - pos[i, :]
            r = np.sqrt(rvec[0]**2 + rvec[1]**2)
            if not (r < rtrunc or rtrunc < 0):
                continue
            if hermitian:
                G = 0.5*(GreenRS2Full(rvec, k0) + np.conj(np.transpose(GreenRS2Full(-rvec, k0))))
            else:
                G = GreenRS2Full(rvec, k0)
            for a in range(2):
                for b in range(2):
                    interactions.append(["hopping", i, j, a, b, prefactor*G[a, b]*np.conj(ph_i[a])*ph_j[b]])
    damping = 0 if hermitian else -0.5j
    for i in range(Nsite):
        x, y = pos[i, :]
        trap = Vquad*(x**2 + y**2)
        interactions.append(["onsite", i, 0, trap - B + damping])
        interactions.append(["onsite", i, 1, trap + B + damping])
    return interactions


# interactions (hopping and onsite terms) corresponding to the simplified Hamiltonian, with rotation-invariance-restoring transform
def RealDipolarInteractionFiniteCircularSimplifiedTransform(pos, B, Vquad,k0, rtrunc, Lmax):
    """List the hopping and onsite terms of the simplified model.

    Hopping i<-j (for i != j with |r_ij| < rtrunc, or all pairs if
    rtrunc < 0) has amplitude
    ``-3*pi/k0 * G[a, b] * conj(phase_i[a]) * phase_j[b]`` with
    ``G = GreenRS2Simplified(r_j - r_i)``. Site phases for sector L are
    (exp(2i*pi*L/Lmax), -exp(2i*pi*L/Lmax)) for orbitals (-, +).
    Onsite: ``Vquad*r**2 -/+ B``.

    NOTE(review): the full-model builder uses (exp(+i*theta), exp(-i*theta))
    instead. Since G[0,1] is proportional to exp(2i*phi), the off-diagonal
    hopping with these phases does not look invariant under a joint rotation
    of both sites. Confirm this transform is intended before using it with
    the angular-momentum basis.

    Parameters
    ----------
    pos : array (Nsite, 2) of site positions.
    """
    Nsite = pos.shape[0]
    Nsector = Nsite//Lmax
    prefactor = -3*np.pi/k0

    def sector_phases(site):
        sector = site//Nsector
        return [np.exp(2j*np.pi*sector/Lmax), -np.exp(2j*np.pi*sector/Lmax)]

    interactions = []
    for i in range(Nsite):
        ph_i = sector_phases(i)
        for j in range(Nsite):
            if i == j:
                continue
            ph_j = sector_phases(j)
            rvec = pos[j, :] - pos[i, :]
            r = np.sqrt(rvec[0]**2 + rvec[1]**2)
            if not (r < rtrunc or rtrunc < 0):
                continue
            G = GreenRS2Simplified(rvec, k0)
            for a in range(2):
                for b in range(2):
                    interactions.append(["hopping", i, j, a, b, prefactor*G[a, b]*np.conj(ph_i[a])*ph_j[b]])
    for i in range(Nsite):
        x, y = pos[i, :]
        trap = Vquad*(x**2 + y**2)
        interactions.append(["onsite", i, 0, trap - B])
        interactions.append(["onsite", i, 1, trap + B])
    return interactions



# interactions (hopping and onsite terms) corresponding to the anti-Hermitian full Hamiltonian, with rotation-invariance-restoring transform
def AntiHermitianInteractionFiniteCircularFullTransform(pos, B, Vquad,k0, rtrunc, Lmax):
    """List the terms of the anti-Hermitian part of the full model.

    Hopping uses ``G = (G(r) - G(-r)^dagger)/2`` with ``G = GreenRS2Full`` and
    the same sector phases as the full-model builder:
    (exp(+2i*pi*L/Lmax), exp(-2i*pi*L/Lmax)) for orbitals (-, +). Every site
    gets onsite ``-0.5j`` on both orbitals. ``B`` and ``Vquad`` are not used.

    Parameters
    ----------
    pos : array (Nsite, 2) of site positions.
    """
    Nsite = pos.shape[0]
    Nsector = Nsite//Lmax
    prefactor = -3*np.pi/k0

    def sector_phases(site):
        sector = site//Nsector
        return [np.exp(2j*np.pi*sector/Lmax), np.exp(-2j*np.pi*sector/Lmax)]

    interactions = []
    for i in range(Nsite):
        ph_i = sector_phases(i)
        for j in range(Nsite):
            if i == j:
                continue
            ph_j = sector_phases(j)
            rvec = pos[j, :] - pos[i, :]
            r = np.sqrt(rvec[0]**2 + rvec[1]**2)
            if not (r < rtrunc or rtrunc < 0):
                continue
            G = 0.5*(GreenRS2Full(rvec, k0) - np.conj(np.transpose(GreenRS2Full(-rvec, k0))))
            for a in range(2):
                for b in range(2):
                    interactions.append(["hopping", i, j, a, b, prefactor*G[a, b]*np.conj(ph_i[a])*ph_j[b]])
    for i in range(Nsite):
        interactions.append(["onsite", i, 0, -0.5j])
        interactions.append(["onsite", i, 1, -0.5j])
    return interactions

##################################################################################################################
# Functions to generate lattice positions
##################################################################################################################

# check if a site is inside a hexagon of size D
def CheckIfInside(D,pos):
    """Return True if ``pos = (x, y)`` lies in the hexagon of size ``D``.

    The hexagon has flat edges at y = ±D and vertices at (±2D/√3, 0) and
    (±D/√3, ±D). Points on the boundary count as inside.
    """
    x, y = pos
    s3 = np.sqrt(3)
    outside = [
        y < -D,
        y > D,
        y < x*s3 - 2*D,
        y > -x*s3 + 2*D,
        y < -x*s3 - 2*D,
        y > x*s3 + 2*D,
    ]
    return not any(outside)

# get all the lattice sites within a hexagon
def GetLatticePoints(D0, latvec,orb, d):
    """Return (xs, ys) of all lattice sites inside the hexagon of size D0*d.

    Sites are ``(i + o[0])*latvec[0] + (j + o[1])*latvec[1]`` for every
    orbital offset ``o`` in ``orb`` and integers i, j in [-extent, extent],
    where ``extent = int(4/3*D0 + 1)`` (printed as a debug aid).
    """
    extent = int(4./3.*D0 + 1)
    D = D0*d
    print(extent)
    a1 = np.array(latvec[0])
    a2 = np.array(latvec[1])
    span = range(-extent, extent + 1)
    xs, ys = [], []
    for i in span:
        for j in span:
            for o in orb:
                p = (i + o[0])*a1 + (j + o[1])*a2
                if CheckIfInside(D, p):
                    xs.append(p[0])
                    ys.append(p[1])
    return xs, ys

# get all the lattice sites within "sector 0" of the hexagon
def GetLatticePointsSegment(D0, latvec, orb,d):
    """Like ``GetLatticePoints`` but only for non-negative i, j in [0, extent].

    This yields the sites of "sector 0" of the hexagon of size D0*d.
    ``extent = int(4/3*D0 + 1)`` is printed as a debug aid.
    """
    extent = int(4./3.*D0 + 1)
    print(extent)
    D = D0*d
    a1 = np.array(latvec[0])
    a2 = np.array(latvec[1])
    span = range(extent + 1)
    xs, ys = [], []
    for i in span:
        for j in span:
            for o in orb:
                p = (i + o[0])*a1 + (j + o[1])*a2
                if CheckIfInside(D, p):
                    xs.append(p[0])
                    ys.append(p[1])
    return xs, ys

# rotate a single site by n times 2pi/Lmax
def Rotate(pos, n, Lmax):
    """Rotate point ``pos = (x, y)`` counter-clockwise by ``n*2*pi/Lmax``."""
    x, y = pos
    angle = 2*np.pi/Lmax*n
    c, s = np.cos(angle), np.sin(angle)
    return np.array([x*c - y*s, x*s + y*c])

# rotate all the sites of sector 0
def RotateAll(xs,ys, n, Lmax):
    """Rotate every point (xs[k], ys[k]) by ``n*2*pi/Lmax``.

    Returns two lists (x', y').
    """
    angle = 2*np.pi/Lmax*n
    c, s = np.cos(angle), np.sin(angle)
    xres, yres = [], []
    for k, x in enumerate(xs):
        y = ys[k]
        xres.append(x*c - y*s)
        yres.append(x*s + y*c)
    return xres, yres

##################################################################################################################
# Functions to operate on eigenstates, measure expectation values etc.
##################################################################################################################

# convert eigenstates from angular momentum basis to site basis
def ConvertEigenstates(evecs, basis, mombasis, L,Lmax):
    """Map eigenvectors from the angular-momentum basis at ``L`` to the site basis.

    ``evecs[i, :]`` are the amplitudes of orbit i. Every config ``o`` in
    that orbit receives ``evecs[i, :] * exp(2j*pi*phase[L,o]/Lmax)/sqrt(norm[L,o])``.

    Returns
    -------
    complex array (len(basis), neval).
    """
    _, neval = evecs.shape
    orbits, _, phase, norm = mombasis
    out = np.zeros((len(basis), neval), dtype=complex)
    for row, orbit in enumerate(orbits[L]):
        for conf_index in orbit:
            coeff = np.exp(2j*np.pi*float(phase[L, conf_index])/float(Lmax))
            out[conf_index, :] += evecs[row, :]*coeff/np.sqrt(norm[L, conf_index])
    return out

# compute the generalized overlap (sum of squared overlaps with all states in the matrix Q)
def GeneralizedOverlap(state, Q):
    """Return ``sum_i |<state|Q[:, i]>|**2`` over the columns of ``Q``."""
    return sum(np.abs(np.dot(np.conj(state), column))**2 for column in Q.T)

#compute particle density at each site
def ParticleDensities(pos, sitestates, basis):
    """Return per-site, per-orbital densities for each eigenstate.

    ``dens[site, orbital, ieval]`` sums ``|sitestates[i, ieval]|**2`` over all
    configs ``i`` whose ``basis[i][site] == orbital + 1``. Only ``len(pos)``
    is used from ``pos``.
    """
    _, neval = sitestates.shape
    dens = np.zeros((len(pos), 2, neval))
    weights = np.abs(sitestates)**2
    for i, conf in enumerate(basis):
        for site, occ in enumerate(conf):
            if occ > 0:
                dens[site, occ - 1, :] += weights[i, :]
    return dens

#compute radial density (return the radii and densities at the radii)
def RadialDensities(pos, sitestates, basis):
    """Return (radii, densrad): total density summed over sites of equal radius.

    Radii are listed in order of first appearance in ``pos``. A site whose
    radius is within 1e-5 of an already-seen radius joins that shell (if
    several match, the last one wins). ``densrad`` has shape
    (len(radii), neval).
    """
    per_orbital = ParticleDensities(pos, sitestates, basis)
    total = per_orbital[:, 0, :] + per_orbital[:, 1, :]
    tol = 0.00001
    _, neval = sitestates.shape
    shell_of_site = np.zeros(len(pos), dtype=int)
    radii = []
    for site, p in enumerate(pos):
        r = np.sqrt(p[0]**2 + p[1]**2)
        shell = len(radii)
        for idx, known in enumerate(radii):
            if np.abs(known - r) < tol:
                shell = idx
        shell_of_site[site] = shell
        if shell == len(radii):  # no existing shell matched
            radii.append(r)
    densrad = np.zeros((len(radii), neval))
    for site in range(len(pos)):
        densrad[shell_of_site[site], :] += total[site, :]
    return radii, densrad

# undo the rotational-invariance-restoring transformation
def OriginalPhases(evecs, basis, Lmax, Nsite):
    """Undo the rotation-invariance-restoring phase transformation.

    Each config's amplitudes are multiplied by the product over occupied
    sites of exp(+2i*pi*L/Lmax) (orbital code 1) or exp(-2i*pi*L/Lmax)
    (orbital code 2), where L = site // (Nsite//Lmax) is the site's sector.
    """
    Nconf, neval = evecs.shape
    Nsector = Nsite//Lmax
    out = np.zeros((Nconf, neval), dtype=complex)
    for iconf, conf in enumerate(basis):
        factor = 1
        for site, occ in enumerate(conf):
            sector = site//Nsector
            if occ == 1:
                factor *= np.exp(2j*np.pi*sector/Lmax)
            elif occ == 2:
                factor *= np.exp(-2j*np.pi*sector/Lmax)
        out[iconf, :] = factor*evecs[iconf, :]
    return out

# compute the vector b_1^\dagger b_2^\dagger|0>, where b_1 and b_2 are defined by orb 1 and orb 2
# "basis" is the site basis
def TwoParticleAmplitudeVector2part(orb1, orb2, basis):
    """Return the amplitudes of b_1^dagger b_2^dagger |0> in the site basis.

    ``orb1``/``orb2`` are single-particle vectors indexed by the flattened
    mode ``2*site + orbital`` (orbital = code - 1). For a configuration
    with modes a < b the amplitude is ``orb1[a]*orb2[b] + orb1[b]*orb2[a]``
    (bosonic symmetrisation).

    Returns
    -------
    list with one amplitude per configuration of ``basis``.

    Raises
    ------
    ValueError
        If a configuration does not hold exactly two particles. Previously
        extra particles were silently ignored, giving a meaningless amplitude.
    """
    amplitudes = []
    for occ in basis:
        modes = [2*i + o - 1 for i, o in enumerate(occ) if o > 0]
        if len(modes) != 2:
            raise ValueError(
                "TwoParticleAmplitudeVector2part expects two-particle "
                f"configurations, got {len(modes)} particles"
            )
        a, b = modes
        amplitudes.append(orb1[a]*orb2[b] + orb1[b]*orb2[a])
    return amplitudes

