import tempfile
import time
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Optional, Union, Dict, Callable
import optuna
import optuna.samplers
import optuna.trial
import numpy as np
import json
import importlib
from .utils import Timer
@dataclass(frozen=True)
class TuneConfig:
seed: int
function: Union[str, Callable] # Points to baseline's main function
space: Dict[str, Any] # Hyperparameter search space
sampler: Dict[str, Any] = field(default_factory=dict)
    def __post_init__(self):
        assert 'seed' not in self.sampler, (
            'Set the random seed via TuneConfig.seed, not in the sampler options'
        )
def _suggest(trial: optuna.trial.Trial, distribution: str, label: str, *args):
    """Dispatch to the corresponding `trial.suggest_<distribution>` method with the given label and arguments"""
    return getattr(trial, f'suggest_{distribution}')(label, *args)
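# Illustrative mapping (the label 'model.lr' is hypothetical):
#     _suggest(trial, 'float', 'model.lr', 1e-5, 1e-2)
# dispatches to trial.suggest_float('model.lr', 1e-5, 1e-2).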
def sample_config(
trial: optuna.trial.Trial,
space: Union[bool, int, float, str, bytes, list, dict],
label_parts: list,
) -> Any:
"""Sample configuration from hyperparameter space
Args:
trial: Optuna trial object
space: Hyperparameter space definition
label_parts: Parts of the parameter name for nested configs
Returns:
Sampled parameter value
"""
if isinstance(space, (bool, int, float, str, bytes)):
return space
elif isinstance(space, dict):
        if '_tune_' in space or '-tune-' in space:
            # Dict-form '_tune_' entries are reserved for special (non-scalar)
            # distributions; none are currently supported.
            distribution = space['_tune_'] if '_tune_' in space else space['-tune-']
if distribution == "complex-custom-distribution":
raise NotImplementedError("Complex custom distributions not implemented")
else:
raise ValueError(f'Unknown distribution: "{distribution}"')
else:
return {
key: sample_config(trial, subspace, label_parts + [key])
for key, subspace in space.items()
}
elif isinstance(space, list):
if not space:
return space
elif space[0] not in ['_tune_', '-tune-']:
return [
sample_config(trial, subspace, label_parts + [str(i)])
for i, subspace in enumerate(space)
]
else:
            # space = ["_tune_"/"-tune-", distribution, distribution_arg_0, distribution_arg_1, ...]
_, distribution, *args = space
label = '.'.join(map(str, label_parts))
            # A leading '?' marks an optional parameter:
            # space = ["_tune_"/"-tune-", "?<distribution>", default, *distribution_args]
            if distribution.startswith('?'):
default, args_ = args[0], args[1:]
if trial.suggest_categorical('?' + label, [False, True]):
return _suggest(trial, distribution.lstrip('?'), label, *args_)
else:
return default
            # '$list' samples a fixed-size list:
            # space = ["_tune_"/"-tune-", "$list", size, item_distribution, *item_args]
            elif distribution == '$list':
size, item_distribution, *item_args = args
return [
_suggest(trial, item_distribution, label + f'.{i}', *item_args)
for i in range(size)
]
else:
return _suggest(trial, distribution, label, *args)
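# Illustrative example (the parameter names 'n_layers' and 'lr' are hypothetical):
# a space such as
#     {'model': {'n_layers': ['_tune_', 'int', 1, 8],
#                'lr': ['_tune_', 'float', 1e-5, 1e-2]}}
# resolves to {'model': {'n_layers': <int>, 'lr': <float>}} via
# trial.suggest_int('model.n_layers', 1, 8) and
# trial.suggest_float('model.lr', 1e-5, 1e-2), while an optional entry such as
# ['_tune_', '?int', 0, 1, 8] first samples a boolean flag and falls back to
# the default (0) when the flag is False.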
def import_function(function_path: str) -> Callable:
"""Import function from string path
Args:
function_path: Dot-separated path to function (e.g., "bin.lol.main_function")
Returns:
Imported function
"""
try:
module_path, function_name = function_path.rsplit('.', 1)
module = importlib.import_module(module_path)
return getattr(module, function_name)
    except Exception as e:
        raise ImportError(f"Failed to import {function_path}: {e}") from e
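# Illustrative usage (the module path 'bin.model' is hypothetical):
#     main = import_function('bin.model.main')
# is equivalent to `from bin.model import main`.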
def create_report(config: Dict[str, Any]) -> Dict[str, Any]:
"""Create initial report structure
Args:
config: Configuration dictionary
Returns:
Report dictionary
"""
return {
'config': config,
'start_time': time.time(),
'trials': [],
'best': None,
'status': 'running'
}
def get_checkpoint_path(output_path: Path) -> Path:
"""Get checkpoint file path
Args:
output_path: Output directory path
Returns:
Checkpoint file path
"""
return output_path / 'checkpoint.pkl'
def load_checkpoint(output_path: Path) -> Dict[str, Any]:
"""Load checkpoint from file
Args:
output_path: Output directory path
Returns:
Checkpoint data
"""
import pickle
with open(get_checkpoint_path(output_path), 'rb') as f:
return pickle.load(f)
def dump_checkpoint(data: Dict[str, Any], output_path: Path) -> None:
"""Save checkpoint to file
Args:
data: Checkpoint data
output_path: Output directory path
"""
import pickle
checkpoint_path = get_checkpoint_path(output_path)
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
with open(checkpoint_path, 'wb') as f:
pickle.dump(data, f)
def dump_summary(summary: Dict[str, Any], output_path: Path) -> None:
"""Save summary to JSON file
Args:
summary: Summary data
output_path: Output directory path
"""
output_path.mkdir(parents=True, exist_ok=True)
with open(output_path / 'summary.json', 'w') as f:
json.dump(summary, f, indent=2, default=str)
def dump_report(report: Dict[str, Any], output_path: Path) -> None:
"""Save full report to JSON file
Args:
report: Report data
output_path: Output directory path
"""
output_path.mkdir(parents=True, exist_ok=True)
with open(output_path / 'report.json', 'w') as f:
json.dump(report, f, indent=2, default=str)
def summarize(report: Dict[str, Any]) -> Dict[str, Any]:
"""Create summary from full report
Args:
report: Full report
Returns:
Summary dictionary
"""
summary = {
'n_trials': len(report.get('trials', [])),
'status': report.get('status', 'unknown'),
'start_time': report.get('start_time'),
'end_time': report.get('end_time'),
}
if report.get('best'):
summary['best_score'] = report['best'].get('metrics', {}).get('val', {}).get('score', 0)
summary['best_config'] = report['best'].get('config', {})
return summary
def backup_output(output_path: Path) -> None:
"""Create backup of output directory
Args:
output_path: Output directory path
"""
# Simple backup by copying to timestamped directory
import shutil
timestamp = int(time.time())
backup_path = output_path.parent / f"{output_path.name}_backup_{timestamp}"
if output_path.exists():
shutil.copytree(output_path, backup_path, dirs_exist_ok=True)
def start(output_path: Union[str, Path], force: bool = False, continue_: bool = False) -> bool:
"""Initialize output directory and check if should start
Args:
output_path: Output directory path
force: Force restart even if output exists
continue_: Continue from existing checkpoint
Returns:
True if should start/continue, False if should skip
"""
output_path = Path(output_path)
if output_path.exists() and not force and not continue_:
print(f"Output directory {output_path} already exists. Use --force or --continue")
return False
if force and output_path.exists():
import shutil
shutil.rmtree(output_path)
output_path.mkdir(parents=True, exist_ok=True)
return True
def finish(output_path: Path, report: Dict[str, Any]) -> None:
"""Finalize the tuning run
Args:
output_path: Output directory path
report: Final report
"""
    report['end_time'] = time.time()
    # Do not overwrite an 'interrupted' or 'failed' status set by the caller
    if report.get('status') == 'running':
        report['status'] = 'completed'
    dump_report(report, output_path)
    dump_summary(summarize(report), output_path)
    print(f"Tuning finished with status '{report['status']}'. Results saved to {output_path}")
def tune_hyperparameters(
config: Dict[str, Any],
output: Union[str, Path],
*,
force: bool = False,
continue_: bool = False,
    n_trials: Optional[int] = 50,
    timeout: Optional[int] = 6000,
    optimize_metric: str = "accuracy"
) -> Optional[Dict[str, Any]]:
    """Main hyperparameter tuning function
    Args:
        config: Tuning configuration (the fields of TuneConfig)
        output: Output directory path
        force: Force restart even if the output directory exists
        continue_: Continue from an existing checkpoint
        n_trials: Maximum number of trials (None means no limit)
        timeout: Time budget in seconds (None means no limit)
        optimize_metric: Metric to maximize; nested metrics use dots (e.g. 'val.accuracy')
    Returns:
        Tuning report or None if skipped
    """
# -- Initialize output directory and check if should start ---
if not start(output, force=force, continue_=continue_):
return None
# -- Create initial report structure -------------------------
output = Path(output)
report = create_report(config)
# Create TuneConfig from dictionary -------------------------
from .config_manager import make_config
C = make_config(TuneConfig, config)
# Set random seed -------------------------------------------
np.random.seed(C.seed)
# Import evaluation function --------------------------------
function: Callable = (
import_function(C.function) if isinstance(C.function, str) else C.function
)
    # Resume from an existing checkpoint: output/checkpoint.pkl --
if get_checkpoint_path(output).exists():
checkpoint = load_checkpoint(output)
report, study, trial_reports, timer = (
checkpoint['report'],
checkpoint['study'],
checkpoint['trial_reports'],
checkpoint['timer'],
)
n_trials = None if n_trials is None else n_trials - len(study.trials)
timeout = None if timeout is None else timeout - timer.total_elapsed
report.setdefault('continuations', []).append(len(study.trials))
print(f'Resuming from checkpoint ({len(study.trials)} completed, {n_trials or "inf"} remaining)')
# Or we create a new study ----------------------------------
else:
study = optuna.create_study(
direction='maximize',
sampler=optuna.samplers.TPESampler(**C.sampler, seed=C.seed),
)
trial_reports = []
timer = Timer()
# Objective function for each trial -------------------------
def objective(trial: optuna.trial.Trial) -> float:
raw_config = sample_config(trial, C.space, [])
with tempfile.TemporaryDirectory(suffix=f'_trial_{trial.number}') as tmp:
trial_report = function(raw_config, Path(tmp) / 'output')
assert trial_report is not None
trial_report['trial_id'] = trial.number
trial_report['tuning_time'] = str(timer)
trial_reports.append(trial_report)
try:
import torch
if torch.cuda.is_available():
torch.cuda.empty_cache()
except ImportError:
pass
# Extract the metric to optimize -------------------------
metrics = trial_report['metrics']
if '.' in optimize_metric:
# Handle nested metrics like 'val.accuracy' or 'test.f1'
parts = optimize_metric.split('.')
            metric_value = metrics
            for part in parts:
                metric_value = metric_value.get(part, 0) if isinstance(metric_value, dict) else 0
else:
# Default to val metric if available, otherwise test
if 'val' in metrics and optimize_metric in metrics['val']:
metric_value = metrics['val'][optimize_metric]
elif 'test' in metrics and optimize_metric in metrics['test']:
metric_value = metrics['test'][optimize_metric]
else:
metric_value = metrics.get('score', 0)
return float(metric_value)
# Callback to save progress after each trial is finished ----
def callback(*_, **__):
if study.trials:
report['best'] = trial_reports[study.best_trial.number]
report['time'] = str(timer)
report['n_completed_trials'] = len(trial_reports)
report['trials'] = trial_reports
dump_checkpoint({
'report': report,
'study': study,
'trial_reports': trial_reports,
'timer': timer,
}, output)
        # Save the best result so far ----------------------------
        dump_summary(summarize(report), output)
        dump_report(report, output)
        # backup_output(output)  # optional: copy the output directory after each trial
timer.start()
warnings.filterwarnings('ignore', category=optuna.exceptions.ExperimentalWarning)
warnings.filterwarnings('ignore', category=FutureWarning)
try:
study.optimize(
objective,
n_trials=n_trials,
timeout=timeout,
callbacks=[callback],
show_progress_bar=True,
gc_after_trial=True
)
except KeyboardInterrupt:
print("Tuning interrupted by user")
report['status'] = 'interrupted'
except Exception as e:
print(f"Tuning failed with error: {e}")
report['status'] = 'failed'
report['error'] = str(e)
finish(output, report)
return report
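# Illustrative usage sketch (all names below are hypothetical, including the
# entry point 'bin.model.main', the output path, and the metric 'val.accuracy'):
#     report = tune_hyperparameters(
#         {
#             'seed': 0,
#             'function': 'bin.model.main',
#             'space': {'model': {'lr': ['_tune_', 'float', 1e-5, 1e-2]}},
#             'sampler': {'n_startup_trials': 10},
#         },
#         'exp/tuning/model',
#         n_trials=100,
#         optimize_metric='val.accuracy',
#     )
# The entry point must accept (config, output_path) and return a report dict
# with a 'metrics' key, as assumed by objective() above.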
def update_eval_config(eval_config_path: Path, best_config: Dict[str, Any]) -> None:
"""Update eval config file with best hyperparameters from tuning
Args:
eval_config_path: Path to evaluation config file
best_config: Best configuration from tuning
"""
from confection import Config
if eval_config_path.exists():
eval_config = Config().from_disk(eval_config_path)
else:
eval_config = {}
# Update hyperparameters while preserving non-tunable settings
    def update_nested(target: Dict[str, Any], source: Dict[str, Any]) -> None:
        for key, value in source.items():
            if isinstance(value, dict) and isinstance(target.get(key), dict):
                update_nested(target[key], value)
            else:
                target[key] = value
update_nested(eval_config, best_config)
# Save updated config
eval_config_path.parent.mkdir(parents=True, exist_ok=True)
Config(eval_config).to_disk(eval_config_path)
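# Illustrative follow-up (assumes the trial report returned by the tuned
# function stores the sampled configuration under 'config', as summarize()
# also assumes; the paths below are hypothetical):
#     best = load_checkpoint(Path('exp/tuning/model'))['report']['best']
#     update_eval_config(Path('exp/evaluation/model.cfg'), best.get('config', {}))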