import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from typing import Tuple, Optional, Dict, Any
from abc import ABC, abstractmethod
class WSDataLoader(ABC):
    """Abstract base class for weak supervision data loaders."""

    def __init__(self, config):
        self.config = config

    @abstractmethod
    def load_split_data(self, dataset: str, seed: int) -> Tuple[Any, Any, Any, Optional[Any]]:
        """Load the train/val/test splits for a dataset.

        Args:
            dataset: Dataset name
            seed: Random seed

        Returns:
            Tuple of (train_data, val_data, test_data, metadata)
        """
        pass

    def create_dataloader(self, x_data: np.ndarray, labels: np.ndarray,
                          batch_size: int = 200, shuffle: bool = True) -> DataLoader:
        """Create a PyTorch DataLoader over hard (integer) labels.

        Args:
            x_data: Input data
            labels: Integer class labels
            batch_size: Batch size
            shuffle: Whether to shuffle data

        Returns:
            PyTorch DataLoader
        """
        return DataLoader(
            TensorDataset(
                torch.tensor(x_data, dtype=torch.float),
                torch.tensor(labels, dtype=torch.long)
            ),
            batch_size=batch_size,
            shuffle=shuffle
        )

    def create_soft_dataloader(self, x_data: np.ndarray, soft_labels: np.ndarray,
                               batch_size: int = 200, shuffle: bool = True) -> DataLoader:
        """Create a PyTorch DataLoader over soft (probabilistic) labels.

        Unlike create_dataloader, labels are kept as torch.float rather than
        torch.long, so they can be used with probability-target objectives.

        Args:
            x_data: Input data
            soft_labels: Per-class label probabilities
            batch_size: Batch size
            shuffle: Whether to shuffle data

        Returns:
            PyTorch DataLoader
        """
        return DataLoader(
            TensorDataset(
                torch.tensor(x_data, dtype=torch.float),
                torch.tensor(soft_labels, dtype=torch.float)
            ),
            batch_size=batch_size,
            shuffle=shuffle
        )

    def create_weak_dataloader(self, x_data: np.ndarray, weak_labels: np.ndarray,
                               batch_size: int = 200, shuffle: bool = True) -> DataLoader:
        """Create a PyTorch DataLoader over weak supervision labels.

        Args:
            x_data: Input data
            weak_labels: Integer votes from weak supervision sources
            batch_size: Batch size
            shuffle: Whether to shuffle data

        Returns:
            PyTorch DataLoader
        """
        return DataLoader(
            TensorDataset(
                torch.tensor(x_data, dtype=torch.float),
                torch.tensor(weak_labels, dtype=torch.long)
            ),
            batch_size=batch_size,
            shuffle=shuffle
        )
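

# Usage sketch (hypothetical subclass and data): the three factory methods
# above differ only in the dtype of the label tensor. Hard labels and weak
# votes become torch.long (suitable for nn.CrossEntropyLoss), while soft
# labels stay torch.float (suitable for probability-target objectives).
#
#     loader = SomeConcreteLoader(config={})   # assumed subclass
#     x = np.random.rand(100, 16).astype(np.float32)
#     p = np.random.rand(100, 2)
#     p /= p.sum(axis=1, keepdims=True)        # rows sum to 1
#     dl_hard = loader.create_dataloader(x, np.random.randint(0, 2, 100))
#     dl_soft = loader.create_soft_dataloader(x, p)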


class StandardWSDataLoader(WSDataLoader):
    """Standard data loader for WRENCH-style datasets."""

    def load_split_data(self, dataset: str, seed: int) -> Tuple[Any, Any, Any, Optional[Any]]:
        """Load standard WRENCH-style data splits.

        Args:
            dataset: Dataset name
            seed: Random seed

        Returns:
            Tuple of (train_data, val_data, test_data, metadata)
        """
        # Loading is dataset-specific (e.g. reading from the other/ folder
        # structure), so concrete subclasses must override this method.
        raise NotImplementedError("Subclasses should implement dataset-specific loading")
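
    # Subclass sketch (hypothetical names): a concrete loader would override
    # load_split_data with dataset-specific I/O, e.g.
    #
    #     class YouTubeWSDataLoader(StandardWSDataLoader):
    #         def load_split_data(self, dataset, seed):
    #             train, val, test = read_wrench_splits(dataset)  # assumed helper
    #             return train, val, test, None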

    def filter_abstain_votes(self, data: np.ndarray, votes: np.ndarray, labels: np.ndarray,
                             abstain_value: int = -1) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Filter out data points on which every weak supervision source abstained.

        Args:
            data: Input data
            votes: Weak supervision votes, one column per source
            labels: True labels
            abstain_value: Value representing an abstain vote

        Returns:
            Filtered (data, votes, labels)
        """
        # Keep rows where at least one source cast a non-abstain vote
        valid_mask = ~(votes == abstain_value).all(axis=1)
        return data[valid_mask], votes[valid_mask], labels[valid_mask]
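
    # Worked example of the abstain mask on toy votes (abstain_value=-1):
    #     votes = np.array([[-1, -1], [0, -1], [-1, 1]])
    #     ~(votes == -1).all(axis=1)  ->  [False, True, True]
    # Row 0 is dropped (both sources abstained); rows 1 and 2 are kept.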

    def split_labeled_unlabeled(self, data: np.ndarray, labels: np.ndarray,
                                num_labeled: int, seed: int = 42) -> Tuple[np.ndarray, np.ndarray]:
        """Split data into labeled and unlabeled portions.

        Args:
            data: Input data
            labels: Labels (kept for API symmetry; the split itself is random
                and does not use them)
            num_labeled: Number of labeled examples to select
            seed: Random seed

        Returns:
            Tuple of (labeled_indices, unlabeled_indices)
        """
        np.random.seed(seed)
        all_indices = np.arange(len(data))
        # Sample num_labeled indices without replacement; the rest are unlabeled
        labeled_indices = np.random.choice(all_indices, size=num_labeled, replace=False)
        unlabeled_indices = np.setdiff1d(all_indices, labeled_indices)
        return labeled_indices, unlabeled_indices
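

# Minimal end-to-end sketch on synthetic data (the feature/vote shapes and the
# use of -1 as the abstain value are illustrative assumptions):
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    loader = StandardWSDataLoader(config={})

    x = rng.random((50, 8)).astype(np.float32)    # 50 points, 8 features
    y = rng.integers(0, 2, size=50)               # binary ground truth
    votes = rng.choice([-1, 0, 1], size=(50, 3))  # 3 weak sources, -1 = abstain

    # Drop points on which all three sources abstained
    x_f, votes_f, y_f = loader.filter_abstain_votes(x, votes, y)

    # Carve out a small labeled subset; the rest is treated as unlabeled
    lab_idx, unlab_idx = loader.split_labeled_unlabeled(x_f, y_f, num_labeled=10, seed=0)

    # Batch the unlabeled portion with its weak votes
    dl = loader.create_weak_dataloader(x_f[unlab_idx], votes_f[unlab_idx], batch_size=16)
    xb, vb = next(iter(dl))
    print(xb.shape, vb.shape)  # torch.Size([16, 8]) torch.Size([16, 3])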