Source code for epbd_bert.datasets.sequence_randepbd_multimodal_dataset

import torch
import transformers

from epbd_bert.datasets.sequence_epbd_dataset import SequenceEPBDDataset


class SequenceRandEPBDMultiModalDataset(SequenceEPBDDataset):
    """Dataset for multi-modal transformer, with random EPBD features.

    Everything is inherited from ``SequenceEPBDDataset`` except
    ``_get_epbd_features``, which is overridden so that no EPBD feature file
    is read. Instead, a freshly sampled random tensor is returned on every call.

    NOTE(review): presumably intended as a control / ablation baseline against
    the real-EPBD dataset — confirm with the training configs that use it.
    """

    # Shape of the random feature tensor returned per sample.
    # NOTE(review): presumably mirrors the shape of the real EPBD features
    # produced by the parent class — TODO confirm against SequenceEPBDDataset.
    _RANDOM_FEATURE_SHAPE = (6, 200)

    def __init__(
        self,
        data_path: str,
        tokenizer: transformers.PreTrainedTokenizer,
        home_dir="",
    ):
        """Forward all construction arguments unchanged to the parent dataset.

        Args:
            data_path: Passed through to ``SequenceEPBDDataset.__init__``.
            tokenizer: Passed through to ``SequenceEPBDDataset.__init__``.
            home_dir: Passed through to ``SequenceEPBDDataset.__init__``
                (defaults to the empty string).
        """
        super().__init__(data_path, tokenizer, home_dir)

    def _get_epbd_features(self, fname):
        """Return random stand-in EPBD features instead of loading ``fname``.

        Args:
            fname: Accepted for interface compatibility with the parent's
                override point; it is deliberately ignored here.

        Returns:
            A new ``torch.float32`` tensor of shape ``(6, 200)`` whose entries
            are drawn uniformly from ``[0, 1)`` by ``torch.rand``. The draw uses
            torch's global RNG, so values differ between calls unless that RNG
            is seeded by the caller.
        """
        return torch.rand(self._RANDOM_FEATURE_SHAPE, dtype=torch.float32)