src.common.WSDataLoader

class src.common.WSDataLoader(config)[source]

Bases: ABC

Abstract base class for weak supervision data loaders
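
Because the class is abstract, it is used by subclassing it and implementing load_split_data(). A minimal, illustrative sketch only: the MyWSDataLoader name, the fabricated arrays, and the assumption that __init__ simply stores config are not part of the documented API.

    import numpy as np

    from src.common import WSDataLoader

    class MyWSDataLoader(WSDataLoader):
        """Hypothetical loader that returns small random splits."""

        def load_split_data(self, dataset, seed):
            # A real subclass would read `dataset` from disk; random arrays
            # keep this sketch self-contained.
            rng = np.random.default_rng(seed)
            x = rng.normal(size=(300, 16)).astype(np.float32)
            y = rng.integers(0, 2, size=300)
            train = (x[:200], y[:200])
            val = (x[200:250], y[200:250])
            test = (x[250:], y[250:])
            metadata = {"dataset": dataset, "n_features": 16}
            return train, val, test, metadata

    loader = MyWSDataLoader(config={})  # config contents depend on the subclass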

Methods

create_dataloader(x_data, labels[, ...])

Create PyTorch DataLoader

create_soft_dataloader(x_data, soft_labels)

Create DataLoader for soft labels

create_weak_dataloader(x_data, weak_labels)

Create DataLoader for weak labels

load_split_data(dataset, seed)

Load train/val/test splits

create_dataloader(x_data, labels, batch_size=200, shuffle=True)[source]

Create PyTorch DataLoader

Args:

x_data: Input data
labels: Label array
batch_size: Batch size
shuffle: Whether to shuffle data

Returns:

PyTorch DataLoader

Parameters:
  • x_data (numpy.ndarray)

  • labels (numpy.ndarray)

  • batch_size (int)

  • shuffle (bool)

Return type:

torch.utils.data.DataLoader
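
A hedged usage sketch (loader is a concrete WSDataLoader subclass instance, as in the class-level example above; shapes and data are illustrative):

    import numpy as np

    x_train = np.random.rand(1000, 16).astype(np.float32)  # features
    y_train = np.random.randint(0, 2, size=1000)           # hard labels

    train_loader = loader.create_dataloader(x_train, y_train,
                                            batch_size=200, shuffle=True)

    # Assumption: batches are (inputs, labels) pairs of tensors.
    for x_batch, y_batch in train_loader:
        ...  # one training step per batch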

create_soft_dataloader(x_data, soft_labels, batch_size=200, shuffle=True)[source]

Create DataLoader for soft labels

Args:

x_data: Input data
soft_labels: Soft labels/probabilities
batch_size: Batch size
shuffle: Whether to shuffle data

Returns:

PyTorch DataLoader

Parameters:
  • x_data (numpy.ndarray)

  • soft_labels (numpy.ndarray)

  • batch_size (int)

  • shuffle (bool)

Return type:

torch.utils.data.DataLoader
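
Illustrative sketch: soft labels are per-example probability distributions over classes (the 3-class shape here is an assumption, not something required by the API):

    import numpy as np

    x_train = np.random.rand(1000, 16).astype(np.float32)

    # One probability distribution per example; rows sum to 1.
    scores = np.random.rand(1000, 3)
    soft_labels = scores / scores.sum(axis=1, keepdims=True)

    # `loader` is a concrete WSDataLoader subclass instance (see class-level sketch).
    soft_loader = loader.create_soft_dataloader(x_train, soft_labels, batch_size=200)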

create_weak_dataloader(x_data, weak_labels, batch_size=200, shuffle=True)[source]

Create DataLoader for weak labels

Args:

x_data: Input data
weak_labels: Weak supervision labels
batch_size: Batch size
shuffle: Whether to shuffle data

Returns:

PyTorch DataLoader

Parameters:
  • x_data (numpy.ndarray)

  • weak_labels (numpy.ndarray)

  • batch_size (int)

  • shuffle (bool)

Return type:

torch.utils.data.DataLoader
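
Illustrative sketch: weak labels are noisy labels produced upstream, e.g. by heuristics or a label model. Whether the method expects one aggregated label per example or a matrix of labeling-function votes depends on the concrete subclass; the sketch below assumes one label per example:

    import numpy as np

    x_train = np.random.rand(1000, 16).astype(np.float32)

    # Assumption: one noisy/programmatic label per example.
    weak_labels = np.random.randint(0, 2, size=1000)

    # `loader` is a concrete WSDataLoader subclass instance (see class-level sketch).
    weak_loader = loader.create_weak_dataloader(x_train, weak_labels,
                                                batch_size=200, shuffle=True)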

abstractmethod load_split_data(dataset, seed)[source]

Load train/val/test splits

Args:

dataset: Dataset name
seed: Random seed

Returns:

Tuple of (train_data, val_data, test_data, metadata)

Parameters:
  • dataset (str)

  • seed (int)

Return type:

Tuple[Any, Any, Any, Any | None]
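
Calling the implemented method on a concrete subclass might look like this (the dataset name and empty config are placeholders; metadata may be None, per the return type above):

    # `MyWSDataLoader` is the hypothetical subclass from the class-level sketch.
    loader = MyWSDataLoader(config={})
    train_data, val_data, test_data, metadata = loader.load_split_data("some_dataset", seed=42)

    if metadata is not None:
        print(metadata)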