# data_processing.py

"""
This file contains utilities for preprocessing and tokenizing datasets, as well as the custom dataset
and data-loader helpers used for training the ContraBin model. It includes parsing functions for
decompiled code, source code, and binary code.
"""

import re
import pandas as pd
import torch
from transformers import AutoTokenizer


def parse_decompile(x):
    """
    Parses decompiled binary code by removing comments, typedefs, and extra whitespace.

    Processing order:
        1. Every ``//`` comment is deleted up to the end of its line, including a comment
           on the last line when the text does not end with a newline.
        2. If "typedef" occurs, the text is cut at its last occurrence and the first two
           chunks terminated by a semicolon followed by a newline are dropped.
        3. If "extern" occurs, the text is cut at its last occurrence and the first two
           lines are dropped.
        4. Newlines become spaces and whitespace is compacted (see inline comments).

    Args:
        x (str): Decompiled binary code.

    Returns:
        str: Cleaned and parsed binary code on a single line. Leading or trailing
        whitespace may remain.
    """
    # "." never matches a newline, so each comment is removed up to (not including) its
    # line break; unlike a pattern that requires the newline, this also strips a comment
    # sitting on the final, unterminated line.
    x = re.sub(r"//.*", "", x)

    # Keep only what follows the last typedef: chunk 0 is the typedef statement itself,
    # chunk 1 is the statement right after it; both are discarded.
    i_type = x.rfind("typedef")
    if i_type != -1:
        x = ";\n".join(x[i_type:].split(";\n")[2:])

    # Same idea for the last extern declaration, but counted in lines.
    i_ext = x.rfind("extern")
    if i_ext != -1:
        x = "\n".join(x[i_ext:].split("\n")[2:])

    x = x.replace("\n", " ")  # Remove newlines

    # Whitespace compaction. The characters + - * / = > < are excluded from the
    # "symbol" side of each rule, so spacing around operators is largely retained.
    x = re.sub(r'(\w)\s+(\w)', r'\1 \2', x)  # Collapse whitespace between words to one space
    x = re.sub(r'(?![\+\-\*\/=><])(\W)\s+(?![\+\-\*\/=><])(\W)', r'\1\2', x)  # Strip between symbols
    x = re.sub(r'(?![\+\-\*\/=><])(\W)\s+(\w)', r'\1\2', x)  # Strip between symbols and words
    x = re.sub(r'(\w)\s+(?![\+\-\*\/=><])(\W)', r'\1\2', x)  # Strip between words and symbols
    return x


def parse_source(x):
    """
    Cleans source code by removing block comments, tabs, newlines, and extra whitespace.

    Only ``/* ... */`` comments are removed; ``//`` comments are left in place.

    Args:
        x (str): Source code.

    Returns:
        str: Parsed source code on a single line. Leading or trailing whitespace may
        remain.
    """
    # Non-greedy so each comment ends at its own closing marker (code between two
    # comments is kept), and DOTALL so comments spanning several lines are matched.
    x = re.sub(r"/\*.*?\*/", " ", x, flags=re.DOTALL)
    x = x.replace("\t", " ")  # Replace tabs with spaces
    x = x.replace("\n", " ")  # Remove newlines

    # Whitespace compaction. The characters + - * / = > < are excluded from the
    # "symbol" side of each rule, so spacing around operators is largely retained.
    x = re.sub(r'(\w)\s+(\w)', r'\1 \2', x)  # Collapse whitespace between words to one space
    x = re.sub(r'(?![\+\-\*\/=><])(\W)\s+(?![\+\-\*\/=><])(\W)', r'\1\2', x)  # Strip between symbols
    x = re.sub(r'(?![\+\-\*\/=><])(\W)\s+(\w)', r'\1\2', x)  # Strip between symbols and words
    x = re.sub(r'(\w)\s+(?![\+\-\*\/=><])(\W)', r'\1\2', x)  # Strip between words and symbols
    return x


def parse_binary(x):
    """
    Cleans binary code by deleting every pair of consecutive spaces.

    Pairs are consumed left to right, so a run with an odd number of spaces leaves one
    space behind and isolated single spaces are untouched.

    Args:
        x (str): Binary code.

    Returns:
        str: Binary code with all double-space pairs removed.
    """
    double_space = "  "
    return x.replace(double_space, "")


class Dataset(torch.utils.data.Dataset):
    """
    Tokenized (source, binary, comment) triplets for ContraBin.

    The three text columns are tokenized once at construction time; indexing afterwards
    only converts the stored encodings of one row into tensors.
    """

    def __init__(self, dataframe, tokenizer):
        """
        Tokenizes the three text columns up front.

        Args:
            dataframe (pd.DataFrame): Data providing 'source', 'binary', and 'comment' columns.
            tokenizer: Callable applied to each column (as a list) with padding and
                truncation enabled.

        Note:
            The truncation limits are read from ``configs`` (``source_max_length``,
            ``binary_max_length``, ``comment_max_length``), a name this module expects
            to be resolvable when the constructor runs.
        """
        self.triplets = dataframe

        source_texts = list(dataframe['source'])
        self.encoded_source = tokenizer(
            source_texts,
            padding=True,
            truncation=True,
            max_length=configs.source_max_length,
        )

        binary_texts = list(dataframe['binary'])
        self.encoded_binary = tokenizer(
            binary_texts,
            padding=True,
            truncation=True,
            max_length=configs.binary_max_length,
        )

        comment_texts = list(dataframe['comment'])
        self.encoded_comment = tokenizer(
            comment_texts,
            padding=True,
            truncation=True,
            max_length=configs.comment_max_length,
        )

    def __getitem__(self, idx):
        """
        Builds the tensors for one row.

        Args:
            idx (int): Row position.

        Returns:
            tuple: Three dicts (source, binary, comment), each mapping an encoding field
            name to a tensor built from that field's entry at ``idx``.
        """
        encodings = (self.encoded_source, self.encoded_binary, self.encoded_comment)
        return tuple(
            {field: torch.tensor(rows[idx]) for field, rows in encoding.items()}
            for encoding in encodings
        )

    def __len__(self):
        """
        Reports how many rows the dataset holds.

        Returns:
            int: Length of the data passed to the constructor.
        """
        row_count = len(self.triplets)
        return row_count


def build_loaders(dataset, mode):
    """
    Wraps a dataset in a DataLoader configured from ``configs``.

    Args:
        dataset (Dataset): Custom dataset object.
        mode (str): Shuffling is enabled only when this equals "train"; any other value
            (e.g. "valid") keeps the original order.

    Returns:
        DataLoader: Loader using ``configs.batch_size`` and ``configs.num_workers``.
    """
    loader_options = {
        "batch_size": configs.batch_size,
        "num_workers": configs.num_workers,
        "shuffle": mode == "train",
    }
    return torch.utils.data.DataLoader(dataset, **loader_options)