Source code for src.common.data_manager

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from typing import Tuple, Optional, Dict, Any
from abc import ABC, abstractmethod


def _build_tensor_loader(x_data: np.ndarray, targets: np.ndarray,
                         target_dtype: torch.dtype, batch_size: int,
                         shuffle: bool) -> DataLoader:
    """Wrap features and targets in a TensorDataset-backed DataLoader.

    Shared by every ``create_*dataloader`` method of ``WSDataLoader`` so the
    tensor conversion lives in exactly one place.

    Args:
        x_data: Input data; converted to a ``torch.float`` tensor
        targets: Per-example targets; converted to ``target_dtype``
        target_dtype: Torch dtype used for the target tensor
        batch_size: Batch size
        shuffle: Whether to shuffle data

    Returns:
        PyTorch DataLoader yielding ``(features, targets)`` batches
    """
    features = torch.tensor(x_data, dtype=torch.float)
    target_tensor = torch.tensor(targets, dtype=target_dtype)
    # TensorDataset requires both tensors to agree on their first dimension.
    dataset = TensorDataset(features, target_tensor)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


class WSDataLoader(ABC):
    """Abstract base class for weak supervision data loaders.

    Subclasses must implement ``load_split_data``; the ``create_*dataloader``
    helpers are concrete and differ only in the dtype given to the targets.
    """

    def __init__(self, config):
        # Stored as-is; this class never reads it.
        self.config = config

    @abstractmethod
    def load_split_data(self, dataset: str, seed: int) -> Tuple[Any, Any, Any, Optional[Any]]:
        """Load train/val/test splits

        Args:
            dataset: Dataset name
            seed: Random seed

        Returns:
            Tuple of (train_data, val_data, test_data, metadata)
        """
        pass

    def create_dataloader(self, x_data: np.ndarray, labels: np.ndarray,
                          batch_size: int = 200, shuffle: bool = True) -> DataLoader:
        """Create PyTorch DataLoader

        Args:
            x_data: Input data (converted to ``torch.float``)
            labels: Labels (converted to ``torch.long``)
            batch_size: Batch size
            shuffle: Whether to shuffle data

        Returns:
            PyTorch DataLoader
        """
        return _build_tensor_loader(x_data, labels, torch.long, batch_size, shuffle)

    def create_soft_dataloader(self, x_data: np.ndarray, soft_labels: np.ndarray,
                               batch_size: int = 200, shuffle: bool = True) -> DataLoader:
        """Create DataLoader for soft labels

        Args:
            x_data: Input data (converted to ``torch.float``)
            soft_labels: Soft labels/probabilities (converted to ``torch.float``)
            batch_size: Batch size
            shuffle: Whether to shuffle data

        Returns:
            PyTorch DataLoader
        """
        return _build_tensor_loader(x_data, soft_labels, torch.float, batch_size, shuffle)

    def create_weak_dataloader(self, x_data: np.ndarray, weak_labels: np.ndarray,
                               batch_size: int = 200, shuffle: bool = True) -> DataLoader:
        """Create DataLoader for weak labels

        Args:
            x_data: Input data (converted to ``torch.float``)
            weak_labels: Weak supervision labels (converted to ``torch.long``)
            batch_size: Batch size
            shuffle: Whether to shuffle data

        Returns:
            PyTorch DataLoader
        """
        return _build_tensor_loader(x_data, weak_labels, torch.long, batch_size, shuffle)
class StandardWSDataLoader(WSDataLoader):
    """Standard data loader for WRENCH-style datasets"""

    def load_split_data(self, dataset: str, seed: int) -> Tuple[Any, Any, Any, Optional[Any]]:
        """Load standard WRENCH-style data splits

        Args:
            dataset: Dataset name
            seed: Random seed

        Returns:
            Tuple of (train_data, val_data, test_data, metadata) once a
            subclass provides an implementation.

        Raises:
            NotImplementedError: Always, in this class; dataset-specific
                subclasses are expected to override this method.
        """
        # This would typically load from the other/ folder structure
        # For now, return placeholder that can be overridden by specific loaders
        raise NotImplementedError("Subclasses should implement dataset-specific loading")

    def filter_abstain_votes(self, data: np.ndarray, votes: np.ndarray, labels: np.ndarray,
                             abstain_value: int = -1) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Filter out data points where all votes are abstain

        Args:
            data: Input data
            votes: Weak supervision votes; reduced along axis 1, so it must
                have at least two dimensions
            labels: True labels
            abstain_value: Value representing abstain votes

        Returns:
            Filtered (data, votes, labels)
        """
        # A row is dropped only when every entry along axis 1 is the abstain value.
        all_abstain = (votes == abstain_value).all(axis=1)
        keep = ~all_abstain
        return data[keep], votes[keep], labels[keep]

    def split_labeled_unlabeled(self, data: np.ndarray, labels: np.ndarray,
                                num_labeled: int, seed: int = 42) -> Tuple[np.ndarray, np.ndarray]:
        """Split data into labeled and unlabeled portions

        Args:
            data: Input data; only its length is used
            labels: Labels (currently unused; kept for interface compatibility)
            num_labeled: Number of labeled examples to select
            seed: Random seed

        Returns:
            Tuple of (labeled_indices, unlabeled_indices). The labeled indices
            are in draw order; the unlabeled indices are sorted ascending.

        Raises:
            ValueError: If ``num_labeled`` is negative or exceeds ``len(data)``.
        """
        total = len(data)
        if not 0 <= num_labeled <= total:
            raise ValueError(
                f"num_labeled must be between 0 and {total}, got {num_labeled}"
            )

        # A private legacy generator draws the same sample that
        # np.random.seed(seed) + np.random.choice would, without reseeding
        # the process-wide NumPy RNG behind the caller's back.
        rng = np.random.RandomState(seed)
        all_indices = np.arange(total)
        labeled_indices = rng.choice(all_indices, size=num_labeled, replace=False)
        unlabeled_indices = np.setdiff1d(all_indices, labeled_indices)
        return labeled_indices, unlabeled_indices