import torch
from torch.utils.data import Dataset
import numpy as np
from ConfigV1 import *

# Dataset
class MultichannelDataset(Dataset):
    """Map-style dataset pairing multichannel inputs with per-sample targets.

    Inputs are supplied channels-last and stored channels-first so that each
    item comes out as ``(num_channels, height, width)``. Every item is
    returned as a pair of ``float32`` tensors.
    """

    def __init__(self, input_data, reflectivity_data):
        """
        Args:
            input_data: Input data with channels as the last dimension, shape (num_samples, height, width, num_channels).
            reflectivity_data: Output data (e.g., reflectivity), shape (num_samples, height, width).

        Raises:
            ValueError: If ``input_data`` is not 4-dimensional, or if the two
                arguments do not contain the same number of samples.
        """
        # Fail early with a readable message instead of numpy's generic
        # "axes don't match array" error from the transpose below.
        input_ndim = np.ndim(input_data)
        if input_ndim != 4:
            raise ValueError(
                "input_data must be 4-D (num_samples, height, width, "
                f"num_channels), got {input_ndim} dimension(s)"
            )

        # Reorder input data to have channels first: (num_samples, num_channels, height, width)
        self.input_data = np.transpose(input_data, (0, 3, 1, 2))
        self.reflectivity_data = reflectivity_data

        # __len__ is driven by the inputs alone, so a mismatch would otherwise
        # surface only as an IndexError mid-iteration (too few targets) or as
        # silently ignored targets (too many).
        num_inputs = len(self.input_data)
        num_targets = len(self.reflectivity_data)
        if num_inputs != num_targets:
            raise ValueError(
                "input_data and reflectivity_data must contain the same "
                f"number of samples, got {num_inputs} and {num_targets}"
            )

    def __len__(self) -> int:
        """Return the number of samples in the dataset."""
        return len(self.input_data)

    def __getitem__(self, idx):
        """Return the ``(input, reflectivity)`` pair at ``idx`` as float32 tensors.

        The input tensor has shape (num_channels, height, width); the
        reflectivity tensor has the shape of one entry of
        ``reflectivity_data``.
        """
        input_sample = self.input_data[idx]  # Shape: (num_channels, height, width)
        reflectivity_sample = self.reflectivity_data[idx]  # Shape: (height, width)

        # torch.tensor always copies, so the returned tensors never alias the
        # arrays held by the dataset.
        input_tensor = torch.tensor(input_sample, dtype=torch.float32)
        reflectivity_tensor = torch.tensor(reflectivity_sample, dtype=torch.float32)

        return input_tensor, reflectivity_tensor
