import itertools as itools
import math

import matplotlib.pyplot as plt
import numpy as np
import scipy as scp
import sympy as sp
from sympy.combinatorics import IntegerPartition

#########################################################################
# operate on occupation basis
#########################################################################

def CheckMinMaxRules(part,q, mmax):
  """Heuristically test whether ``part`` could be squeezed from the root partition.

  ``part`` lists the orbital index of every particle.  The configuration is
  rejected when any index exceeds ``mmax``, or when the occupation n_i of some
  orbital i exceeds 2*i/q + 1 or 2*(mmax - i)/q + 1, i.e. when too many particles
  sit close to either end of the allowed orbital range.  This rejects some (not
  all) configurations that cannot be squeezed from a "root partition"
  (1 0 1 0 1 ... in the language of occupation numbers).

  Returns True if no rule is violated, False otherwise.
  """
  occupation=np.zeros(mmax+1)
  in_range=True
  for orbital in part:
    if orbital>mmax:
      # maximum momentum too large: remember the violation, skip the count
      in_range=False
      continue
    occupation[orbital]+=1
  # occupation near the beginning / end of the orbital range must not be too large
  too_crowded=any(
    (n>i*2/q+1) or (n>(mmax-i)*2/q+1)
    for i, n in enumerate(occupation)
  )
  return in_range and not too_crowded

def GenerateOccupationBasisMmax(Npart, q, DeltaL, mmax):
  """Generate the occupation basis at L = L0 + DeltaL with highest orbital <= mmax.

  L0 = q*Npart*(Npart-1)/2.  Configurations are generated as integer partitions of
  the total angular momentum L with at most Npart parts, zero-padded to length
  Npart (the orbital index of every particle, in descending order).  A
  configuration is kept when CheckMinMaxRules accepts it with maximal orbital
  q*(Npart-1) + DeltaL (i.e. it may be "squeezed" from the root partition) and its
  largest orbital is at most ``mmax``.  Partitions are visited in
  IntegerPartition.next_lex order, starting from [L].

  Returns a list of configurations (lists of numpy integers).
  """
  mmax0=q*(Npart-1)
  L0=q*Npart*(Npart-1)/sp.Integer(2)
  L=L0+DeltaL
  mmax1=mmax0+DeltaL

  def accept(conf):
    # squeezable from the root partition and maximum occupied orbital at most mmax
    return CheckMinMaxRules(conf,q,mmax1) and max(conf)<=mmax

  if L==0:
    # IntegerPartition rejects zero summands, so the partition of 0 (e.g. Npart=1,
    # DeltaL=0: every particle in orbital 0) is built directly.
    conf=list(np.zeros(Npart, dtype=int))
    return [conf] if accept(conf) else []

  basis=[]
  part0=IntegerPartition([L])  # first partition: L 0 0 0 ...
  part=part0
  while True:  # next_lex wraps around, so stop once we are back at [L]
    summands=part.partition
    if len(summands)<=Npart:
      conf=list(np.concatenate((summands, np.zeros(Npart-len(summands), dtype=int))))
      if accept(conf):
        basis.append(conf)
    part=part.next_lex()
    if part==part0:
      break
  return basis

def GenerateOccupationBasis(Npart, q, DeltaL):
  """Generate the occupation basis at total angular momentum L = L0 + DeltaL.

  L0 = q*Npart*(Npart-1)/2.  Configurations are integer partitions of L with at
  most Npart parts, zero-padded to length Npart (the orbital index of every
  particle, in descending order).  Only configurations accepted by
  CheckMinMaxRules with maximal orbital q*(Npart-1) + DeltaL are kept.  Partitions
  are visited in IntegerPartition.next_lex order, starting from [L].

  Returns a list of configurations (lists of numpy integers).
  """
  mmax0=q*(Npart-1)
  L0=q*Npart*(Npart-1)/sp.Integer(2)
  L=L0+DeltaL
  mmax=mmax0+DeltaL

  if L==0:
    # IntegerPartition rejects zero summands, so the partition of 0 (e.g. Npart=1,
    # DeltaL=0: every particle in orbital 0) is built directly.
    conf=list(np.zeros(Npart, dtype=int))
    return [conf] if CheckMinMaxRules(conf,q,mmax) else []

  basis=[]
  part0=IntegerPartition([L])  # first partition: L 0 0 0 ...
  part=part0
  while True:  # next_lex wraps around, so stop once we are back at [L]
    summands=part.partition
    if len(summands)<=Npart:
      conf=list(np.concatenate((summands, np.zeros(Npart-len(summands), dtype=int))))
      if CheckMinMaxRules(conf,q,mmax):
        basis.append(conf)
    part=part.next_lex()
    if part==part0:
      break
  return basis

def FindIndexNaive(part, basis):
  """Return the index of ``part`` in ``basis`` by linear search, or -1 if absent.

  Every entry of ``basis`` is compared with ``==``; when ``part`` occurs more than
  once, the index of the last occurrence is returned.  A miss prints "NOT FOUND".
  """
  hits=[position for position, candidate in enumerate(basis) if candidate==part]
  if not hits:
    print("NOT FOUND")
    return -1
  return hits[-1]


#########################################################################
# operate on symbolic expressions for Laughlin states
#########################################################################


# generate symbolic variables for particle coordinates z_i to be used by SymPy
def GenerateCoordinateSymbols(Npart):
  """Return the tuple of SymPy symbols (z0, z1, ..., z{Npart-1}) for particle coordinates.

  A tuple is returned for every Npart, including Npart == 1 (sp.symbols("z0 ")
  would return a bare Symbol, which breaks callers that take len() of the result
  or iterate over it) and Npart == 0 (empty tuple instead of an error).
  """
  return tuple(sp.Symbol("z"+str(i)) for i in range(Npart))

def GetGroundStatePolynomial(symbols, q):
  """Return the (unexpanded) Jastrow polynomial prod_{i<j} (z_i - z_j)**q.

  ``symbols`` are the particle coordinates; with fewer than two of them the empty
  product sp.Integer(1) is returned.
  """
  count=len(symbols)
  jastrow_factors=[(symbols[lo]-symbols[hi])**q
                   for hi in range(count) for lo in range(hi)]
  result=sp.Integer(1)
  for factor in jastrow_factors:
    result*=factor
  return result

#generate the symmetric polynomials
def GenerateSymmetricPolynomials(symbols, DeltaL):
  """Return symmetric polynomials in ``symbols`` of total degree DeltaL.

  One polynomial is produced for every integer partition (lambda_1, ..., lambda_k)
  of DeltaL with k <= len(symbols), in the order visited by
  IntegerPartition.next_lex starting from [1, ..., 1].  Each polynomial is the sum
  of z_{i_1}**lambda_1 * ... * z_{i_k}**lambda_k over all ordered tuples of
  distinct indices (i_1, ..., i_k), i.e. a positive integer multiple of the
  monomial symmetric polynomial m_lambda.  If DeltaL is not positive, [1] is
  returned.
  """
  Npart=len(symbols)
  if not DeltaL>0:
    return [sp.Integer(1)]
  # collect every partition of DeltaL; next_lex wraps back to the starting one
  start=IntegerPartition(np.ones(DeltaL, dtype=int))
  all_powers=[start.partition]
  nxt=start.next_lex()
  while nxt!=start:
    all_powers.append(nxt.partition)
    nxt=nxt.next_lex()
  sympols=[]
  for powers in all_powers:
    if len(powers)>Npart:
      continue
    # symmetrize the monomial with the given powers over particle indices
    total=sp.Integer(0)
    for targets in itools.permutations(range(Npart), len(powers)):
      monomial=sp.Integer(1)
      for index, power in zip(targets, powers):
        monomial*=symbols[index]**power
      total+=monomial
    sympols.append(total)
  return sympols


#generate the symmetric polynomials with highest power being at most mmax
def GenerateSymmetricPolynomialsMmax(symbols, DeltaL, mmax):
  """Return symmetric polynomials of total degree DeltaL whose highest power is <= mmax.

  Same construction as GenerateSymmetricPolynomials: for every integer partition
  of DeltaL (in IntegerPartition.next_lex order from [1, ..., 1]) with at most
  len(symbols) parts and largest part <= ``mmax``, the monomial with those powers
  is summed over all ordered tuples of distinct particle indices.  If DeltaL is not
  positive, [1] is returned regardless of ``mmax``.
  """
  Npart=len(symbols)
  if not DeltaL>0:
    return [sp.Integer(1)]
  start=IntegerPartition(np.ones(DeltaL, dtype=int))
  all_powers=[start.partition]
  nxt=start.next_lex()
  while nxt!=start:
    all_powers.append(nxt.partition)
    nxt=nxt.next_lex()
  sympols=[]
  for powers in all_powers:
    if len(powers)>Npart or not np.max(powers)<=mmax:
      continue
    total=sp.Integer(0)
    for targets in itools.permutations(range(Npart), len(powers)):
      monomial=sp.Integer(1)
      for index, power in zip(targets, powers):
        monomial*=symbols[index]**power
      total+=monomial
    sympols.append(total)
  return sympols



def GetWavefunctionCoefficients(psi, symbols,basis):
  """Expand a 2D Laughlin-type polynomial in the orbital-occupation basis ``basis``.

  Every monomial z_0**d_0 * ... * z_{N-1}**d_{N-1} of ``psi`` belongs to the
  configuration given by its degree list sorted in descending order; monomials
  that differ only by permuting particles belong to the same configuration (the
  same Fock state).  For each configuration:

  * The coefficient is taken from the monomial whose degree list is already in
    descending order (identity assignment of orbitals to particles).  For a
    symmetric psi all monomials of a configuration share this coefficient.  For an
    antisymmetric psi (odd q) they differ by permutation signs, and this choice
    gives every configuration the same sign convention.  If that monomial is
    absent, the first monomial seen for the configuration is used.
  * The coefficient is multiplied by sqrt(mult * prod_m 2**m * m!), where mult is
    the number of monomials of the configuration (Fock-state normalization) and
    2**m * m! accounts for the normalization of the orbitals.

  Args:
    psi: SymPy polynomial in ``symbols`` (need not be expanded).
    symbols: sequence of the coordinate symbols z_0, ..., z_{N-1}.
    basis: list of configurations (orbital indices sorted in descending order).

  Returns:
    numpy float array of length len(basis).  The vector is normalized over all
    configurations present in psi; configurations that are not in ``basis`` are
    dropped (FindIndexNaive prints "NOT FOUND" for them).
  """
  # Poly.terms() yields (degree tuple, coefficient) for every monomial.  Building
  # the Poly expands psi efficiently and, unlike iterating expand(psi).args, also
  # works when psi is a single term (a constant, or one monomial).
  set_mult={}   # configuration -> number of monomials belonging to it
  set_coeff={}  # configuration -> representative coefficient
  for degrees, coeff in sp.Poly(psi, *symbols).terms():
    if coeff==0:  # the zero polynomial reports a single zero term
      continue
    config=tuple(sorted(degrees, reverse=True))
    set_mult[config]=set_mult.get(config, 0)+1
    if config not in set_coeff or degrees==config:
      set_coeff[config]=coeff
  configs=list(set_mult)
  coeffs_float=[]
  for config in configs:
    prod=sp.Integer(set_mult[config])  # normalization of the Fock state
    for m in config:
      prod*=sp.Integer(2)**m*sp.factorial(m)  # normalization of the orbitals
    coeffs_float.append(float(set_coeff[config]*sp.sqrt(prod)))
  coeffs_float=np.array(coeffs_float)
  norm=np.sqrt(np.sum(np.abs(coeffs_float)**2))
  coeffs_float=coeffs_float/norm
  coeffs_float2=np.zeros(len(basis))
  for config, value in zip(configs, coeffs_float):  # place each configuration in the basis
    ind=FindIndexNaive(list(config), basis)
    if ind>=0:
      coeffs_float2[ind]=value
  return coeffs_float2

def GetWavefunctionCoefficients1D(psi, symbols,basis):
  """Expand a 1D Laughlin-type polynomial in the orbital-occupation basis ``basis``.

  Same grouping as GetWavefunctionCoefficients: every monomial belongs to the
  configuration given by its degree list sorted in descending order.  The
  representative coefficient is that of the monomial whose degree list is already
  descending, which gives antisymmetric (odd-q) polynomials a consistent sign
  convention; otherwise the first monomial seen is used.  The coefficient is
  multiplied by sqrt(mult), where mult is the number of monomials of the
  configuration.  No orbital-dependent factor is applied.

  Args:
    psi: SymPy polynomial in ``symbols`` (need not be expanded).
    symbols: sequence of the coordinate symbols z_0, ..., z_{N-1}.
    basis: list of configurations (orbital indices sorted in descending order).

  Returns:
    numpy float array of length len(basis).  The vector is normalized over all
    configurations present in psi; configurations that are not in ``basis`` are
    dropped (FindIndexNaive prints "NOT FOUND" for them).
  """
  # Poly.terms() handles single-term psi too, unlike iterating expand(psi).args
  set_mult={}   # configuration -> number of monomials belonging to it
  set_coeff={}  # configuration -> representative coefficient
  for degrees, coeff in sp.Poly(psi, *symbols).terms():
    if coeff==0:  # the zero polynomial reports a single zero term
      continue
    config=tuple(sorted(degrees, reverse=True))
    set_mult[config]=set_mult.get(config, 0)+1
    if config not in set_coeff or degrees==config:
      set_coeff[config]=coeff
  configs=list(set_mult)
  coeffs_float=np.array([float(set_coeff[config]*sp.sqrt(sp.Integer(set_mult[config])))
                         for config in configs])
  norm=np.sqrt(np.sum(np.abs(coeffs_float)**2))
  coeffs_float=coeffs_float/norm
  coeffs_float2=np.zeros(len(basis))
  for config, value in zip(configs, coeffs_float):
    ind=FindIndexNaive(list(config), basis)
    if ind>=0:
      coeffs_float2[ind]=value
  return coeffs_float2


def GetPrefactors(basis, mmax):
  """Return prod_m 1/sqrt(n_m!) for every configuration in ``basis``.

  Each configuration lists the orbital index of every particle; n_m is the number
  of particles in orbital m (0 <= m <= mmax).  An index larger than ``mmax``
  raises IndexError.

  Returns a list with one prefactor per configuration, in the order of ``basis``.
  """
  prefs=[]
  for conf in basis:
    occ=np.zeros(mmax+1, dtype=int)
    for c in conf:
      occ[c]+=1
    pref=1
    for o in occ:
      # math.factorial, not np.math.factorial: np.math was only an alias of the
      # stdlib module and was removed in NumPy 2.0
      pref*=1/np.sqrt(math.factorial(int(o)))
    prefs.append(pref)
  return prefs



def GetFewLowestEdgeStates(Npart, q, DeltaLmax):
  """Orthonormalized 2D Laughlin edge states for Delta L = 0..DeltaLmax (verbose).

  For each Delta L the occupation basis (GenerateOccupationBasis) is built, each
  edge state psi0 * P (P from GenerateSymmetricPolynomials) is expanded with
  GetWavefunctionCoefficients, the overlap matrix of these vectors is printed, and
  the vectors are orthonormalized with a QR decomposition.  Intermediate results
  are printed along the way.

  Returns a list with one [basis, Q] pair per Delta L; Q's columns are the
  orthonormalized states expressed in that basis.
  """
  mmax0=q*(Npart-1)
  L0=q*Npart*(Npart-1)/sp.Integer(2)
  print("mmax0:", mmax0, "L0:", L0)
  coords=GenerateCoordinateSymbols(Npart)
  print(coords)
  ground=GetGroundStatePolynomial(coords, q)
  results=[]
  for DeltaL in range(DeltaLmax+1):
    print("Delta L =", DeltaL)
    basis=GenerateOccupationBasis(Npart, q, DeltaL)
    print("basis:", basis)
    sympols=GenerateSymmetricPolynomials(coords, DeltaL)
    print("Symmetric polynomials:", sympols)
    vectors=[]
    for pol in sympols:
      coeffs=GetWavefunctionCoefficients(ground*pol, coords, basis)
      print(coeffs)
      vectors.append(coeffs)
    OverlapMatrix(vectors)
    Q=np.linalg.qr(np.transpose(vectors))[0]
    print(Q)
    results.append([basis,Q])
  return results


def GetFewLowestEdgeStatesMmax(Npart, q, DeltaLmax, mmax):
  """Orthonormalized 2D Laughlin edge states for Delta L = 0..DeltaLmax, orbitals <= mmax.

  For each Delta L the occupation basis is built with GenerateOccupationBasisMmax.
  The edge states psi0 * P, with P from GenerateSymmetricPolynomialsMmax (largest
  power mmax - q*(Npart-1)), are expanded with GetWavefunctionCoefficients and
  orthonormalized by a QR decomposition.

  Returns a list with one [basis, Q] pair per Delta L; Q's columns are the
  orthonormalized states in that basis.  When no symmetric polynomial survives
  the mmax restriction, the pair is [np.array([[]]), np.array([[]])].
  """
  coords=GenerateCoordinateSymbols(Npart)
  ground=GetGroundStatePolynomial(coords, q)
  highest_extra_power=mmax-q*(Npart-1)
  results=[]
  for DeltaL in range(DeltaLmax+1):
    basis=GenerateOccupationBasisMmax(Npart, q, DeltaL, mmax)
    # edge excitation = ground state polynomial * symmetric polynomial
    vectors=[GetWavefunctionCoefficients(ground*pol, coords, basis)
             for pol in GenerateSymmetricPolynomialsMmax(coords, DeltaL, highest_extra_power)]
    if not vectors:
      results.append([np.array([[]]),np.array([[]])])
      continue
    Q,_=np.linalg.qr(np.transpose(vectors))  # orthogonalize
    results.append([basis,Q])
  return results


def GetFewLowestEdgeStatesMmax1D(Npart, q, DeltaLmax, mmax):
  """Orthonormalized 1D Laughlin edge states for Delta L = 0..DeltaLmax, orbitals <= mmax.

  Identical pipeline to GetFewLowestEdgeStatesMmax, except that the coefficients
  are obtained with GetWavefunctionCoefficients1D.

  Returns a list with one [basis, Q] pair per Delta L; Q's columns are the
  orthonormalized states in that basis.  When no symmetric polynomial survives
  the mmax restriction, the pair is [np.array([[]]), np.array([[]])].
  """
  coords=GenerateCoordinateSymbols(Npart)
  ground=GetGroundStatePolynomial(coords, q)
  highest_extra_power=mmax-q*(Npart-1)
  results=[]
  for DeltaL in range(DeltaLmax+1):
    basis=GenerateOccupationBasisMmax(Npart, q, DeltaL, mmax)
    vectors=[GetWavefunctionCoefficients1D(ground*pol, coords, basis)
             for pol in GenerateSymmetricPolynomialsMmax(coords, DeltaL, highest_extra_power)]
    if not vectors:
      results.append([np.array([[]]),np.array([[]])])
      continue
    Q,_=np.linalg.qr(np.transpose(vectors))  # orthogonalize
    results.append([basis,Q])
  return results

#########################################################################################
# Other useful functions
#########################################################################################


def OverlapMatrix(vectors):
  """Print the Gram matrix S[i, j] = dot(vectors[i], vectors[j]); returns None."""
  count=len(vectors)
  gram=np.zeros((count,count))
  for row, col in itools.product(range(count), repeat=2):
    gram[row,col]=np.dot(vectors[row], vectors[col])
  print("Overlaps:")
  print(gram)


if __name__ == "__main__":
  # Demo run: edge states for Npart=2, q=2, Delta L = 0..6, keeping orbitals up to
  # m = 6.  The other entry points (GetFewLowestEdgeStates, GenerateOccupationBasis,
  # GenerateSymmetricPolynomials, GetWavefunctionCoefficients, ...) can be called
  # the same way for exploration.
  demo_args=dict(Npart=2, q=2, DeltaLmax=6, mmax=6)
  GetFewLowestEdgeStatesMmax(**demo_args)

"""
#compare density
dens=np.zeros(mmax+1)

for iset, set0 in enumerate(sets):
  occ=np.zeros(mmax+1)
  for s in set0:
    occ[s]+=1
  dens+=np.abs(coeffs_float[iset])**2*occ
  print(set0, occ,coeffs_float[iset],np.abs(coeffs_float[iset])**2, dens)

rs=np.arange(1, mmax+1.0001)

plt.plot(dens/rs)
"""

plt.show()

