import re
import toml
import tempfile
import confection
from pathlib import Path
from typing import TypeVar, Type, Dict, Any, Union, Optional
from dataclasses import dataclass, fields, is_dataclass, asdict, MISSING
T = TypeVar('T')
def move_unsectioned_to_top(toml_string):
""" π if above line is unsectioned or """
lines = toml_string.strip().split("\n")
unsectioned, sectioned = [], []
in_section = False
for i, line in enumerate(lines):
if line.strip().startswith("[") and line.strip().endswith("]"):
in_section = True
elif line.strip() and in_section:
            prev_unsectioned = i > 0 and lines[i - 1] in unsectioned[-2:]
two_empty_before = (
i >= 2 and not lines[i - 1].strip() and not lines[i - 2].strip()
)
if prev_unsectioned or two_empty_before:
in_section = False
(sectioned if in_section else unsectioned).append(line)
return "\n".join([l for l in unsectioned if l.strip()] + [""] + sectioned)
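# Illustrative sketch (input string is hypothetical): keys that appear below a
# [section] header but are detached from it (the previous line was unsectioned,
# or two blank lines precede them) are hoisted above the first section header.
#
#     raw = "[model]\nlayers = 2\n\n\nseed = 1"
#     move_unsectioned_to_top(raw)
#     # "seed = 1" now precedes "[model]" in the returned string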
def move_unsectioned_from_file(input_file, output_file=None):
""" π move unsectioned contents all to the top of the file """
with open(input_file, "r") as f:
content = f.read()
reorganized = move_unsectioned_to_top(content)
output_path = output_file if output_file else input_file
with open(output_path, "w") as f:
f.write(reorganized)
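# Usage sketch (file names are hypothetical): rewrite a file in place, or into
# output_file when one is given.
#
#     move_unsectioned_from_file("experiment.toml")
#     move_unsectioned_from_file("experiment.toml", output_file="cleaned.toml")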
class DollarTomlDecoder(toml.TomlDecoder):
"""Custom TOML decoder that treats ${} patterns as strings"""
def load_line(self, line, currentlevel, multikey, multibackslash):
"""Override load_line to handle ${} patterns before parsing"""
if (
"=" in line
and not line.strip().startswith("#")
and not line.strip().startswith("[")
):
key_part, value_part = line.split("=", 1)
value = value_part.strip()
# If value contains ${} or starts with problematic patterns, quote it
if re.search(r"\$\{[^}]+\}", value) or (
value.startswith(".") and not value.replace(".", "", 1).isdigit()
):
if not (value.startswith('"') or value.startswith("'")):
line = f'{key_part}= "{value}"'
# Call parent method with potentially modified line
return super().load_line(line, currentlevel, multikey, multibackslash)
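# Illustrative example (assumes the uiri/toml package, whose TomlDecoder exposes
# load_line): an unquoted interpolation placeholder such as
#
#     path = ${paths.root}/data
#
# is rewritten to
#
#     path = "${paths.root}/data"
#
# before the stock parser sees it, so the placeholder survives as a plain string
# for confection to interpolate later.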
class Config(confection.Config):
def fix_toplevel_config(self, config: Dict):
result = {k: v for k, v in config.items() if isinstance(v, dict)}
toplevel = {k: v for k, v in config.items() if not isinstance(v, dict)}
        if toplevel:
            result = {'config': toplevel, **result}
return result
def unwrap_toplevel_config(self, config: Dict):
result = {k: v for k, v in config.items() if k != "config"}
if "config" in config and isinstance(config["config"], dict):
result.update(config["config"])
return result
def from_disk(self, path, *, interpolate = True, overrides: Dict = {}):
move_unsectioned_from_file(path)
raw_config = toml.load(path, decoder=DollarTomlDecoder())
wrapped_config = self.fix_toplevel_config(raw_config)
with tempfile.NamedTemporaryFile(mode='w', suffix='.toml', delete=False) as f:
toml.dump(wrapped_config, f)
tmp_path = f.name
try:
loaded = super().from_disk(tmp_path, interpolate=interpolate, overrides=overrides)
return self.unwrap_toplevel_config(loaded)
finally:
Path(tmp_path).unlink()
def clean_config(self, path):
with open(path, "r") as f:
content = f.read()
content = re.sub(r'^\[config\]\s*\n?', '', content, flags=re.MULTILINE)
with open(path, "w") as f:
f.write(content)
def to_disk(self, path, *, interpolate = True):
try:
super().to_disk(path, interpolate=interpolate)
except confection.ConfigValidationError:
wrapped_config = self.fix_toplevel_config(self)
confection.Config(wrapped_config).to_disk(path, interpolate=interpolate)
self.clean_config(path)
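# Usage sketch (file names are hypothetical): top-level keys are wrapped into a
# synthetic [config] table so confection accepts them, then unwrapped again on
# the way out, so callers see a flat dict.
#
#     cfg = Config().from_disk("experiment.toml")   # returns a plain dict
#     cfg["seed"]                                    # top-level key, no section
#     Config(cfg).to_disk("experiment_out.toml")     # re-wraps if validation fails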
def read_config(config_class: Type[T], config_path: Union[str, Path], overrides: Dict[str, Any] = {}) -> T:
"""Read and validate config using dataclass
Args:
config_class: Dataclass type to create
config_path: Path to TOML config file
overrides: Dictionary of config overrides (supports nested keys with dots)
Returns:
Validated config instance
"""
config_path = Path(config_path)
if config_path.exists():
overrides = {
f"config.{k}" if "." not in k else k: v for k, v in overrides.items()
}
config_data = Config().from_disk(config_path, overrides=overrides)
else:
config_data = {}
return make_config(config_class, config_data)
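# Example (dataclass and file name are hypothetical): plain override keys are
# rewritten to "config.<key>", targeting the synthetic [config] table that
# Config.fix_toplevel_config creates for unsectioned values.
#
#     @dataclass
#     class TrainConfig:
#         lr: float = 1e-3
#         epochs: int = 10
#
#     cfg = read_config(TrainConfig, "train.toml", overrides={"lr": 3e-4})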
def make_config(config_class: Type[T], config_data: Dict[str, Any], overrides: Dict[str, Any] = {}) -> T:
"""Create dataclass instance from dictionary
Args:
config_class: Dataclass type to create
config_data: Configuration data
overrides: Dictionary of config overrides (supports nested keys with dots)
Returns:
Dataclass instance
"""
if not is_dataclass(config_class):
raise ValueError(f"config_class must be a dataclass, got {config_class}")
# Create defaults dict from dataclass field defaults
field_defaults = {f.name: f for f in fields(config_class)}
for field_name in overrides:
if "." not in field_name and field_name not in config_data:
field_info = field_defaults.get(field_name)
            if field_info and (
                field_info.default is not MISSING
                or field_info.default_factory is not MISSING
            ):
config_data.update({field_name: overrides[field_name]})
# Filter config_data to only include fields that exist in the dataclass
filtered_data = {}
for key, value in config_data.items():
if key in field_defaults:
            field_type = field_defaults[key].type
value = overrides.get(key, value)
if is_dataclass(field_type) and isinstance(value, dict):
nested_overrides = {}
prefix = f"{key}."
for override_key, override_value in overrides.items():
if override_key.startswith(prefix):
nested_key = override_key[len(prefix) :]
nested_overrides[nested_key] = override_value
filtered_data[key] = make_config(field_type, value, nested_overrides)
else:
filtered_data[key] = value
return config_class(**filtered_data)
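# Example (types are illustrative): dotted override keys are routed into nested
# dataclass fields, so "data.path" ends up on DataConfig.path.
#
#     @dataclass
#     class ExperimentConfig:
#         seed: int = 0
#         data: DataConfig = None
#
#     cfg = make_config(
#         ExperimentConfig,
#         {"seed": 1, "data": {"name": "mnist"}},
#         overrides={"data.path": "/tmp/mnist"},
#     )
#     # cfg.data.path == "/tmp/mnist"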
def serialize_config(data: Any) -> Any:
"""
Recursively convert Path objects to strings in nested data structures
"""
if isinstance(data, Path):
return data.as_posix()
elif isinstance(data, dict):
return {key: serialize_config(value) for key, value in data.items()}
elif isinstance(data, (list, tuple)):
return type(data)(serialize_config(item) for item in data)
elif hasattr(data, "__dataclass_fields__"):
# Handle dataclass instances
return serialize_config(asdict(data))
else:
return data
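# Example: Path objects are converted to POSIX strings recursively, so the
# result is safe to hand to the TOML writer.
#
#     serialize_config({"out": Path("runs") / "exp1", "tags": ["a", "b"]})
#     # -> {"out": "runs/exp1", "tags": ["a", "b"]}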
def save_config(config: Any, path: Union[str, Path]) -> None:
"""Save config to TOML file
Args:
config: Configuration object (dataclass or dict)
path: Path to save config file
"""
path = Path(path)
path.parent.mkdir(parents=True, exist_ok=True)
config_dict = serialize_config(config)
Config(config_dict).to_disk(path)
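# Usage sketch (output path is hypothetical): dataclass instances are flattened
# through serialize_config/asdict before being written.
#
#     save_config(BaseConfig(seed=7), "outputs/base.toml")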
@dataclass
class BaseConfig:
"""Base configuration class with common fields"""
seed: int = 42
device: str = "auto" # "auto", "cpu", "cuda"
@dataclass
class DataConfig:
"""Base data configuration"""
name: str
    path: Optional[str] = None
    preprocessing: Optional[Dict[str, Any]] = None
def __post_init__(self):
if self.path is None:
from .constant import Data2Path
self.path = Data2Path[self.name]
if self.preprocessing is None:
self.preprocessing = {}
@dataclass
class OutputConfig:
"""Output configuration"""
    folder: Optional[str] = None
save_model: bool = False
save_predictions: bool = False
    def __post_init__(self):
        if self.folder is not None:
            Path(self.folder).mkdir(exist_ok=True, parents=True)