Source code for src.common.base_trainer

from abc import ABC, abstractmethod
from typing import Dict, Any, Tuple, Optional
from pathlib import Path
from dataclasses import asdict
import json
import numpy as np
import torch

class BaseTrainer(ABC):
    """Abstract base class for all baseline method trainers.

    Subclasses implement the data/model hooks (``load_data``, ``preprocess``,
    ``train``, ``evaluate``, ``get_predictions``). This class supplies the
    shared pipeline (``train_and_evaluate``) and helpers for persisting
    models, predictions and results.

    Configuration attributes read by this class:
        ``config.output.folder``, ``config.output.save_model``,
        ``config.output.save_predictions``, and optionally ``config.mode``
        and ``config.model.batch_size``. ``config`` itself is passed to
        ``dataclasses.asdict`` when results are assembled.
    """

    def __init__(self, config: Any) -> None:
        self.config = config
        # Populated by train_and_evaluate() with the most recent results.
        self.results = {}

    @abstractmethod
    def load_data(self) -> Tuple[Any, Any, Any]:
        """Load and preprocess data

        Returns:
            Tuple of (train_data, val_data, test_data). ``val_data`` may be
            ``None``; the pipeline then skips validation.
        """
        pass

    @abstractmethod
    def preprocess(self, data: Any) -> Any:
        """Apply method-specific preprocessing

        Args:
            data: Raw data to preprocess

        Returns:
            Preprocessed data
        """
        pass

    @abstractmethod
    def train(self, train_data: Any) -> Any:
        """Train the model

        Args:
            train_data: Training data

        Returns:
            Trained model
        """
        pass

    @abstractmethod
    def evaluate(self, model: Any, test_data: Any) -> Dict[str, float]:
        """Evaluate the model

        Args:
            model: Trained model
            test_data: Test data

        Returns:
            Dictionary of evaluation metrics. The pipeline looks up the
            ``'accuracy'`` key to derive the overall score.
        """
        pass

    @abstractmethod
    def get_predictions(self, model: Any, data: Any) -> np.ndarray:
        """Get model predictions

        Args:
            model: Trained model
            data: Data to predict on

        Returns:
            Predictions array
        """
        pass

    def load_checkpoint(self, checkpoint_path: Path) -> Any:
        """Load model from checkpoint (can be overridden by subclasses)

        The default simply calls ``load_model(checkpoint_path)`` without a
        ``model_class``, so it returns whatever object ``torch.load`` yields.
        For a checkpoint written by ``save_model`` from an object exposing
        ``state_dict`` that is the state dict, not a model instance, so
        subclasses whose models are saved that way should override this.

        Args:
            checkpoint_path: Path to checkpoint file

        Returns:
            Loaded model
        """
        return self.load_model(checkpoint_path)

    def save_model(self, model: Any, path: Path) -> None:
        """Save trained model

        Objects exposing ``state_dict`` are stored as their state dict;
        anything else is serialized whole. Missing parent directories are
        created.

        Args:
            model: Model to save
            path: Path to save model (``Path`` or ``str``)
        """
        path = Path(path)  # accept plain strings as well as Path objects
        path.parent.mkdir(parents=True, exist_ok=True)

        payload = model.state_dict() if hasattr(model, 'state_dict') else model
        torch.save(payload, path)

    def load_model(self, path: Path, model_class: Any = None, **model_kwargs) -> Any:
        """Load trained model

        Args:
            path: Path to load model from
            model_class: Model class to instantiate (if loading state_dict)
            model_kwargs: Arguments for model class

        Returns:
            Loaded model. When ``model_class`` is ``None`` this is the raw
            object returned by ``torch.load``.
        """
        # SECURITY NOTE(review): torch.load is pickle-based; only load
        # checkpoints from trusted sources. Recent PyTorch releases default to
        # weights_only=True, which can reject fully pickled model objects in
        # the model_class=None path -- confirm against the deployed version.
        if model_class is None:
            # Load full model
            return torch.load(path)

        # Load state dict into new model instance
        model = model_class(**model_kwargs)
        model.load_state_dict(torch.load(path))
        return model

    def get_model_info(self, model: Any) -> Dict[str, Any]:
        """Get model information like parameter count

        Args:
            model: Model to analyze

        Returns:
            Dictionary with ``'model_class'`` and, when the model exposes
            ``parameters()``, total / trainable / non-trainable parameter
            counts.
        """
        info = {}

        if hasattr(model, 'parameters'):
            # Materialize once so the parameters are not traversed twice.
            params = list(model.parameters())
            total_params = sum(p.numel() for p in params)
            trainable_params = sum(p.numel() for p in params if p.requires_grad)
            info.update({
                'total_parameters': total_params,
                'trainable_parameters': trainable_params,
                'non_trainable_parameters': total_params - trainable_params,
            })

        # Every object has __class__, so this entry is always present.
        info['model_class'] = model.__class__.__name__
        return info

    def train_and_evaluate(self) -> Dict[str, Any]:
        """Complete training and evaluation pipeline

        Steps: optionally restore ``best_model.pt`` from the output folder
        (only when ``config.mode == 'eval'`` and the file exists), otherwise
        train a new model; evaluate on validation (if any) and test data;
        optionally save the model and predictions; assemble results.

        Returns:
            Results dictionary with metrics and metadata. Also stored on
            ``self.results``.
        """
        from .utils import Timer

        timer = Timer()
        timer.start()

        output_dir = Path(self.config.output.folder)
        checkpoint_path = output_dir / "best_model.pt"

        # Check for existing checkpoint
        model = None
        if getattr(self.config, 'mode', None) == 'eval' and checkpoint_path.exists():
            print(f"Loading model from checkpoint: {checkpoint_path}")
            try:
                model = self.load_checkpoint(checkpoint_path)
                print("✓ Model loaded successfully from checkpoint")
            except Exception as e:
                # Deliberate best-effort: fall back to training from scratch.
                print(f"✗ Failed to load checkpoint: {e}")
                print("Training new model...")
                model = None

        needs_training = model is None

        # Data is loaded on both paths; the training split is preprocessed
        # only when training. Preprocessing order stays train -> val -> test
        # in case a subclass's preprocess() is stateful.
        train_data, val_data, test_data = self.load_data()
        if needs_training:
            train_data = self.preprocess(train_data)
        val_data = self.preprocess(val_data) if val_data is not None else None
        test_data = self.preprocess(test_data)

        if needs_training:
            model = self.train(train_data)

            # Save model if requested
            if self.config.output.save_model:
                self.save_model(model, checkpoint_path)
                print(f"✓ Model saved to: {checkpoint_path}")

        # Evaluate
        val_metrics = self.evaluate(model, val_data) if val_data is not None else {}
        test_metrics = self.evaluate(model, test_data)

        model_info = self.get_model_info(model)

        # The timer is stopped before predictions are written, so
        # execution_time excludes the prediction pass.
        execution_time = timer.stop()

        batch_size = getattr(self.config.model, 'batch_size', 'N/A')

        # Save predictions if requested
        predictions_info = {}
        if self.config.output.save_predictions:
            predictions_info = self._save_predictions(
                model, val_data, test_data, output_dir / "predictions.npy"
            )

        results = {
            'config': asdict(self.config),
            'metrics': {
                'val': val_metrics,
                'test': test_metrics,
                # Prefer validation accuracy, then test accuracy, then 0.
                'score': val_metrics.get('accuracy', test_metrics.get('accuracy', 0)),
            },
            'model_info': model_info,
            'execution_time_seconds': execution_time,
            'batch_size': batch_size,
            **predictions_info,
        }

        self.results = results
        return results

    def _save_predictions(self, model: Any, val_data: Any, test_data: Any,
                          predictions_path: Path) -> Dict[str, Any]:
        """Compute val/test predictions and write them to ``predictions_path``.

        Returns:
            ``{'predictions_saved': <path as str>}`` for merging into results.
        """
        val_predictions = self.get_predictions(model, val_data) if val_data is not None else None
        test_predictions = self.get_predictions(model, test_data)

        # The output folder is otherwise only created as a side effect of
        # save_model(); make sure it exists when only predictions are saved.
        predictions_path.parent.mkdir(parents=True, exist_ok=True)

        # np.save wraps the dict in a pickled 0-d object array; read it back
        # with np.load(path, allow_pickle=True).item().
        predictions_data = {
            'val_predictions': val_predictions,
            'test_predictions': test_predictions,
        }
        np.save(predictions_path, predictions_data)
        print(f"✓ Predictions saved to: {predictions_path}")
        return {'predictions_saved': str(predictions_path)}

    @staticmethod
    def _json_default(obj: Any) -> Any:
        """Fallback for ``json.dump``: unwrap array-likes, else use ``str``.

        Objects exposing ``tolist()`` (NumPy scalars/arrays, torch tensors)
        become native Python numbers/lists; everything else is stringified.
        """
        to_list = getattr(obj, 'tolist', None)
        if callable(to_list):
            return to_list()
        return str(obj)

    def save_results(self, results: Dict[str, Any], output_path: Path) -> None:
        """Save results to JSON file

        Writes ``results.json`` (full results) and ``summary.json``
        (validation/test accuracy and score) into ``output_path``, creating
        the directory if needed.

        Args:
            results: Results dictionary; must contain ``results['metrics']``
                with ``'test'`` and ``'score'`` entries.
            output_path: Directory to save results in (``Path`` or ``str``)
        """
        output_path = Path(output_path)  # accept plain strings as well
        output_path.mkdir(parents=True, exist_ok=True)

        # Save results
        with open(output_path / 'results.json', 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2, default=self._json_default)

        # Save summary
        metrics = results['metrics']
        summary = {
            'val_accuracy': metrics.get('val', {}).get('accuracy', 'N/A'),
            'test_accuracy': metrics['test'].get('accuracy', 'N/A'),
            'score': metrics['score'],
        }
        # Same fallback serializer as above, so NumPy scalar metrics do not
        # raise TypeError here after results.json was already written.
        with open(output_path / 'summary.json', 'w', encoding='utf-8') as f:
            json.dump(summary, f, indent=2, default=self._json_default)

    def _standardize_results(self, raw_results: Any) -> Dict[str, float]:
        """Convert method-specific results to standard format

        Args:
            raw_results: Raw results from method

        Returns:
            Standardized results dictionary: dicts are returned unchanged; a
            non-empty list/tuple contributes its first element as accuracy; a
            bare number (including NumPy scalars) is the accuracy; anything
            else yields ``{'accuracy': 0.0}``.
        """
        if isinstance(raw_results, dict):
            return raw_results

        if isinstance(raw_results, (list, tuple)) and len(raw_results) >= 1:
            # Assume first element is accuracy
            return {'accuracy': float(raw_results[0])}

        # np.number covers NumPy scalars such as float32/int64, which are not
        # subclasses of the builtin float/int and would otherwise fall through.
        if isinstance(raw_results, (int, float, np.number)):
            return {'accuracy': float(raw_results)}

        return {'accuracy': 0.0}