Multi-View Contrastive Learning Domain Generalization
Description
MVCLDG (Multi-View Contrastive Learning Domain Generalization) is a cross-subject domain generalization model for electroencephalogram (EEG) signal classification. It is designed to improve event-related potential (ERP) recognition on new, unseen subjects by combining multi-view feature extraction with domain-invariant representation learning.
- Multi-view feature fusion: combines amplitude information from the raw EEG signal with phase information derived from the Hilbert Transform (HT) to make features more discriminative.
- Domain-invariant representation learning: reduces differences between feature distributions across domains through a domain alignment loss and a contrastive learning loss.
- Multi-view contrastive learning: applies contrastive learning jointly to the raw, HT, and fused views.
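The HT phase view can be illustrated with SciPy. The sketch below is a hypothetical helper, not part of the MVCLDG code. It assumes each trial is a (channels, samples) array and that the phase view is the instantaneous phase of the analytic signal along the time axis. The record does not say whether the data are band-pass filtered before the transform. Phase is usually most interpretable on narrow-band signals, so treat this as a sketch only.

```python
import numpy as np
from scipy.signal import hilbert


def hilbert_phase_view(eeg: np.ndarray) -> np.ndarray:
    """Hypothetical helper: instantaneous phase (radians, in [-pi, pi]) per channel.

    eeg has shape (channels, samples). scipy.signal.hilbert returns the
    analytic signal x + i*H[x] along the time axis; its angle is the phase.
    """
    analytic = hilbert(eeg, axis=-1)
    return np.angle(analytic)


# Synthetic trial: 64 channels x 256 samples (as in input_shape=(1, 64, 256))
rng = np.random.default_rng(0)
trial = rng.standard_normal((64, 256))

raw_view = trial                      # amplitude view: the raw signal itself
ht_view = hilbert_phase_view(trial)   # phase view: same shape as the raw view
print(ht_view.shape)  # (64, 256)
```

Because the phase view keeps the raw view's shape, both views can feed encoders with the same input geometry. How MVCLDG actually fuses them is not described here.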
Install the dependencies with pip:

```
pip install torch numpy scipy scikit-learn matplotlib tqdm
```
A typical training run looks like this:
```python
from torch.utils.data import DataLoader  # required for the DataLoader used below

from mvclg import MVCLDGModel, MVCLDGTrainer, EEGDatasetWithHT

# Initialize model
model = MVCLDGModel(input_shape=(1, 64, 256), num_classes=2, num_domains=4)

# Prepare data
dataset = EEGDatasetWithHT(data_path='path/to/data', dataset_id=1, include_ht=True)
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)
domain_labels = dataset.get_domain_labels()

# Create trainer and train
# '...' stands for further settings such as tradeoff_align and tradeoff_contrast
trainer = MVCLDGTrainer(model, device='cuda', config={'learning_rate': 1e-3, ...})
trainer.train(data_loader, domain_labels, epochs=50)
```
The num_domains parameter sets the number of source domains used for domain generalization, and include_ht controls whether the Hilbert Transform view is included. Training minimizes a combination of classification loss, domain alignment loss, and contrastive learning loss. The weights of the last two are set by tradeoff_align and tradeoff_contrast in the config dictionary.
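The record names the three loss terms and the two weights, but not their formulas. The following NumPy sketch assumes the objective is a weighted sum, L = L_cls + tradeoff_align · L_align + tradeoff_contrast · L_contrast. The individual terms are stand-ins chosen for illustration, not the MVCLDG definitions:

- cross-entropy for classification;
- a linear-kernel MMD (squared distance between per-domain feature means) for alignment;
- InfoNCE between the raw and HT views for the contrastive term.

The function names and the concatenation used as the "fused" features are hypothetical.

```python
from itertools import combinations

import numpy as np


def cross_entropy(logits, labels):
    """Mean softmax cross-entropy; logits (N, C), integer labels (N,)."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()


def mean_alignment_loss(features, domains):
    """Stand-in alignment loss: squared distance between per-domain feature
    means (linear-kernel MMD), averaged over all domain pairs in the batch."""
    means = [features[domains == d].mean(axis=0) for d in np.unique(domains)]
    pairs = list(combinations(means, 2))
    if not pairs:  # a single domain in the batch: nothing to align
        return 0.0
    return float(np.mean([np.sum((a - b) ** 2) for a, b in pairs]))


def info_nce(z1, z2, temperature=0.1):
    """Stand-in contrastive loss: row i of z1 and row i of z2 (two views of
    the same trial) are the positive pair; all other rows are negatives."""
    z1 = z1 / np.linalg.norm(z1, axis=1, keepdims=True)
    z2 = z2 / np.linalg.norm(z2, axis=1, keepdims=True)
    sim = z1 @ z2.T / temperature  # (N, N) scaled cosine similarities
    return cross_entropy(sim, np.arange(len(z1)))


def total_loss(logits, labels, raw_feat, ht_feat, domains, config):
    """Assumed weighted sum of the three terms, weighted as in the config."""
    l_cls = cross_entropy(logits, labels)
    fused = np.concatenate([raw_feat, ht_feat], axis=1)  # hypothetical fusion
    l_align = mean_alignment_loss(fused, domains)
    l_con = info_nce(raw_feat, ht_feat)
    return (l_cls
            + config["tradeoff_align"] * l_align
            + config["tradeoff_contrast"] * l_con)


# Usage with random features: 8 trials, 2 classes, 4 source domains
rng = np.random.default_rng(0)
logits = rng.standard_normal((8, 2))
labels = rng.integers(0, 2, size=8)
raw = rng.standard_normal((8, 16))
ht = raw + 0.1 * rng.standard_normal((8, 16))
domains = np.repeat(np.arange(4), 2)
cfg = {"tradeoff_align": 0.5, "tradeoff_contrast": 0.1}
print(total_loss(logits, labels, raw, ht, domains, cfg))
```

Setting both tradeoff weights to zero reduces the objective to plain classification. Raising them trades source-domain accuracy for more domain-invariant features.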
The train method expects a DataLoader that yields batches of EEG data and their labels. The EEGDatasetWithHT class loads and preprocesses the EEG data, optionally computing the Hilbert Transform for the phase view. The domain_labels indicate which domain each sample belongs to and are required for computing the domain alignment loss.
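The record does not show how dataset.get_domain_labels() assigns domains. A common cross-subject setup treats each source subject as one domain and holds out one subject as the unseen target. The sketch below makes that assumption; the helper leave_one_subject_out is hypothetical. It splits trials by subject ID and maps the remaining source subjects to contiguous labels 0..K-1, where K would be the value passed as num_domains.

```python
import numpy as np


def leave_one_subject_out(subject_ids, target_subject):
    """Hypothetical helper: split trials into source/target sets and assign
    one contiguous domain label per source subject.

    Returns (source_idx, target_idx, domain_labels, num_domains).
    """
    subject_ids = np.asarray(subject_ids)
    source_mask = subject_ids != target_subject
    source_idx = np.flatnonzero(source_mask)
    target_idx = np.flatnonzero(~source_mask)
    # return_inverse maps each source subject ID to an index in 0..K-1
    source_subjects, domain_labels = np.unique(
        subject_ids[source_mask], return_inverse=True
    )
    return source_idx, target_idx, domain_labels, len(source_subjects)


# Five subjects with two trials each; subject 11 is held out as the target
subjects = [3, 3, 5, 5, 7, 7, 9, 9, 11, 11]
src, tgt, dom, n_domains = leave_one_subject_out(subjects, target_subject=11)
print(dom.tolist())  # [0, 0, 1, 1, 2, 2, 3, 3]
print(n_domains)     # 4
```

With four source subjects, this yields num_domains=4, the value used in the example above. The target subject's trials are never seen during training.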
Files
Size: 25.9 kB
md5: 6159b3c90a71b833deda1d57d9b49368