# dataset.py

"""
This file defines the Dataset class for loading and preprocessing data.
It handles tokenization and structuring of the source code, binary code, 
and comments for use in the ContraBin model.
"""

import torch

class Dataset(torch.utils.data.Dataset):
    """
    A PyTorch Dataset class to handle the loading and preprocessing of triplet data
    (source code, binary code, and comments).

    All three text columns are tokenized once, eagerly, in ``__init__``;
    ``__getitem__`` only converts the stored encodings of one row to tensors.
    """

    def __init__(
        self,
        dataframe,
        tokenizer,
        *,
        source_max_length=None,
        binary_max_length=None,
        comment_max_length=None,
    ):
        """
        Initializes the Dataset with the provided data and tokenizer.

        Args:
            dataframe (pd.DataFrame): A pandas DataFrame containing the columns
                                      'source', 'binary' and 'comment'.
            tokenizer (transformers.PreTrainedTokenizer): A tokenizer for encoding text.
                It is called as ``tokenizer(texts, padding=True, truncation=True,
                max_length=...)`` and its result must support ``.items()``.
            source_max_length (int, optional): Truncation length for the 'source'
                column. Defaults to ``configs.source_max_length``.
            binary_max_length (int, optional): Truncation length for the 'binary'
                column. Defaults to ``configs.binary_max_length``.
            comment_max_length (int, optional): Truncation length for the 'comment'
                column. Defaults to ``configs.comment_max_length``.

        NOTE(review): the defaults read the module-level name ``configs``, which
        is not imported in the visible part of this file -- confirm it is in
        scope, or pass the three lengths explicitly.
        """
        self.triplets = dataframe

        # Each limit is resolved right before its column is tokenized, so the
        # order of ``configs`` lookups matches the original implementation.
        self.encoded_source = self._tokenize_column(
            dataframe, 'source', tokenizer,
            self._resolve_max_length(source_max_length, 'source_max_length'),
        )
        self.encoded_binary = self._tokenize_column(
            dataframe, 'binary', tokenizer,
            self._resolve_max_length(binary_max_length, 'binary_max_length'),
        )
        self.encoded_comment = self._tokenize_column(
            dataframe, 'comment', tokenizer,
            self._resolve_max_length(comment_max_length, 'comment_max_length'),
        )

    @staticmethod
    def _resolve_max_length(explicit, config_attr):
        """Return ``explicit`` if given, otherwise ``configs.<config_attr>``."""
        if explicit is not None:
            return explicit
        return getattr(configs, config_attr)

    @staticmethod
    def _tokenize_column(dataframe, column, tokenizer, max_length):
        """Tokenize every entry of ``dataframe[column]`` in a single tokenizer call."""
        return tokenizer(
            list(dataframe[column]),
            padding=True,
            truncation=True,
            max_length=max_length,
        )

    @staticmethod
    def _row_tensors(encoding, idx):
        """Build a ``{key: torch.Tensor}`` dict for row ``idx`` of one encoding."""
        return {
            key: torch.tensor(values[idx])
            for key, values in encoding.items()
        }

    def __getitem__(self, idx):
        """
        Retrieves the triplet data (source, binary, comment) for a specific index.

        Args:
            idx (int): The index of the data to retrieve.

        Returns:
            tuple: Three dicts (source, binary, comment), each mapping every key
                   of the corresponding tokenizer output to a ``torch.Tensor``
                   built from that key's entry at ``idx``.
        """
        return (
            self._row_tensors(self.encoded_source, idx),
            self._row_tensors(self.encoded_binary, idx),
            self._row_tensors(self.encoded_comment, idx),
        )

    def __len__(self):
        """
        Returns the total number of data points in the dataset.

        Returns:
            int: The length of the dataset (``len`` of the stored dataframe).
        """
        return len(self.triplets)