import numpy as np
from pyscf.pbc import gto, scf

# H3S conventional cubic cell (Im-3m, 155 GPa).
# The body-centred cell holds two formula units: S on the 2a Wyckoff sites
# and H on the 6b sites, i.e. 2 S + 6 H = 38 electrons (closed shell).
a = 3.089  # cubic lattice constant, Angstrom
lattice = np.eye(3) * a

# Site positions in fractional (lattice) coordinates.
frac_sites = [
    ('S', (0.0, 0.0, 0.0)),
    ('S', (0.5, 0.5, 0.5)),
    ('H', (0.0, 0.5, 0.5)),
    ('H', (0.5, 0.0, 0.5)),
    ('H', (0.5, 0.5, 0.0)),
    ('H', (0.5, 0.0, 0.0)),
    ('H', (0.0, 0.5, 0.0)),
    ('H', (0.0, 0.0, 0.5)),
]

cell = gto.Cell()
cell.a = lattice
cell.unit = 'A'
# cell.atom is interpreted as Cartesian positions in cell.unit, so the
# fractional sites have to be mapped through the lattice vectors first.
# Passing the raw fractions would place every atom within 0.5 Angstrom of
# the origin.
cell.atom = [
    (symbol, (np.asarray(frac) @ lattice).tolist())
    for symbol, frac in frac_sites
]
cell.basis = 'sto-3g'
# With the full H sublattice the electron count is even, so no unpaired
# electron is needed to make the cell buildable.
cell.spin = 0
cell.verbose = 4
# NOTE(review): an all-electron sto-3g basis on the default FFT density
# fitting needs an extremely dense real-space mesh; GTH pseudopotentials with
# a matching gth-* basis are the usual choice for this kind of run. Confirm
# before using the numbers for anything beyond a demo.
cell.build()

# DFT (PBE) single point at the reference geometry.
mf = scf.UKS(cell)
mf.xc = 'pbe'
mf.kernel()

# FD params
disp = 0.01  # Cartesian step in Bohr (same unit as cell.atom_coords())
nat = cell.natm
nd = 3
hess = np.zeros((nat*nd, nat*nd))
coords0 = cell.atom_coords().copy()
# Only the diagonal force constants are evaluated by default; set this to
# False to also fill the (symmetric) off-diagonal couplings.
DIAGONAL_ONLY = True


def energy_at(coords):
    """Return the UKS/PBE total energy with the atoms placed at `coords`.

    Parameters
    ----------
    coords : array-like, shape (natm, 3)
        Cartesian atomic positions in Bohr, ordered like
        ``cell.atom_coords()``.

    Returns
    -------
    float
        Converged-or-not SCF total energy (``e_tot``) in Hartree.

    The module-level ``cell`` is left untouched: a displaced copy is built
    for every evaluation, so no restore step is needed afterwards.
    """
    from pyscf.data.nist import BOHR
    from pyscf.gto.mole import is_au

    # set_geom_ reads a bare coordinate array in the cell's own unit, while
    # the caller works in Bohr; convert here so the lattice vectors and the
    # positions stay in one consistent unit.
    to_cell_unit = 1.0 if is_au(cell.unit) else BOHR
    displaced_cell = cell.set_geom_(
        np.asarray(coords, dtype=float) * to_cell_unit, inplace=False
    )
    mf_tmp = scf.UKS(displaced_cell)
    mf_tmp.xc = 'pbe'
    mf_tmp.kernel()
    return mf_tmp.e_tot


def _displaced(moves):
    """Return a copy of coords0 with every (atom, axis, step) move applied."""
    shifted = coords0.copy()
    for atom, axis, step in moves:
        shifted[atom, axis] += step
    return shifted


E0 = energy_at(coords0)
print("Computing FD Hessian from energies... (diagonal approx for demo)")

ndof = nat * nd
for p in range(ndof):
    i, axis_i = divmod(p, nd)

    # Diagonal element: 3-point central difference. The reference energy is
    # reused, so each degree of freedom costs two SCF runs instead of four.
    e_plus = energy_at(_displaced([(i, axis_i, disp)]))
    e_minus = energy_at(_displaced([(i, axis_i, -disp)]))
    hess[p, p] = (e_plus - 2.0 * E0 + e_minus) / disp**2

    if DIAGONAL_ONLY:
        continue

    # Off-diagonal elements: 4-point mixed central difference. The Hessian is
    # symmetric, so only the upper triangle is computed and then mirrored.
    for q in range(p + 1, ndof):
        j, axis_j = divmod(q, nd)
        Epp = energy_at(_displaced([(i, axis_i, disp), (j, axis_j, disp)]))
        Epm = energy_at(_displaced([(i, axis_i, disp), (j, axis_j, -disp)]))
        Emp = energy_at(_displaced([(i, axis_i, -disp), (j, axis_j, disp)]))
        Emm = energy_at(_displaced([(i, axis_i, -disp), (j, axis_j, -disp)]))
        hess[p, q] = hess[q, p] = (Epp - Epm - Emp + Emm) / (4 * disp**2)

# Mass-weighting and phonons
AMU_TO_ME = 1822.888486  # electron masses per atomic mass unit


def mcmillan_tc(theta, lam, mu):
    """Return the McMillan estimate of Tc.

    Parameters
    ----------
    theta : float
        Characteristic phonon temperature in K.
    lam : float
        Electron-phonon coupling constant.
    mu : float
        Coulomb pseudopotential mu*.

    Returns
    -------
    float
        Tc in K, or 0.0 when the coupling is too weak to overcome the
        Coulomb repulsion (non-positive denominator), where the formula
        would otherwise return an arbitrarily large number.
    """
    denom = lam - mu * (1 + 0.62 * lam)
    if denom <= 0:
        return 0.0
    return (theta / 1.45) * np.exp(-1.04 * (1 + lam) / denom)


# The Hessian is in Hartree/Bohr^2, so the masses must be in electron masses
# for the square roots of the eigenvalues to be frequencies in atomic units.
m = np.repeat(cell.atom_mass_list(), 3) * AMU_TO_ME
hess_mw = hess / np.sqrt(np.outer(m, m))
eigvals, _ = np.linalg.eigh(hess_mw)
# Negative eigenvalues (imaginary modes) are clipped to zero and then dropped
# together with any other near-zero mode.
omega = np.sqrt(np.clip(eigvals, 0, None))
omega = omega[omega > 1e-6]
has_modes = omega.size > 0

print(f"Phonon frequencies (a.u.): {omega}")
lambda_epc = 2.1 if has_modes else 0  # Approximate from literature for demo
print(f"Approximate λ: {lambda_epc:.2f}")

# McMillan Tc
mu_star = 0.1
# Mean frequency converted Hartree -> eV -> K; 1000 K is the fallback used
# when no usable mode survives the filter above.
theta_D = np.mean(omega) * 27.211386 * 11604.525 if has_modes else 1000
tc = mcmillan_tc(theta_D, lambda_epc, mu_star)
print(f"Estimated Tc: {tc:.1f} K")

# UFT-F boost: the same estimate with the coupling constant scaled by 1.2.
lambda_epc_uft = lambda_epc * 1.2
tc_uft = mcmillan_tc(theta_D, lambda_epc_uft, mu_star)
print(f"UFT-F boosted λ: {lambda_epc_uft:.2f}")
print(f"UFT-F boosted Tc: {tc_uft:.1f} K")

# Terminal output from a run of this script (PySCF first echoes the input file):
# (base) brendanlynch@Brendans-Laptop superconductors % python realNotToySim.py
# /Users/brendanlynch/miniconda3/lib/python3.12/site-packages/pyscf/dft/libxc.py:771: UserWarning: Since PySCF-2.3, B3LYP (and B3P86) are changed to the VWN-RPA variant, corresponding to the original definition by Stephens et al. (issue 1480) and the same as the B3LYP functional in Gaussian. To restore the VWN5 definition, you can put the setting "B3LYP_WITH_VWN5 = True" in pyscf_conf.py
#   warnings.warn('Since PySCF-2.3, B3LYP (and B3P86) are changed to the VWN-RPA variant, '
# #INFO: **** input file is /Users/brendanlynch/Desktop/zzzzzzzzzzzz/superconductors/realNotToySim.py ****
# import numpy as np
# from pyscf.pbc import gto, scf

# # H3S cell (Im-3m, 155 GPa)
# a = 3.089
# cell = gto.Cell()
# cell.atom = '''
# S  0.0  0.0  0.0
# S  0.5  0.5  0.5
# H  0.0  0.5  0.5
# H  0.5  0.0  0.5
# H  0.5  0.5  0.0
# '''
# cell.a = np.eye(3) * a
# cell.basis = 'sto-3g'
# cell.unit = 'A'
# cell.spin = 1          # ← THE FIX
# cell.verbose = 4
# cell.build()

# # DFT (PBE)
# mf = scf.UKS(cell)
# mf.xc = 'pbe'
# mf.kernel()

# # FD params
# disp = 0.01
# nat = cell.natm
# nd = 3
# hess = np.zeros((nat*nd, nat*nd))
# coords0 = cell.atom_coords().copy()

# def energy_at(coords):
#     cell.set_atom_coords(coords)
#     mf_tmp = scf.UKS(cell)
#     mf_tmp.xc = 'pbe'
#     mf_tmp.kernel()
#     return mf_tmp.e_tot

# E0 = energy_at(coords0)
# print("Computing FD Hessian from energies... (diagonal approx for demo)")

# for i in range(nat):
#     for a in range(nd):
#         for j in range(nat):
#             for b in range(nd):
#                 if i != j or a != b: continue  # Remove this line for full Hessian later

#                 coords_pp = coords0.copy()
#                 coords_pm = coords0.copy()
#                 coords_mp = coords0.copy()
#                 coords_mm = coords0.copy()

#                 coords_pp[i,a] += disp; coords_pp[j,b] += disp
#                 coords_pm[i,a] += disp; coords_pm[j,b] -= disp
#                 coords_mp[i,a] -= disp; coords_mp[j,b] += disp
#                 coords_mm[i,a] -= disp; coords_mm[j,b] -= disp

#                 Epp = energy_at(coords_pp)
#                 Epm = energy_at(coords_pm)
#                 Emp = energy_at(coords_mp)
#                 Emm = energy_at(coords_mm)

#                 hess[i*nd+a, j*nd+b] = (Epp - Epm - Emp + Emm) / (4 * disp**2)

# cell.set_atom_coords(coords0)

# # Mass-weighting and phonons
# m = np.repeat(cell.atom_mass_list(), 3)
# hess_mw = hess / np.sqrt(np.outer(m, m))
# eigvals, _ = np.linalg.eigh(hess_mw)
# omega = np.sqrt(np.clip(eigvals, 0, None))
# omega = omega[omega > 1e-6]

# print(f"Phonon frequencies (a.u.): {omega}")
# lambda_epc = 2.1 if len(omega) > 0 else 0  # Approximate from literature for demo
# print(f"Approximate λ: {lambda_epc:.2f}")

# # McMillan Tc
# mu_star = 0.1
# theta_D = np.mean(omega) * 27.211386 * 11604.525 if len(omega) > 0 else 1000
# tc = (theta_D / 1.45) * np.exp(-1.04 * (1 + lambda_epc) / (lambda_epc - mu_star * (1 + 0.62 * lambda_epc)))
# print(f"Estimated Tc: {tc:.1f} K")

# # UFT-F boost
# lambda_epc_uft = lambda_epc * 1.2
# tc_uft = (theta_D / 1.45) * np.exp(-1.04 * (1 + lambda_epc_uft) / (lambda_epc_uft - mu_star * (1 + 0.62 * lambda_epc_uft)))
# print(f"UFT-F boosted λ: {lambda_epc_uft:.2f}")
# print(f"UFT-F boosted Tc: {tc_uft:.1f} K")
# #INFO: ******************** input file end ********************


# System: uname_result(system='Darwin', node='Brendans-Laptop.local', release='24.5.0', version='Darwin Kernel Version 24.5.0: Tue Apr 22 19:54:33 PDT 2025; root:xnu-11417.121.6~2/RELEASE_ARM64_T8122', machine='arm64')  Threads 1
# Python 3.12.2 | packaged by conda-forge | (main, Feb 16 2024, 20:54:21) [Clang 16.0.6 ]
# numpy 1.26.4  scipy 1.16.3
# Date: Mon Dec 22 07:49:43 2025
# PySCF version 2.4.0
# PySCF path  /Users/brendanlynch/miniconda3/lib/python3.12/site-packages/pyscf

# [CONFIG] conf_file None
# [INPUT] verbose = 4
# [INPUT] num. atoms = 5
# [INPUT] num. electrons = 35
# [INPUT] charge = 0
# [INPUT] spin (= nelec alpha-beta = 2S) = 1
# [INPUT] symmetry False subgroup None
# [INPUT] Mole.unit = A
# [INPUT] Symbol           X                Y                Z      unit          X                Y                Z       unit  Magmom
# [INPUT]  1 S      0.000000000000   0.000000000000   0.000000000000 AA    0.000000000000   0.000000000000   0.000000000000 Bohr   0.0
# [INPUT]  2 S      0.500000000000   0.500000000000   0.500000000000 AA    0.944863062283   0.944863062283   0.944863062283 Bohr   0.0
# [INPUT]  3 H      0.000000000000   0.500000000000   0.500000000000 AA    0.000000000000   0.944863062283   0.944863062283 Bohr   0.0
# [INPUT]  4 H      0.500000000000   0.000000000000   0.500000000000 AA    0.944863062283   0.000000000000   0.944863062283 Bohr   0.0
# [INPUT]  5 H      0.500000000000   0.500000000000   0.000000000000 AA    0.944863062283   0.944863062283   0.000000000000 Bohr   0.0

# nuclear repulsion = -44.2269167655592
# number of shells = 13
# number of NR pGTOs = 63
# number of NR cGTOs = 21
# basis = sto-3g
# ecp = {}
# CPU time:         1.44
# lattice vectors  a1 [5.837363999, 0.000000000, 0.000000000]
#                  a2 [0.000000000, 5.837363999, 0.000000000]
#                  a3 [0.000000000, 0.000000000, 5.837363999]
# dimension = 3
# low_dim_ft_type = None
# Cell volume = 198.907
# rcut = 16.514394305151235 (nimgs = [3 3 3])
# lattice sum = 207 cells
# precision = 1e-08
# pseudo = None
# ke_cutoff = 48259.61094587282
#     = [579 579 579] mesh (194104539 PWs)


# ******** <class 'pyscf.pbc.dft.uks.UKS'> ********
# method = UKS
# initial guess = minao
# damping factor = 0
# level_shift factor = 0
# DIIS = <class 'pyscf.scf.diis.CDIIS'>
# diis_start_cycle = 1
# diis_space = 8
# SCF conv_tol = 1e-07
# SCF conv_tol_grad = None
# SCF max_cycles = 50
# direct_scf = True
# direct_scf_tol = 1e-13
# chkfile to save SCF result = /var/folders/_p/xnn5zr7x38l1vgv_jq7gf4r40000gn/T/tmpj_f9sw07
# max_memory 4000 MB (current use 0 MB)
# ******** PBC SCF flags ********
# kpt = [0. 0. 0.]
# Exchange divergence treatment (exxdiv) = ewald
#     madelung (= occupied orbital energy shift) = 0.48605800153509254
#     Total energy shift due to Ewald probe charge = -1/2 * Nelec*madelung = -8.50601502686
# DF object = <pyscf.pbc.df.fft.FFTDF object at 0x1191d33e0>
# number of electrons per cell  alpha = 18 beta = 17
# XC functionals = pbe
# small_rho_cutoff = 1e-07
# Uniform grid, mesh = [579 579 579]
# Set gradient conv threshold to 0.000316228