Source code for src.common.utils

import os
import json
import pickle
import shutil
import time
from pathlib import Path
from typing import Any, Dict, Optional, Union
import numpy as np
import torch
import random



def set_seed(seed: int = 42) -> None:
    """Seed every random number generator this module knows about.

    The seed is applied to Python's ``random`` module, NumPy's global
    generator and PyTorch (CPU and CUDA), and is also exported through the
    ``PYTHONHASHSEED`` environment variable.

    Args:
        seed: Random seed value
    """
    # Python-level sources of randomness.
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Numerical libraries.
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # Set the cuDNN deterministic flag and turn off benchmark mode.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # Without CUDA there are no further devices to seed.
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
def get_device(device_str: str = "auto") -> torch.device:
    """Resolve a device specification into a ``torch.device``.

    Args:
        device_str: Device specification ("auto", "cpu", "cuda", "cuda:0",
            etc.). "auto" selects "cuda" when CUDA is available and "cpu"
            otherwise; any other value is passed to ``torch.device`` as is.

    Returns:
        PyTorch device
    """
    # Explicit specifications are handed straight to PyTorch.
    if device_str != "auto":
        return torch.device(device_str)

    backend = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(backend)
def ensure_dir(path: Union[str, Path]) -> Path:
    """Create a directory (and any missing parents) if it is not there yet.

    Args:
        path: Directory path

    Returns:
        The directory as a ``Path`` object
    """
    directory = Path(path)
    # exist_ok makes repeated calls on the same directory harmless.
    directory.mkdir(parents=True, exist_ok=True)
    return directory
def save_json(data: Dict[str, Any], path: Union[str, Path]) -> None:
    """Save data to a JSON file, creating parent directories as needed.

    Values that ``json`` cannot encode natively are converted with ``str``.
    The document is serialised completely before the target file is opened,
    so a serialisation error (for example a dict with tuple keys) neither
    truncates an existing file nor leaves a half-written one behind.

    Args:
        data: Data to save
        path: File path

    Raises:
        TypeError: If ``data`` contains keys that JSON cannot represent.
        ValueError: If ``data`` contains circular references.
    """
    # Serialise first: any failure happens before the file system is touched.
    payload = json.dumps(data, indent=2, default=str)

    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    # Explicit encoding keeps the result independent of the platform locale.
    with open(path, 'w', encoding='utf-8') as f:
        f.write(payload)
def load_json(path: Union[str, Path]) -> Dict[str, Any]:
    """Load data from a JSON file.

    The file is always decoded as UTF-8 rather than with the locale's
    default encoding, so non-ASCII content loads the same on every platform.

    Args:
        path: File path

    Returns:
        Loaded data

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)
def save_pickle(data: Any, path: Union[str, Path]) -> None:
    """Serialise an object to a pickle file.

    Missing parent directories of ``path`` are created first.

    Args:
        data: Data to save
        path: File path
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    with target.open('wb') as handle:
        pickle.dump(data, handle)


def load_pickle(path: Union[str, Path]) -> Any:
    """Deserialise an object from a pickle file.

    NOTE(review): unpickling can execute arbitrary code; this presumably
    only ever reads files the project wrote itself - confirm with callers.

    Args:
        path: File path

    Returns:
        Loaded data
    """
    with open(path, 'rb') as handle:
        loaded = pickle.load(handle)
    return loaded


def copy_file(src: Union[str, Path], dst: Union[str, Path]) -> None:
    """Copy a file, creating the destination's parent directories if needed.

    The copy is done with ``shutil.copy2``.

    Args:
        src: Source file path
        dst: Destination file path
    """
    source, target = Path(src), Path(dst)
    target.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy2(source, target)
class Timer:
    """Stopwatch that accumulates wall-clock time over start/stop cycles.

    Time is measured with ``time.time()``. ``str(timer)`` renders the
    accumulated total as ``HH:MM:SS.ss``.
    """

    def __init__(self):
        """Create a stopped timer with nothing accumulated."""
        # Timestamp of the running interval, or None while stopped.
        self.start_time = None
        # Seconds accumulated by completed start/stop intervals.
        self.total_elapsed = 0.0

    def start(self):
        """Begin a new interval.

        Calling this while an interval is already running restarts it; the
        time of the interrupted interval is not added to the total.
        """
        self.start_time = time.time()

    def stop(self):
        """End the running interval and add it to the total.

        Returns:
            Length of the interval in seconds, or 0.0 if the timer was not
            running.
        """
        if self.start_time is None:
            return 0.0
        elapsed = time.time() - self.start_time
        self.total_elapsed += elapsed
        self.start_time = None
        return elapsed

    def elapsed(self):
        """Return seconds since the last ``start()``, or 0.0 when stopped."""
        if self.start_time is None:
            return 0.0
        return time.time() - self.start_time

    def total(self):
        """Return accumulated seconds including any interval still running."""
        # elapsed() already yields 0.0 when stopped, and it tests `is None`
        # rather than truthiness, so a start timestamp of 0.0 is honoured.
        return self.total_elapsed + self.elapsed()

    def reset(self):
        """Stop the timer and discard all accumulated time."""
        self.start_time = None
        self.total_elapsed = 0.0

    def __str__(self):
        """Format the total as ``HH:MM:SS.ss``."""
        # Round to centiseconds *before* splitting into fields; otherwise
        # 59.999 s would be shown as "00:00:60.00" instead of "00:01:00.00".
        centiseconds = round(self.total() * 100)
        whole_minutes, centiseconds = divmod(centiseconds, 6000)
        hours, minutes = divmod(whole_minutes, 60)
        return f"{hours:02d}:{minutes:02d}:{centiseconds / 100:05.2f}"
def _format_metric(label: str, value: Any) -> str:
    """Render one ``label: value`` line, with 4 decimals for numeric values."""
    # NumPy scalars such as np.float32 are not `float` subclasses, so they
    # are listed explicitly to get the same fixed-point formatting.
    if isinstance(value, (int, float, np.integer, np.floating)):
        return f"{label}: {value:.4f}"
    return f"{label}: {value}"


def format_results_summary(results: Dict[str, Any]) -> str:
    """Format results for display.

    Reads ``results['metrics']`` and reports, when present, the ``accuracy``
    of the ``val`` and ``test`` entries and the top-level ``score``. Numeric
    values are printed with four decimals; anything else is printed as is,
    and a missing accuracy is shown as ``N/A``.

    Args:
        results: Results dictionary

    Returns:
        One line per metric joined with newlines, or "No metrics found" if
        none of the known metrics is present.
    """
    lines = []
    if 'metrics' in results:
        metrics = results['metrics']
        for split, label in (('val', 'Validation Accuracy'),
                             ('test', 'Test Accuracy')):
            if split in metrics:
                accuracy = metrics[split].get('accuracy', 'N/A')
                lines.append(_format_metric(label, accuracy))
        if 'score' in metrics:
            lines.append(_format_metric('Score', metrics['score']))
    return "\n".join(lines) if lines else "No metrics found"


def normalize_obj(obj):
    """Convert objects to comparable forms.

    ``Path`` objects become strings, dict values and list/tuple items are
    normalised recursively, and tuples come back as lists. Dict keys and
    every other object are returned unchanged.
    """
    if isinstance(obj, Path):
        return str(obj)
    if isinstance(obj, dict):
        return {k: normalize_obj(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [normalize_obj(item) for item in obj]
    return obj


def log_experiment_info(config: Any, output_path: Path) -> None:
    """Log experiment configuration and environment info.

    Writes ``experiment_info.json`` into ``output_path`` via ``save_json``.
    The file holds a timestamp, the configuration and the Python / PyTorch /
    CUDA environment.

    Args:
        config: Experiment configuration; its ``__dict__`` is stored when it
            has one, otherwise ``dict(config)``.
        output_path: Output directory path (a plain string is accepted too).
    """
    import sys  # local import: `os.sys` is an undocumented implementation detail

    cuda_available = torch.cuda.is_available()
    info = {
        'timestamp': time.strftime("%Y-%m-%d %H:%M:%S"),
        'config': config.__dict__ if hasattr(config, '__dict__') else dict(config),
        'environment': {
            'python_version': sys.version,
            'pytorch_version': torch.__version__,
            'cuda_available': cuda_available,
            'device_count': torch.cuda.device_count() if cuda_available else 0,
        },
    }
    save_json(info, Path(output_path) / 'experiment_info.json')


def execute_py(file_path: Path, args: str):
    """Run a Python script in a subprocess and print its output.

    The script is run with the interpreter executing this process, so it
    sees the same environment and installed packages. On success the
    script's stdout is printed; on a non-zero exit status the error and the
    script's stderr are printed instead and no exception propagates.

    Args:
        file_path: Path of the script to run.
        args: Command-line arguments as one string, split with shell-like
            quoting rules.
    """
    import shlex
    import subprocess
    import sys

    # Fall back to a PATH lookup only if the interpreter path is unknown.
    interpreter = sys.executable or 'python'
    command = [interpreter, str(file_path), *shlex.split(args)]
    try:
        result = subprocess.run(command, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        print(f"Error: {e}")
        print(e.stderr)
    else:
        print(result.stdout)


def set_pdb(enable=True, port=5680, host="127.0.0.1"):
    """Open a debugpy listener and block until a debugger client attaches.

    Args:
        enable: When false, do nothing (debugpy is not even imported).
        port: Port to listen on.
        host: Interface to listen on.
    """
    if not enable:
        return
    # Imported lazily so debugpy is only required when debugging is requested.
    import debugpy

    debugpy.listen((host, port))
    debugpy.wait_for_client()