import numpy as np
from scipy.sparse.linalg import eigs
import matplotlib.pyplot as plt
import sys
import os
import inspect
import math
from itertools import permutations, combinations_with_replacement, combinations

# import ED functions
# Put the parent directory of this script first on sys.path so that modules
# located there (presumably the project-local helpers imported below) are found.
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))  # directory containing this script
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir) 
import NonHermitianAngularMomentumED as ed  # exact-diagonalization (ED) helpers
import sympy_wavefunctions as sympywfn  # model (Laughlin) wavefunctions and their orbital bases

##################################################################################################
##  Start calculations
##################################################################################################

lambda0 = 1  # atomic transition wavelength
k0 = 2 * np.pi / lambda0  # atomic transition wavenumber
# True: keep only the Hermitian part of the Hamiltonian; False: whole non-Hermitian
# Hamiltonian (only passed to the full, non-simplified interaction model)
hermitian = True
simplify = False  # True: simplified model; False: full model
D0 = 2.1  # rescaled system size (the physical size is D0 * lattice constant)
Lmax = 6  # Lmax-fold rotational symmetry of the array
Npart = 3  # number of particles
rtrunc = -1  # truncation radius of the interaction (-1 means no truncation)
B0 = 12  # rescaled magnetic field
q = 2  # target Laughlin state at filling 1/q

orb = [[1. / 3., 1. / 3.], [2. / 3., 2. / 3.]]  # positions of the atoms in the unit cell

mmax = q * (Npart - 1)  # maximum orbital m making up the Laughlin ground state
# total angular momentum of the Laughlin ground state, q*N*(N-1)/2, reduced modulo
# Lmax because the array only has Lmax-fold rotational symmetry
L0 = (q * Npart * (Npart - 1) // 2) % Lmax

# scan ranges: lattice constant d0 (in units of lambda0, x axis) and rescaled
# confining potential Vquad0 (labelled V_harm0 in the plots, y axis)
d0s = np.linspace(0.01, 0.5, 20, endpoint=True)
nd0s = len(d0s)
Vquad0s = np.linspace(0, 10, 20, endpoint=True)
nVquad0s = len(Vquad0s)

overlaps_pd = np.zeros((nd0s, nVquad0s))  # overlaps with the Laughlin state, filled in by the scan
# norms, i.e. projections of the true ground state onto the subspace of orbital
# configurations making up the Laughlin state
norms_pd = np.zeros((nd0s, nVquad0s))


# Generate the sites once with unit lattice constant only to count them. The bases
# built below are reused for every d0 of the scan, which assumes the site count does
# not depend on d (the scan loop checks this).
latvec = [[np.sqrt(3) / 2., -0.5], [np.sqrt(3) / 2., 0.5]]
xs, ys = ed.GetLatticePointsSegment(D0, latvec, orb, 1)
Nseg = len(xs)  # sites in one angular segment
Nsite = Nseg * Lmax  # total number of sites (Lmax rotated copies of the segment)

basisSP = ed.GenerateBasis(1, Nsite)  # single-particle site basis
mombasisSP = ed.FindAngularMomentumEigenstates(basisSP, 1, Lmax)  # single-particle angular-momentum basis

basis = ed.GenerateBasis(Npart, Nsite)  # Npart-particle site basis
mombasis = ed.FindAngularMomentumEigenstates(basis, Npart, Lmax)  # Npart-particle angular-momentum basis

def _sorted_dense_eig(H):
    """Fully diagonalize ``H`` (anything with ``.todense()``) via dense numpy.

    Returns ``(evals, evecs)`` as plain ndarrays, sorted by ascending real part of
    the eigenvalues; column ``j`` of ``evecs`` belongs to ``evals[j]``.
    """
    evals, evecs = np.linalg.eig(np.asarray(H.todense()))
    order = np.argsort(np.real(evals))
    return evals[order], evecs[:, order]


# ---- quantities that do not depend on (d0, Vquad0): computed once, outside the scan
nsite_basis = Nsite  # number of sites the precomputed bases were generated for
# single-particle bands per angular-momentum sector needed for orbitals m = 0..mmax
# (orbital m <-> band m // Lmax of sector m % Lmax)
Nsub = int(np.ceil((mmax + 1) / Lmax))
# model Laughlin states; entry 0 is (orbital-occupation basis, Laughlin ground state Q)
laughlin_states = sympywfn.GetFewLowestEdgeStatesMmax(Npart, q, Lmax, mmax)
basis_laughlin, Q = laughlin_states[0]
# normalization prefactors coming from two or more bosons in the same orbital
prefactors = sympywfn.GetPrefactors(basis_laughlin, mmax)
# For each many-body basis configuration, the single-particle vector indices
# 2*ic + c - 1 of its occupied entries (c > 0); presumably site ic with internal
# state c - 1 (TODO confirm). Depends only on `basis`, so it is built once here
# instead of once per Laughlin configuration and per scan point.
occupied_sp_indices = [
    [2 * ic + c - 1 for ic, c in enumerate(conf) if c > 0] for conf in basis
]
neval_many = 10  # many-body eigenpairs requested from the sparse solver

for id0, d0 in enumerate(d0s):
    for iVquad0, Vquad0 in enumerate(Vquad0s):
        print(id0, iVquad0)
        d = d0 * lambda0  # lattice constant
        Vquad = Vquad0 / d**5 / k0**3  # non-rescaled confining potential
        B = B0 / d**3 / k0**3  # non-rescaled magnetic field
        D = D0 * d  # system size
        latvec = [[d * np.sqrt(3) / 2., -d * 0.5], [d * np.sqrt(3) / 2., d * 0.5]]  # lattice vectors
        # generate the sites of one angular segment, then its Lmax rotated copies
        xs, ys = ed.GetLatticePointsSegment(D0, latvec, orb, d)
        Nseg = len(xs)
        Nsite = Nseg * Lmax
        if Nsite != nsite_basis:
            # the bases (basisSP, basis, ...) were built for nsite_basis sites
            raise RuntimeError(
                "number of sites changed to {} at d0={} (bases were built for {})".format(
                    Nsite, d0, nsite_basis))
        xall = []
        yall = []
        for i in range(Lmax):
            xres, yres = ed.RotateAll(xs, ys, i, Lmax)
            xall.extend(xres)
            yall.extend(yres)
        pos = np.transpose(np.array([xall, yall]))
        # generate the interactions defining the Hamiltonian
        if simplify:
            interactions = ed.RealDipolarInteractionFiniteCircularSimplifiedTransform(
                pos, B, Vquad, k0, rtrunc, Lmax)
        else:
            interactions = ed.RealDipolarInteractionFiniteCircularFullTransform(
                pos, B, Vquad, k0, rtrunc, Lmax, hermitian)

        # ---- single-particle eigenstates in every angular-momentum sector L
        evecssiteSPL = []
        densL = []
        for L in range(Lmax):
            H = ed.GenerateCRSHamiltonianWithAngularMomentum(L, interactions, basisSP, Lmax, mombasisSP)
            _, evec = _sorted_dense_eig(H)
            evecsite = ed.ConvertEigenstates(evec, basisSP, mombasisSP, L, Lmax)
            evecssiteSPL.append(evecsite)
            rs, dens = ed.RadialDensities(pos, evecsite, basisSP)
            densL.append(dens)

        # <r> of orbital m = band m // Lmax of sector m % Lmax, for m = 0..mmax
        # (rs is taken from the last sector, as the radial grid is shared)
        rs_arr = np.array(rs)
        rmeans_relevant = [
            sum(rs_arr * densL[m % Lmax][:, m // Lmax]) for m in range(mmax + 1)
        ]
        # LLL-like condition: <r> must increase strictly with m
        landau_condition = not any(
            rmeans_relevant[m] >= rmeans_relevant[m + 1] for m in range(mmax)
        )
        if not landau_condition:
            # put NaN as a result if the LLL-like condition is not fulfilled
            overlaps_pd[id0, iVquad0] = float("nan")
            norms_pd[id0, iVquad0] = float("nan")
            continue

        # ---- many-particle eigenstates (only for angular momentum L0)
        H = ed.GenerateCRSHamiltonianWithAngularMomentum(L0, interactions, basis, Lmax, mombasis)
        dim_L0 = len(mombasis[0][L0])
        if dim_L0 < 10:  # dense diagonalization if the Hilbert space is small enough
            _, evec = _sorted_dense_eig(H)
        else:
            # ARPACK (eigs) on a sparse matrix requires k < dim - 1; the former
            # k = min(dim - 1, neval) raised TypeError for dim = 10 and 11
            k_sparse = min(dim_L0 - 2, neval_many)
            evals, evec = eigs(H, k=k_sparse, which='SR', return_eigenvectors=True)
            evec = evec[:, np.argsort(np.real(evals))]
        evecsite = ed.ConvertEigenstates(evec, basis, mombasis, L0, Lmax)

        # ---- overlap with the model Laughlin ground state
        # single-particle orbitals ordered by continuum angular momentum m = isub*Lmax + L
        evecsSP_ordered = [
            evecssiteSPL[L][:, isub] for isub in range(Nsub) for L in range(Lmax)
        ]
        target_state = evecsite[:, 0]  # target state is the ED ground state at L = L0
        # ED ground state converted to the orbital-occupation basis of the Laughlin state
        target_state_LL = np.zeros(len(basis_laughlin), dtype="complex")
        for il, bl in enumerate(basis_laughlin):
            for i, conf2 in enumerate(occupied_sp_indices):
                for perm in permutations(bl):
                    prod = prefactors[il]
                    for j, p in enumerate(perm):
                        prod *= np.conj(evecsSP_ordered[p][conf2[j]])
                    target_state_LL[il] += prod * target_state[i]
        overlap = ed.GeneralizedOverlap(target_state_LL, Q)
        # <psi|psi> of the projected state; need not be 1 because basis_laughlin spans
        # only a subspace. It is real, so drop the zero imaginary part explicitly
        # instead of relying on a lossy complex -> float cast into norms_pd.
        norm = np.real(np.vdot(target_state_LL, target_state_LL))
        overlaps_pd[id0, iVquad0] = overlap  # add overlap to the map
        norms_pd[id0, iVquad0] = norm  # add norm to the map

####################################################################################
# plot data
###################################################################################
# Axis limits and grid steps of the scan (also written to the output file headers).
dVquad0 = Vquad0s[1] - Vquad0s[0]
Vquad0max = np.max(Vquad0s)
Vquad0min = np.min(Vquad0s)
dd0 = d0s[1] - d0s[0]
d0max = np.max(d0s)
d0min = np.min(d0s)

# Pad the image extent by half a grid step so each pixel is centred on its grid point.
plot_extent = [d0min - dd0 / 2., d0max + dd0 / 2., Vquad0min - dVquad0 / 2., Vquad0max + dVquad0 / 2.]
# One figure per quantity; titles make the otherwise identical-looking figures distinguishable.
for plot_title, plot_data in (
    ("Overlap of the ED ground state with the Laughlin state", overlaps_pd),
    ("Norm of the projection onto the Laughlin orbital subspace", norms_pd),
):
    plt.figure()
    # rows of the data are d0 and columns Vquad0; transpose so d0 runs along x
    plt.imshow(np.transpose(plot_data), extent=plot_extent, origin="lower")
    plt.colorbar()
    ax = plt.gca()
    ax.set_aspect("auto")
    plt.xlabel(r"$d_0$")
    plt.ylabel(r"$V_{harm0}$")
    plt.title(plot_title)


####################################################################################
# save data to files
###################################################################################
# Each file holds a 2D array of overlaps or norms: rows correspond to d0 values and columns to Vquad0 values. The header lists, in order: d0min d0max dd0 Vquad0min Vquad0max dVquad0.
# Header: grid limits and steps, so that readers can rebuild both scan axes.
header = "{} {} {} {} {} {}".format(d0min, d0max, dd0, Vquad0min, Vquad0max, dVquad0)
# Run parameters shared by both output file names.
run_params = (Npart, hermitian, D0, B0, rtrunc)
outputs = (
    ("phasediagram_gsoverlap_OBC_Npart={}_herm{}_D0={:.4f}_B0={:.4f}_rtrunc={:.4f}.txt", overlaps_pd),
    ("phasediagram_gsnorm_OBC_Npart={}_herm{}_D0={:.4f}_B0={:.4f}_rtrunc={:.4f}.txt", norms_pd),
)
for name_template, table in outputs:
    np.savetxt(name_template.format(*run_params), table, header=header)

plt.show()
