import numpy as np
import matplotlib.pyplot as plt
import sys
import os
import inspect
import math
from string import ascii_lowercase


# import ED functions
# Make the directory one level above this script importable, so that the
# exact-diagonalization helper module NonHermitianAngularMomentumED (used as
# `ed` below) can be found. The sys.path insertion must run before the import.
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir) 
import NonHermitianAngularMomentumED as ed


##################################################################################################
##  Start calculations
##################################################################################################
# ---------------------------------------------------------------------------
# Physical constants: lengths are measured in units of the atomic transition
# wavelength lambda0 (the lattice spacing used in the scan is d = d0 * lambda0).
# ---------------------------------------------------------------------------
lambda0 = 1                 # atomic transition wavelength
k0 = 2 * np.pi / lambda0    # atomic transition wavenumber

# Model switches.
hermitian = True   # True: keep only the Hermitian part; False: full non-Hermitian Hamiltonian (full model only)
simplify = False   # True: simplified interaction model; False: full model

# Geometry and field, in rescaled (dimensionless) form.
D0 = 2.1       # rescaled system size (physical size is D0 * d)
Lmax = 6       # order of the discrete rotational symmetry (Lmax-fold)
Npart = 1      # particle number; this phase diagram assumes a single particle
rtrunc = -1    # interaction truncation radius; -1 disables truncation
B0 = 12        # rescaled magnetic field (physical B = B0 / d**3 / k0**3)

orb = [[1 / 3, 1 / 3], [2 / 3, 2 / 3]]  # positions of the atoms inside the unit cell

max_orbital = 4  # highest LLL orbital m for which the <r> ordering must hold

# Scan grid: x axis is d0, y axis is the rescaled confinement strength Vquad0.
d0s = np.linspace(0.01, 0.5, num=100)
Vquad0s = np.linspace(0, 10, num=100)
nd0s = d0s.size
nVquad0s = Vquad0s.size

# Result grid: landau_conditions[id0, iVquad0] is True where the LLL-like ordering holds.
landau_conditions = np.zeros(shape=(nd0s, nVquad0s), dtype=bool)

################################################################################
##  Start phase diagram calculations
################################################################################
# For every grid point (d0, Vquad0): build the single-particle Hamiltonian in each
# angular-momentum sector L, diagonalize it, compute the radial densities of the
# eigenstates and check whether their mean radii show the LLL-like ordering.
#
# Everything that depends only on d0 (lattice, site positions, bases, B) is built
# once per d0 instead of once per grid point.
# NOTE(review): this assumes the ed helpers do not mutate pos / basisSP /
# mombasisSP in place -- confirm.
for id0, d0 in enumerate(d0s):
    d = d0 * lambda0
    B = B0 / d**3 / k0**3  # non-rescaled magnetic field
    latvec = [[d * np.sqrt(3) / 2., -d * 0.5], [d * np.sqrt(3) / 2., d * 0.5]]  # lattice vectors

    # Sites of one angular segment, followed by all Lmax rotated copies of it.
    # GetLatticePointsSegment receives the rescaled size D0 together with d (the
    # physical size would be D0 * d); presumably it rescales internally -- confirm.
    xs, ys = ed.GetLatticePointsSegment(D0, latvec, orb, d)
    Nseg = len(xs)
    Nsite = Nseg * Lmax
    xall = []
    yall = []
    for i in range(Lmax):
        xres, yres = ed.RotateAll(xs, ys, i, Lmax)
        xall.extend(xres)
        yall.extend(yres)
    pos = np.transpose(np.array([xall, yall]))  # one (x, y) row per site
    basisSP = ed.GenerateBasis(Npart, Nsite)  # site basis
    mombasisSP = ed.FindAngularMomentumEigenstates(basisSP, Npart, Lmax)  # angular-momentum basis

    for iVquad0, Vquad0 in enumerate(Vquad0s):
        print(id0, iVquad0)
        Vquad = Vquad0 / d**5 / k0**3  # non-rescaled confining potential strength
        if simplify:
            interactions = ed.RealDipolarInteractionFiniteCircularSimplifiedTransform(pos, B, Vquad, k0, rtrunc, Lmax)
        else:
            interactions = ed.RealDipolarInteractionFiniteCircularFullTransform(pos, B, Vquad, k0, rtrunc, Lmax, hermitian)

        # Radial densities of all eigenstates in each sector L, with the eigenstates
        # sorted by the real part of their eigenvalue. np.linalg.eig (not eigh) is
        # used in every case: the simplified model is not passed `hermitian`, so H
        # is not guaranteed to be Hermitian here.
        densL = []
        for L in range(Lmax):
            H = ed.GenerateCRSHamiltonianWithAngularMomentum(L, interactions, basisSP, Lmax, mombasisSP)
            evals, evecs = np.linalg.eig(np.asarray(H.todense()))
            evecs = evecs[:, np.argsort(np.real(evals))]
            evecsite = ed.ConvertEigenstates(evecs, basisSP, mombasisSP, L, Lmax)  # to site basis
            rs, dens = ed.RadialDensities(pos, evecsite, basisSP)  # dens[:, j]: radial density of eigenstate j
            densL.append(dens)

        # The continuum LLL orbital m is identified with the (m // Lmax)-th lowest
        # eigenstate of sector L = m % Lmax; <r> = sum_r r * n(r) for m = 0..max_orbital.
        # This replaces a fixed "two states per sector" loop, which raised IndexError
        # once max_orbital + 1 > 2 * Lmax. rs comes from the last sector, i.e. all
        # sectors are assumed to share one radial grid.
        rs_arr = np.array(rs)
        rmeans_relevant = [np.sum(rs_arr * densL[m % Lmax][:, m // Lmax]) for m in range(max_orbital + 1)]

        # LLL-like structure: <r> strictly increasing with m. Written as all(a < b)
        # so that a NaN <r> fails the check instead of silently passing it.
        landau_conditions[id0, iVquad0] = all(
            r_inner < r_outer for r_inner, r_outer in zip(rmeans_relevant, rmeans_relevant[1:])
        )

# Plot the phase diagram with d0 on the x axis and Vquad0 on the y axis.
# landau_conditions is indexed [id0, iVquad0], so it is transposed for imshow
# (image rows = y). The extent is padded by half a grid step on every side so
# that each pixel is centred on its (d0, Vquad0) grid value.
plt.figure()
dVquad0 = Vquad0s[1] - Vquad0s[0]
Vquad0max = np.max(Vquad0s)
Vquad0min = np.min(Vquad0s)
dd0 = d0s[1] - d0s[0]
d0max = np.max(d0s)
d0min = np.min(d0s)
plt.imshow(np.transpose(landau_conditions),
           extent=[d0min - dd0 / 2., d0max + dd0 / 2., Vquad0min - dVquad0 / 2., Vquad0max + dVquad0 / 2.],
           origin="lower")
ax = plt.gca()
ax.set_aspect("auto")
plt.xlabel(r"$d_0$")
plt.ylabel(r"$V_{harm0}$")

# Save the phase diagram to a text file: rows correspond to d0s, columns to
# Vquad0s, entries are 0/1 (1 = LLL-like ordering found). The header stores
# "d0min d0max dd0 Vquad0min Vquad0max dVquad0" so the grid can be rebuilt.
# fmt="%d" writes literal 0s and 1s; the default "%.18e" would write
# 0.000000000000000000e+00 / 1.000000000000000000e+00 for every entry.
header = "{} {} {} {} {} {}".format(d0min, d0max, dd0, Vquad0min, Vquad0max, dVquad0)
np.savetxt("phasediagram_landau_OBC_maxo={}_herm{}_D0={:.4f}_B0={:.4f}_rtrunc={:.4f}.txt".format(max_orbital, hermitian, D0, B0, rtrunc),
           landau_conditions.astype(int), header=header, fmt="%d")


plt.show()
