diff --git a/docs/user-guide/advanced.md b/docs/user-guide/advanced.md index a595731..737d483 100644 --- a/docs/user-guide/advanced.md +++ b/docs/user-guide/advanced.md @@ -117,8 +117,8 @@ Sparkwheel recognizes these special keys in configuration: - `_target_`: Class or function path to instantiate (e.g., `"torch.nn.Linear"`) - `_disabled_`: Skip instantiation if `true` (removed from parent). See [Instantiation](instantiation.md#_disabled_-skip-instantiation) for details. -- `_requires_`: List of dependencies to evaluate/instantiate first - `_mode_`: Operating mode for instantiation (see below) +- `_imports_`: Declare imports available to all expressions (see [Imports](#imports-for-expressions) below) ### `_mode_` - Instantiation Modes @@ -289,25 +289,56 @@ except ConfigKeyError as e: Color output is auto-detected and respects `NO_COLOR` environment variable. -## Globals for Expressions +## Imports for Expressions -Pre-import modules for use in expressions: +Make modules available to all expressions. There are two ways to do this: + +### Method 1: `_imports_` Key in YAML + +Declare imports directly in your config file: + +```yaml +# config.yaml +_imports_: + torch: torch + np: numpy + Path: pathlib.Path + +# Now use them in expressions +device: "$torch.device('cuda' if torch.cuda.is_available() else 'cpu')" +data: "$np.array([1, 2, 3])" +save_path: "$Path('/data/models')" +``` + +The `_imports_` key is removed from the config after processing—it won't appear in your resolved config. + +### Method 2: `imports` Parameter in Python + +Pass imports when creating the Config: ```python from sparkwheel import Config -# Pre-import torch for all expressions -config = Config(globals={"torch": "torch", "np": "numpy"}) +# Pre-import modules for all expressions +config = Config(imports={"torch": "torch", "np": "numpy"}) config.update("config.yaml") # Now expressions can use torch and np without importing ``` -Example config: +### Combining Both Methods -```yaml -device: "$torch.device('cuda' if torch.cuda.is_available() else 'cpu')" -data: "$np.array([1, 2, 3])" +You can use both approaches together—they merge: + +```python +from collections import Counter + +config = Config(imports={"Counter": Counter}) +config.update({ + "_imports_": {"json": "json"}, + "data": '$json.dumps({"a": 1})', + "counts": "$Counter([1, 1, 2])" +}) ``` ## Type Hints diff --git a/docs/user-guide/basics.md b/docs/user-guide/basics.md index ff4741f..7a36ddb 100644 --- a/docs/user-guide/basics.md +++ b/docs/user-guide/basics.md @@ -409,10 +409,10 @@ Sparkwheel reserves certain keys with special meaning: - `_target_`: Specifies a class to instantiate - `_disabled_`: Skip instantiation if true -- `_requires_`: Dependencies that must be resolved first - `_mode_`: Instantiation mode (default, callable, debug) +- `_imports_`: Declare imports available to all expressions -These are covered in detail in [Instantiation Guide](instantiation.md). +These are covered in detail in [Instantiation Guide](instantiation.md) and [Advanced Features](advanced.md). ## Common Patterns diff --git a/docs/user-guide/instantiation.md b/docs/user-guide/instantiation.md index 258c6ea..b3f17c6 100644 --- a/docs/user-guide/instantiation.md +++ b/docs/user-guide/instantiation.md @@ -224,7 +224,6 @@ augmentation: - `_target_`: Class or function path to instantiate (required) - `_args_`: List of positional arguments to pass - `_disabled_`: Skip instantiation if `true` (removed from parent) -- `_requires_`: Dependencies to resolve first - `_mode_`: Instantiation mode (`"default"`, `"callable"`, or `"debug"`) For complete details, see the [Advanced Features](advanced.md) and [API Reference](../reference/). diff --git a/src/sparkwheel/config.py b/src/sparkwheel/config.py index 744ca06..87f3423 100644 --- a/src/sparkwheel/config.py +++ b/src/sparkwheel/config.py @@ -148,7 +148,7 @@ class Config: ``` Args: - globals: Pre-imported packages for expressions (e.g., {"torch": "torch"}) + imports: Pre-imported packages for expressions (e.g., {"torch": "torch"}) schema: Dataclass schema for continuous validation coerce: Auto-convert compatible types (default: True) strict: Reject fields not in schema (default: True) @@ -159,7 +159,7 @@ def __init__( self, data: dict[str, Any] | None = None, # Internal/testing use only *, # Rest are keyword-only - globals: dict[str, Any] | None = None, + imports: dict[str, Any] | None = None, schema: type | None = None, coerce: bool = True, strict: bool = True, @@ -171,7 +171,7 @@ def __init__( Args: data: Initial data (internal/testing use only, not validated) - globals: Pre-imported packages for expression evaluation + imports: Pre-imported packages for expression evaluation schema: Dataclass schema for continuous validation coerce: Auto-convert compatible types strict: Reject fields not in schema @@ -196,14 +196,14 @@ def __init__( self._strict: bool = strict self._allow_missing: bool = allow_missing - # Process globals (import string module paths) - self._globals: dict[str, Any] = {} - if isinstance(globals, dict): - for k, v in globals.items(): - self._globals[k] = optional_import(v)[0] if isinstance(v, str) else v + # Process imports (import string module paths) + self._imports: dict[str, Any] = {} + if isinstance(imports, dict): + for k, v in imports.items(): + self._imports[k] = optional_import(v)[0] if isinstance(v, str) else v self._loader = Loader() - self._preprocessor = Preprocessor(self._loader, self._globals) + self._preprocessor = Preprocessor(self._loader, self._imports) def get(self, id: str = "", default: Any = None) -> Any: """Get raw config value (unresolved). @@ -683,6 +683,10 @@ def _parse(self, reset: bool = True) -> None: if reset: self._resolver.reset() + # Process _imports_ key if present in config data + # This allows YAML-based imports that become available to all expressions + self._process_imports_key() + # Phase 2: Expand local raw references (%key) now that all composition is complete # CLI overrides have been applied, so local refs will see final values self._data = self._preprocessor.process_raw_refs( @@ -693,7 +697,7 @@ def _parse(self, reset: bool = True) -> None: self._data = self._preprocessor.process(self._data, self._data, id="") # Stage 2: Parse config tree to create Items - parser = Parser(globals=self._globals, metadata=self._locations) + parser = Parser(globals=self._imports, metadata=self._locations) items = parser.parse(self._data) # Stage 3: Add items to resolver @@ -701,6 +705,52 @@ def _parse(self, reset: bool = True) -> None: self._is_parsed = True + def _process_imports_key(self) -> None: + """Process _imports_ key from config data. + + The _imports_ key allows declaring imports directly in YAML: + + ```yaml + _imports_: + torch: torch + np: numpy + Path: pathlib.Path + + model: + device: "$torch.device('cuda')" + ``` + + These imports become available to all expressions in the config. + The _imports_ key is removed from the data after processing. + """ + imports_key = "_imports_" + if imports_key not in self._data: + return + + imports_config = self._data.pop(imports_key) + if not isinstance(imports_config, dict): + return + + # Process each import + for name, module_path in imports_config.items(): + if isinstance(module_path, str): + # Handle dotted paths like "pathlib.Path" or "collections.Counter" + # Split into module and attribute if needed + if "." in module_path: + parts = module_path.rsplit(".", 1) + # First try as a module (e.g., "os.path") + module_obj, success = optional_import(module_path) + if not success: + # Try as module.attribute (e.g., "pathlib.Path") + module_obj, success = optional_import(parts[0], name=parts[1]) + self._imports[name] = module_obj + else: + # Simple module name like "json" + self._imports[name] = optional_import(module_path)[0] + else: + # Already a module or callable + self._imports[name] = module_path + def _get_by_id(self, id: str) -> Any: """Get config value by ID path. @@ -789,7 +839,8 @@ def parse_overrides(args: list[str]) -> dict[str, Any]: """Parse CLI argument overrides with automatic type inference. Supports only key=value syntax with operator prefixes. - Types are automatically inferred using ast.literal_eval(). + Values are parsed using YAML syntax (via ``yaml.safe_load``), ensuring + CLI overrides behave identically to values in YAML config files. Args: args: List of argument strings to parse (e.g., from argparse) @@ -805,11 +856,11 @@ def parse_overrides(args: list[str]) -> dict[str, Any]: Examples: >>> # Basic overrides (compose/merge) - >>> parse_overrides(["model::lr=0.001", "debug=True"]) + >>> parse_overrides(["model::lr=0.001", "debug=true"]) {"model::lr": 0.001, "debug": True} >>> # With operators - >>> parse_overrides(["=model={'_target_': 'ResNet'}", "~old_param"]) + >>> parse_overrides(["=model={_target_: ResNet}", "~old_param"]) {"=model": {'_target_': 'ResNet'}, "~old_param": None} >>> # Nested paths with operators @@ -817,9 +868,12 @@ def parse_overrides(args: list[str]) -> dict[str, Any]: {"=optimizer::lr": 0.01, "~model::old_param": None} Note: - The '=' character serves dual purpose: - - In 'key=value' → assignment operator (CLI syntax) - - In '=key=value' → replace operator prefix (config operator) + - The '=' character serves dual purpose: + - In 'key=value' → assignment operator (CLI syntax) + - In '=key=value' → replace operator prefix (config operator) + - Values use YAML syntax: ``true``/``false``, ``yes``/``no``, ``on``/``off`` + for booleans, ``null`` or ``~`` for None, ``{key: value}`` for dicts. + - Python's ``None`` is parsed as the string ``"None"`` (use ``null`` instead). """ import yaml diff --git a/src/sparkwheel/items.py b/src/sparkwheel/items.py index dfe5f02..300feec 100644 --- a/src/sparkwheel/items.py +++ b/src/sparkwheel/items.py @@ -116,7 +116,6 @@ class Component(Item, Instantiable): - `_target_`: Full module path (e.g., "collections.Counter") - `_args_`: List of positional arguments to pass to the target - - `_requires_`: Dependencies to evaluate/instantiate first - `_disabled_`: Skip instantiation if True - `_mode_`: Instantiation mode: - `"default"`: Returns component(*args, **kwargs) @@ -124,7 +123,7 @@ class Component(Item, Instantiable): - `"debug"`: Returns pdb.runcall(component, *args, **kwargs) """ - non_arg_keys = {"_target_", "_disabled_", "_requires_", "_mode_", "_args_"} + non_arg_keys = {"_target_", "_disabled_", "_mode_", "_args_"} def __init__(self, config: Any, id: str = "", source_location: Location | None = None) -> None: super().__init__(config=config, id=id, source_location=source_location) diff --git a/src/sparkwheel/schema.py b/src/sparkwheel/schema.py index 2d8bf40..43a7ee5 100644 --- a/src/sparkwheel/schema.py +++ b/src/sparkwheel/schema.py @@ -310,7 +310,7 @@ class AppConfig: if strict: unexpected_fields = set(config.keys()) - set(schema_fields.keys()) # Filter out sparkwheel special keys - special_keys = {"_target_", "_disabled_", "_requires_", "_mode_"} + special_keys = {"_target_", "_disabled_", "_mode_", "_imports_", "_args_"} unexpected_fields = unexpected_fields - special_keys if unexpected_fields: diff --git a/tests/test_components.py b/tests/test_components.py index a8bd7b0..98e87f7 100644 --- a/tests/test_components.py +++ b/tests/test_components.py @@ -102,7 +102,6 @@ def test_resolve_args(self): config = { "_target_": "collections.Counter", "_disabled_": False, - "_requires_": [], "_mode_": "default", "iterable": [1, 2, 2], } @@ -112,6 +111,7 @@ def test_resolve_args(self): assert kwargs == {"iterable": [1, 2, 2]} assert "_target_" not in kwargs assert "_disabled_" not in kwargs + assert "_mode_" not in kwargs def test_is_disabled_false(self): """Test is_disabled returns False.""" diff --git a/tests/test_config.py b/tests/test_config.py index 087dcea..6d1e51a 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -134,17 +134,127 @@ def test_init_with_none(self): assert isinstance(parser._data, dict) assert parser._data == {} - def test_init_with_globals_dict(self): - """Test Config init with globals dict.""" - parser = Config({}, globals={"pd": "pandas"}) - assert "pd" in parser._globals + def test_init_with_imports_dict(self): + """Test Config init with imports dict.""" + parser = Config({}, imports={"pd": "pandas"}) + assert "pd" in parser._imports - def test_init_with_globals_callable(self): - """Test Config init with globals containing callables.""" + def test_init_with_imports_callable(self): + """Test Config init with imports containing callables.""" from collections import Counter - parser = Config({}, globals={"Counter": Counter}) - assert parser._globals["Counter"] is Counter + parser = Config({}, imports={"Counter": Counter}) + assert parser._imports["Counter"] is Counter + + +class TestConfigImports: + """Test _imports_ key handling.""" + + def test_imports_key_basic(self): + """Test _imports_ key makes modules available to expressions.""" + config = Config().update( + { + "_imports_": {"json": "json"}, + "data": '$json.dumps({"a": 1})', + } + ) + result = config.resolve("data") + assert result == '{"a": 1}' + + def test_imports_key_multiple_modules(self): + """Test _imports_ with multiple modules.""" + config = Config().update( + { + "_imports_": { + "os": "os", + "Path": "pathlib.Path", + }, + "sep": "$os.sep", + "path_type": "$Path", + } + ) + import os + from pathlib import Path + + assert config.resolve("sep") == os.sep + assert config.resolve("path_type") is Path + + def test_imports_key_removed_from_data(self): + """Test _imports_ key is removed from config data after processing.""" + config = Config().update( + { + "_imports_": {"json": "json"}, + "data": '$json.dumps({"a": 1})', + } + ) + config.resolve() # Trigger parsing + assert "_imports_" not in config._data + + def test_imports_key_combined_with_imports_parameter(self): + """Test _imports_ key works with imports parameter.""" + from collections import Counter + + config = Config(imports={"Counter": Counter}).update( + { + "_imports_": {"json": "json"}, + "counter": "$Counter([1, 1, 2])", + "data": '$json.dumps({"a": 1})', + } + ) + assert config.resolve("counter") == Counter([1, 1, 2]) + assert config.resolve("data") == '{"a": 1}' + + def test_imports_key_invalid_value_ignored(self): + """Test _imports_ with invalid value is ignored gracefully.""" + config = Config().update( + { + "_imports_": "not a dict", + "value": 42, + } + ) + result = config.resolve("value") + assert result == 42 + + def test_imports_key_with_dotted_class_path(self): + """Test _imports_ with dotted path to a class (e.g., pathlib.Path).""" + from collections import Counter + + config = Config().update( + { + "_imports_": {"Counter": "collections.Counter"}, + "counts": "$Counter([1, 1, 2, 2, 2])", + } + ) + result = config.resolve("counts") + assert result == Counter([1, 1, 2, 2, 2]) + + def test_imports_key_with_dotted_module_path(self): + """Test _imports_ with dotted path to a submodule (e.g., os.path).""" + import os.path + + config = Config().update( + { + "_imports_": {"ospath": "os.path"}, + "sep": "$ospath.sep", + } + ) + result = config.resolve("sep") + assert result == os.path.sep + + def test_imports_key_with_non_string_value(self): + """Test _imports_ with non-string value (already imported module).""" + import json + + # Pass the module directly via imports parameter, then use _imports_ with non-string + # Note: Can't put module in _imports_ dict in update() due to deepcopy, + # so we test via direct _data manipulation before parse + config = Config() + config._data = { + "_imports_": {"my_json": json}, + "data": '$my_json.dumps({"a": 1})', + } + result = config.resolve("data") + assert result == '{"a": 1}' class TestConfigReferences: diff --git a/tests/test_items.py b/tests/test_items.py index bbd07da..e2595fa 100644 --- a/tests/test_items.py +++ b/tests/test_items.py @@ -293,7 +293,6 @@ def test_instantiate_args_not_in_non_arg_keys(self): assert "_args_" not in kwargs assert "_target_" not in kwargs assert "_disabled_" not in kwargs - assert "_requires_" not in kwargs assert "_mode_" not in kwargs