from pathlib import Path
from typing import Literal
import dpdata
import numpy as np

from ase import Atoms
from ase.calculators.calculator import Calculator
from ase.calculators.mixing import SumCalculator
from dftd3.ase import DFTD3
from tqdm import tqdm
from deepmd.calculator import DP


def _mae_rmse(errors) -> tuple:
    """Return ``(MAE, RMSE)`` taken over every element of *errors*."""
    flat = np.asarray(errors).reshape(-1)
    return np.mean(np.abs(flat)), np.sqrt(np.mean(np.square(flat)))


def run_ase_dptest(
        calc: Calculator,
        test_data: Path,
        dispersion_correction: Literal["d3bj", "d3zero"] | None = None,
        # check all supported levels at dftd3.qcschema._available_levels
        dispersion_functional: str = "PBE",
) -> dict:
    """Evaluate an ASE calculator against labeled deepmd/npy test data.

    Every directory below ``test_data`` that contains a ``type_map.raw`` file
    is treated as one system. Each frame is converted to an ASE ``Atoms``
    object, evaluated with ``calc``, and compared with the stored labels.

    Before the energy metrics are computed, a per-element energy shift is
    fitted by least squares on the element counts of each frame and removed
    from the predictions, so a constant per-element offset between model and
    labels does not count as error.

    Args:
        calc: ASE calculator used for the predictions.
        test_data: Root directory that is searched recursively for systems.
        dispersion_correction: DFT-D3 damping scheme. When given, ``calc`` is
            wrapped in a ``SumCalculator`` together with a ``DFTD3``
            calculator; ``None`` leaves ``calc`` untouched.
        dispersion_functional: Value passed as ``method`` to ``DFTD3``. Only
            used when ``dispersion_correction`` is not ``None``.

    Returns:
        A dict whose values are single-element lists. Always present:
        ``energy_mae``, ``energy_rmse``, ``energy_mae_natoms``,
        ``energy_rmse_natoms``. ``force_mae`` / ``force_rmse`` are added when
        at least one frame has force labels. ``virial_mae`` / ``virial_rmse``
        / ``virial_mae_natoms`` / ``virial_rmse_natoms`` are added when at
        least one frame has virial labels and the calculator provided a
        stress for it.

    Raises:
        AssertionError: No ``type_map.raw`` was found below ``test_data``.
        ValueError: Systems were found but they contain no frames.
    """
    test_data = Path(test_data)
    if dispersion_correction is not None:
        calc = SumCalculator(
            [
                calc,
                DFTD3(method=dispersion_functional, damping=dispersion_correction),
            ]
        )

    max_ele_num = 120  # width of the per-frame element-count vector
    systems = [marker.parent for marker in test_data.rglob("type_map.raw")]
    if not systems:
        # Raised explicitly instead of via `assert` so the check survives
        # `python -O`; the exception type is kept for existing callers.
        raise AssertionError(f"No systems found in the test data {test_data}.")

    energy_pre = []
    energy_lab = []
    atom_num = []
    force_err = []
    virial_err = []
    virial_err_per_atom = []

    for system_dir in tqdm(systems, desc="Systems"):
        # Decide the on-disk format per system: a `real_atom_types.npy` file
        # is what selects the "deepmd/npy/mixed" loader.
        if any(system_dir.rglob("real_atom_types.npy")):
            dataset = dpdata.MultiSystems()
            dataset.load_systems_from_file(system_dir, fmt="deepmd/npy/mixed")
        else:
            dataset = dpdata.LabeledSystem(system_dir, fmt="deepmd/npy")

        for subset in tqdm(dataset, desc="Set", leave=False):
            for frame in tqdm(subset, desc="Frames", leave=False):
                atoms: Atoms = frame.to_ase_structure()[0]
                atoms.calc = calc
                natoms = len(atoms)

                # Energy
                energy_pre.append(np.array(atoms.get_potential_energy()))
                energy_lab.append(frame.data["energies"])
                atom_num.append(
                    np.bincount(atoms.get_atomic_numbers(), minlength=max_ele_num)
                )

                # Force: only evaluated when the frame carries force labels.
                if "forces" in frame.data:
                    force_err.append(
                        frame.data["forces"].squeeze(0)
                        - np.array(atoms.get_forces())
                    )

                # Virial: only evaluated when the frame carries virial labels.
                if "virials" in frame.data:
                    try:
                        # virial = -stress * volume, stress as a 3x3 tensor
                        virial_pred = (
                            -atoms.get_stress(voigt=False) * atoms.get_volume()
                        )
                    except (
                            NotImplementedError,  # atoms.get_stress() for eqv2
                            ValueError,  # atoms.get_volume()
                    ):
                        # Best effort: this calculator/cell cannot provide a
                        # virial, so the frame is skipped for virial metrics.
                        continue
                    virial_err.append(frame.data["virials"] - virial_pred)
                    # Normalise by this frame's own atom count; it must not
                    # depend on whether force labels happen to exist.
                    virial_err_per_atom.append(virial_err[-1] / natoms)

    if not energy_pre:
        raise ValueError(f"No frames found in the test data {test_data}.")

    atom_num = np.array(atom_num)
    energy_pre = np.array(energy_pre).reshape(-1)
    energy_lab = np.array(energy_lab).reshape(-1)

    # Least-squares per-element shift that best explains the raw energy error.
    shift_bias, *_ = np.linalg.lstsq(
        atom_num, energy_pre - energy_lab, rcond=1e-10
    )
    unbiased_energy = energy_pre - atom_num @ shift_bias - energy_lab
    unbiased_energy_err_per_a = unbiased_energy / atom_num.sum(-1)

    energy_mae, energy_rmse = _mae_rmse(unbiased_energy)
    energy_mae_natoms, energy_rmse_natoms = _mae_rmse(unbiased_energy_err_per_a)
    res = {
        "energy_mae": [energy_mae],
        "energy_rmse": [energy_rmse],
        "energy_mae_natoms": [energy_mae_natoms],
        "energy_rmse_natoms": [energy_rmse_natoms],
    }
    if force_err:
        # Frames may differ in atom count, hence concatenate rather than stack.
        force_mae, force_rmse = _mae_rmse(np.concatenate(force_err))
        res.update({"force_mae": [force_mae], "force_rmse": [force_rmse]})
    if virial_err_per_atom:
        virial_mae, virial_rmse = _mae_rmse(np.stack(virial_err))
        virial_mae_natoms, virial_rmse_natoms = _mae_rmse(
            np.stack(virial_err_per_atom)
        )
        res.update(
            {
                "virial_mae": [virial_mae],
                "virial_rmse": [virial_rmse],
                "virial_mae_natoms": [virial_mae_natoms],
                "virial_rmse_natoms": [virial_rmse_natoms],
            }
        )
    return res


# Replicate the zero-shot test of DPA-3.1-3M on the force-field task (ff-task),
# taking the Vandermause2022Active dataset as an example.
model = DP('../../model/DPA-3.1-3M.pt', head="Omat24")
result = run_ase_dptest(model, Path("../../data/Force-Field-Task-Datasets/Vandermause2022Active"))
# The factor 1000 matches the "meV" unit labels below
# (assumes the metrics are returned in eV-based units — TODO confirm).
print(f"Energy RMSE: {result['energy_rmse_natoms'][0] * 1000:.1f} meV/atom")
# run_ase_dptest only adds force/virial metrics when the dataset provides the
# corresponding labels, so guard the lookups instead of failing with KeyError
# after the whole evaluation has already run.
if "force_rmse" in result:
    print(f"Force RMSE: {result['force_rmse'][0] * 1000:.1f} meV/Å")
if "virial_rmse_natoms" in result:
    print(f"Virial RMSE: {result['virial_rmse_natoms'][0] * 1000:.1f} meV/atom")





