import numpy as np
from scipy.sparse.linalg import eigsh, eigs
import matplotlib.pyplot as plt
import sys
import os
import inspect
import itertools as itools
from itertools import permutations, combinations_with_replacement, combinations

# import ED functions
# The directory containing this script is resolved from the current frame,
# and its parent directory is put at the front of sys.path so that the two
# project modules below can be imported from there (presumably that is where
# they live -- confirm against the repository layout).
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir) 
import sympy_wavefunctions as sympywfn
import NonHermitianAngularMomentumED as ed

##################################################################################################
##  Start calculations
##################################################################################################

# ---------------- system parameters ---------
# -- length scales
lambda0 = 1  # atomic transition wavelength
k0 = 2*np.pi/lambda0  # atomic transition wavenumber
d0 = 0.5  # dimensionless lattice constant
d = d0*lambda0  # lattice constant
D0 = 0.6  # system size (dimensionless)
D = D0*d  # system size

# -- model choices
hermitian = True  # diagonalize only the Hermitian part of the Hamiltonian (keep True, otherwise the orbitals are non-orthogonal)
simplify = False  # use the simplified model instead of the full model
rtrunc = -1  # truncation radius of the interaction (-1: no truncation)
Lmax = 6  # Lmax-fold rotational symmetry
Npart = 2  # number of particles
mmax = 5  # cutoff in orbital space
B0 = 12  # rescaled magnetic field
Vquad0 = 0  # rescaled confining potential strength (not useful for D0=0.6)

# -- optional post-processing steps
ComputeOverlaps = True  # overlaps with the model 1D Laughlin states
ComputeAntihermitian = True  # decay rates as expectation values of the antihermitian part of the Hamiltonian
ComputeRadialDensity = False  # radial density (not useful for D0=0.6)

# -- derived quantities
B = B0/d**3/k0**3  # non-rescaled magnetic field
Vquad = Vquad0/d**5/k0**3  # non-rescaled confining potential strength
latvec = [
  [d*np.sqrt(3)/2., -d*0.5],
  [d*np.sqrt(3)/2., d*0.5],
]  # the two lattice vectors
V = 0  # not used in the visible part of this script
q = 2  # parameter q handed to the model Laughlin-state generator

orb = [
  [1./3., 1./3.],
  [2./3., 2./3.],
]  # positions of the atoms in the unit cell

# ----------- get lattice positions (and plot them) ----------
# Sites are generated for sector 0 only; the remaining Lmax-1 sectors are
# obtained by rotating that sector.
xs, ys = ed.GetLatticePointsSegment(D0, latvec, orb, d)  # site positions in sector 0
Nseg = len(xs)  # number of sites per sector
Nsite = Lmax*Nseg  # total number of sites
xall, yall = [], []
plt.figure()
for sector in range(Lmax):
  xrot, yrot = ed.RotateAll(xs, ys, sector, Lmax)  # sector 0 rotated into this sector
  plt.scatter(xrot, yrot)  # one scatter call per sector
  xall.extend(xrot)
  yall.extend(yrot)
pos = np.array([xall, yall]).T  # one (x, y) row per site

ax = plt.gca()
ax.set_aspect("equal")


# generate interactions (with rotation-invariance-restoring transform);
# the full model additionally receives the `hermitian` switch
if not simplify:
  interactions = ed.RealDipolarInteractionFiniteCircularFullTransform(
    pos, B, Vquad, k0, rtrunc, Lmax, hermitian)
else:
  interactions = ed.RealDipolarInteractionFiniteCircularSimplifiedTransform(
    pos, B, Vquad, k0, rtrunc, Lmax)


################################################################################
##  Find single-particle eigenstates (orbitals for later overlap calculation)
################################################################################

basisSP = ed.GenerateBasis(1, Nsite)  # single-particle site basis
mombasisSP = ed.FindAngularMomentumEigenstates(basisSP, 1, Lmax)  # single-particle angular-momentum basis

evalsSPL = []      # eigenvalues, one array per angular momentum L
evecsSPL = []      # eigenvectors in the angular-momentum basis, one array per L
evecssiteSPL = []  # the same eigenvectors converted to the site basis, one entry per L
evalsSP_all = []   # flat list of all eigenvalues (e.g. for plotting)
LsSP_all = []      # angular momentum of every entry of evalsSP_all
for L in range(Lmax):
  # Hamiltonian restricted to the angular-momentum-L subspace, diagonalized densely
  H = ed.GenerateCRSHamiltonianWithAngularMomentum(L, interactions, basisSP, Lmax, mombasisSP)
  evalsSP, evecsSP = np.linalg.eig(H.todense())
  # sort by the real part of the eigenvalue; np.array() turns the
  # np.matrix returned for a .todense() input into a plain ndarray
  order = np.argsort(np.real(evalsSP))
  evalsSP = evalsSP[order]
  evecsSP = np.array(evecsSP[:, order])
  evalsSP_all.extend(evalsSP)
  LsSP_all.extend(np.ones(len(evalsSP))*L)
  evalsSPL.append(evalsSP)
  evecsSPL.append(evecsSP)
  evecssiteSPL.append(ed.ConvertEigenstates(evecsSP, basisSP, mombasisSP, L, Lmax))

# plot single-particle spectrum
# (no colorbar: the scatter carries no colour data, so there is nothing to map)
plt.figure()
plt.scatter(LsSP_all, np.real(evalsSP_all))

################################################################################
##  Find many-particle eigenstates
################################################################################

basis = ed.GenerateBasis(Npart, Nsite)  # site-space many-particle basis
mombasis = ed.FindAngularMomentumEigenstates(basis, Npart, Lmax)  # many-particle angular-momentum basis
evalsL = []      # eigenvalues, one array per angular momentum L
evecsL = []      # eigenvectors in the angular-momentum basis, one array per L
evecssiteL = []  # the same eigenvectors converted to the site basis, one entry per L
neval = 10  # number of eigenvalues to calculate if the matrix is sparse
evals_all = []   # flat list of all eigenvalues (e.g. for plotting/output)
Ls_all = []      # angular momentum of every entry of evals_all
for L in range(Lmax):
  H = ed.GenerateCRSHamiltonianWithAngularMomentum(L, interactions, basis, Lmax, mombasis)
  dim = len(mombasis[0][L])  # size of the angular-momentum-L subspace
  if dim < 20:
    # small Hamiltonian: full dense diagonalization
    evalsMP, evecsMP = np.linalg.eig(H.todense())
  else:
    # large Hamiltonian: a few eigenvalues with the smallest real part.
    # eigs requires k < dim-1, hence the dim-2 upper bound.
    evalsMP, evecsMP = eigs(H, k=min(dim-2, neval), which='SR', return_eigenvectors=True)
  # common post-processing: sort by real part and store; np.asarray() turns
  # the np.matrix coming out of the dense branch into a plain ndarray
  order = np.argsort(np.real(evalsMP))
  evalsMP = evalsMP[order]
  evecsMP = np.asarray(evecsMP[:, order])
  evals_all.extend(evalsMP)
  Ls_all.extend(np.ones(len(evalsMP))*L)
  evalsL.append(evalsMP)
  evecsL.append(evecsMP)
  evecssiteL.append(ed.ConvertEigenstates(evecsMP, basis, mombasis, L, Lmax))

# plot the spectrum here only when neither of the later sections
# (antihermitian rates, overlaps) is enabled; points are coloured by
# -2*Im(E)
if not ComputeAntihermitian and not ComputeOverlaps:
  plt.figure()
  minus_two_imag = -2*np.imag(evals_all)
  plt.scatter(Ls_all, np.real(evals_all), c=minus_two_imag)
  plt.colorbar()

################################################################################
##  Radial density
################################################################################
densL = []       # radial densities, one array per angular momentum L
rmeans_all = []  # density-weighted sum of radii for every state, flat list
rmeansL = []     # the same values grouped by L
if ComputeRadialDensity:
  for L in range(Lmax):
    states = evecssiteL[L]
    rs, dens = ed.RadialDensities(pos, states, basis)
    densL.append(dens)
    radii = np.array(rs)
    # sum_r r * density(r) for each state (column of dens)
    rmeans = [sum(radii*dens[:, k]) for k in range(states.shape[1])]
    rmeans_all.extend(rmeans)
    rmeansL.append(rmeans)
  # Output table: the first row holds the radii (preceded by two NaN
  # placeholders); every further row is [L, state index, densities...].
  dens_all = [np.concatenate([[float("nan"), float("nan")], rs])]
  for L, dens in enumerate(densL):
    dens_all.extend(
      np.concatenate([[L, k], dens[:, k]]) for k in range(dens.shape[1])
    )
  raddens_file = "raddens_OBC_d0={:.4f}_herm{}_simp{}_D0={:.4f}_Npart={}_B0={:.4f}_Vquad0={:.4f}_rtrunc={:.4f}.txt".format(d0, hermitian,simplify, D0,Npart,B0,Vquad0,rtrunc)
  np.savetxt(raddens_file, dens_all)




################################################################################
##  Expectation value of the antihermitian part
################################################################################
if ComputeAntihermitian:
  interactions_antiherm = ed.AntiHermitianInteractionFiniteCircularFullTransform(
    pos, B, Vquad, k0, rtrunc, Lmax)
  expval_all = []  # <psi|H_antiherm|psi> for every state, same ordering as evals_all
  for L in range(Lmax):
    print("L=", L)
    H_antiherm = ed.GenerateCRSHamiltonianWithAngularMomentum(
      L, interactions_antiherm, basis, Lmax, mombasis)
    states = evecsL[L]  # eigenvectors in the angular-momentum basis (columns)
    expval_all.extend([
      np.conj(states[:, k]) @ H_antiherm @ states[:, k]
      for k in range(states.shape[1])
    ])
  rates_all = -2*np.imag(expval_all)  # decay rates

  #----------- plot spectrum with decay rates
  plt.figure()
  plt.scatter(Ls_all, np.real(evals_all), c=rates_all)
  plt.colorbar()
  # mark the subradiant states: entries with rate > Npart are masked, the
  # remaining ones are outlined with a blue square
  hide = np.ma.masked_greater(rates_all, Npart).mask
  Ls_marked = np.ma.masked_array(Ls_all, mask=hide)
  energies_marked = np.real(np.ma.masked_array(evals_all, mask=hide))
  plt.scatter(Ls_marked, energies_marked, marker="s", s=100, facecolor="None", edgecolor="b")


################################################################################
##  Calculate overlaps with model Laughlin states
################################################################################

# default value of mmax
if mmax == -1:
  mmax = q*(Npart-1)+Lmax-1

Lmax_ov = 8  # maximum L' for which the overlap is computed
if ComputeOverlaps:
  L0 = (q*Npart*(Npart-1)//2) % Lmax
  # model Laughlin states and the basis of orbital occupations, one entry per Delta L'
  laughlin_states = sympywfn.GetFewLowestEdgeStatesMmax1D(Npart, q, Lmax_ov-L0, mmax)
  # relevant single-particle orbitals (site basis), arranged according to
  # continuum angular momentum: index = isub*Lmax + L
  Nsub = int(np.ceil((mmax+1)/Lmax))
  evecsSP_ordered = [evecssiteSPL[L][:, isub]
                     for isub in range(Nsub) for L in range(Lmax)]
  # Index into a single-particle orbital for every occupied entry of each
  # many-particle configuration.  It depends on the basis only, so it is
  # computed once instead of once per eigenstate and orbital configuration.
  occupied = [[2*ic+c-1 for ic, c in enumerate(conf) if c > 0] for conf in basis]
  # empty arrays to fill with overlaps, one per array angular momentum L
  overlaps_all = [np.zeros(evecssiteL[L].shape[1]) for L in range(Lmax)]
  # ------ loop over delta L', compute overlaps
  for deltaL in range(Lmax_ov-L0+1):
    L = (L0+deltaL) % Lmax  # array angular momentum L corresponding to given Delta L'
    basis_laughlin, Q = laughlin_states[deltaL]  # orbital-occupation basis and array Q with all Laughlin states
    prefactors = sympywfn.GetPrefactors(basis_laughlin, mmax)  # normalization prefactors coming from two or more bosons at the same orbital
    # Matrix taking a site-basis many-particle state to basis_laughlin:
    #   to_orbital[il, i] = prefactors[il] * sum_perm prod_j conj(orbital_{perm[j]}[occupied[i][j]])
    # It is the same for every eigenstate of this L, so it is built once per
    # Delta L' and applied with a matrix-vector product below (equal to the
    # per-state accumulation up to floating-point summation order).
    to_orbital = np.zeros((len(basis_laughlin), len(basis)), dtype="complex")
    for il, bl in enumerate(basis_laughlin):
      orbital_perms = list(permutations(bl))
      for i, sites in enumerate(occupied):
        amplitude = 0
        for perm in orbital_perms:
          prod = prefactors[il]
          for j, p in enumerate(perm):
            prod *= np.conj(evecsSP_ordered[p][sites[j]])
          amplitude += prod
        to_orbital[il, i] = amplitude
    overlaps = []  # overlaps for all ED eigenstates of given L
    for itstate in range(evecssiteL[L].shape[1]):
      # ED eigenstate (site basis), flattened to 1-D before the product
      target_state = np.asarray(evecssiteL[L][:, itstate]).ravel()
      target_state_LL = to_orbital @ target_state
      overlaps.append(ed.GeneralizedOverlap(target_state_LL, Q))
    # several Delta L' can map to the same L; their overlaps add up
    overlaps_all[L % Lmax] += overlaps
  # store overlaps in a way suitable for plotting (same ordering as evals_all)
  overlaps_all2 = []
  for ovs in overlaps_all:
    overlaps_all2.extend(ovs)
  # plot spectrum with overlaps
  plt.figure()
  plt.scatter(Ls_all, np.real(evals_all), c=overlaps_all2)
  plt.colorbar()


# -------- arrange data for output and create output file header --------
# Each enabled quantity contributes one column; label and data are kept
# together so that header and table cannot get out of step.
columns = [("L", Ls_all), ("re(E)", np.real(evals_all))]
if simplify or not hermitian:
  columns.append(("im(E)", np.imag(evals_all)))
if ComputeOverlaps:
  columns.append(("overlap", overlaps_all2))
if ComputeAntihermitian:
  columns.append(("antiherm rate", rates_all))
if ComputeRadialDensity:
  columns.append(("average radius", rmeans_all))

# NOTE: np.savetxt prepends its own comment marker to the header as well.
header = "# " + "".join(label + ", " for label, _ in columns)
output = [data for _, data in columns]

# output data (one row per eigenstate)
spectrum_file = "spectrum_OBC_d0={:.4f}_herm{}_simp{}_D0={:.4f}_Npart={}_B0={:.4f}_Vquad0={:.4f}_rtrunc={:.4f}_mmax{}.txt".format(d0, hermitian,simplify, D0,Npart,B0,Vquad0,rtrunc, mmax)
np.savetxt(spectrum_file, np.transpose(output), header=header)

plt.show()
