src.common.BaseTrainer

class src.common.BaseTrainer(config)[source]

Bases: ABC

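Abstract base class for all baseline method trainers. Subclasses must implement the abstract methods evaluate, get_predictions, load_data, preprocess, and train.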

Methods

evaluate(model, test_data)

Evaluate the model

get_model_info(model)

Get model information like parameter count

get_predictions(model, data)

Get model predictions

load_checkpoint(checkpoint_path)

Load model from checkpoint (can be overridden by subclasses)

load_data()

Load and preprocess data

load_model(path[, model_class])

Load trained model

preprocess(data)

Apply method-specific preprocessing

save_model(model, path)

Save trained model

save_results(results, output_path)

Save results to JSON file

train(train_data)

Train the model

train_and_evaluate()

Complete training and evaluation pipeline
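
A minimal sketch of a concrete trainer, covering the five methods marked abstractmethod on this page. The BaseTrainer defined here is a hypothetical stand-in so the example runs on its own; real code would import src.common.BaseTrainer. The assumption that the constructor stores config, the MeanBaselineTrainer class, and the toy data are all illustrative. Predictions are returned as a plain list rather than a numpy.ndarray to keep the sketch dependency-free.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Tuple


class BaseTrainer(ABC):
    """Hypothetical stand-in mirroring the documented abstract methods.

    Real code would import src.common.BaseTrainer instead.
    """

    def __init__(self, config: Dict[str, Any]) -> None:
        # Assumption: the real constructor keeps the config for later use.
        self.config = config

    @abstractmethod
    def load_data(self) -> Tuple[Any, Any, Any]: ...

    @abstractmethod
    def preprocess(self, data: Any) -> Any: ...

    @abstractmethod
    def train(self, train_data: Any) -> Any: ...

    @abstractmethod
    def get_predictions(self, model: Any, data: Any) -> Any: ...

    @abstractmethod
    def evaluate(self, model: Any, test_data: Any) -> Dict[str, float]: ...


class MeanBaselineTrainer(BaseTrainer):
    """Toy trainer: the 'model' is just the mean of the training values."""

    def load_data(self) -> Tuple[List[float], List[float], List[float]]:
        values = self.preprocess([1, 2, 3, 4, 5, 6])
        # Split into (train_data, val_data, test_data).
        return values[:4], values[4:5], values[5:]

    def preprocess(self, data: Any) -> List[float]:
        return [float(x) for x in data]

    def train(self, train_data: List[float]) -> float:
        return sum(train_data) / len(train_data)

    def get_predictions(self, model: float, data: List[float]) -> List[float]:
        # Predict the training mean for every example.
        return [model for _ in data]

    def evaluate(self, model: float, test_data: List[float]) -> Dict[str, float]:
        preds = self.get_predictions(model, test_data)
        mae = sum(abs(p - y) for p, y in zip(preds, test_data)) / len(test_data)
        return {"mae": mae}


trainer = MeanBaselineTrainer(config={})
train_data, val_data, test_data = trainer.load_data()
model = trainer.train(train_data)
metrics = trainer.evaluate(model, test_data)
```

Because the stand-in base class derives from ABC, instantiating it directly, or instantiating a subclass that omits any abstract method, raises TypeError.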

abstractmethod evaluate(model, test_data)[source]

Evaluate the model

Args:

model: Trained model
test_data: Test data

Returns:

Dictionary of evaluation metrics

Parameters:
  • model (Any)

  • test_data (Any)

Return type:

Dict[str, float]

get_model_info(model)[source]

Get model information like parameter count

Args:

model: Model to analyze

Returns:

Dictionary with model information

Parameters:

model (Any)

Return type:

Dict[str, Any]

abstractmethod get_predictions(model, data)[source]

Get model predictions

Args:

model: Trained model
data: Data to predict on

Returns:

Predictions array

Parameters:
  • model (Any)

  • data (Any)

Return type:

numpy.ndarray

load_checkpoint(checkpoint_path)[source]

Load model from checkpoint (can be overridden by subclasses)

Args:

checkpoint_path: Path to checkpoint file

Returns:

Loaded model

Parameters:

checkpoint_path (Path)

Return type:

Any

abstractmethod load_data()[source]

Load and preprocess data

Returns:

Tuple of (train_data, val_data, test_data)

Return type:

Tuple[Any, Any, Any]

load_model(path, model_class=None, **model_kwargs)[source]

Load trained model

Args:

path: Path to load model from
model_class: Model class to instantiate (if loading state_dict)
model_kwargs: Arguments for model class

Returns:

Loaded model

Parameters:
  • path (Path)

  • model_class (Any)

Return type:

Any

abstractmethod preprocess(data)[source]

Apply method-specific preprocessing

Args:

data: Raw data to preprocess

Returns:

Preprocessed data

Parameters:

data (Any)

Return type:

Any

save_model(model, path)[source]

Save trained model

Args:

model: Model to save
path: Path to save model

Parameters:
  • model (Any)

  • path (Path)

Return type:

None

save_results(results, output_path)[source]

Save results to JSON file

Args:

results: Results dictionary
output_path: Path to save results

Parameters:
  • results (Dict[str, Any])

  • output_path (Path)

Return type:

None
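
A sketch of what saving a results dictionary to JSON might look like, using only the standard library. The function below is a hypothetical stand-in, not the actual method body. Creating the parent directory and falling back to str() for values that JSON cannot encode are assumptions, and the results values are illustrative.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Dict


def save_results(results: Dict[str, Any], output_path: Path) -> None:
    """Hypothetical stand-in for BaseTrainer.save_results."""
    output_path = Path(output_path)
    # Assumption: create missing parent directories instead of failing.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with output_path.open("w", encoding="utf-8") as f:
        # Assumption: default=str converts non-JSON values (e.g. Path) to strings.
        json.dump(results, f, indent=2, default=str)


# Illustrative values only.
results = {"metrics": {"accuracy": 0.5}, "checkpoint": Path("model.pt")}

with tempfile.TemporaryDirectory() as tmp:
    out = Path(tmp) / "runs" / "results.json"
    save_results(results, out)
    loaded = json.loads(out.read_text(encoding="utf-8"))
```

After the round trip, the Path value comes back as a plain string, which is the usual trade-off of the default=str fallback.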

abstractmethod train(train_data)[source]

Train the model

Args:

train_data: Training data

Returns:

Trained model

Parameters:

train_data (Any)

Return type:

Any

train_and_evaluate()[source]

Complete training and evaluation pipeline

Returns:

Results dictionary with metrics and metadata

Return type:

Dict[str, Any]
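
One plausible shape for this pipeline, assuming it calls load_data, train, and evaluate in that order. The actual implementation may also use the validation split, save the model, or record different metadata. The run_pipeline helper, the toy callables, and the "metrics" and "metadata" result keys are all hypothetical.

```python
from typing import Any, Callable, Dict, List, Tuple

Split = Tuple[List[float], List[float], List[float]]


def run_pipeline(
    load_data: Callable[[], Split],
    train: Callable[[List[float]], Any],
    evaluate: Callable[[Any, List[float]], Dict[str, float]],
) -> Dict[str, Any]:
    """Hypothetical sketch of a load -> train -> evaluate pipeline."""
    train_data, _val_data, test_data = load_data()
    model = train(train_data)
    metrics = evaluate(model, test_data)
    # Assumption: results bundle the metrics with some run metadata.
    return {
        "metrics": metrics,
        "metadata": {"n_train": len(train_data), "n_test": len(test_data)},
    }


def toy_load_data() -> Split:
    return [1.0, 2.0, 3.0], [4.0], [5.0]


def toy_train(train_data: List[float]) -> float:
    # The "model" is the mean of the training values.
    return sum(train_data) / len(train_data)


def toy_evaluate(model: float, test_data: List[float]) -> Dict[str, float]:
    mae = sum(abs(model - y) for y in test_data) / len(test_data)
    return {"mae": mae}


results = run_pipeline(toy_load_data, toy_train, toy_evaluate)
```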