import numpy as np
from scipy.interpolate import griddata
import matplotlib.pyplot as plt
from matplotlib.patches import Circle

# Overall canvas: 3 3/8 in wide by 4 1/2 in tall, 300 dpi, using the
# constrained layout engine. It is split into two side-by-side subfigures
# that hold the left (figL) and right (figR) panel grids.
fig = plt.figure(
    figsize=(3 + 3 / 8, 4 + 1 / 2),
    dpi=300,
    layout='constrained',
)
figL, figR = fig.subfigures(nrows=1, ncols=2, wspace=0.1)

def interp(Z, factor=3):
    """Upsample a 2-D array sampled on a regular grid over [-1, 1] x [-1, 1].

    *Z* is treated as samples at ``np.linspace(-1, 1, n)`` along each axis
    (rows along y, columns along x). It is resampled with piecewise-cubic
    ``scipy.interpolate.griddata`` onto a grid with *factor* times as many
    points per axis. Every output point outside the unit circle is then set
    to zero.

    Parameters
    ----------
    Z : array_like, shape (ny, nx)
        Values to resample; real or complex. Non-square input is supported.
    factor : int, optional
        Upsampling factor per axis (default 3).

    Returns
    -------
    ndarray, shape (factor * ny, factor * nx)
        Interpolated values, zero outside the unit circle.
    """
    Z = np.asarray(Z)
    ny, nx = Z.shape

    # Coordinates of the input samples (columns -> x, rows -> y).
    X, Y = np.meshgrid(np.linspace(-1.0, 1.0, nx), np.linspace(-1.0, 1.0, ny))
    # Finer target grid covering the same square.
    xx, yy = np.meshgrid(
        np.linspace(-1.0, 1.0, factor * nx),
        np.linspace(-1.0, 1.0, factor * ny),
    )

    # fill_value=0 covers any target point griddata treats as outside the
    # convex hull of the input samples.
    zz = griddata(
        (X.ravel(), Y.ravel()), Z.ravel(), (xx, yy),
        method='cubic', fill_value=0,
    )
    zz[xx**2 + yy**2 > 1] = 0.0
    return zz

# Raw complex128 dumps read with np.fromfile (native byte order), then
# reshaped (C order) into a stack of 2-D arrays indexed along the first
# axis. NOTE(review): confirm the writer used the same dtype and layout.
# Alternative left-panel input:
# datL = np.fromfile('./Data/wf_weak.dat', dtype=np.complex128).reshape(16, 201, 201)
datL = np.fromfile('./Data/wave_function_Nh.dat', dtype=np.complex128).reshape((16, 201, 201))
datR = np.fromfile('./Data/wave_function_1.dat', dtype=np.complex128).reshape((8, 101, 101))

# Each subfigure gets a 5 x 2 grid of axes; every row is one (density, phase) pair.
axsL = figL.subplots(nrows=5, ncols=2)
axsR = figR.subplots(nrows=5, ncols=2)

def _show_clipped_to_unit_disk(ax, img, **imshow_kwargs):
    """Draw *img* on *ax*, clipped to the radius-1 circle and outlined in black."""
    im = ax.imshow(img, interpolation='lanczos', interpolation_stage='rgba',
                   **imshow_kwargs)
    # Hide every pixel outside the radius-1 circle (in data coordinates).
    im.set_clip_path(Circle((0, 0), radius=1, lw=0.5, transform=ax.transData))
    # Draw the circle boundary itself as a thin outline.
    ax.add_patch(Circle((0, 0), fill=False, radius=1, lw=0.5, edgecolor='black'))
    return im


def plot_dens_phase(ax, wfdat, extent=(-1.1, 1.1, -1.1, 1.1)):
    """Plot the density and phase of a 2-D array on a pair of axes.

    *wfdat* is row-reversed and upsampled with ``interp``. The result is
    drawn as ``abs(.)**2`` (cmap 'hot') on ``ax[0]`` and as ``np.angle(.)``
    (cmap 'hsv', fixed to [-pi, pi]) on ``ax[1]``. Both images are clipped
    to the radius-1 circle, which is also drawn as a thin black outline.

    Parameters
    ----------
    ax : sequence of two matplotlib Axes
        The density goes on ``ax[0]`` and the phase on ``ax[1]``.
    wfdat : ndarray, shape (ny, nx)
        2-D samples (real or complex) passed to ``interp``.
    extent : 4-tuple of float, optional
        Image extent ``(left, right, bottom, top)`` in data coordinates.
        Defaults to ``(-1.1, 1.1, -1.1, 1.1)``.

    NOTE(review): ``interp`` samples its output on [-1, 1] along both axes,
    but the default extent stretches the image to +/-1.1. The radius-1 clip
    circle therefore corresponds to radius ~0.91 of ``interp``'s grid. If the
    margin is only meant to keep the outline inside the axes, pass an extent
    of about (-1, 1, -1, 1) and widen the axis limits instead. Confirm the
    intent.
    """
    # Reverse the row order: with imshow's default origin='upper', the first
    # row of wfdat then ends up at the bottom of the image.
    dat = interp(wfdat[::-1])

    _show_clipped_to_unit_disk(ax[0], np.abs(dat)**2, cmap='hot', extent=extent)

    # Pin the colour scale to np.angle's full range so that a given phase
    # maps to the same hue in every panel and the cyclic 'hsv' map wraps
    # exactly at +/-pi. Without this, each panel autoscales to its own
    # min/max, and a nearly uniform phase would be spread over all hues.
    _show_clipped_to_unit_disk(ax[1], np.angle(dat), cmap='hsv', extent=extent,
                               vmin=-np.pi, vmax=np.pi)

# Left panel rows show these entries of datL; right panel rows show the
# leading entries of datR (one per axes row).
indicesL = [0, 3, 6, 9, 13]
for rowL, rowR, idxL, wfR in zip(axsL, axsR, indicesL, datR):
    plot_dens_phase(rowL, datL[idxL])
    plot_dens_phase(rowR, wfR)

# No frames, ticks or labels on any of the image axes.
for ax in (*axsL.flat, *axsR.flat):
    ax.set_axis_off()

figL.suptitle('(a)', fontweight='bold')
figR.suptitle('(b)', fontweight='bold')

# Subfigures share the parent figure's layout engine, so these paddings
# apply to both panels; configure them in one place on the parent.
fig.get_layout_engine().set(w_pad=0, h_pad=0, hspace=0)

# Save the figure object explicitly rather than pyplot's "current" figure.
fig.savefig('./Images/drawings.pdf')
fig.savefig('./Images/drawings.png')
