import json
import numbers
from abc import ABC, abstractmethod
from dataclasses import asdict
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

import numpy as np
import torch
class BaseTrainer(ABC):
    """Abstract base class for all baseline method trainers.

    Subclasses implement the data/model hooks (``load_data``, ``preprocess``,
    ``train``, ``evaluate``, ``get_predictions``). This class provides model
    persistence helpers and the end-to-end ``train_and_evaluate`` pipeline.
    """

    def __init__(self, config):
        """Store the run configuration.

        Args:
            config: Run configuration. ``train_and_evaluate`` reads
                ``config.output.folder``, ``config.output.save_model``,
                ``config.output.save_predictions``, optionally ``config.mode``
                and ``config.model.batch_size``. It also serializes the whole
                object with ``dataclasses.asdict``, so it must be a dataclass
                instance for that method.
        """
        self.config = config
        # Overwritten by train_and_evaluate with the latest results dict.
        self.results = {}

    @abstractmethod
    def load_data(self) -> Tuple[Any, Any, Any]:
        """Load and preprocess data.

        Returns:
            Tuple of (train_data, val_data, test_data). ``val_data`` may be
            ``None``; the pipeline then skips validation.
        """
        pass

    @abstractmethod
    def preprocess(self, data: Any) -> Any:
        """Apply method-specific preprocessing.

        Args:
            data: Raw data to preprocess.

        Returns:
            Preprocessed data.
        """
        pass

    @abstractmethod
    def train(self, train_data: Any) -> Any:
        """Train the model.

        Args:
            train_data: Preprocessed training data.

        Returns:
            Trained model.
        """
        pass

    @abstractmethod
    def evaluate(self, model: Any, test_data: Any) -> Dict[str, float]:
        """Evaluate the model.

        Args:
            model: Trained model.
            test_data: Preprocessed data to evaluate on.

        Returns:
            Dictionary of evaluation metrics. An ``'accuracy'`` key, if
            present, is used as the run's score.
        """
        pass

    @abstractmethod
    def get_predictions(self, model: Any, data: Any) -> np.ndarray:
        """Get model predictions.

        Args:
            model: Trained model.
            data: Preprocessed data to predict on.

        Returns:
            Predictions array.
        """
        pass

    def load_checkpoint(self, checkpoint_path: Path) -> Any:
        """Load a model from a checkpoint file (override in subclasses as needed).

        The default delegates to ``load_model(checkpoint_path)`` without a
        ``model_class``, so it returns whatever object ``torch.load`` yields.

        NOTE: ``save_model`` writes only ``model.state_dict()`` for objects
        that have one. For such models this default therefore returns the raw
        state dict rather than a model instance. Subclasses that train
        ``state_dict``-bearing models should override this to rebuild the
        model, e.g. via ``load_model(path, model_class, **kwargs)``.

        Args:
            checkpoint_path: Path to checkpoint file.

        Returns:
            Loaded object (see note above).
        """
        return self.load_model(checkpoint_path)

    def save_model(self, model: Any, path: Path) -> None:
        """Save a trained model with ``torch.save``.

        Objects exposing ``state_dict`` (e.g. ``torch.nn.Module``) are saved as
        their state dict only. Anything else is pickled whole.

        Args:
            model: Model to save.
            path: Destination file; missing parent directories are created.
        """
        path.parent.mkdir(parents=True, exist_ok=True)
        if hasattr(model, 'state_dict'):
            torch.save(model.state_dict(), path)
        else:
            torch.save(model, path)

    def load_model(self, path: Path, model_class: Any = None, **model_kwargs) -> Any:
        """Load a model previously written by ``save_model``.

        ``torch.load`` unpickles the file, so only load trusted checkpoints.

        Args:
            path: Path to load model from.
            model_class: If given, it is instantiated as
                ``model_class(**model_kwargs)`` and the file is loaded into it
                as a state dict.
            model_kwargs: Constructor arguments for ``model_class``.

        Returns:
            The populated ``model_class`` instance, or (when ``model_class``
            is ``None``) the object stored in the file as-is.
        """
        if model_class is not None:
            model = model_class(**model_kwargs)
            model.load_state_dict(torch.load(path))
            return model
        return torch.load(path)

    def get_model_info(self, model: Any) -> Dict[str, Any]:
        """Summarize a model: parameter counts (when available) and class name.

        Args:
            model: Model to analyze. Parameter counts are reported only if it
                has a ``parameters()`` method yielding tensors.

        Returns:
            Dict with ``total_parameters``, ``trainable_parameters`` and
            ``non_trainable_parameters`` (when countable), plus ``model_class``.
        """
        info = {}
        if hasattr(model, 'parameters'):
            total_params = sum(p.numel() for p in model.parameters())
            trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
            info.update({
                'total_parameters': total_params,
                'trainable_parameters': trainable_params,
                'non_trainable_parameters': total_params - trainable_params,
            })
        # Every Python object has __class__, so this key is always present.
        info['model_class'] = model.__class__.__name__
        return info

    def train_and_evaluate(self) -> Dict[str, Any]:
        """Run the complete training and evaluation pipeline.

        In ``config.mode == 'eval'`` an existing ``best_model.pt`` in the
        output folder is reused. Otherwise, or if loading fails, a new model is
        trained and optionally saved. Metrics are then computed on the
        validation (if any) and test splits, and predictions are optionally
        saved.

        Returns:
            Results dictionary with metrics and metadata. The same dict is
            stored on ``self.results``.
        """
        from .utils import Timer

        timer = Timer()
        timer.start()

        checkpoint_path = Path(self.config.output.folder) / "best_model.pt"
        model = self._try_load_checkpoint(checkpoint_path)

        if model is None:
            train_data, val_data, test_data = self.load_data()
            train_data = self.preprocess(train_data)
            val_data, test_data = self._preprocess_eval_splits(val_data, test_data)
            model = self.train(train_data)
            if self.config.output.save_model:
                self.save_model(model, checkpoint_path)
                print(f"✓ Model saved to: {checkpoint_path}")
        else:
            # Model came from the checkpoint: only the evaluation splits are needed.
            _, val_data, test_data = self.load_data()
            val_data, test_data = self._preprocess_eval_splits(val_data, test_data)

        val_metrics = self.evaluate(model, val_data) if val_data is not None else {}
        test_metrics = self.evaluate(model, test_data)
        model_info = self.get_model_info(model)

        # Timing deliberately stops before predictions are written to disk.
        execution_time = timer.stop()
        batch_size = getattr(self.config.model, 'batch_size', 'N/A')

        predictions_info = {}
        if self.config.output.save_predictions:
            predictions_info = self._save_predictions(model, val_data, test_data)

        results = {
            'config': asdict(self.config),
            'metrics': {
                'val': val_metrics,
                'test': test_metrics,
                # Prefer validation accuracy, fall back to test accuracy, then 0.
                'score': val_metrics.get('accuracy', test_metrics.get('accuracy', 0)),
            },
            'model_info': model_info,
            'execution_time_seconds': execution_time,
            'batch_size': batch_size,
            **predictions_info,
        }
        self.results = results
        return results

    def _try_load_checkpoint(self, checkpoint_path: Path) -> Optional[Any]:
        """Return a model restored from ``checkpoint_path`` in eval mode, else ``None``."""
        if getattr(self.config, 'mode', None) != 'eval' or not checkpoint_path.exists():
            return None
        print(f"Loading model from checkpoint: {checkpoint_path}")
        try:
            model = self.load_checkpoint(checkpoint_path)
        except Exception as e:
            # Best-effort reuse: an unreadable checkpoint falls back to training.
            print(f"✗ Failed to load checkpoint: {e}")
            print("Training new model...")
            return None
        print("✓ Model loaded successfully from checkpoint")
        return model

    def _preprocess_eval_splits(self, val_data: Any, test_data: Any) -> Tuple[Any, Any]:
        """Preprocess the validation split (if present), then the test split."""
        val_data = self.preprocess(val_data) if val_data is not None else None
        return val_data, self.preprocess(test_data)

    def _save_predictions(self, model: Any, val_data: Any, test_data: Any) -> Dict[str, str]:
        """Write val/test predictions to ``<output folder>/predictions.npy``.

        The dict is stored as a 0-d object array. Read it back with
        ``np.load(path, allow_pickle=True).item()``.

        Returns:
            ``{'predictions_saved': <path as str>}``.
        """
        output_dir = Path(self.config.output.folder)
        # Nothing else guarantees the folder exists here (e.g. when
        # save_model is disabled), so create it before writing.
        output_dir.mkdir(parents=True, exist_ok=True)
        predictions_path = output_dir / "predictions.npy"
        val_predictions = self.get_predictions(model, val_data) if val_data is not None else None
        test_predictions = self.get_predictions(model, test_data)
        predictions_data = {
            'val_predictions': val_predictions,
            'test_predictions': test_predictions,
        }
        np.save(predictions_path, predictions_data)
        print(f"✓ Predictions saved to: {predictions_path}")
        return {'predictions_saved': str(predictions_path)}

    @staticmethod
    def _json_default(obj: Any) -> Any:
        """JSON fallback: unwrap NumPy scalars to Python values, stringify anything else."""
        if isinstance(obj, np.generic):
            return obj.item()
        return str(obj)

    def save_results(self, results: Dict[str, Any], output_path: Path) -> None:
        """Save full results and a short summary as JSON files.

        Writes ``results.json`` (entire dict) and ``summary.json`` (val/test
        accuracy and score) into ``output_path``. NumPy scalar values are
        written as plain JSON numbers; other non-JSON values become strings.

        Args:
            results: Results dictionary as produced by ``train_and_evaluate``.
            output_path: Directory to write into; created if missing.
        """
        output_path.mkdir(parents=True, exist_ok=True)
        with open(output_path / 'results.json', 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2, default=self._json_default)

        metrics = results['metrics']
        summary = {
            'val_accuracy': metrics.get('val', {}).get('accuracy', 'N/A'),
            'test_accuracy': metrics['test'].get('accuracy', 'N/A'),
            'score': metrics['score'],
        }
        with open(output_path / 'summary.json', 'w', encoding='utf-8') as f:
            json.dump(summary, f, indent=2, default=self._json_default)

    def _standardize_results(self, raw_results: Any) -> Dict[str, float]:
        """Convert method-specific results to the standard metrics format.

        Args:
            raw_results: A metrics dict (returned unchanged), a non-empty
                list/tuple whose first element is taken as accuracy, or a
                single real number (including NumPy numeric scalars).

        Returns:
            Standardized results dictionary. Unrecognized input yields
            ``{'accuracy': 0.0}``.
        """
        if isinstance(raw_results, dict):
            return raw_results
        if isinstance(raw_results, (list, tuple)) and raw_results:
            return {'accuracy': float(raw_results[0])}
        # numbers.Real also covers NumPy scalars such as np.float32, which
        # are not subclasses of the built-in float.
        if isinstance(raw_results, numbers.Real):
            return {'accuracy': float(raw_results)}
        return {'accuracy': 0.0}