"""DOI cache system for storing CrossRef API responses."""

import json
import logging
import time
from pathlib import Path
from typing import Any, Optional

logger = logging.getLogger(__name__)


class DOICache:
    """File-backed cache for DOI metadata and DOI resolution status.

    The cache is a single JSON file mapping a normalized DOI (lower-cased,
    surrounding whitespace stripped) to an entry of the form::

        {
            "metadata": <dict or None>,
            "timestamp": <float, seconds since the epoch>,
            "resolution": {            # optional
                "resolves": <bool>,
                "error_message": <str or None>,
                "timestamp": <float>,
            },
        }

    Entries whose ``timestamp`` is ``cache_expiry_days`` old or older are
    treated as expired. Every mutation is written back to disk immediately.
    """

    def __init__(
        self,
        cache_dir: str = ".cache",
        cache_filename: Optional[str] = None,
        manuscript_name: Optional[str] = None,
        cache_expiry_days: float = 30,
    ):
        """Initialize DOI cache.

        Args:
            cache_dir: Directory to store cache files
            cache_filename: Name of the cache file (if None, uses manuscript-specific naming)
            manuscript_name: Name of the manuscript (used for manuscript-specific caching)
            cache_expiry_days: Age in days at which an entry is considered expired
        """
        self.cache_dir = Path(cache_dir)
        self.manuscript_name = manuscript_name

        # Determine cache filename
        if cache_filename is not None:
            # Use provided filename (backward compatibility)
            self.cache_file = self.cache_dir / cache_filename
        elif manuscript_name is not None:
            # Use manuscript-specific filename
            self.cache_file = self.cache_dir / f"doi_cache_{manuscript_name}.json"
        else:
            # Default filename
            self.cache_file = self.cache_dir / "doi_cache.json"

        self.cache_expiry_days = cache_expiry_days

        # parents=True so a nested cache_dir (e.g. "build/.cache") whose
        # parent directories do not exist yet does not raise.
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        # Load existing cache
        self._cache = self._load_cache()

    @staticmethod
    def _normalize(doi: str) -> str:
        """Return the cache key for *doi* (lower-cased, whitespace stripped)."""
        return doi.lower().strip()

    def _is_expired(self, timestamp: float, now: Optional[float] = None) -> bool:
        """Return True if *timestamp* is ``cache_expiry_days`` old or older.

        Args:
            timestamp: Time the data was stored, in seconds since the epoch
            now: Reference time; defaults to the current time
        """
        if now is None:
            now = time.time()
        return (now - timestamp) >= self.cache_expiry_days * 24 * 3600

    def _load_cache(self) -> dict[str, Any]:
        """Load cache from file, dropping expired and malformed entries.

        Returns:
            The still-valid entries keyed by DOI, or an empty dict when the
            file is missing, unreadable, or not a JSON object.
        """
        if not self.cache_file.exists():
            return {}

        try:
            with open(self.cache_file, encoding="utf-8") as f:
                cache_data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers both json.JSONDecodeError and
            # UnicodeDecodeError; OSError covers unreadable files.
            logger.warning(f"Error loading cache file: {e}. Starting with empty cache.")
            return {}

        if not isinstance(cache_data, dict):
            logger.warning(
                "Error loading cache file: expected a JSON object, got "
                f"{type(cache_data).__name__}. Starting with empty cache."
            )
            return {}

        now = time.time()
        cleaned_cache: dict[str, Any] = {}

        for doi, entry in cache_data.items():
            timestamp = entry.get("timestamp") if isinstance(entry, dict) else None
            if not isinstance(timestamp, (int, float)):
                # Legacy entries without a usable timestamp - remove them
                logger.debug(f"Removing legacy cache entry for DOI: {doi}")
            elif self._is_expired(timestamp, now):
                logger.debug(f"Expired cache entry for DOI: {doi}")
            else:
                cleaned_cache[doi] = entry

        return cleaned_cache

    def _save_cache(self) -> None:
        """Save cache to file.

        The data is written to a sibling temporary file that then replaces
        the real cache file, so a failure part-way through serialization
        cannot leave a truncated cache behind. Errors are logged, not raised.
        """
        tmp_file = self.cache_file.with_name(self.cache_file.name + ".tmp")
        try:
            with open(tmp_file, "w", encoding="utf-8") as f:
                json.dump(self._cache, f, indent=2, ensure_ascii=False)
            tmp_file.replace(self.cache_file)
        except Exception as e:
            # Deliberately best-effort: a failed save must not break callers.
            logger.error(f"Error saving cache file: {e}")
            try:
                tmp_file.unlink(missing_ok=True)
            except OSError:
                pass

    def get(self, doi: str) -> Optional[dict[str, Any]]:
        """Get cached metadata for a DOI.

        Args:
            doi: DOI to look up

        Returns:
            Cached metadata if available and not expired, None otherwise
        """
        normalized_doi = self._normalize(doi)
        entry = self._cache.get(normalized_doi)

        if entry is not None and "timestamp" in entry:
            if self._is_expired(entry["timestamp"]):
                # Entry expired, remove it
                logger.debug(f"Cache entry expired for DOI: {doi}")
                del self._cache[normalized_doi]
                self._save_cache()
            else:
                metadata = entry.get("metadata")
                # Entries created only to hold a resolution status carry
                # metadata=None; those are a miss for metadata lookups.
                if metadata is not None:
                    logger.debug(f"Cache hit for DOI: {doi}")
                    return metadata

        logger.debug(f"Cache miss for DOI: {doi}")
        return None

    def set(self, doi: str, metadata: dict[str, Any]) -> None:
        """Cache metadata for a DOI.

        A resolution status already stored for the same DOI is kept.

        Args:
            doi: DOI to cache
            metadata: Metadata to cache
        """
        normalized_doi = self._normalize(doi)

        entry: dict[str, Any] = {"metadata": metadata, "timestamp": time.time()}

        # Do not discard resolution data recorded by set_resolution_status().
        previous = self._cache.get(normalized_doi)
        if previous is not None and "resolution" in previous:
            entry["resolution"] = previous["resolution"]

        self._cache[normalized_doi] = entry

        self._save_cache()
        logger.debug(f"Cached metadata for DOI: {doi}")

    def set_resolution_status(
        self, doi: str, resolves: bool, error_message: Optional[str] = None
    ) -> None:
        """Cache DOI resolution status.

        Args:
            doi: DOI to cache status for
            resolves: Whether the DOI resolves
            error_message: Optional error message if resolution failed
        """
        normalized_doi = self._normalize(doi)
        now = time.time()

        resolution_data = {
            "resolves": resolves,
            "error_message": error_message,
            "timestamp": now,
        }

        entry = self._cache.get(normalized_doi)
        entry_expired = (
            entry is not None
            and "timestamp" in entry
            and self._is_expired(entry["timestamp"], now)
        )

        if entry is not None and not entry_expired:
            # Live entry: attach the status alongside the cached metadata
            entry["resolution"] = resolution_data
        else:
            # No entry, or one that has already expired and would be purged
            # (taking the fresh status with it): start a new entry.
            self._cache[normalized_doi] = {
                "metadata": None,
                "resolution": resolution_data,
                "timestamp": now,
            }

        self._save_cache()
        logger.debug(f"Cached resolution status for DOI {doi}: {resolves}")

    def get_resolution_status(self, doi: str) -> Optional[dict[str, Any]]:
        """Get cached resolution status for a DOI.

        Args:
            doi: DOI to look up

        Returns:
            Resolution status if available and not expired, None otherwise
        """
        normalized_doi = self._normalize(doi)
        entry = self._cache.get(normalized_doi)
        resolution_data = entry.get("resolution") if entry is not None else None

        # Guard against malformed resolution records read from disk
        if isinstance(resolution_data, dict) and "timestamp" in resolution_data:
            if not self._is_expired(resolution_data["timestamp"]):
                logger.debug(f"Cache hit for DOI resolution: {doi}")
                return resolution_data

            # Resolution data expired, remove it (the metadata is kept)
            logger.debug(f"Cache entry expired for DOI resolution: {doi}")
            del entry["resolution"]
            self._save_cache()

        logger.debug(f"Cache miss for DOI resolution: {doi}")
        return None

    def clear(self) -> None:
        """Clear all cached entries."""
        self._cache.clear()
        self._save_cache()
        logger.info("Cleared DOI cache")

    def cleanup_expired(self) -> int:
        """Remove expired entries from cache.

        Returns:
            Number of entries removed
        """
        now = time.time()
        expired_dois = [
            doi
            for doi, entry in self._cache.items()
            if "timestamp" in entry and self._is_expired(entry["timestamp"], now)
        ]

        for doi in expired_dois:
            del self._cache[doi]

        if expired_dois:
            self._save_cache()
            logger.info(f"Removed {len(expired_dois)} expired cache entries")

        return len(expired_dois)

    def stats(self) -> dict[str, Any]:
        """Get cache statistics.

        Returns:
            Dictionary with the manuscript name, total/valid/expired entry
            counts, the cache file path and its size in bytes (0 when the
            file cannot be stat'ed). Entries without a timestamp count
            towards the total only.
        """
        now = time.time()
        valid_entries = 0
        expired_entries = 0

        for entry in self._cache.values():
            if "timestamp" not in entry:
                continue
            if self._is_expired(entry["timestamp"], now):
                expired_entries += 1
            else:
                valid_entries += 1

        # A single stat() avoids the exists()/stat() race
        try:
            cache_size_bytes = self.cache_file.stat().st_size
        except OSError:
            cache_size_bytes = 0

        return {
            "manuscript_name": self.manuscript_name,
            "total_entries": len(self._cache),
            "valid_entries": valid_entries,
            "expired_entries": expired_entries,
            "cache_file": str(self.cache_file),
            "cache_size_bytes": cache_size_bytes,
        }
