from pathlib import Path
from uuid import UUID

import netCDF4
import numpy as np
import pytest
from cloudnetpy_qc import quality

# Example site metadata. Its keys (name, altitude, latitude, longitude) are
# the ones AllProductsFun reads from ``site_meta``.
SITE_META = dict(
    name="Kumpula",
    altitude=50,
    latitude=23,
    longitude=34.0,
)


class Check:
    """Shared pytest checks for a Cloudnet product file.

    The attributes below are only annotated here, so they must be set
    (e.g. by a subclass) before the tests run.
    """

    temp_path: str  # path of the netCDF file under test
    nc: netCDF4.Dataset  # opened/closed by the autouse fixture below
    date: str  # expected date, hyphen-separated year-month-day
    site_meta: dict  # expected site metadata (name and coordinates)
    uuid: UUID  # expected file UUID

    @pytest.fixture(autouse=True)
    def run_before_and_after_tests(self):
        """Open ``temp_path`` as ``self.nc`` for each test and close it afterwards."""
        self.nc = netCDF4.Dataset(self.temp_path)
        yield
        self.nc.close()

    def test_qc(self):
        """Run the cloudnetpy-qc suite and require the unit/name checks to pass.

        Fails if any selected QC test reports exceptions, or if the report
        does not contain each selected test exactly once.
        """
        expected_ids = ("TestUnits", "TestLongNames", "TestStandardNames")
        report = quality.run_tests(
            Path(self.temp_path),
            # NOTE(review): a placeholder site-meta dict is passed here rather
            # than self.site_meta; confirm this is intentional.
            {"time": None, "latitude": 0, "longitude": 0, "altitude": 0},
            ignore_tests=["TestCFConvention", "TestCoordinates"],
        )
        found_ids = []
        for test in report.tests:
            if test.test_id in expected_ids:
                assert not test.exceptions, test.exceptions
                found_ids.append(test.test_id)
        # Compare the ids themselves, not just a count: a count would pass if
        # one test appeared twice while another was missing.
        assert sorted(found_ids) == sorted(expected_ids), (
            f"QC tests found: {found_ids}, expected: {list(expected_ids)}"
        )

    def test_common(self):
        """Run every ``test_*`` method of AllProductsFun against this file.

        Methods run in class-definition order; the first failing check
        aborts the remaining ones.
        """
        checks = AllProductsFun(self.nc, self.site_meta, self.date, self.uuid)
        for name in vars(AllProductsFun):
            # Prefix match so helper names merely containing "test_" are skipped.
            if name.startswith("test_"):
                getattr(checks, name)()


class AllProductsFun:
    """Common tests for all Cloudnet products.

    Each ``test_*`` method inspects an already-open netCDF4 dataset and
    raises ``AssertionError`` naming the offending variable or attribute.
    """

    def __init__(self, nc: netCDF4.Dataset, site_meta: dict, date: str, uuid: UUID):
        """Store the dataset and the expected metadata to compare against.

        Args:
            nc: Open netCDF4 dataset under test.
            site_meta: Expected site metadata; the ``name``, ``altitude``,
                ``latitude`` and ``longitude`` keys are read.
            date: Expected date as a hyphen-separated year-month-day string.
            uuid: Expected value of the ``file_uuid`` global attribute.
        """
        self.nc = nc
        self.site_meta = site_meta
        self.date = date
        self.uuid = uuid

    def test_variable_names(self):
        """Check that the mandatory coordinate variables exist."""
        for key in ("time", "latitude", "longitude", "altitude"):
            assert key in self.nc.variables, f"missing variable: {key}"

    def test_nan_values(self):
        """Check that no variable consists entirely of NaN values."""
        for key, variable in self.nc.variables.items():
            assert not np.isnan(variable).all(), f"{key} is all NaN"

    def test_time_axis(self):
        """Check that ``time`` is tagged as the T axis."""
        axis = self.nc.variables["time"].axis
        assert axis == "T", f"time.axis: {axis} != T"

    def test_empty_units(self):
        """Check that any ``units`` attribute present is non-empty."""
        # NOTE(review): subsumed by test_invalid_units, which additionally
        # requires every variable to have units.
        for key, variable in self.nc.variables.items():
            if hasattr(variable, "units"):
                value = variable.units
                assert value != "", f"{key} - {value}"

    def test_variable_values(self):
        """Check that site coordinates match ``site_meta`` within 1e-2."""
        for key in ("altitude", "latitude", "longitude"):
            value = self.nc.variables[key][:]
            expected = self.site_meta[key]
            assert np.all(np.isclose(value, expected, atol=1e-2)), (
                f"{key}: {value} != {expected}"
            )

    def test_invalid_units(self):
        """Check that every variable has a non-empty ``units`` attribute."""
        for key, variable in self.nc.variables.items():
            assert hasattr(variable, "units"), f"{key} has no units"
            assert variable.units != "", key

    def test_units(self):
        """Custom units that are not tested in QC tests."""
        data = [
            ("time", f"hours since {self.date} 00:00:00 +00:00"),
        ]
        for key, expected in data:
            if key in self.nc.variables:
                value = self.nc.variables[key].units
                assert value == expected, f"{value} != {expected}"

    def test_long_name_format(self):
        """Check that every variable has a ``long_name`` not ending in a period."""
        for key, variable in self.nc.variables.items():
            assert hasattr(variable, "long_name"), f"{key} has no long_name"
            value = variable.long_name
            assert not value.endswith("."), f"{key} long_name ends with '.': {value}"

    def test_global_attributes(self):
        """Check global attributes against the expected site, UUID and date."""
        nc = self.nc
        expected_location = self.site_meta["name"]
        assert nc.location == expected_location, (
            f"location: {nc.location} != {expected_location}"
        )
        assert nc.file_uuid == str(self.uuid), (
            f"file_uuid: {nc.file_uuid} != {self.uuid}"
        )
        assert nc.Conventions == "CF-1.8", f"Conventions: {nc.Conventions} != CF-1.8"
        y, m, d = self.date.split("-")
        # year/month/day are compared as strings, exactly as split from the date.
        assert nc.year == y, f"year: {nc.year} != {y}"
        assert nc.month == m, f"month: {nc.month} != {m}"
        assert nc.day == d, f"day: {nc.day} != {d}"
        for key in ("cloudnetpy_version", "references", "history"):
            assert hasattr(nc, key), f"missing global attribute: {key}"
