import numpy as np
from typing import Dict, Any
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
from snorkel.labeling import LFAnalysis
def compute_majority_vote(weak_labels: np.ndarray) -> np.ndarray:
"""Efficiently compute majority vote using numpy operations"""
majority_votes = np.full(len(weak_labels), -1)
for i in range(len(weak_labels)):
votes = weak_labels[i]
non_abstain = votes[votes != -1]
if len(non_abstain) > 0:
# Use bincount for efficient majority vote
counts = np.bincount(non_abstain)
majority_votes[i] = np.argmax(counts)
return majority_votes
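
# Illustrative example (hypothetical values): rows are samples, columns are labelers,
# and -1 marks an abstain. Ties resolve to the lowest label via np.argmax over counts.
#
#     >>> L = np.array([[ 0,  0, -1],
#     ...               [ 1, -1,  1],
#     ...               [-1, -1, -1]])
#     >>> compute_majority_vote(L)
#     array([0, 1, -1])
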
def get_label_patterns(weak_labels: np.ndarray) -> tuple:
    """Get unique labeling patterns (unique rows of the L matrix) and their counts"""
    unique_patterns, counts = np.unique(weak_labels, axis=0, return_counts=True)
    return unique_patterns, counts
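
# Illustrative example (hypothetical values): each unique row of the L matrix is one
# labeling pattern, and the counts record how often each pattern occurs.
#
#     >>> L = np.array([[0, 1], [0, 1], [1, 1]])
#     >>> get_label_patterns(L)
#     (array([[0, 1],
#             [1, 1]]), array([2, 1]))
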
class BaseEvaluator:
"""Base evaluator with core classification metrics"""
@staticmethod
def compute_metrics(predictions: np.ndarray, labels: np.ndarray,
average: str = 'weighted') -> Dict[str, float]:
"""Compute basic classification metrics
Args:
predictions: Predicted labels
labels: True labels
average: Averaging strategy for multi-class metrics
Returns:
Dictionary of basic metrics
"""
accuracy = accuracy_score(labels, predictions)
precision, recall, f1, _ = precision_recall_fscore_support(
labels, predictions, average=average, zero_division=0
)
return {
'accuracy': float(accuracy),
'precision': float(precision),
'recall': float(recall),
'f1': float(f1)
}
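
# Illustrative (hypothetical) usage of the base metrics; values shown are rounded:
#
#     >>> BaseEvaluator.compute_metrics(np.array([0, 1, 1]), np.array([0, 1, 0]))
#     {'accuracy': 0.67, 'precision': 0.83, 'recall': 0.67, 'f1': 0.67}
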
class Evaluator(BaseEvaluator):
"""Extended evaluator with weak supervision specific metrics"""
@staticmethod
def compute_metrics(predictions: np.ndarray, labels: np.ndarray,
average: str = 'weighted', **kwargs) -> Dict[str, float]:
"""Compute standard classification metrics with optional weak supervision metrics
Args:
predictions: Predicted labels
labels: True labels
average: Averaging strategy for multi-class metrics
kwargs: Additional arguments (weak_labels, prediction_probs, loss_info, etc.)
Returns:
Dictionary of metrics
"""
# Get base metrics
        metrics = BaseEvaluator.compute_metrics(predictions, labels, average)
# Add weak supervision specific metrics
if 'weak_labels' in kwargs:
weak_metrics = Evaluator.compute_weak_supervision_metrics(
predictions, labels, kwargs['weak_labels']
)
metrics.update(weak_metrics)
# Add coverage metrics if soft predictions available
if 'prediction_probs' in kwargs:
coverage_metrics = Evaluator.compute_coverage_metrics(kwargs['prediction_probs'])
metrics.update(coverage_metrics)
# Add training info if available
if 'loss_info' in kwargs:
loss_info = kwargs['loss_info']
metrics['final_loss'] = float(loss_info.get('final_loss', 0.0))
metrics['loss_converged'] = bool(loss_info.get('converged', False))
if 'prediction_ranges' in kwargs:
pred_ranges = kwargs['prediction_ranges']
metrics['prediction_min'] = float(pred_ranges.get('min', 0.0))
metrics['prediction_max'] = float(pred_ranges.get('max', 1.0))
metrics['prediction_std'] = float(pred_ranges.get('std', 0.0))
return metrics
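
    # Illustrative (hypothetical) call: weak-supervision signals are passed through
    # kwargs alongside the standard arguments, e.g.
    #
    #     Evaluator.compute_metrics(preds, y, weak_labels=L,
    #                               prediction_probs=probs,
    #                               loss_info={"final_loss": 0.12, "converged": True})
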
@staticmethod
def compute_coverage(predictions: np.ndarray, abstain_threshold: float = 0.001) -> float:
"""Compute coverage (percentage of non-abstaining predictions)
Args:
predictions: Soft predictions or confidence scores
abstain_threshold: Threshold below which predictions are considered abstaining
Returns:
Coverage ratio
"""
if predictions.ndim > 1:
            # For soft predictions, treat max probabilities within abstain_threshold of 0.5 as abstains
max_probs = np.max(predictions, axis=1)
coverage = np.mean(np.abs(max_probs - 0.5) > abstain_threshold)
else:
# For hard predictions, assume -1 is abstain
coverage = np.mean(predictions != -1)
return float(coverage)
@staticmethod
def compute_adjusted_accuracy(soft_predictions: np.ndarray, labels: np.ndarray) -> float:
"""Compute accuracy that handles abstaining predictions
Args:
soft_predictions: Soft prediction probabilities
labels: True labels
Returns:
Adjusted accuracy
"""
# Convert soft predictions to hard predictions
hard_predictions = np.argmax(soft_predictions, axis=1)
# Identify non-abstaining predictions (not exactly 0.5 probability)
max_probs = np.max(soft_predictions, axis=1)
non_abstain_mask = np.abs(max_probs - 0.5) > 0.001
if np.sum(non_abstain_mask) == 0:
return 0.0
# Compute accuracy only on non-abstaining predictions
return float(accuracy_score(labels[non_abstain_mask], hard_predictions[non_abstain_mask]))
@staticmethod
def compute_confusion_matrix(predictions: np.ndarray, labels: np.ndarray) -> np.ndarray:
"""Compute confusion matrix
Args:
predictions: Predicted labels
labels: True labels
Returns:
Confusion matrix
"""
return confusion_matrix(labels, predictions)
@staticmethod
def evaluate_weak_supervision_method(results: Dict[str, Any], labels: np.ndarray,
method_name: str = "unknown") -> Dict[str, Any]:
"""Comprehensive evaluation for weak supervision methods
Args:
results: Dictionary containing predictions and other method outputs
labels: True labels
method_name: Name of the method being evaluated
Returns:
Comprehensive evaluation dictionary
"""
evaluation = {
'method': method_name,
'total_samples': len(labels)
}
        # Handle different result formats (check for a raw array first, since the
        # membership test below assumes a dict-like object)
        if isinstance(results, np.ndarray):
            predictions = results
        elif 'predictions' in results:
            predictions = results['predictions']
else:
# Try to extract predictions from results
predictions = None
for key in ['preds', 'pred', 'output', 'y_pred']:
if key in results:
predictions = results[key]
break
if predictions is None:
evaluation['error'] = "No predictions found in results"
return evaluation
        # Compute metrics based on prediction type
        predictions = np.asarray(predictions)
        if predictions.ndim > 1 and predictions.shape[1] > 1:
# Soft predictions
evaluation['prediction_type'] = 'soft'
evaluation['coverage'] = Evaluator.compute_coverage(predictions)
evaluation['adjusted_accuracy'] = Evaluator.compute_adjusted_accuracy(predictions, labels)
# Convert to hard predictions for standard metrics
hard_predictions = np.argmax(predictions, axis=1)
standard_metrics = Evaluator.compute_metrics(hard_predictions, labels)
evaluation.update(standard_metrics)
else:
# Hard predictions
evaluation['prediction_type'] = 'hard'
if predictions.ndim > 1:
predictions = predictions.flatten()
# Filter out abstain predictions (-1)
non_abstain_mask = predictions != -1
if np.sum(non_abstain_mask) > 0:
filtered_preds = predictions[non_abstain_mask]
filtered_labels = labels[non_abstain_mask]
standard_metrics = Evaluator.compute_metrics(filtered_preds, filtered_labels)
evaluation.update(standard_metrics)
evaluation['coverage'] = float(np.mean(non_abstain_mask))
else:
evaluation['coverage'] = 0.0
evaluation['accuracy'] = 0.0
        # Add confusion matrix if possible
        if evaluation.get('accuracy', 0.0) > 0:
            try:
                if predictions.ndim > 1:
                    cm = Evaluator.compute_confusion_matrix(
                        np.argmax(predictions, axis=1), labels
                    )
                else:
                    mask = predictions != -1
                    cm = Evaluator.compute_confusion_matrix(predictions[mask], labels[mask])
                evaluation['confusion_matrix'] = cm.tolist()
            except Exception as e:
                evaluation['confusion_matrix_error'] = str(e)
return evaluation
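
    # Illustrative (hypothetical) call: `results` can be a dict of method outputs or a
    # raw prediction array, e.g.
    #
    #     report = Evaluator.evaluate_weak_supervision_method(
    #         {"predictions": probs}, labels=y, method_name="label_model"
    #     )
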
@staticmethod
def compute_weak_supervision_metrics(
predictions: np.ndarray, labels: np.ndarray, weak_labels: np.ndarray
) -> Dict[str, float]:
"""Compute metrics specific to weak supervision using Snorkel's LFAnalysis
Args:
predictions: Model predictions
labels: True labels
weak_labels: Weak supervision labels (L matrix from Snorkel)
Returns:
Dictionary of weak supervision metrics
"""
lf_analysis = LFAnalysis(L=weak_labels)
# Get summary with ground truth if available
if labels is not None:
lf_summary = lf_analysis.lf_summary(Y=labels)
else:
lf_summary = lf_analysis.lf_summary()
metrics = {}
# Extract key metrics from LF summary
metrics["avg_labeler_coverage"] = float(lf_summary["Coverage"].mean())
metrics["abstain_rate"] = float(1.0 - lf_summary["Coverage"].mean())
if "Emp. Acc." in lf_summary.columns:
# Individual labeler accuracies
emp_accs = lf_summary["Emp. Acc."].fillna(0.0)
metrics["avg_labeler_accuracy"] = float(emp_accs.mean())
# Add individual labeler metrics
for i, (coverage, accuracy) in enumerate(zip(lf_summary["Coverage"], emp_accs)):
metrics[f"labeler_{i}_coverage"] = float(coverage)
metrics[f"labeler_{i}_accuracy"] = float(accuracy)
# Overall coverage (any labeler voted)
non_abstain_mask = (weak_labels != -1).any(axis=1)
metrics["weak_label_coverage"] = float(non_abstain_mask.mean())
# Conflict and overlap analysis from LFAnalysis
if len(lf_summary) > 1: # Multi-labeler case
metrics["avg_conflicts"] = float(lf_summary["Conflicts"].mean())
metrics["avg_overlaps"] = float(lf_summary["Overlaps"].mean())
            # Majority vote analysis (using the bincount-based helper above)
majority_votes = compute_majority_vote(weak_labels)
valid_majority = majority_votes != -1
if valid_majority.sum() > 0:
metrics["majority_vote_agreement"] = float(
(predictions[valid_majority] == majority_votes[valid_majority]).mean()
)
if labels is not None:
metrics["majority_vote_accuracy"] = float(
(majority_votes[valid_majority] == labels[valid_majority]).mean()
)
else: # Single labeler case
valid_votes = weak_labels.flatten() != -1
if valid_votes.sum() > 0:
if labels is not None:
metrics["weak_label_accuracy"] = float(
(weak_labels.flatten()[valid_votes] == labels[valid_votes]).mean()
)
metrics["weak_label_agreement"] = float(
(predictions[valid_votes] == weak_labels.flatten()[valid_votes]).mean()
)
        # Label distribution analysis over unique labeling patterns
unique_patterns, pattern_counts = get_label_patterns(weak_labels)
metrics['num_label_patterns'] = len(unique_patterns)
# Most common labeling pattern
max_pattern_count = max(pattern_counts)
metrics['largest_pattern_size'] = int(max_pattern_count)
metrics['largest_pattern_ratio'] = float(max_pattern_count) / len(weak_labels)
return metrics
@staticmethod
def compute_coverage_metrics(prediction_probs: np.ndarray,
confidence_threshold: float = 0.5) -> Dict[str, float]:
"""Compute coverage and confidence metrics from prediction probabilities
Args:
prediction_probs: Prediction probabilities/logits
confidence_threshold: Threshold for confident predictions
Returns:
Dictionary of coverage metrics
"""
metrics = {}
if prediction_probs.ndim > 1:
# Multi-class probabilities
max_probs = np.max(prediction_probs, axis=1)
confident_mask = max_probs >= confidence_threshold
metrics['confident_coverage'] = float(np.mean(confident_mask))
metrics['avg_confidence'] = float(np.mean(max_probs))
metrics['min_confidence'] = float(np.min(max_probs))
metrics['max_confidence'] = float(np.max(max_probs))
# Entropy-based uncertainty
eps = 1e-8
entropies = -np.sum(prediction_probs * np.log(prediction_probs + eps), axis=1)
metrics['avg_entropy'] = float(np.mean(entropies))
metrics['prediction_uncertainty'] = float(np.std(entropies))
else:
# Binary probabilities or single values
confident_mask = np.abs(prediction_probs - 0.5) >= (confidence_threshold - 0.5)
metrics['confident_coverage'] = float(np.mean(confident_mask))
metrics['avg_confidence'] = float(np.mean(np.abs(prediction_probs - 0.5) + 0.5))
return metrics
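

if __name__ == "__main__":
    # Minimal smoke test, not part of the module's API: every name and number below is
    # an illustrative assumption. It builds a toy L matrix (rows = samples, columns =
    # labeling functions, -1 = abstain), simulated model predictions, and soft scores,
    # then runs the evaluator end to end and prints the resulting metrics.
    rng = np.random.default_rng(0)
    n = 200
    y_true = rng.integers(0, 2, size=n)

    # Three synthetic labelers: two noisy-but-useful, one adversarial.
    L = np.stack(
        [
            np.where(rng.random(n) < 0.8, y_true, -1),
            np.where(rng.random(n) < 0.5, y_true, -1),
            np.where(rng.random(n) < 0.6, 1 - y_true, -1),
        ],
        axis=1,
    )

    # Simulated downstream model outputs: hard predictions plus row-normalized scores.
    preds = np.where(rng.random(n) < 0.85, y_true, 1 - y_true)
    probs = rng.random((n, 2))
    probs /= probs.sum(axis=1, keepdims=True)

    metrics = Evaluator.compute_metrics(preds, y_true, weak_labels=L, prediction_probs=probs)
    for name in sorted(metrics):
        print(f"{name}: {metrics[name]}")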