# Source code for src.common.tune_manager

import tempfile
import time
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Optional, Union, Dict, Callable
import optuna
import optuna.samplers
import optuna.trial
import numpy as np
import json
import importlib

from .utils import Timer

@dataclass(frozen=True)
class TuneConfig:
    """Immutable settings for one hyperparameter-tuning run.

    Raises:
        ValueError: If ``sampler`` contains a ``'seed'`` entry. The seed is
            supplied separately through ``seed``, so a second one would be
            passed twice to the sampler constructor.
    """

    # Random seed; used for NumPy and passed to the Optuna sampler.
    seed: int
    # Points to baseline's main function (dotted import path or a callable).
    function: Union[str, Callable]
    # Hyperparameter search space, consumed by ``sample_config``.
    space: Dict[str, Any]
    # Extra keyword arguments for the Optuna sampler (must not contain 'seed').
    sampler: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        # An explicit raise (rather than ``assert``) keeps this check active
        # even when Python runs with -O.
        if 'seed' in self.sampler:
            raise ValueError(
                "TuneConfig.sampler must not contain 'seed'; "
                "set TuneConfig.seed instead"
            )


def _suggest(trial: optuna.trial.Trial, distribution: str, label: str, *args):
    """Helper function to suggest values from trial"""
    return getattr(trial, f'suggest_{distribution}')(label, *args)


def sample_config(
    trial: optuna.trial.Trial,
    space: Union[bool, int, float, str, bytes, list, dict],
    label_parts: list,
) -> Any:
    """Recursively turn a search-space description into concrete values.

    Args:
        trial: Optuna trial object used to draw the tunable values
        space: Hyperparameter space definition (literal, dict or list)
        label_parts: Path of keys/indices leading to ``space``; joined with
            '.' to form the parameter name given to the trial

    Returns:
        Sampled parameter value. Literals and empty lists are returned as-is;
        dicts and plain lists are sampled element-wise; a list starting with
        "_tune_"/"-tune-" is replaced by the suggested value. Any other type
        falls through and yields ``None``.

    Raises:
        NotImplementedError: For a dict tagged "complex-custom-distribution"
        ValueError: For a dict tagged with any other distribution name
    """
    # Literal values are passed through untouched.
    if isinstance(space, (bool, int, float, str, bytes)):
        return space

    if isinstance(space, dict):
        is_tagged = '_tune_' in space or '-tune-' in space
        if not is_tagged:
            sampled = {}
            for name, child in space.items():
                sampled[name] = sample_config(trial, child, label_parts + [name])
            return sampled
        # Dict-style tune specs are recognised but never sampled.
        tag = '_tune_' if '_tune_' in space else '-tune-'
        distribution = space[tag]
        if distribution == "complex-custom-distribution":
            raise NotImplementedError("Complex custom distributions not implemented")
        raise ValueError(f'Unknown distribution: "{distribution}"')

    if not isinstance(space, list):
        # No branch handles this type; the result is None.
        return None

    if not space:
        return space

    if space[0] not in ['_tune_', '-tune-']:
        sampled_items = []
        for position, child in enumerate(space):
            sampled_items.append(
                sample_config(trial, child, label_parts + [str(position)])
            )
        return sampled_items

    # space = ["_tune_"/"-tune-", distribution, distribution_arg_0, distribution_arg_1, ...]
    _, distribution, *args = space
    label = '.'.join(str(part) for part in label_parts)

    if distribution.startswith('?'):
        # Optional parameter: the first argument is the fallback value, and an
        # extra boolean parameter named '?<label>' decides whether to sample.
        default, args_ = args[0], args[1:]
        if not trial.suggest_categorical('?' + label, [False, True]):
            return default
        return _suggest(trial, distribution.lstrip('?'), label, *args_)

    if distribution == '$list':
        # Fixed-size list of independently sampled items named '<label>.<i>'.
        size, item_distribution, *item_args = args
        items = []
        for i in range(size):
            items.append(_suggest(trial, item_distribution, label + f'.{i}', *item_args))
        return items

    return _suggest(trial, distribution, label, *args)


def import_function(function_path: str) -> Callable:
    """Import function from string path

    Args:
        function_path: Dot-separated path to function (e.g., "bin.lol.main_function")

    Returns:
        Imported function

    Raises:
        ImportError: If the path is not of the form "module.attribute", the
            module cannot be imported, or the attribute does not exist. The
            underlying exception is attached as ``__cause__``.
    """
    try:
        module_path, separator, function_name = function_path.rpartition('.')
        if not (separator and module_path and function_name):
            # Fail with a readable reason instead of an unpacking error.
            raise ValueError(
                'expected a dotted path such as "package.module.function"'
            )
        module = importlib.import_module(module_path)
        return getattr(module, function_name)
    except Exception as e:
        # Importing can fail in arbitrary ways (the module's top-level code
        # runs), so everything is normalised to ImportError; ``from e`` keeps
        # the original traceback reachable.
        raise ImportError(f"Failed to import {function_path}: {e}") from e


def create_report(config: Dict[str, Any]) -> Dict[str, Any]:
    """Build the skeleton report for a fresh tuning run.

    Args:
        config: Configuration dictionary (stored by reference, not copied)

    Returns:
        Report dictionary with the config, the current wall-clock time as
        ``start_time``, an empty trial list, no best trial and status 'running'
    """
    report: Dict[str, Any] = {}
    report['config'] = config
    report['start_time'] = time.time()
    report['trials'] = []
    report['best'] = None
    report['status'] = 'running'
    return report


def get_checkpoint_path(output_path: Path) -> Path:
    """Return where the tuning checkpoint lives inside ``output_path``.

    Args:
        output_path: Output directory path

    Returns:
        Path of the file 'checkpoint.pkl' directly under ``output_path``
    """
    checkpoint_name = 'checkpoint.pkl'
    return output_path.joinpath(checkpoint_name)


def load_checkpoint(output_path: Path) -> Dict[str, Any]:
    """Read back the pickled checkpoint stored under ``output_path``.

    Args:
        output_path: Output directory path

    Returns:
        Checkpoint data, exactly as unpickled from the checkpoint file
    """
    import pickle

    checkpoint_file = get_checkpoint_path(output_path)
    # NOTE: unpickling executes arbitrary code; only load checkpoints that
    # this tool wrote itself.
    with open(checkpoint_file, 'rb') as handle:
        restored = pickle.load(handle)
    return restored


def dump_checkpoint(data: Dict[str, Any], output_path: Path) -> None:
    """Save checkpoint to file

    The data is pickled to a temporary sibling file first and then renamed
    over the real checkpoint, so an interruption in the middle of the write
    (e.g. Ctrl-C during tuning) cannot leave a truncated checkpoint behind.

    Args:
        data: Checkpoint data
        output_path: Output directory path
    """
    import pickle
    checkpoint_path = get_checkpoint_path(output_path)
    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

    tmp_path = checkpoint_path.with_name(checkpoint_path.name + '.tmp')
    try:
        with open(tmp_path, 'wb') as f:
            pickle.dump(data, f)
        # Rename within the same directory replaces the old checkpoint in one step.
        tmp_path.replace(checkpoint_path)
    finally:
        # Only still present if pickling or the rename failed.
        if tmp_path.exists():
            tmp_path.unlink()


def dump_summary(summary: Dict[str, Any], output_path: Path) -> None:
    """Write ``summary`` as 'summary.json' inside ``output_path``.

    The directory is created if needed. Values that JSON cannot encode are
    written as their ``str()`` representation.

    Args:
        summary: Summary data
        output_path: Output directory path
    """
    output_path.mkdir(parents=True, exist_ok=True)
    summary_file = output_path / 'summary.json'
    with summary_file.open('w') as stream:
        json.dump(summary, stream, indent=2, default=str)


def dump_report(report: Dict[str, Any], output_path: Path) -> None:
    """Write the full ``report`` as 'report.json' inside ``output_path``.

    The directory is created if needed. Values that JSON cannot encode are
    written as their ``str()`` representation.

    Args:
        report: Report data
        output_path: Output directory path
    """
    output_path.mkdir(parents=True, exist_ok=True)
    report_file = output_path / 'report.json'
    with report_file.open('w') as stream:
        json.dump(report, stream, indent=2, default=str)


def summarize(report: Dict[str, Any]) -> Dict[str, Any]:
    """Condense a full report into its headline numbers.

    Args:
        report: Full report

    Returns:
        Summary dictionary with the trial count, status and start/end times.
        When the report has a non-empty 'best' entry, 'best_score' (taken from
        best -> metrics -> val -> score, 0 if absent) and 'best_config' are
        added as well.
    """
    trials = report.get('trials', [])
    summary = dict(
        n_trials=len(trials),
        status=report.get('status', 'unknown'),
        start_time=report.get('start_time'),
        end_time=report.get('end_time'),
    )

    best = report.get('best')
    if not best:
        return summary

    val_metrics = best.get('metrics', {}).get('val', {})
    summary['best_score'] = val_metrics.get('score', 0)
    summary['best_config'] = best.get('config', {})
    return summary


def backup_output(output_path: Path) -> None:
    """Create backup of output directory

    Copies ``output_path`` to a sibling directory named
    ``<name>_backup_<unix timestamp>``. Does nothing when ``output_path``
    does not exist.

    Args:
        output_path: Output directory path
    """
    # Simple backup by copying to timestamped directory
    import shutil

    if not output_path.exists():
        return
    timestamp = int(time.time())
    backup_path = output_path.parent / f"{output_path.name}_backup_{timestamp}"
    # dirs_exist_ok: two backups within the same second share a directory name.
    shutil.copytree(output_path, backup_path, dirs_exist_ok=True)


def start(output_path: Union[str, Path], force: bool = False, continue_: bool = False) -> bool:
    """Prepare the output directory and decide whether the run may proceed.

    Args:
        output_path: Output directory path
        force: Force restart even if output exists (the directory is deleted)
        continue_: Continue from existing checkpoint (the directory is kept)

    Returns:
        True if should start/continue, False if should skip because the
        directory already exists and neither flag was given
    """
    target = Path(output_path)
    already_present = target.exists()

    if already_present and not (force or continue_):
        print(f"Output directory {target} already exists. Use --force or --continue")
        return False

    if already_present and force:
        # A forced restart wipes whatever the previous run left behind.
        import shutil
        shutil.rmtree(target)

    target.mkdir(parents=True, exist_ok=True)
    return True


def finish(output_path: Path, report: Dict[str, Any]) -> None:
    """Finalize the tuning run

    Stamps ``end_time``, settles the final status and writes both the full
    report and its summary to ``output_path``.

    A report that is still 'running' (or has no status) is marked
    'completed'. Any other status already present, such as 'interrupted' or
    'failed', is kept so that the saved files reflect how the run ended.

    Args:
        output_path: Output directory path
        report: Final report (modified in place)
    """
    report['end_time'] = time.time()
    if report.get('status', 'running') == 'running':
        report['status'] = 'completed'
    dump_report(report, output_path)
    dump_summary(summarize(report), output_path)
    print(f"Tuning {report['status']}. Results saved to {output_path}")


def _extract_metric(metrics: Dict[str, Any], optimize_metric: str) -> float:
    """Pick the value to maximize out of a trial's ``metrics`` mapping."""
    if '.' in optimize_metric:
        # Handle nested metrics like 'val.accuracy' or 'test.f1'
        value: Any = metrics
        for part in optimize_metric.split('.'):
            if not hasattr(value, 'get'):
                # An intermediate level was missing (already replaced by 0):
                # treat it like a missing leaf instead of calling .get on 0.
                return 0.0
            value = value.get(part, 0)
        return float(value)

    # Default to val metric if available, otherwise test
    for part_name in ('val', 'test'):
        if part_name in metrics and optimize_metric in metrics[part_name]:
            return float(metrics[part_name][optimize_metric])
    return float(metrics.get('score', 0))


def tune_hyperparameters(
    config: Dict[str, Any],
    output: Union[str, Path],
    *,
    force: bool = False,
    continue_: bool = False,
    n_trials: Optional[int] = 50,
    timeout: Optional[float] = 6000,
    optimize_metric: str = "accuracy",
) -> Optional[Dict[str, Any]]:
    """Main hyperparameter tuning function

    Runs an Optuna study (TPE sampler, direction 'maximize') over
    ``config['space']``, calling the configured function once per trial and
    saving a checkpoint, summary and report after every finished trial.

    Args:
        config: Tuning configuration (converted to ``TuneConfig``)
        output: Output directory path
        force: Force restart
        continue_: Continue from checkpoint
        n_trials: Total number of trials, or None for no limit. When resuming,
            trials already in the study are subtracted.
        timeout: Total time budget passed to Optuna, or None for no limit.
            When resuming, the time already recorded by the timer is
            subtracted.
        optimize_metric: Name of the metric to maximize. A dotted name such as
            'val.accuracy' is looked up level by level in the trial's
            'metrics'; a plain name is looked up under 'val', then 'test',
            and finally falls back to ``metrics['score']`` (0 if absent).

    Returns:
        Tuning report or None if skipped. If the study stops with
        KeyboardInterrupt or another exception, the status ('interrupted' or
        'failed', plus 'error') is recorded on the report before ``finish``
        writes the final files.
    """
    # -- Initialize output directory and check if should start ---
    if not start(output, force=force, continue_=continue_):
        return None

    # -- Create initial report structure -------------------------
    output = Path(output)
    report = create_report(config)

    # Create TuneConfig from dictionary -------------------------
    from .config_manager import make_config
    C = make_config(TuneConfig, config)

    # Set random seed -------------------------------------------
    np.random.seed(C.seed)

    # Import evaluation function --------------------------------
    function: Callable = (
        import_function(C.function) if isinstance(C.function, str) else C.function
    )

    # Resuming from tune checkpoint (output/checkpoint.pkl) -----
    if get_checkpoint_path(output).exists():
        checkpoint = load_checkpoint(output)
        report, study, trial_reports, timer = (
            checkpoint['report'],
            checkpoint['study'],
            checkpoint['trial_reports'],
            checkpoint['timer'],
        )
        # Only the remaining budget is handed to Optuna; clamp at zero so an
        # already exhausted budget is reported and passed on as "nothing left".
        if n_trials is not None:
            n_trials = max(0, n_trials - len(study.trials))
        if timeout is not None:
            timeout = max(0, timeout - timer.total_elapsed)
        report.setdefault('continuations', []).append(len(study.trials))
        remaining = 'inf' if n_trials is None else n_trials
        print(f'Resuming from checkpoint ({len(study.trials)} completed, {remaining} remaining)')
    # Or we create a new study ----------------------------------
    else:
        study = optuna.create_study(
            direction='maximize',
            sampler=optuna.samplers.TPESampler(**C.sampler, seed=C.seed),
        )
        trial_reports = []
        timer = Timer()

    # Objective function for each trial -------------------------
    def objective(trial: optuna.trial.Trial) -> float:
        raw_config = sample_config(trial, C.space, [])
        with tempfile.TemporaryDirectory(suffix=f'_trial_{trial.number}') as tmp:
            trial_report = function(raw_config, Path(tmp) / 'output')

        assert trial_report is not None
        trial_report['trial_id'] = trial.number
        trial_report['tuning_time'] = str(timer)
        trial_reports.append(trial_report)

        # Best-effort GPU memory cleanup; torch is optional.
        try:
            import torch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except ImportError:
            pass

        # Extract the metric to optimize -------------------------
        return _extract_metric(trial_report['metrics'], optimize_metric)

    # Callback to save progress after each trial is finished ----
    def callback(*_, **__):
        if study.trials:
            # NOTE(review): indexes trial_reports by trial number, which
            # assumes every trial so far appended a report - confirm.
            report['best'] = trial_reports[study.best_trial.number]
        report['time'] = str(timer)
        report['n_completed_trials'] = len(trial_reports)
        report['trials'] = trial_reports

        dump_checkpoint({
            'report': report,
            'study': study,
            'trial_reports': trial_reports,
            'timer': timer,
        }, output)
        # Save result of best trial ------------------------------
        dump_summary(summarize(report), output)
        dump_report(report, output)
        # Per-trial backup_output(output) was short-circuited
        # (`False and ...`) and never ran; it stays disabled.

    timer.start()
    warnings.filterwarnings('ignore', category=optuna.exceptions.ExperimentalWarning)
    warnings.filterwarnings('ignore', category=FutureWarning)

    try:
        study.optimize(
            objective,
            n_trials=n_trials,
            timeout=timeout,
            callbacks=[callback],
            show_progress_bar=True,
            gc_after_trial=True,
        )
    except KeyboardInterrupt:
        print("Tuning interrupted by user")
        report['status'] = 'interrupted'
    except Exception as e:
        # Top-level boundary: record the failure and still write the results.
        print(f"Tuning failed with error: {e}")
        report['status'] = 'failed'
        report['error'] = str(e)

    finish(output, report)
    return report
def update_eval_config(eval_config_path: Path, best_config: Dict[str, Any]) -> None:
    """Merge the tuned hyperparameters into the evaluation config file.

    Values from ``best_config`` overwrite or extend the existing config;
    where both sides hold a dict under the same key, the merge recurses so
    settings that were not tuned are preserved.

    Args:
        eval_config_path: Path to evaluation config file (created if missing)
        best_config: Best configuration from tuning
    """
    from confection import Config

    eval_config = (
        Config().from_disk(eval_config_path) if eval_config_path.exists() else {}
    )

    # Update hyperparameters while preserving non-tunable settings
    def merge_into(target: Dict[str, Any], source: Dict[str, Any]):
        for key, value in source.items():
            recurse = (
                key in target
                and isinstance(value, dict)
                and isinstance(target[key], dict)
            )
            if recurse:
                merge_into(target[key], value)
            else:
                target[key] = value

    merge_into(eval_config, best_config)

    # Save updated config
    eval_config_path.parent.mkdir(parents=True, exist_ok=True)
    Config(eval_config).to_disk(eval_config_path)