import os
import json
import pickle
import shutil
import time
from pathlib import Path
from typing import Any, Dict, Optional, Union
import numpy as np
import torch
import random
[docs]
def set_seed(seed: int = 42) -> None:
    """Seed every random number generator the project relies on.

    Covers Python's ``random`` module, the ``PYTHONHASHSEED`` environment
    variable, NumPy's global RNG and PyTorch (CPU and CUDA), and puts cuDNN
    into deterministic, non-benchmarking mode.

    Args:
        seed: Value used to seed all generators.
    """
    # Python-level sources of randomness.
    random.seed(seed)
    # Hash randomization is fixed when an interpreter starts, so this only
    # influences interpreters launched afterwards (e.g. child processes).
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)

    # PyTorch generators: CPU, current CUDA device, then every CUDA device.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
[docs]
def get_device(device_str: str = "auto") -> torch.device:
    """Resolve a device specification into a ``torch.device``.

    Args:
        device_str: ``"auto"`` selects CUDA when it is available and the CPU
            otherwise; any other value (``"cpu"``, ``"cuda"``, ``"cuda:0"``,
            ...) is passed to ``torch.device`` unchanged.

    Returns:
        The resolved PyTorch device.
    """
    if device_str != "auto":
        return torch.device(device_str)
    backend = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(backend)
[docs]
def ensure_dir(path: Union[str, Path]) -> Path:
    """Create a directory (including missing parents) unless it already exists.

    Args:
        path: Directory to create.

    Returns:
        The directory as a ``Path``.
    """
    directory = Path(path)
    directory.mkdir(exist_ok=True, parents=True)
    return directory
[docs]
def save_json(data: Dict[str, Any], path: Union[str, Path]) -> None:
    """Write *data* to *path* as JSON indented by two spaces.

    Missing parent directories are created first. Values the ``json`` module
    cannot encode natively are converted with ``str`` instead of raising.

    Args:
        data: Object to serialize.
        path: Destination file.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    with target.open('w') as handle:
        json.dump(data, handle, default=str, indent=2)
[docs]
def load_json(path: Union[str, Path]) -> Dict[str, Any]:
    """Load and return the JSON document stored at *path*.

    The file is read as raw bytes so that ``json`` detects the encoding
    itself (UTF-8 with or without a BOM, UTF-16, UTF-32). Text mode would
    decode with the platform's locale encoding, which mangles UTF-8 files on
    non-UTF-8 systems and rejects a UTF-8 BOM.

    Args:
        path: File path.

    Returns:
        The decoded JSON value.

    Raises:
        FileNotFoundError: If *path* does not exist.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    with open(path, 'rb') as f:
        return json.load(f)
def save_pickle(data: Any, path: Union[str, Path]) -> None:
    """Pickle *data* into *path*, creating parent directories on demand.

    Args:
        data: Object to serialize.
        path: Destination file.
    """
    destination = Path(path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    with destination.open('wb') as stream:
        pickle.dump(data, stream)
def load_pickle(path: Union[str, Path]) -> Any:
    """Unpickle and return the object stored at *path*.

    Warning:
        Unpickling can execute arbitrary code. Only load files that come
        from a trusted source.

    Args:
        path: File path.

    Returns:
        The deserialized object.
    """
    with open(path, 'rb') as stream:
        obj = pickle.load(stream)
    return obj
def copy_file(src: Union[str, Path], dst: Union[str, Path]) -> None:
    """Copy *src* to *dst* with ``shutil.copy2``, which also copies metadata.

    The destination's parent directory is created if it is missing.

    Args:
        src: Source file path.
        dst: Destination file path.
    """
    source, destination = Path(src), Path(dst)
    destination.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy2(source, destination)
[docs]
class Timer:
    """Accumulating stopwatch.

    Each ``start()``/``stop()`` pair measures one interval. ``total()``
    returns the sum of all completed intervals plus the one currently
    running, if any. Intervals are measured with ``time.perf_counter``, a
    monotonic clock, so system clock adjustments cannot skew them. The timer
    can also be used as a context manager::

        with Timer() as t:
            work()
        print(t)  # HH:MM:SS.ss
    """

    def __init__(self):
        # perf_counter() reading when the running interval began; None when stopped.
        self.start_time = None
        # Seconds accumulated over all completed intervals.
        self.total_elapsed = 0.0

    def start(self):
        """Begin a new interval.

        Calling this while already running discards the in-progress interval.
        """
        self.start_time = time.perf_counter()

    def stop(self):
        """End the running interval and add it to the total.

        Returns:
            Length of the interval in seconds, or 0.0 if the timer was not
            running.
        """
        if self.start_time is None:
            return 0.0
        elapsed = time.perf_counter() - self.start_time
        self.total_elapsed += elapsed
        self.start_time = None
        return elapsed

    def elapsed(self):
        """Return the seconds spent in the running interval (0.0 when stopped)."""
        if self.start_time is None:
            return 0.0
        return time.perf_counter() - self.start_time

    def total(self):
        """Return the seconds of all completed intervals plus the running one."""
        # elapsed() already yields 0.0 when stopped; no truthiness test on
        # start_time, which would wrongly ignore a reading of exactly 0.0.
        return self.total_elapsed + self.elapsed()

    def reset(self):
        """Stop the timer and discard all accumulated time."""
        self.start_time = None
        self.total_elapsed = 0.0

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc, tb):
        # Always record the interval; returning None lets exceptions propagate.
        self.stop()

    def __str__(self):
        # Round to the displayed precision before splitting into fields so
        # that e.g. 59.999 s renders as 00:01:00.00 rather than 00:00:60.00.
        total_seconds = round(self.total(), 2)
        hours, remainder = divmod(total_seconds, 3600)
        minutes, seconds = divmod(remainder, 60)
        return f"{int(hours):02d}:{int(minutes):02d}:{seconds:05.2f}"
def _format_metric(label: str, value: Any) -> str:
    """Render one ``label: value`` line, using 4 decimals for numeric values."""
    # NumPy scalars (np.float32, np.int64, ...) do not subclass Python's
    # int/float, so they are listed explicitly. bool is excluded so that
    # True/False are not rendered as 1.0000/0.0000.
    is_number = (
        isinstance(value, (int, float, np.integer, np.floating))
        and not isinstance(value, bool)
    )
    return f"{label}: {value:.4f}" if is_number else f"{label}: {value}"


def format_results_summary(results: Dict[str, Any]) -> str:
    """Format the headline metrics of a results dictionary for display.

    Reads ``results['metrics']`` and emits, in this order and only when the
    key is present:

    * ``val``  -> ``Validation Accuracy`` (from its ``accuracy`` entry)
    * ``test`` -> ``Test Accuracy`` (from its ``accuracy`` entry)
    * ``score`` -> ``Score``

    A missing ``accuracy`` entry is shown as ``N/A``. Numeric values,
    including NumPy scalars, use 4 decimals; other values are shown as-is.

    Args:
        results: Results dictionary.

    Returns:
        One metric per line, or ``"No metrics found"`` if nothing applies.
    """
    if 'metrics' not in results:
        return "No metrics found"
    metrics = results['metrics']

    lines = []
    for split, label in (('val', 'Validation Accuracy'), ('test', 'Test Accuracy')):
        if split in metrics:
            lines.append(_format_metric(label, metrics[split].get('accuracy', 'N/A')))
    if 'score' in metrics:
        lines.append(_format_metric('Score', metrics['score']))
    return "\n".join(lines) if lines else "No metrics found"
def normalize_obj(obj):
    """Recursively convert *obj* into plain values that compare by content.

    ``Path`` objects become strings. Dicts are rebuilt with normalized values
    (keys are left untouched). Lists and tuples both become lists. Any other
    value is returned unchanged.
    """
    if isinstance(obj, Path):
        return str(obj)
    if isinstance(obj, dict):
        return {key: normalize_obj(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return list(map(normalize_obj, obj))
    return obj
def log_experiment_info(config: Any, output_path: Union[str, Path]) -> None:
    """Record the experiment configuration and runtime environment.

    Writes ``experiment_info.json`` inside *output_path* with:

    * a local-time timestamp,
    * the configuration,
    * Python/PyTorch versions and CUDA availability and device count.

    Args:
        config: Experiment configuration. Objects with a ``__dict__``
            contribute their attributes; anything else must be accepted by
            ``dict()``.
        output_path: Directory that receives ``experiment_info.json``.
    """
    import sys  # os.sys is an undocumented side effect of os's own imports

    cuda_available = torch.cuda.is_available()
    info = {
        'timestamp': time.strftime("%Y-%m-%d %H:%M:%S"),
        'config': vars(config) if hasattr(config, '__dict__') else dict(config),
        'environment': {
            'python_version': sys.version,
            'pytorch_version': torch.__version__,
            'cuda_available': cuda_available,
            'device_count': torch.cuda.device_count() if cuda_available else 0,
        },
    }
    save_json(info, Path(output_path) / 'experiment_info.json')
def execute_py(file_path: Path, args: str = ""):
    """Run a Python script in a child process and print its output.

    The script runs under the same interpreter as the current process
    (``sys.executable``), not whatever ``python`` happens to resolve to on
    PATH. *args* is split with shell-like quoting rules (``shlex.split``),
    but no shell is involved.

    If the script exits with a non-zero status, the error, its captured
    stdout and its stderr are printed; the exception is not re-raised.

    Args:
        file_path: Path to the script to run.
        args: Command-line arguments for the script as a single string.
    """
    import shlex
    import subprocess
    import sys

    cmd = [sys.executable, str(file_path), *shlex.split(args)]
    try:
        result = subprocess.run(cmd, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        print(f"Error: {e}")
        # Output produced before the failure is often the most useful clue.
        if e.stdout:
            print(e.stdout)
        print(e.stderr)
        return
    print(result.stdout)
def set_pdb(enable=True, port=5680, host="127.0.0.1"):
    """Optionally start a debugpy server and block until a debugger attaches.

    Args:
        enable: When False this is a no-op and ``debugpy`` is not imported,
            so the package is only required when debugging is requested.
        port: TCP port to listen on.
        host: Interface to bind; the default (127.0.0.1) only accepts
            connections from the local machine.
    """
    if not enable:
        return
    import debugpy

    debugpy.listen((host, port))
    # wait_for_client() blocks silently; tell the user why the program paused.
    print(f"Waiting for debugger to attach on {host}:{port} ...", flush=True)
    debugpy.wait_for_client()