src.common.WSDataLoader
- class src.common.WSDataLoader(config)
Bases: ABC
Abstract base class for weak supervision data loaders
Methods
- create_dataloader(x_data, labels[, ...]): Create PyTorch DataLoader
- create_soft_dataloader(x_data, soft_labels): Create DataLoader for soft labels
- create_weak_dataloader(x_data, weak_labels): Create DataLoader for weak labels
- load_split_data(dataset, seed): Load train/val/test splits
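WSDataLoader is abstract, so it is used through a subclass. The sketch below assumes that `load_split_data(dataset, seed)` is the abstract hook a subclass has to provide. This page does not say which methods are abstract or what `load_split_data` returns. Because `src.common` cannot be imported here, `WSDataLoaderStandIn`, `SyntheticLoader`, the `train_frac`/`val_frac` config keys and the dict-of-`(x, y)` return value are all hypothetical.

```python
from abc import ABC, abstractmethod

import numpy as np


class WSDataLoaderStandIn(ABC):
    """Hypothetical stand-in for src.common.WSDataLoader (not the real class)."""

    def __init__(self, config):
        self.config = config

    @abstractmethod
    def load_split_data(self, dataset, seed):
        """Assumed abstract hook: return train/val/test splits."""


class SyntheticLoader(WSDataLoaderStandIn):
    """Hypothetical subclass that splits an in-memory dataset by seed."""

    def load_split_data(self, dataset, seed):
        x, y = dataset  # assumption: dataset is an (x, y) tuple of arrays
        rng = np.random.default_rng(seed)  # same seed -> same permutation
        order = rng.permutation(len(x))
        n_train = int(self.config["train_frac"] * len(x))
        n_val = int(self.config["val_frac"] * len(x))
        idx = {
            "train": order[:n_train],
            "val": order[n_train:n_train + n_val],
            "test": order[n_train + n_val:],  # remainder goes to test
        }
        return {name: (x[i], y[i]) for name, i in idx.items()}


x = np.arange(100, dtype=np.float64).reshape(50, 2)
y = np.arange(50) % 2
loader = SyntheticLoader({"train_frac": 0.6, "val_frac": 0.2})
splits = loader.load_split_data((x, y), seed=0)
```

Seeding a local `default_rng` rather than the global NumPy state keeps each split reproducible from the `seed` argument alone.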
- create_dataloader(x_data, labels, batch_size=200, shuffle=True)
Create PyTorch DataLoader
- Args:
x_data: Input data
labels: Labels
batch_size: Batch size
shuffle: Whether to shuffle data
- Returns:
PyTorch DataLoader
- Parameters:
x_data (numpy.ndarray)
labels (numpy.ndarray)
batch_size (int)
shuffle (bool)
- Return type:
torch.utils.data.DataLoader
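Building a DataLoader from NumPy arrays usually means wrapping them as tensors in a `TensorDataset`. The function below is a hypothetical equivalent of `create_dataloader`, not its actual source. The float32/int64 casts are an assumption, chosen because `CrossEntropyLoss` expects int64 class indices.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_dataloader(x_data: np.ndarray, labels: np.ndarray,
                    batch_size: int = 200, shuffle: bool = True) -> DataLoader:
    """Hypothetical equivalent of create_dataloader (not the actual source)."""
    # Assumption: features become float32, hard labels become int64 class ids.
    x = torch.as_tensor(x_data, dtype=torch.float32)
    y = torch.as_tensor(labels, dtype=torch.long)
    return DataLoader(TensorDataset(x, y), batch_size=batch_size, shuffle=shuffle)


rng = np.random.default_rng(0)
x_data = rng.normal(size=(450, 8))
labels = rng.integers(0, 2, size=450)
loader = make_dataloader(x_data, labels)
batch_sizes = [xb.shape[0] for xb, _ in loader]
```

Because `drop_last` defaults to `False` in `DataLoader`, 450 samples with the default `batch_size=200` give batches of 200, 200 and 50. Shuffling only changes which samples land in each batch.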
- create_soft_dataloader(x_data, soft_labels, batch_size=200, shuffle=True)
Create DataLoader for soft labels
- Args:
x_data: Input data
soft_labels: Soft labels/probabilities
batch_size: Batch size
shuffle: Whether to shuffle data
- Returns:
PyTorch DataLoader
- Parameters:
x_data (numpy.ndarray)
soft_labels (numpy.ndarray)
batch_size (int)
shuffle (bool)
- Return type:
torch.utils.data.DataLoader
- create_weak_dataloader(x_data, weak_labels, batch_size=200, shuffle=True)
Create DataLoader for weak labels
- Args:
x_data: Input data
weak_labels: Weak supervision labels
batch_size: Batch size
shuffle: Whether to shuffle data
- Returns:
PyTorch DataLoader
- Parameters:
x_data (numpy.ndarray)
weak_labels (numpy.ndarray)
batch_size (int)
shuffle (bool)
- Return type:
torch.utils.data.DataLoader
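This page does not specify the layout of `weak_labels`. In weak supervision it is often a label matrix with one column per labeling function and a sentinel for "abstain". The sketch assumes exactly that: an `(n_samples, n_sources)` int array with `-1` as the abstain value. `make_weak_dataloader` is a hypothetical equivalent, not the actual source. `majority_vote` is one simple way to turn a batch of weak labels into hard labels, with ties going to the lowest class id. It is swapped in for illustration and is not necessarily what this library does.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

ABSTAIN = -1  # assumption: -1 marks a labeling function that did not fire


def make_weak_dataloader(x_data, weak_labels, batch_size=200, shuffle=True):
    """Hypothetical equivalent of create_weak_dataloader (not the actual source)."""
    x = torch.as_tensor(x_data, dtype=torch.float32)
    # Assumption: weak_labels is an (n_samples, n_sources) matrix of class ids.
    L = torch.as_tensor(weak_labels, dtype=torch.long)
    return DataLoader(TensorDataset(x, L), batch_size=batch_size, shuffle=shuffle)


def majority_vote(L_batch, num_classes, abstain=ABSTAIN):
    """Aggregate one batch of weak labels; returns (labels, covered mask)."""
    fired = L_batch != abstain
    votes = torch.zeros(L_batch.shape[0], num_classes)
    # Abstains are redirected to class 0 but added with weight 0.
    votes.scatter_add_(1, L_batch.clamp(min=0), fired.float())
    # argmax picks the first maximum; rows where nothing fired default to 0.
    return votes.argmax(dim=1), fired.any(dim=1)


rng = np.random.default_rng(0)
x_data = rng.normal(size=(250, 6))
weak_labels = rng.integers(-1, 2, size=(250, 5))  # values in {-1, 0, 1}
loader = make_weak_dataloader(x_data, weak_labels, batch_size=100)
xb, Lb = next(iter(loader))
preds, covered = majority_vote(Lb, num_classes=2)
```

The returned `covered` mask lets a training loop skip samples on which every source abstained, rather than silently training them toward class 0.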