# data.py
# -*- coding: utf-8 -*-
"""
Data utilities for multimodal (video + audio) scene classification.

Features
--------
- Video loader: decord (preferred) -> OpenCV fallback
- Audio loader: torchaudio (preferred) -> librosa fallback
- Mel-spectrogram extraction (log-mel)
- Uniform temporal sampling of frames
- Works with single-label (int) or multi-label (comma-separated) targets
- Safe collate with padding for variable-length audio/time
- Minimal external assumptions about annotation format

Expected annotation CSV
-----------------------
Required columns:
    video_path : str, path to video file
Optional columns:
    audio_path : str, path to audio file (if empty, audio extracted from video or skipped)
    label      : str/int, single label (int) or comma-separated multi-labels (e.g., "0,3,7")
    start_sec  : float, optional clip start time in seconds (default 0)
    end_sec    : float, optional clip end time in seconds (default None -> full length)

Example:
    video_path,audio_path,label,start_sec,end_sec
    data/film1.mp4,,3,0,12.5
    data/film2.mp4,data/film2.wav,0,,
    data/film3.mp4,,1,30,45
"""

from __future__ import annotations
import os
import math
import csv
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.utils.data import Dataset, DataLoader

# Optional deps
try:
    import decord  # type: ignore
    from decord import VideoReader, cpu as decord_cpu
    _HAS_DECORD = True
except Exception:
    _HAS_DECORD = False

try:
    import cv2  # type: ignore
    _HAS_OPENCV = True
except Exception:
    _HAS_OPENCV = False

try:
    import torchaudio  # type: ignore
    _HAS_TORCHAUDIO = True
except Exception:
    _HAS_TORCHAUDIO = False

try:
    import librosa  # type: ignore
    _HAS_LIBROSA = True
except Exception:
    _HAS_LIBROSA = False

try:
    from torchvision import transforms as T  # type: ignore
    _HAS_TORCHVISION = True
except Exception:
    _HAS_TORCHVISION = False


# ----------------------------
# Utility: safe assert message
# ----------------------------
def _require(cond: bool, msg: str):
    if not cond:
        raise RuntimeError(msg)


# ----------------------------
# Video loading helpers
# ----------------------------
def _load_video_frames(
    path: str,
    num_frames: int,
    fps: Optional[float] = None,
    start_sec: float = 0.0,
    end_sec: Optional[float] = None,
    resize_hw: Tuple[int, int] = (224, 224),
) -> Tensor:
    """
    Returns frames as float tensor [T, C, H, W] in range [0,1].
    Uniformly samples `num_frames` between [start_sec, end_sec] (or full video).

    Args:
        path: video file; must exist on disk.
        num_frames: number of frames to sample (values <= 1 yield a single frame).
        fps: accepted for interface compatibility only. Sampling is uniform over
            the clip window and does not depend on this value.
        start_sec: clip start in seconds; negative values are clamped to frame 0.
        end_sec: clip end in seconds; None means "until the last frame".
        resize_hw: output (height, width).

    Raises:
        RuntimeError: if the file is missing, no video backend is installed, the
            video cannot be opened/decoded, or the window selects no frames.
    """
    _require(os.path.isfile(path), f"Video not found: {path}")

    if _HAS_DECORD:
        vr = VideoReader(path, ctx=decord_cpu(0))
        # Same 25 fps fallback as the OpenCV path when the container reports 0 fps.
        video_fps = float(vr.get_avg_fps()) or 25.0
        idxs = _clip_frame_indices(len(vr), video_fps, start_sec, end_sec, num_frames, path)

        frames = vr.get_batch(idxs).asnumpy()  # (T, H, W, 3), uint8
        frames = _resize_numpy_batch(frames, resize_hw)
        frames = torch.from_numpy(frames).float() / 255.0  # (T, H, W, 3)
        return frames.permute(0, 3, 1, 2)  # -> (T, C, H, W)

    _require(_HAS_OPENCV, "Neither decord nor opencv found. Install `decord` (preferred) or `opencv-python`.")
    cap = cv2.VideoCapture(path)
    try:
        _require(cap.isOpened(), f"Failed to open video with OpenCV: {path}")

        video_fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        idxs = _clip_frame_indices(total_frames, video_fps, start_sec, end_sec, num_frames, path)

        frames = []
        for idx in idxs:
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ok, frame = cap.read()
            _require(ok, f"Failed to read frame {idx} from {path}")
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = cv2.resize(frame, (resize_hw[1], resize_hw[0]), interpolation=cv2.INTER_LINEAR)
            frames.append(frame)
    finally:
        # Always release the capture handle, including when a read/validation fails.
        cap.release()

    stacked = torch.stack([torch.from_numpy(f) for f in frames])  # (T, H, W, 3), uint8
    stacked = stacked.float() / 255.0
    return stacked.permute(0, 3, 1, 2)  # (T, C, H, W)


def _clip_frame_indices(
    total_frames: int,
    video_fps: float,
    start_sec: float,
    end_sec: Optional[float],
    num_frames: int,
    path: str,
) -> List[int]:
    """Translate a time window into `num_frames` uniformly spaced frame indices.

    The window is clamped to [0, total_frames - 1]; `end_sec=None` means "until
    the last frame". `path` is only used to make the error message actionable.

    Raises:
        RuntimeError: if the clamped window contains no frame (e.g. `start_sec`
            lies past the end of the video, or the container reports 0 frames).
    """
    if end_sec is None:
        end_sec = total_frames / video_fps
    start_idx = max(0, int(start_sec * video_fps))
    end_idx = min(total_frames - 1, int(end_sec * video_fps))
    _require(
        end_idx >= start_idx,
        f"Clip window [{start_sec}, {end_sec}] sec selects no frames "
        f"({total_frames} frames @ {video_fps} fps): {path}",
    )
    return _uniform_sample_indices(start_idx, end_idx, num_frames)


def _uniform_sample_indices(start_idx: int, end_idx: int, num_samples: int) -> List[int]:
    """Pick `num_samples` evenly spaced integer indices from [start_idx, end_idx].

    Returns `[start_idx]` when `num_samples <= 1`. Indices may repeat when the
    range holds fewer positions than `num_samples`.

    Raises:
        RuntimeError: if `end_idx < start_idx`.
    """
    _require(end_idx >= start_idx, "end_idx must be >= start_idx")
    if num_samples <= 1:
        return [start_idx]
    last_offset = max(1, end_idx - start_idx + 1) - 1
    offsets = torch.linspace(0, last_offset, num_samples).round().long()
    return [start_idx + int(offset) for offset in offsets.tolist()]


def _resize_numpy_batch(frames_np, resize_hw: Tuple[int, int]):
    # frames_np: (T, H, W, 3) uint8
    if frames_np.shape[1:3] == (resize_hw[0], resize_hw[1]):
        return frames_np
    if _HAS_OPENCV:
        return np_stack_resize_cv(frames_np, resize_hw)  # type: ignore
    # Fallback: PIL via torchvision if available
    if _HAS_TORCHVISION:
        import numpy as np
        from PIL import Image
        out = []
        for f in frames_np:
            img = Image.fromarray(f)
            img = img.resize((resize_hw[1], resize_hw[0]), Image.BILINEAR)
            out.append(np.array(img))
        return np.stack(out, axis=0)
    # Minimal pure-numpy (nearest) fallback
    import numpy as np
    out = []
    th, tw = resize_hw
    for f in frames_np:
        # crude nearest neighbor resize
        h, w = f.shape[:2]
        ys = (np.linspace(0, h - 1, th)).astype(int)
        xs = (np.linspace(0, w - 1, tw)).astype(int)
        out.append(f[ys][:, xs])
    return np.stack(out, axis=0)


def np_stack_resize_cv(frames_np, resize_hw: Tuple[int, int]):
    """Bilinearly resize every frame with OpenCV and stack the results on axis 0.

    `resize_hw` is (height, width); OpenCV expects (width, height), hence the swap.
    """
    import numpy as np
    dsize = (resize_hw[1], resize_hw[0])
    return np.stack(
        [cv2.resize(frame, dsize, interpolation=cv2.INTER_LINEAR) for frame in frames_np],
        axis=0,
    )


# ----------------------------
# Audio loading helpers
# ----------------------------
def _crop_waveform(wav, sr: int, start_sec: float, end_sec: Optional[float]):
    """Slice the last (time) axis of `wav` to the [start_sec, end_sec) window.

    Works for both torch tensors and numpy arrays. Negative times are clamped to
    0 (matching how the video loader clamps its start index) so they never wrap
    around via negative slicing; an end before the start yields an empty slice.
    """
    start = max(0, int(start_sec * sr))
    end = None if end_sec is None else max(start, int(end_sec * sr))
    return wav[..., start:end]


def _load_audio_mel(
    audio_path: Optional[str],
    video_path: str,
    sample_rate: int = 16000,
    n_mels: int = 128,
    win_length_ms: int = 25,
    hop_length_ms: int = 10,
    start_sec: float = 0.0,
    end_sec: Optional[float] = None,
    log_eps: float = 1e-6,
) -> Optional[Tensor]:
    """
    Returns log-mel spectrogram as Tensor [n_mels, T] (float32), or None if audio unavailable.
    If audio_path is None, attempts to read audio track from video when torchaudio is present.

    Backends are tried in order: torchaudio (reads `audio_path` if it is an
    existing file, otherwise `video_path`), then librosa (only for an existing
    `audio_path`). A backend failure is logged as a warning and treated as
    "audio unavailable" rather than raised, so a sample without usable audio
    does not abort loading.

    Args:
        audio_path: optional standalone audio file.
        video_path: video file used as the audio source when `audio_path` is unusable.
        sample_rate: target sample rate; audio is resampled to it.
        n_mels: number of mel bands.
        win_length_ms / hop_length_ms: STFT window and hop, in milliseconds.
        start_sec / end_sec: clip window in seconds (end None -> to the end).
        log_eps: added to the power spectrogram before taking the log.
    """
    import logging  # local import: this module's header does not import logging

    log = logging.getLogger(__name__)

    sr = sample_rate
    win_length = int(sr * (win_length_ms / 1000.0))
    hop_length = int(sr * (hop_length_ms / 1000.0))
    # The STFT needs win_length <= n_fft; keep 2048 unless the window is longer.
    n_fft = max(2048, win_length)
    has_audio_file = bool(audio_path) and os.path.isfile(audio_path)

    # 1) Try torchaudio
    if _HAS_TORCHAUDIO:
        source = audio_path if has_audio_file else video_path
        try:
            wav, file_sr = torchaudio.load(source)
            if wav.shape[0] > 1:
                wav = torch.mean(wav, dim=0, keepdim=True)  # mono
            if file_sr != sr:
                wav = torchaudio.functional.resample(wav, file_sr, sr)
            wav = _crop_waveform(wav, sr, start_sec, end_sec)

            mel_spec = torchaudio.transforms.MelSpectrogram(
                sample_rate=sr,
                n_fft=n_fft,
                win_length=win_length,
                hop_length=hop_length,
                n_mels=n_mels,
                center=True,
                power=2.0,
            )(wav)
            return torch.log(mel_spec + log_eps).squeeze(0)  # [n_mels, T]
        except Exception as exc:
            # Best-effort: report why torchaudio failed, then fall through to librosa.
            log.warning(
                "torchaudio could not produce a mel-spectrogram from %s (%s: %s)",
                source, type(exc).__name__, exc,
            )

    # 2) librosa fallback
    if _HAS_LIBROSA:
        # librosa can't read from video directly; require explicit audio_path
        if not has_audio_file:
            return None  # no audio available
        try:
            import numpy as np
            wav, _ = librosa.load(audio_path, sr=sr, mono=True)
            wav = _crop_waveform(wav, sr, start_sec, end_sec)
            mel = librosa.feature.melspectrogram(
                y=wav,
                sr=sr,
                n_fft=n_fft,
                hop_length=hop_length,
                win_length=win_length,
                n_mels=n_mels,
                center=True,
                power=2.0,
            )
            mel = np.log(mel + log_eps).astype("float32")
            return torch.from_numpy(mel)  # [n_mels, T]
        except Exception as exc:
            log.warning(
                "librosa could not produce a mel-spectrogram from %s (%s: %s)",
                audio_path, type(exc).__name__, exc,
            )
            return None

    return None


# ----------------------------
# Label parsing
# ----------------------------
def _parse_label(raw: str, num_classes: Optional[int] = None) -> Union[int, Tensor]:
    """
    If raw includes comma -> multi-label one-hot vector (len=num_classes required).
    Else -> int label.
    """
    if raw is None:
        raise ValueError("Missing label value.")
    txt = str(raw).strip()
    if "," in txt:
        _require(num_classes is not None, "num_classes must be provided for multi-label targets.")
        idxs = []
        for p in txt.split(","):
            p = p.strip()
            if p == "":
                continue
            idxs.append(int(p))
        y = torch.zeros(num_classes, dtype=torch.float32)
        for i in idxs:
            if 0 <= i < num_classes:
                y[i] = 1.0
        return y
    return int(txt)


# ----------------------------
# Dataset
# ----------------------------
class VideoAudioSceneDataset(Dataset):
    """Dataset yielding one video clip (plus optional audio and label) per CSV row.

    Each item is a dict with keys:
        "video"      : frames from `_load_video_frames`, after the per-frame transform
        "audio_mel"  : log-mel tensor from `_load_audio_mel`, or None
        "label"      : int, 1-D float tensor, or None when the row has no label
        "start_sec"  : clip start parsed from the row (0.0 when blank)
        "end_sec"    : clip end parsed from the row (None when blank)
        "video_path" : only when `return_paths=True`
        "audio_path" : only when `return_paths=True` and the row provides one
    """

    def __init__(
        self,
        csv_path: str,
        num_frames: int = 16,
        fps: Optional[float] = None,
        resize_hw: Tuple[int, int] = (224, 224),
        sample_rate: int = 16000,
        n_mels: int = 128,
        win_length_ms: int = 25,
        hop_length_ms: int = 10,
        transform: Optional[Any] = None,  # torchvision-like transform applied per-frame (C,H,W)
        return_paths: bool = False,
        num_classes: Optional[int] = None,  # required for multi-label
    ):
        """
        Args:
            csv_path: annotation CSV (see header in module docstring)
            num_frames: frames sampled per clip
            fps: forwarded unchanged to `_load_video_frames`
            resize_hw: output frame size as (height, width)
            sample_rate, n_mels, win_length_ms, hop_length_ms: forwarded to `_load_audio_mel`
            transform: torchvision-like transform applied to each frame tensor (C,H,W).
                When None and torchvision is importable, a float conversion followed by
                normalization with the mean/std listed below is used.
            return_paths: include the media paths in each sample
            num_classes: length of the multi-label vector; required for comma-separated labels

        Raises:
            RuntimeError: if the CSV is missing, lacks `video_path`, or has no rows.
        """
        _require(os.path.isfile(csv_path), f"CSV not found: {csv_path}")
        self.items = self._read_csv(csv_path)
        self.num_frames = num_frames
        self.fps = fps
        self.resize_hw = resize_hw
        self.sample_rate = sample_rate
        self.n_mels = n_mels
        self.win_length_ms = win_length_ms
        self.hop_length_ms = hop_length_ms
        self.transform = transform
        self.return_paths = return_paths
        self.num_classes = num_classes

        if self.transform is None and _HAS_TORCHVISION:
            # reasonable defaults
            self.transform = T.Compose([
                T.ConvertImageDtype(torch.float32),
                T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225]),
            ])

    def _read_csv(self, csv_path: str) -> List[Dict[str, str]]:
        """Read all annotation rows; requires a `video_path` column and at least one row."""
        # utf-8-sig strips a leading BOM (otherwise the first header would not
        # match "video_path"); newline="" is what the csv module expects.
        with open(csv_path, "r", encoding="utf-8-sig", newline="") as f:
            reader = csv.DictReader(f)
            # `fieldnames` is None for a completely empty file.
            fieldnames = reader.fieldnames or []
            _require("video_path" in fieldnames, "CSV must have column `video_path`.")
            rows: List[Dict[str, str]] = list(reader)
        _require(len(rows) > 0, "Empty CSV.")
        return rows

    @staticmethod
    def _parse_optional_float(value: Optional[Any]) -> Optional[float]:
        """Return `value` as float, or None when it is None or blank/whitespace-only."""
        if value is None:
            return None
        text = str(value).strip()
        return float(text) if text else None

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Load the clip described by row `idx` and return it as a sample dict."""
        row = self.items[idx]
        vpath = row["video_path"]
        apath = row.get("audio_path", None) or None
        label_raw = row.get("label", None)
        # Blank or missing time cells fall back to "from 0" / "to the end".
        start_sec = self._parse_optional_float(row.get("start_sec")) or 0.0
        end_sec = self._parse_optional_float(row.get("end_sec"))

        # video frames [T, C, H, W]
        frames = _load_video_frames(
            vpath,
            num_frames=self.num_frames,
            fps=self.fps,
            start_sec=start_sec,
            end_sec=end_sec,
            resize_hw=self.resize_hw,
        )
        # per-frame transform
        if self.transform is not None:
            frames = torch.stack([self.transform(frame) for frame in frames], dim=0)

        # audio log-mel [n_mels, Ta] or None
        mel = _load_audio_mel(
            apath,
            vpath,
            sample_rate=self.sample_rate,
            n_mels=self.n_mels,
            win_length_ms=self.win_length_ms,
            hop_length_ms=self.hop_length_ms,
            start_sec=start_sec,
            end_sec=end_sec,
        )

        # label
        y: Optional[Union[int, Tensor]] = None
        if label_raw is not None and str(label_raw).strip() != "":
            y = _parse_label(label_raw, num_classes=self.num_classes)

        sample: Dict[str, Any] = {
            "video": frames,           # [T, C, H, W]
            "audio_mel": mel,          # [n_mels, Ta] or None
            "label": y,                # int or 1D tensor or None
            "start_sec": start_sec,
            "end_sec": end_sec,
        }
        if self.return_paths:
            sample["video_path"] = vpath
            if apath:
                sample["audio_path"] = apath
        return sample


# ----------------------------
# Collate: pad variable length
# ----------------------------
def pad_collate(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Collate dataset samples into a batch dict.

    - "video": samples stacked along a new leading axis; all must share T.
    - "audio_mel" / "audio_len": mels zero-padded on the time axis to the longest
      one in the batch, plus the unpadded lengths. Both are None as soon as any
      sample has no mel.
    - "label": a long tensor [B] for int / 0-dim labels (missing -> -1), a float
      tensor [B, K] for 1-D labels (missing -> zero row), or None when the batch
      mixes the two kinds.

    Raises:
        RuntimeError: if the videos do not all have the same number of frames.
    """
    clips = [sample["video"] for sample in batch]
    num_frames = clips[0].shape[0]
    _require(
        all(clip.shape[0] == num_frames for clip in clips),
        "All videos in a batch must share T (use same num_frames).",
    )
    video = torch.stack(clips, dim=0)  # [B, T, C, H, W]

    # audio: pad on the time axis, or drop entirely if any sample lacks it
    mels = [sample.get("audio_mel") for sample in batch]
    if any(m is None for m in mels):
        mel_padded, mel_lens = None, None
    else:
        n_mels = mels[0].shape[0]
        lengths = [int(m.shape[1]) for m in mels]  # type: ignore
        mel_padded = torch.zeros((len(mels), n_mels, max(lengths)), dtype=torch.float32)
        mel_lens = torch.tensor(lengths, dtype=torch.long)
        for row, (m, length) in enumerate(zip(mels, lengths)):
            mel_padded[row, :, :length] = m  # type: ignore

    # labels: decide the layout from the labels that are actually present
    labels = [sample.get("label") for sample in batch]
    present = [y for y in labels if y is not None]

    def _is_scalar(y: Any) -> bool:
        return isinstance(y, int) or (isinstance(y, torch.Tensor) and y.dim() == 0)

    def _is_vector(y: Any) -> bool:
        return isinstance(y, torch.Tensor) and y.dim() == 1

    if all(_is_scalar(y) for y in present):
        # single-label ints -> tensor, -1 marks a missing label
        y_out = torch.tensor([-1 if y is None else int(y) for y in labels], dtype=torch.long)
    elif all(_is_vector(y) for y in present):
        # multi-label -> one row per sample, missing labels stay all-zero
        width = max(y.numel() for y in present)
        y_out = torch.zeros((len(labels), width), dtype=torch.float32)
        for row, y in enumerate(labels):
            if y is None:
                continue
            y_out[row, :y.numel()] = y
    else:
        y_out = None

    return {
        "video": video,                 # [B, T, C, H, W]
        "audio_mel": mel_padded,       # [B, n_mels, Ta] or None
        "audio_len": mel_lens,         # [B] or None
        "label": y_out,                # [B] or [B, K] or None
    }


# ----------------------------
# Builder
# ----------------------------
def build_dataloader(
    csv_path: str,
    batch_size: int = 8,
    shuffle: bool = True,
    num_workers: int = 4,
    pin_memory: bool = True,
    **dataset_kwargs,
) -> DataLoader:
    """Build a `DataLoader` over `VideoAudioSceneDataset` that collates with `pad_collate`.

    Args:
        csv_path: annotation CSV passed to the dataset.
        batch_size, shuffle, num_workers: forwarded to `DataLoader`.
        pin_memory: forwarded to `DataLoader`. Defaults to True (the previously
            hard-coded value); pass False on machines without an accelerator.
        **dataset_kwargs: extra keyword arguments for `VideoAudioSceneDataset`
            (e.g. `num_frames`, `resize_hw`, `num_classes`).
    """
    ds = VideoAudioSceneDataset(csv_path=csv_path, **dataset_kwargs)
    return DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        collate_fn=pad_collate,
    )


# ----------------------------
# Quick self-check (optional)
# ----------------------------
if __name__ == "__main__":
    # Smoke test: build a loader from a real CSV (with real media on disk),
    # pull one batch and print the shapes of what came back.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("--csv", required=True, help="Path to annotation CSV.")
    cli.add_argument("--batch_size", type=int, default=2)
    opts = cli.parse_args()

    loader = build_dataloader(
        opts.csv,
        batch_size=opts.batch_size,
        shuffle=False,
        num_workers=0,  # load in the main process
        num_frames=8,
        resize_hw=(224, 224),
        num_classes=12,  # set if using multi-labels
        return_paths=True,
    )

    first_batch = next(iter(loader))
    print("Video:", first_batch["video"].shape)

    mel_batch = first_batch["audio_mel"]
    if mel_batch is not None:
        print("Audio mel:", mel_batch.shape, "lens:", first_batch["audio_len"])

    label_batch = first_batch["label"]
    print("Label:", label_batch.shape if label_batch is not None else None)
1