from __future__ import annotations

import argparse
from pathlib import Path

from ase.units import Hartree
from ase.io.trajectory import write_atoms
from ase.build import molecule
from gpaw.mpi import world
import numpy as np
from ase.io.ulm import Writer


def write_wavefunctions(fpath: Path,
                        nM: int,
                        nn: int,
                        nt: int = 301,
                        time: float = 1240,  # 30fs
                        ):
    """Write a dummy wave-function (tag ``'WFW'``) ULM file.

    The file contains one ``'init'`` record at ``niter=0``, ``time=0`` with
    real random coefficients, followed by ``nt - 1`` ``'propagate'`` records
    with complex random coefficients, evenly spaced in time up to ``time``.

    Parameters
    ----------
    fpath
        Output file path.
    nM
        Size of the last axis of the coefficient arrays (number of basis
        functions).
    nn
        Number of states; states whose relative index
        (``linspace(0, 1, nn)``) is below 0.5 get occupation 2, the rest 0.
    nt
        Total number of records, including the initial one.
    time
        Time of the last record (the default 1240 is ~30 fs when read as
        atomic units of time).
    """
    # Leading dimensions, both of size 1.
    # NOTE(review): GPAW's coefficient layout is believed to be
    # (spins, k-points, ...) — confirm the order.
    ushape = (1, 1)
    occ_un = 2.0 * (np.linspace(0, 1, nn) < 0.5).reshape((1, 1, -1))

    writer = Writer(str(fpath), mode='w', tag='WFW')
    writer.write(version=3, split=False)

    # Initial record: niter=0 at t=0.
    writer.write(niter=0, time=0, action='init')
    w = writer.child('wave_functions')
    w.write(coefficients=np.random.rand(*ushape, nn, nM))
    w.write(occupations=occ_un)
    writer.sync()

    # Propagation records. Count from 1 so that niter stays in step with the
    # time grid (niter=0 belongs to the init record above) rather than
    # repeating niter=0 for the first propagation step.
    t_i = np.linspace(0, time, nt)
    for niter, t in enumerate(t_i[1:], start=1):
        writer.write(niter=niter, time=t, action='propagate')

        w = writer.child('wave_functions')
        w.write(coefficients=np.random.rand(*ushape, nn, nM)
                + 1j * np.random.rand(*ushape, nn, nM))
        w.write(occupations=occ_un)
        writer.sync()

    # Everything was already synced; this releases the file handle.
    writer.close()


def write_ksd(fpath: Path, nM: int):
    """Write a dummy Kohn-Sham decomposition (tag ``'KSD'``) ULM file.

    Random eigenvalues are sorted ascending and the Fermi level is set to
    the eigenvalue at index ``nM // 2``. States strictly below it get
    occupation 2; all others get 0. The transition list contains every
    index pair ``(i, a)`` with ``i < a`` whose occupation difference is at
    least 1e-3, sorted by energy difference ``eig[a] - eig[i]``.

    Parameters
    ----------
    fpath
        Output file path.
    nM
        Number of basis functions, also used as the number of states.
    """
    ushape = (1, 1)  # leading dimensions, both of size 1
    atoms = molecule('H2')  # Dummy molecule

    # Generate reasonable eigenenergies
    eig_un = np.sort(np.random.rand(*ushape, nM))
    eig_n = eig_un[0, 0]
    fermilevel = eig_n[nM // 2]
    occ_un = np.array(2 * (eig_un < fermilevel), dtype=int)
    occ_n = occ_un[0, 0]

    # Candidate transitions i -> a for all i < a. triu_indices yields them
    # in the same row-major order as a nested `for i: for a in range(i+1, nM)`
    # loop.
    i_p, a_p = np.triu_indices(nM, k=1)
    f_p = occ_n[i_p] - occ_n[a_p]
    keep_p = f_p >= 1e-3  # drop pairs without a (positive) occupation change
    i_p, a_p, f_p = i_p[keep_p], a_p[keep_p], f_p[keep_p]
    w_p = eig_n[a_p] - eig_n[i_p]

    # Sort according to energy difference
    p_s = np.argsort(w_p)
    f_p = f_p[p_s]
    w_p = w_p[p_s]
    # Shape (Np, 2) even when there are no transitions.
    ia_p = np.stack((i_p[p_s], a_p[p_s]), axis=1).astype(int, copy=False)
    # Flat index of each (i, a) pair in an (nM, nM) matrix.
    P_p = np.ravel_multi_index((ia_p[:, 0], ia_p[:, 1]), (nM, nM))
    Np = len(f_p)

    writer = Writer(str(fpath), mode='w', tag='KSD')
    writer.write(version=1)

    write_atoms(writer.child('atoms'), atoms)

    writer.write(ha=Hartree)
    # n dimension includes all unoccupied states
    # NOTE(review): these arrays are 5-D, (1, 1, 1, nM, nM); confirm this is
    # the layout the reader expects.
    writer.write(S_uMM=np.random.rand(*ushape, 1, nM, nM))
    writer.write(C0_unM=np.random.rand(*ushape, 1, nM, nM))
    # occ_un is written exactly once, so the stored occupations match f_p.
    writer.write(eig_un=eig_un, occ_un=occ_un)
    writer.write(fermilevel=fermilevel)
    writer.write(only_ia=True)

    writer.write(w_p=w_p, f_p=f_p, ia_p=ia_p, P_p=P_p)

    writer.write(dm_vp=np.random.rand(3, Np))
    writer.write(a_M=np.array(np.linspace(0, 1, nM) < 0.5, dtype=int))  # 0 or 1
    writer.write(l_M=np.zeros(nM, dtype=int))  # 0

    writer.close()


def generate_data(dpath: Path,
                  nbasis: int):
    """Write dummy ``ksd.ulm`` and ``wfs.ulm`` files into ``dpath``.

    ``dpath`` is created (including parents) if it does not exist. The KSD
    file uses ``nbasis`` basis functions; the wave-function file uses
    ``nbasis`` basis functions and ``nbasis // 3`` states.

    Raises
    ------
    RuntimeError
        If the MPI communicator has more than one rank.
    """
    # Explicit check rather than `assert`, so it is not stripped by `python -O`.
    if world.size != 1:
        raise RuntimeError('Do not run this script with MPI')
    dpath.mkdir(parents=True, exist_ok=True)

    nM = nbasis
    nn = nbasis // 3

    # Generate dummy ksd and wave-function files
    write_ksd(dpath / 'ksd.ulm', nM)
    write_wavefunctions(dpath / 'wfs.ulm', nM, nn)


if __name__ == "__main__":
    # Command line: <output directory> <number of basis functions>
    cli = argparse.ArgumentParser()
    cli.add_argument('dpath', type=Path)
    cli.add_argument('nbasis', type=int)
    ns = cli.parse_args()
    generate_data(ns.dpath, ns.nbasis)
