Source code for src.common.config_manager

import re
import toml
import tempfile
import confection

from pathlib import Path
from typing import TypeVar, Type, Dict, Any, Union
from dataclasses import dataclass, fields, is_dataclass, asdict

T = TypeVar('T')


def move_unsectioned_to_top(toml_string):
    """Reorder TOML text so that lines judged to be unsectioned come first.

    A line whose stripped form starts with "[" and ends with "]" opens a
    section. While inside a section, a non-blank line is treated as leaving
    it when either

    * the line just before it is equal to one of the last two lines already
      collected as unsectioned, or
    * the two lines just before it are both blank.

    Args:
        toml_string: Raw TOML text.

    Returns:
        The non-blank unsectioned lines, one empty line, then every
        sectioned line in its original order, joined with newlines.
    """
    rows = toml_string.strip().split("\n")
    loose, grouped = [], []
    inside = False
    for idx, row in enumerate(rows):
        text = row.strip()
        if text.startswith("[") and text.endswith("]"):
            # A table header always (re)enters section mode.
            inside = True
        elif text and inside:
            follows_loose = idx > 0 and rows[idx - 1] in loose[-2:]
            after_gap = (
                idx >= 2
                and not rows[idx - 1].strip()
                and not rows[idx - 2].strip()
            )
            if follows_loose or after_gap:
                inside = False
        bucket = grouped if inside else loose
        bucket.append(row)
    # Blank lines are dropped from the unsectioned part only.
    kept = [row for row in loose if row.strip()]
    return "\n".join(kept + [""] + grouped)
    
def move_unsectioned_from_file(input_file, output_file=None):
    """Rewrite a TOML file so that its unsectioned entries come first.

    Args:
        input_file: Path of the TOML file to read.
        output_file: Where to write the reorganized text. When falsy, the
            input file is overwritten in place.
    """
    # TOML documents are UTF-8 by specification; do not depend on the
    # platform's default text encoding when reading or writing them.
    with open(input_file, "r", encoding="utf-8") as f:
        content = f.read()
    reorganized = move_unsectioned_to_top(content)
    output_path = output_file if output_file else input_file
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(reorganized)


class DollarTomlDecoder(toml.TomlDecoder):
    """TOML decoder that keeps ``${...}`` placeholders as plain strings."""

    def load_line(self, line, currentlevel, multikey, multibackslash):
        """Quote troublesome bare values, then defer to the stock decoder.

        Only ``key = value`` lines are touched (comment lines and lines that
        start with "[" are passed through). An unquoted value is wrapped in
        double quotes when it contains a ``${...}`` placeholder, or when it
        starts with "." without being a number once that dot is removed.
        """
        head = line.strip()
        is_assignment = "=" in line and not head.startswith(("#", "["))
        if is_assignment:
            lhs, rhs = line.split("=", 1)
            value = rhs.strip()
            has_placeholder = re.search(r"\$\{[^}]+\}", value) is not None
            dotted_non_number = (
                value.startswith(".")
                and not value.replace(".", "", 1).isdigit()
            )
            already_quoted = value.startswith(('"', "'"))
            if (has_placeholder or dotted_non_number) and not already_quoted:
                line = f'{lhs}= "{value}"'
        # Hand the (possibly rewritten) line to the regular TOML parser.
        return super().load_line(line, currentlevel, multikey, multibackslash)


class Config(confection.Config):
    """``confection.Config`` variant that accepts values declared outside any section.

    Top-level (non-dict) values are moved into a synthetic ``config`` section
    before the data is handed to confection, and flattened back to the top
    level afterwards.
    """

    def fix_toplevel_config(self, config: Dict):
        """Return a new dict with non-dict top-level values nested under ``"config"``.

        Args:
            config: Mapping whose dict values are sections and whose other
                values are top-level settings.

        Returns:
            A dict containing every section of *config*; when top-level
            settings exist, a leading ``"config"`` section holds them.
        """
        result = {k: v for k, v in config.items() if isinstance(v, dict)}
        toplevel = {k: v for k, v in config.items() if not isinstance(v, dict)}
        if toplevel:
            # Merge with an already existing "config" section instead of
            # silently discarding it; top-level values win on key conflicts.
            existing = result.pop("config", {})
            result = {"config": {**existing, **toplevel}, **result}
        return result

    def unwrap_toplevel_config(self, config: Dict):
        """Return a plain dict with the ``"config"`` section flattened to the top level."""
        result = {k: v for k, v in config.items() if k != "config"}
        if "config" in config and isinstance(config["config"], dict):
            result.update(config["config"])
        return result

    def from_disk(self, path, *, interpolate=True, overrides: Dict = {}):
        """Load a TOML file through confection and return the flattened data.

        Note that *path* itself is rewritten in place first so that its
        unsectioned entries sit at the top of the file.

        Args:
            path: TOML file to load.
            interpolate: Forwarded to ``confection.Config.from_disk``.
            overrides: Forwarded to ``confection.Config.from_disk``.

        Returns:
            A plain dict with the synthetic ``config`` section unwrapped.
        """
        move_unsectioned_from_file(path)
        raw_config = toml.load(path, decoder=DollarTomlDecoder())
        wrapped_config = self.fix_toplevel_config(raw_config)
        # The temp file is written as UTF-8 explicitly rather than in the
        # locale encoding (confection is understood to read files as UTF-8).
        tmp = tempfile.NamedTemporaryFile(
            mode='w', suffix='.toml', delete=False, encoding='utf-8'
        )
        tmp_path = tmp.name
        try:
            # Dumping happens inside the try block so the temp file is also
            # removed when serialisation fails.
            with tmp:
                toml.dump(wrapped_config, tmp)
            loaded = super().from_disk(tmp_path, interpolate=interpolate, overrides=overrides)
            return self.unwrap_toplevel_config(loaded)
        finally:
            Path(tmp_path).unlink()

    def clean_config(self, path):
        """Strip every ``[config]`` header line (and trailing whitespace) from the file at *path*."""
        # UTF-8 explicitly rather than the locale encoding (confection is
        # understood to write config files as UTF-8).
        with open(path, "r", encoding="utf-8") as f:
            content = f.read()
        content = re.sub(r'^\[config\]\s*\n?', '', content, flags=re.MULTILINE)
        with open(path, "w", encoding="utf-8") as f:
            f.write(content)

    def to_disk(self, path, *, interpolate=True):
        """Write the config to *path*, falling back to a wrapped form on validation errors.

        If confection rejects the config with ``ConfigValidationError``, the
        top-level values are wrapped in a ``config`` section, written, and the
        ``[config]`` header is then removed from the file again.
        """
        try:
            super().to_disk(path, interpolate=interpolate)
        except confection.ConfigValidationError:
            wrapped_config = self.fix_toplevel_config(self)
            confection.Config(wrapped_config).to_disk(path, interpolate=interpolate)
            self.clean_config(path)

def read_config(config_class: Type[T], config_path: Union[str, Path], overrides: Dict[str, Any] = {}) -> T:
    """Read and validate config using dataclass

    Args:
        config_class: Dataclass type to create
        config_path: Path to TOML config file
        overrides: Dictionary of config overrides (supports nested keys with dots)

    Returns:
        Validated config instance
    """
    config_path = Path(config_path)
    if not config_path.exists():
        # There is no file for the overrides to be merged into, so hand them
        # to make_config directly instead of silently discarding them.
        return make_config(config_class, {}, overrides)

    # Keys without a dot are addressed through the "config" section, which is
    # where Config stores values that live outside any TOML table.
    section_overrides = {
        (k if "." in k else f"config.{k}"): v for k, v in overrides.items()
    }
    config_data = Config().from_disk(config_path, overrides=section_overrides)
    return make_config(config_class, config_data)
def make_config(config_class: Type[T], config_data: Dict[str, Any], overrides: Dict[str, Any] = {}) -> T:
    """Create dataclass instance from dictionary

    Keys of *config_data* that are not fields of *config_class* are ignored.
    Values that are dicts and whose field type is itself a dataclass are
    built recursively; overrides of the form ``"<field>.<rest>"`` are passed
    down to that nested call as ``"<rest>"``.

    Args:
        config_class: Dataclass type to create
        config_data: Configuration data (left unmodified)
        overrides: Dictionary of config overrides (supports nested keys with dots)

    Returns:
        Dataclass instance

    Raises:
        ValueError: If *config_class* is not a dataclass.
    """
    if not is_dataclass(config_class):
        raise ValueError(f"config_class must be a dataclass, got {config_class}")

    field_map = {f.name: f for f in fields(config_class)}

    # Work on a shallow copy so the caller's dictionary is never mutated.
    merged = dict(config_data)

    # Inject overrides for fields that are absent from the data. Only fields
    # declaring a default or default_factory qualify: for a required field
    # both attributes are the same MISSING sentinel.
    for name, override in overrides.items():
        if "." in name or name in merged:
            continue
        info = field_map.get(name)
        if info is not None and info.default is not info.default_factory:
            merged[name] = override

    # Keep only keys that are real fields, applying direct overrides.
    kwargs = {}
    for key, value in merged.items():
        info = field_map.get(key)
        if info is None:
            continue
        value = overrides.get(key, value)
        if is_dataclass(info.type) and isinstance(value, dict):
            prefix = f"{key}."
            nested_overrides = {
                k[len(prefix):]: v
                for k, v in overrides.items()
                if k.startswith(prefix)
            }
            kwargs[key] = make_config(info.type, value, nested_overrides)
        else:
            kwargs[key] = value

    return config_class(**kwargs)
def serialize_config(data: Any) -> Any:
    """Return a copy of *data* in which every ``Path`` is a POSIX-style string.

    Dicts, lists and tuples are walked recursively (lists and tuples keep
    their type); dataclass instances are first turned into dicts with
    ``asdict``. Any other value is returned as is.
    """
    if isinstance(data, Path):
        return data.as_posix()

    if isinstance(data, dict):
        converted = {}
        for name, item in data.items():
            converted[name] = serialize_config(item)
        return converted

    if isinstance(data, (list, tuple)):
        container = type(data)
        return container(map(serialize_config, data))

    if hasattr(data, "__dataclass_fields__"):
        # Dataclass instance: convert to a dict and serialize that.
        return serialize_config(asdict(data))

    return data
def save_config(config: Any, path: Union[str, Path]) -> None:
    """Save config to TOML file

    The parent directory of *path* is created when it does not exist yet.

    Args:
        config: Configuration object (dataclass or dict)
        path: Path to save config file
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    # Path objects are turned into strings before the data is written.
    serializable = serialize_config(config)
    writer = Config(serializable)
    writer.to_disk(target)
@dataclass
class BaseConfig:
    """Base configuration class with common fields"""

    # Seed value; defaults to 42.
    seed: int = 42
    # One of "auto", "cpu", "cuda".
    device: str = "auto"
@dataclass
class DataConfig:
    """Base data configuration"""

    name: str
    path: str = None
    preprocessing: Dict[str, Any] = None

    def __post_init__(self):
        """Fill in ``path`` and ``preprocessing`` when they were left as ``None``."""
        path_missing = self.path is None
        if path_missing:
            # Imported lazily, only when a path has to be looked up by name.
            from .constant import Data2Path

            self.path = Data2Path[self.name]
        # Every instance without explicit preprocessing gets its own empty dict.
        self.preprocessing = {} if self.preprocessing is None else self.preprocessing
@dataclass
class OutputConfig:
    """Output configuration"""

    folder: str = None
    save_model: bool = False
    save_predictions: bool = False

    def __post_init__(self):
        """Create the output folder (including parents) when one is configured."""
        # folder defaults to None and Path(None) raises TypeError, so only
        # touch the filesystem when a folder was actually given.
        if self.folder is not None:
            Path(self.folder).mkdir(exist_ok=True, parents=True)