Source code for src.common.eval_manager

import numpy as np
from typing import Dict, List, Any, Optional
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
from snorkel.labeling import LFAnalysis


def compute_majority_vote(weak_labels: np.ndarray) -> np.ndarray:
    """Compute the per-sample majority vote over labeling function outputs.

    Abstaining votes (-1) are ignored; samples where every labeler abstains
    keep the abstain label -1.
    """
    majority_votes = np.full(len(weak_labels), -1, dtype=int)
    for i in range(len(weak_labels)):
        votes = weak_labels[i]
        non_abstain = votes[votes != -1]
        if len(non_abstain) > 0:
            # bincount + argmax picks the most frequent label (ties go to the smallest label)
            counts = np.bincount(non_abstain)
            majority_votes[i] = np.argmax(counts)
    return majority_votes
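

# Illustrative usage sketch (not part of the original module): a toy label matrix with
# three labeling functions and -1 as the abstain vote. The helper name
# `_example_majority_vote` is hypothetical and exists only for illustration.
def _example_majority_vote() -> None:
    L = np.array([
        [0, 1, 0],     # two LFs vote 0, one votes 1 -> majority 0
        [-1, 1, 1],    # one abstain, two vote 1     -> majority 1
        [-1, -1, -1],  # all abstain                 -> stays -1
    ])
    print(compute_majority_vote(L))  # expected: [ 0  1 -1]
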

def get_label_patterns(weak_labels: np.ndarray) -> tuple:
    """Get unique labeling patterns (rows of the label matrix) and their counts"""
    # np.unique over rows returns each distinct voting pattern and how often it occurs
    unique_patterns, counts = np.unique(weak_labels, return_counts=True, axis=0)
    return unique_patterns, counts
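

# Minimal sketch (illustrative, not part of the original module) of the pattern counter:
# rows with identical labeling-function votes collapse into a single pattern.
# `_example_label_patterns` is a hypothetical helper name.
def _example_label_patterns() -> None:
    L = np.array([
        [0, -1],
        [0, -1],  # duplicate of the first row -> same pattern
        [1, 1],
    ])
    patterns, counts = get_label_patterns(L)
    print(patterns)  # [[ 0 -1], [ 1  1]]
    print(counts)    # [2 1]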


class BaseEvaluator:
    """Base evaluator with core classification metrics"""
    
    @staticmethod
    def compute_metrics(predictions: np.ndarray, labels: np.ndarray, 
                       average: str = 'weighted') -> Dict[str, float]:
        """Compute basic classification metrics
        
        Args:
            predictions: Predicted labels
            labels: True labels
            average: Averaging strategy for multi-class metrics
            
        Returns:
            Dictionary of basic metrics
        """
        accuracy = accuracy_score(labels, predictions)
        precision, recall, f1, _ = precision_recall_fscore_support(
            labels, predictions, average=average, zero_division=0
        )
        return {
            'accuracy': float(accuracy),
            'precision': float(precision),
            'recall': float(recall),
            'f1': float(f1)
        }
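

# Illustrative sketch (not part of the original module) of how BaseEvaluator.compute_metrics
# might be called on toy arrays. `_example_base_metrics` is a hypothetical helper name.
def _example_base_metrics() -> None:
    y_true = np.array([0, 1, 1, 0, 1])
    y_pred = np.array([0, 1, 0, 0, 1])
    # Returns accuracy, precision, recall and f1 as plain floats
    print(BaseEvaluator.compute_metrics(y_pred, y_true))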


class Evaluator(BaseEvaluator):
    """Extended evaluator with weak supervision specific metrics"""

    @staticmethod
    def compute_metrics(predictions: np.ndarray, labels: np.ndarray,
                        average: str = 'weighted', **kwargs) -> Dict[str, float]:
        """Compute standard classification metrics with optional weak supervision metrics

        Args:
            predictions: Predicted labels
            labels: True labels
            average: Averaging strategy for multi-class metrics
            kwargs: Additional arguments (weak_labels, prediction_probs, loss_info, etc.)

        Returns:
            Dictionary of metrics
        """
        # Get base metrics
        metrics = BaseEvaluator.compute_metrics(predictions, labels, average)

        # Add weak supervision specific metrics
        if 'weak_labels' in kwargs:
            weak_metrics = Evaluator.compute_weak_supervision_metrics(
                predictions, labels, kwargs['weak_labels']
            )
            metrics.update(weak_metrics)

        # Add coverage metrics if soft predictions available
        if 'prediction_probs' in kwargs:
            coverage_metrics = Evaluator.compute_coverage_metrics(kwargs['prediction_probs'])
            metrics.update(coverage_metrics)

        # Add training info if available
        if 'loss_info' in kwargs:
            loss_info = kwargs['loss_info']
            metrics['final_loss'] = float(loss_info.get('final_loss', 0.0))
            metrics['loss_converged'] = bool(loss_info.get('converged', False))

        if 'prediction_ranges' in kwargs:
            pred_ranges = kwargs['prediction_ranges']
            metrics['prediction_min'] = float(pred_ranges.get('min', 0.0))
            metrics['prediction_max'] = float(pred_ranges.get('max', 1.0))
            metrics['prediction_std'] = float(pred_ranges.get('std', 0.0))

        return metrics

    @staticmethod
    def compute_coverage(predictions: np.ndarray, abstain_threshold: float = 0.001) -> float:
        """Compute coverage (percentage of non-abstaining predictions)

        Args:
            predictions: Soft predictions or confidence scores
            abstain_threshold: Threshold below which predictions are considered abstaining

        Returns:
            Coverage ratio
        """
        if predictions.ndim > 1:
            # For soft predictions, a sample abstains when its max probability
            # sits at (or very near) 0.5, i.e. the model is maximally uncertain
            max_probs = np.max(predictions, axis=1)
            coverage = np.mean(np.abs(max_probs - 0.5) > abstain_threshold)
        else:
            # For hard predictions, -1 denotes abstain
            coverage = np.mean(predictions != -1)
        return float(coverage)

    @staticmethod
    def compute_adjusted_accuracy(soft_predictions: np.ndarray, labels: np.ndarray) -> float:
        """Compute accuracy that handles abstaining predictions

        Args:
            soft_predictions: Soft prediction probabilities
            labels: True labels

        Returns:
            Adjusted accuracy
        """
        # Convert soft predictions to hard predictions
        hard_predictions = np.argmax(soft_predictions, axis=1)

        # Identify non-abstaining predictions (max probability not pinned at 0.5)
        max_probs = np.max(soft_predictions, axis=1)
        non_abstain_mask = np.abs(max_probs - 0.5) > 0.001

        if np.sum(non_abstain_mask) == 0:
            return 0.0

        # Compute accuracy only on non-abstaining predictions
        return float(accuracy_score(labels[non_abstain_mask], hard_predictions[non_abstain_mask]))

    @staticmethod
    def compute_confusion_matrix(predictions: np.ndarray, labels: np.ndarray) -> np.ndarray:
        """Compute confusion matrix

        Args:
            predictions: Predicted labels
            labels: True labels

        Returns:
            Confusion matrix
        """
        return confusion_matrix(labels, predictions)

    @staticmethod
    def evaluate_weak_supervision_method(results: Dict[str, Any], labels: np.ndarray,
                                         method_name: str = "unknown") -> Dict[str, Any]:
        """Comprehensive evaluation for weak supervision methods

        Args:
            results: Dictionary containing predictions and other method outputs
            labels: True labels
            method_name: Name of the method being evaluated

        Returns:
            Comprehensive evaluation dictionary
        """
        evaluation = {
            'method': method_name,
            'total_samples': len(labels)
        }

        # Handle different result formats (check for a raw array before dict-style lookups)
        if isinstance(results, np.ndarray):
            predictions = results
        elif 'predictions' in results:
            predictions = results['predictions']
        else:
            # Try to extract predictions from results under common key names
            predictions = None
            for key in ['preds', 'pred', 'output', 'y_pred']:
                if key in results:
                    predictions = results[key]
                    break

        if predictions is None:
            evaluation['error'] = "No predictions found in results"
            return evaluation

        # Compute metrics based on prediction type
        if predictions.ndim > 1 and predictions.shape[1] > 1:
            # Soft predictions
            evaluation['prediction_type'] = 'soft'
            evaluation['coverage'] = Evaluator.compute_coverage(predictions)
            evaluation['adjusted_accuracy'] = Evaluator.compute_adjusted_accuracy(predictions, labels)

            # Convert to hard predictions for standard metrics
            hard_predictions = np.argmax(predictions, axis=1)
            standard_metrics = Evaluator.compute_metrics(hard_predictions, labels)
            evaluation.update(standard_metrics)
        else:
            # Hard predictions
            evaluation['prediction_type'] = 'hard'
            if predictions.ndim > 1:
                predictions = predictions.flatten()

            # Filter out abstain predictions (-1)
            non_abstain_mask = predictions != -1
            if np.sum(non_abstain_mask) > 0:
                filtered_preds = predictions[non_abstain_mask]
                filtered_labels = labels[non_abstain_mask]
                standard_metrics = Evaluator.compute_metrics(filtered_preds, filtered_labels)
                evaluation.update(standard_metrics)
                evaluation['coverage'] = float(np.mean(non_abstain_mask))
            else:
                evaluation['coverage'] = 0.0
                evaluation['accuracy'] = 0.0

        # Add confusion matrix if possible
        if 'accuracy' in evaluation and evaluation['accuracy'] > 0:
            try:
                if predictions.ndim > 1:
                    hard_preds = np.argmax(predictions, axis=1)
                    cm = Evaluator.compute_confusion_matrix(hard_preds, labels)
                else:
                    hard_preds = predictions[predictions != -1]
                    filtered_labels = labels[predictions != -1]
                    cm = Evaluator.compute_confusion_matrix(hard_preds, filtered_labels)
                evaluation['confusion_matrix'] = cm.tolist()
            except Exception as e:
                evaluation['confusion_matrix_error'] = str(e)

        return evaluation

    @staticmethod
    def compute_weak_supervision_metrics(
        predictions: np.ndarray,
        labels: np.ndarray,
        weak_labels: np.ndarray
    ) -> Dict[str, float]:
        """Compute metrics specific to weak supervision using Snorkel's LFAnalysis

        Args:
            predictions: Model predictions
            labels: True labels
            weak_labels: Weak supervision labels (L matrix from Snorkel)

        Returns:
            Dictionary of weak supervision metrics
        """
        lf_analysis = LFAnalysis(L=weak_labels)

        # Get summary with ground truth if available
        if labels is not None:
            lf_summary = lf_analysis.lf_summary(Y=labels)
        else:
            lf_summary = lf_analysis.lf_summary()

        metrics = {}

        # Extract key metrics from LF summary
        metrics["avg_labeler_coverage"] = float(lf_summary["Coverage"].mean())
        metrics["abstain_rate"] = float(1.0 - lf_summary["Coverage"].mean())

        if "Emp. Acc." in lf_summary.columns:
            # Individual labeler accuracies
            emp_accs = lf_summary["Emp. Acc."].fillna(0.0)
            metrics["avg_labeler_accuracy"] = float(emp_accs.mean())

            # Add individual labeler metrics
            for i, (coverage, accuracy) in enumerate(zip(lf_summary["Coverage"], emp_accs)):
                metrics[f"labeler_{i}_coverage"] = float(coverage)
                metrics[f"labeler_{i}_accuracy"] = float(accuracy)

        # Overall coverage (at least one labeler voted)
        non_abstain_mask = (weak_labels != -1).any(axis=1)
        metrics["weak_label_coverage"] = float(non_abstain_mask.mean())

        # Conflict and overlap analysis from LFAnalysis
        if len(lf_summary) > 1:
            # Multi-labeler case
            metrics["avg_conflicts"] = float(lf_summary["Conflicts"].mean())
            metrics["avg_overlaps"] = float(lf_summary["Overlaps"].mean())

            # Majority vote analysis
            majority_votes = compute_majority_vote(weak_labels)
            valid_majority = majority_votes != -1
            if valid_majority.sum() > 0:
                metrics["majority_vote_agreement"] = float(
                    (predictions[valid_majority] == majority_votes[valid_majority]).mean()
                )
                if labels is not None:
                    metrics["majority_vote_accuracy"] = float(
                        (majority_votes[valid_majority] == labels[valid_majority]).mean()
                    )
        else:
            # Single labeler case
            valid_votes = weak_labels.flatten() != -1
            if valid_votes.sum() > 0:
                if labels is not None:
                    metrics["weak_label_accuracy"] = float(
                        (weak_labels.flatten()[valid_votes] == labels[valid_votes]).mean()
                    )
                metrics["weak_label_agreement"] = float(
                    (predictions[valid_votes] == weak_labels.flatten()[valid_votes]).mean()
                )

        # Label distribution analysis over unique labeling patterns
        unique_patterns, pattern_counts = get_label_patterns(weak_labels)
        metrics['num_label_patterns'] = len(unique_patterns)

        # Most common labeling pattern
        max_pattern_count = max(pattern_counts)
        metrics['largest_pattern_size'] = int(max_pattern_count)
        metrics['largest_pattern_ratio'] = float(max_pattern_count) / len(weak_labels)

        return metrics

    @staticmethod
    def compute_coverage_metrics(prediction_probs: np.ndarray,
                                 confidence_threshold: float = 0.5) -> Dict[str, float]:
        """Compute coverage and confidence metrics from prediction probabilities

        Args:
            prediction_probs: Prediction probabilities
            confidence_threshold: Threshold for confident predictions

        Returns:
            Dictionary of coverage metrics
        """
        metrics = {}

        if prediction_probs.ndim > 1:
            # Multi-class probabilities
            max_probs = np.max(prediction_probs, axis=1)
            confident_mask = max_probs >= confidence_threshold
            metrics['confident_coverage'] = float(np.mean(confident_mask))
            metrics['avg_confidence'] = float(np.mean(max_probs))
            metrics['min_confidence'] = float(np.min(max_probs))
            metrics['max_confidence'] = float(np.max(max_probs))

            # Entropy-based uncertainty
            eps = 1e-8
            entropies = -np.sum(prediction_probs * np.log(prediction_probs + eps), axis=1)
            metrics['avg_entropy'] = float(np.mean(entropies))
            metrics['prediction_uncertainty'] = float(np.std(entropies))
        else:
            # Binary probabilities or single values
            confident_mask = np.abs(prediction_probs - 0.5) >= (confidence_threshold - 0.5)
            metrics['confident_coverage'] = float(np.mean(confident_mask))
            metrics['avg_confidence'] = float(np.mean(np.abs(prediction_probs - 0.5) + 0.5))

        return metrics
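

# Illustrative end-to-end sketch (not part of the original module). It builds a tiny
# weak-label matrix and toy predictions to show how Evaluator.compute_metrics picks up
# the optional kwargs and how evaluate_weak_supervision_method handles soft outputs.
# Requires snorkel to be installed for the LFAnalysis-based metrics; the helper name
# `_example_evaluator` is hypothetical.
def _example_evaluator() -> None:
    labels = np.array([0, 1, 1, 0])
    weak_labels = np.array([
        [0, -1, 0],
        [1, 1, -1],
        [1, 0, 1],
        [-1, 0, 0],
    ])
    hard_preds = np.array([0, 1, 1, 0])
    soft_preds = np.array([
        [0.9, 0.1],
        [0.2, 0.8],
        [0.4, 0.6],
        [0.5, 0.5],  # treated as abstaining by compute_coverage / adjusted accuracy
    ])

    # Standard metrics plus LF-level diagnostics from the weak label matrix
    metrics = Evaluator.compute_metrics(hard_preds, labels,
                                        weak_labels=weak_labels,
                                        prediction_probs=soft_preds)
    print(metrics['accuracy'], metrics['weak_label_coverage'], metrics['confident_coverage'])

    # Soft predictions go through the comprehensive evaluation path
    report = Evaluator.evaluate_weak_supervision_method({'predictions': soft_preds},
                                                        labels, method_name="toy_model")
    print(report['prediction_type'], report['coverage'])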