import os
from pathlib import Path

import pytest
from omegaconf import DictConfig, open_dict

from mldft.ml.train import train
from tests.helpers.run_if import RunIf


def test_train_cpu(cfg_train: DictConfig, create_two_electron_dataset) -> None:
    """Smoke-test a training run on the CPU accelerator.

    Args:
        cfg_train: A DictConfig containing a valid training configuration.
        create_two_electron_dataset: Fixture requested only for its side effect
            (presumably creating the dataset the config trains on); its value is unused.
    """
    with open_dict(cfg_train):
        trainer_cfg = cfg_train.trainer
        trainer_cfg.accelerator = "cpu"
    train(cfg_train)


@RunIf(min_gpus=1)
def test_train_gpu(cfg_train: DictConfig, create_two_electron_dataset) -> None:
    """Smoke-test a training run on the GPU accelerator (skipped without a GPU).

    Args:
        cfg_train: A DictConfig containing a valid training configuration.
        create_two_electron_dataset: Fixture requested only for its side effect
            (presumably creating the dataset the config trains on); its value is unused.
    """
    with open_dict(cfg_train):
        trainer_cfg = cfg_train.trainer
        trainer_cfg.accelerator = "gpu"
    train(cfg_train)


@RunIf(min_gpus=1)
def test_train_epoch_gpu_amp(cfg_train: DictConfig, create_two_electron_dataset) -> None:
    """Train 1 epoch on GPU with 16-bit mixed precision (skipped without a GPU).

    Args:
        cfg_train: A DictConfig containing a valid training configuration.
        create_two_electron_dataset: Fixture requested only for its side effect
            (presumably creating the dataset the config trains on); its value is unused.
    """
    # Edit the config that is actually trained, inside *its own* open_dict context,
    # so adding ``precision`` and removing a callback are permitted even if the
    # config is in struct mode (the other tests in this module do the same).
    with open_dict(cfg_train):
        cfg_train.trainer.max_epochs = 1
        cfg_train.trainer.accelerator = "gpu"
        cfg_train.trainer.precision = 16
        # Drop the molecule mesh logging callback for the AMP run; the default
        # avoids a failure if the selected config does not define it.
        cfg_train.callbacks.pop("molecule_mesh_logging", None)
    train(cfg_train)


def test_train_epoch_double_val_loop(cfg_train: DictConfig, create_two_electron_dataset) -> None:
    """Train a single CPU epoch that runs the validation loop twice.

    ``val_check_interval = 0.5`` triggers validation after every half of the
    training epoch.

    Args:
        cfg_train: A DictConfig containing a valid training configuration.
        create_two_electron_dataset: Fixture requested only for its side effect
            (presumably creating the dataset the config trains on); its value is unused.
    """
    with open_dict(cfg_train):
        trainer_cfg = cfg_train.trainer
        trainer_cfg.accelerator = "cpu"
        trainer_cfg.max_epochs = 1
        trainer_cfg.val_check_interval = 0.5
    train(cfg_train)


@pytest.mark.slow
def test_train_ddp_sim(cfg_train: DictConfig, create_two_electron_dataset) -> None:
    """Simulate DDP (Distributed Data Parallel) with two spawned CPU processes.

    Args:
        cfg_train: A DictConfig containing a valid training configuration.
        create_two_electron_dataset: Fixture requested only for its side effect
            (presumably creating the dataset the config trains on); its value is unused.
    """
    ddp_overrides = {
        "max_epochs": 1,
        "accelerator": "cpu",
        "devices": 2,
        "strategy": "ddp_spawn",
    }
    # NOTE: also setting ``trainer.precision = 16`` (to reduce RAM usage) was tried
    # and reportedly made this test fail, so precision is left untouched here.
    with open_dict(cfg_train):
        trainer_cfg = cfg_train.trainer
        for option, value in ddp_overrides.items():
            setattr(trainer_cfg, option, value)
    train(cfg_train)


@pytest.mark.slow
def test_train_resume(tmp_path: Path, cfg_train: DictConfig, create_two_electron_dataset) -> None:
    """Train for one epoch, then resume from ``last.ckpt`` and train up to epoch two.

    Checkpoints are read from ``tmp_path / "checkpoints"``, so this relies on the
    ``cfg_train`` fixture directing training output into ``tmp_path``.

    Args:
        tmp_path: The temporary logging path.
        cfg_train: A DictConfig containing a valid training configuration.
        create_two_electron_dataset: Fixture requested only for its side effect
            (presumably creating the dataset the config trains on); its value is unused.
    """
    checkpoint_dir = tmp_path / "checkpoints"

    # First run: a single epoch from scratch.
    with open_dict(cfg_train):
        cfg_train.trainer.max_epochs = 1
    _, _ = train(cfg_train)

    first_run_files = set(os.listdir(checkpoint_dir))
    assert "last.ckpt" in first_run_files
    assert "epoch_000.ckpt" in first_run_files

    # Second run: resume from the last checkpoint and extend training to two epochs.
    with open_dict(cfg_train):
        cfg_train.ckpt_path = str(checkpoint_dir / "last.ckpt")
        cfg_train.trainer.max_epochs = 2
    _, _ = train(cfg_train)

    second_run_files = set(os.listdir(checkpoint_dir))
    # NOTE(review): this also passes when only the first run's epoch_000.ckpt is
    # present, so on its own it does not prove a second epoch was trained.
    assert any(name in second_run_files for name in ("epoch_001.ckpt", "epoch_000.ckpt"))
    assert "epoch_002.ckpt" not in second_run_files
