49 changes: 40 additions & 9 deletions docs/user-guide/advanced.md
@@ -117,8 +117,8 @@ Sparkwheel recognizes these special keys in configuration:

- `_target_`: Class or function path to instantiate (e.g., `"torch.nn.Linear"`)
- `_disabled_`: Skip instantiation if `true` (removed from parent). See [Instantiation](instantiation.md#_disabled_-skip-instantiation) for details.
-- `_requires_`: List of dependencies to evaluate/instantiate first
- `_mode_`: Operating mode for instantiation (see below)
+- `_imports_`: Declare imports available to all expressions (see [Imports](#imports-for-expressions) below)
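
Taken together, a minimal sketch of these keys in one config (the target class and values here are illustrative, not from the original docs):

```yaml
_imports_:
  torch: torch

model:
  _target_: torch.nn.Linear   # class to instantiate
  _mode_: default             # plain call: torch.nn.Linear(in_features=8, out_features=2)
  in_features: 8
  out_features: 2

device: "$torch.device('cpu')"
```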

### `_mode_` - Instantiation Modes

@@ -289,25 +289,56 @@ except ConfigKeyError as e:

Color output is auto-detected and respects `NO_COLOR` environment variable.

-## Globals for Expressions
+## Imports for Expressions

-Pre-import modules for use in expressions:
+Make modules available to all expressions. There are two ways to do this:

### Method 1: `_imports_` Key in YAML

Declare imports directly in your config file:

```yaml
# config.yaml
_imports_:
  torch: torch
  np: numpy
  Path: pathlib.Path

# Now use them in expressions
device: "$torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
data: "$np.array([1, 2, 3])"
save_path: "$Path('/data/models')"
```

The `_imports_` key is removed from the config after processing—it won't appear in your resolved config.

### Method 2: `imports` Parameter in Python

Pass imports when creating the Config:

```python
from sparkwheel import Config

-# Pre-import torch for all expressions
-config = Config(globals={"torch": "torch", "np": "numpy"})
+# Pre-import modules for all expressions
+config = Config(imports={"torch": "torch", "np": "numpy"})
config.update("config.yaml")

# Now expressions can use torch and np without importing
```

-Example config:
-
-```yaml
-device: "$torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
-data: "$np.array([1, 2, 3])"
-```
+### Combining Both Methods
+
+You can use both approaches together—they merge:

```python
from collections import Counter

config = Config(imports={"Counter": Counter})
config.update({
    "_imports_": {"json": "json"},
    "data": '$json.dumps({"a": 1})',
    "counts": "$Counter([1, 1, 2])"
})
```
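
Both value forms shown here are accepted: `Counter` is passed as an already-imported object, while `json` is a module-path string that sparkwheel imports for you (see the `isinstance(v, str)` handling in `config.py` below).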

## Type Hints
4 changes: 2 additions & 2 deletions docs/user-guide/basics.md
@@ -409,10 +409,10 @@ Sparkwheel reserves certain keys with special meaning:

- `_target_`: Specifies a class to instantiate
- `_disabled_`: Skip instantiation if true
-- `_requires_`: Dependencies that must be resolved first
- `_mode_`: Instantiation mode (default, callable, debug)
+- `_imports_`: Declare imports available to all expressions

-These are covered in detail in [Instantiation Guide](instantiation.md).
+These are covered in detail in [Instantiation Guide](instantiation.md) and [Advanced Features](advanced.md).

## Common Patterns

1 change: 0 additions & 1 deletion docs/user-guide/instantiation.md
@@ -224,7 +224,6 @@ augmentation:
- `_target_`: Class or function path to instantiate (required)
- `_args_`: List of positional arguments to pass
- `_disabled_`: Skip instantiation if `true` (removed from parent)
-- `_requires_`: Dependencies to resolve first
- `_mode_`: Instantiation mode (`"default"`, `"callable"`, or `"debug"`)

For complete details, see the [Advanced Features](advanced.md) and [API Reference](../reference/).
86 changes: 70 additions & 16 deletions src/sparkwheel/config.py
@@ -148,7 +148,7 @@ class Config:
    ```

    Args:
-        globals: Pre-imported packages for expressions (e.g., {"torch": "torch"})
+        imports: Pre-imported packages for expressions (e.g., {"torch": "torch"})
        schema: Dataclass schema for continuous validation
        coerce: Auto-convert compatible types (default: True)
        strict: Reject fields not in schema (default: True)
@@ -159,7 +159,7 @@ def __init__(
        self,
        data: dict[str, Any] | None = None,  # Internal/testing use only
        *,  # Rest are keyword-only
-        globals: dict[str, Any] | None = None,
+        imports: dict[str, Any] | None = None,
        schema: type | None = None,
        coerce: bool = True,
        strict: bool = True,
@@ -171,7 +171,7 @@ def __init__(

        Args:
            data: Initial data (internal/testing use only, not validated)
-            globals: Pre-imported packages for expression evaluation
+            imports: Pre-imported packages for expression evaluation
            schema: Dataclass schema for continuous validation
            coerce: Auto-convert compatible types
            strict: Reject fields not in schema
@@ -196,14 +196,14 @@ def __init__(
        self._strict: bool = strict
        self._allow_missing: bool = allow_missing

-        # Process globals (import string module paths)
-        self._globals: dict[str, Any] = {}
-        if isinstance(globals, dict):
-            for k, v in globals.items():
-                self._globals[k] = optional_import(v)[0] if isinstance(v, str) else v
+        # Process imports (import string module paths)
+        self._imports: dict[str, Any] = {}
+        if isinstance(imports, dict):
+            for k, v in imports.items():
+                self._imports[k] = optional_import(v)[0] if isinstance(v, str) else v

        self._loader = Loader()
-        self._preprocessor = Preprocessor(self._loader, self._globals)
+        self._preprocessor = Preprocessor(self._loader, self._imports)

    def get(self, id: str = "", default: Any = None) -> Any:
        """Get raw config value (unresolved).
@@ -683,6 +683,10 @@ def _parse(self, reset: bool = True) -> None:
        if reset:
            self._resolver.reset()

+        # Process _imports_ key if present in config data
+        # This allows YAML-based imports that become available to all expressions
+        self._process_imports_key()
+
        # Phase 2: Expand local raw references (%key) now that all composition is complete
        # CLI overrides have been applied, so local refs will see final values
        self._data = self._preprocessor.process_raw_refs(
@@ -693,14 +697,60 @@ def _parse(self, reset: bool = True) -> None:
        self._data = self._preprocessor.process(self._data, self._data, id="")

        # Stage 2: Parse config tree to create Items
-        parser = Parser(globals=self._globals, metadata=self._locations)
+        parser = Parser(globals=self._imports, metadata=self._locations)
        items = parser.parse(self._data)

        # Stage 3: Add items to resolver
        self._resolver.add_items(items)

        self._is_parsed = True

+    def _process_imports_key(self) -> None:
+        """Process _imports_ key from config data.
+
+        The _imports_ key allows declaring imports directly in YAML:
+
+        ```yaml
+        _imports_:
+          torch: torch
+          np: numpy
+          Path: pathlib.Path
+
+        model:
+          device: "$torch.device('cuda')"
+        ```
+
+        These imports become available to all expressions in the config.
+        The _imports_ key is removed from the data after processing.
+        """
+        imports_key = "_imports_"
+        if imports_key not in self._data:
+            return
+
+        imports_config = self._data.pop(imports_key)
+        if not isinstance(imports_config, dict):
+            return
+
+        # Process each import
+        for name, module_path in imports_config.items():
+            if isinstance(module_path, str):
+                # Handle dotted paths like "pathlib.Path" or "collections.Counter"
+                # Split into module and attribute if needed
+                if "." in module_path:
+                    parts = module_path.rsplit(".", 1)
+                    # First try as a module (e.g., "os.path")
+                    module_obj, success = optional_import(module_path)
+                    if not success:
+                        # Try as module.attribute (e.g., "pathlib.Path")
+                        module_obj, success = optional_import(parts[0], name=parts[1])
+                    self._imports[name] = module_obj
+                else:
+                    # Simple module name like "json"
+                    self._imports[name] = optional_import(module_path)[0]
+            else:
+                # Already a module or callable
+                self._imports[name] = module_path

    def _get_by_id(self, id: str) -> Any:
        """Get config value by ID path.

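The dotted-path fallback above can be exercised in isolation. A minimal sketch (the `sparkwheel.utils` location of `optional_import` is an assumption; this diff only shows the helper being called):

```python
from sparkwheel.utils import optional_import  # assumed module path

# "pathlib.Path" is not importable as a module, so the first attempt fails...
obj, success = optional_import("pathlib.Path")
assert not success

# ...and the fallback imports "pathlib" and pulls out the "Path" attribute.
obj, success = optional_import("pathlib", name="Path")
assert success and obj.__name__ == "Path"
```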
@@ -789,7 +839,8 @@ def parse_overrides(args: list[str]) -> dict[str, Any]:
    """Parse CLI argument overrides with automatic type inference.

    Supports only key=value syntax with operator prefixes.
-    Types are automatically inferred using ast.literal_eval().
+    Values are parsed using YAML syntax (via ``yaml.safe_load``), ensuring
+    CLI overrides behave identically to values in YAML config files.

    Args:
        args: List of argument strings to parse (e.g., from argparse)
@@ -805,21 +856,24 @@ def parse_overrides(args: list[str]) -> dict[str, Any]:

    Examples:
        >>> # Basic overrides (compose/merge)
-        >>> parse_overrides(["model::lr=0.001", "debug=True"])
+        >>> parse_overrides(["model::lr=0.001", "debug=true"])
        {"model::lr": 0.001, "debug": True}

        >>> # With operators
-        >>> parse_overrides(["=model={'_target_': 'ResNet'}", "~old_param"])
+        >>> parse_overrides(["=model={_target_: ResNet}", "~old_param"])
        {"=model": {'_target_': 'ResNet'}, "~old_param": None}

        >>> # Nested paths with operators
        >>> parse_overrides(["=optimizer::lr=0.01", "~model::old_param"])
        {"=optimizer::lr": 0.01, "~model::old_param": None}

    Note:
-        The '=' character serves dual purpose:
-        - In 'key=value' → assignment operator (CLI syntax)
-        - In '=key=value' → replace operator prefix (config operator)
+        - The '=' character serves dual purpose:
+          - In 'key=value' → assignment operator (CLI syntax)
+          - In '=key=value' → replace operator prefix (config operator)
+        - Values use YAML syntax: ``true``/``false``, ``yes``/``no``, ``on``/``off``
+          for booleans, ``null`` or ``~`` for None, ``{key: value}`` for dicts.
+        - Python's ``None`` is parsed as the string ``"None"`` (use ``null`` instead).
    """
    import yaml

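Given the YAML-based value parsing documented above, a quick usage sketch (the `sparkwheel.config` import path is inferred from this diff's file layout):

```python
from sparkwheel.config import parse_overrides  # assumed import path

overrides = parse_overrides(["model::lr=0.001", "debug=true", "tags=[a, b]"])

# yaml.safe_load gives each value its natural YAML type:
assert overrides == {"model::lr": 0.001, "debug": True, "tags": ["a", "b"]}
```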
3 changes: 1 addition & 2 deletions src/sparkwheel/items.py
@@ -116,15 +116,14 @@ class Component(Item, Instantiable):

    - `_target_`: Full module path (e.g., "collections.Counter")
    - `_args_`: List of positional arguments to pass to the target
-    - `_requires_`: Dependencies to evaluate/instantiate first
    - `_disabled_`: Skip instantiation if True
    - `_mode_`: Instantiation mode:
      - `"default"`: Returns component(*args, **kwargs)
      - `"callable"`: Returns functools.partial(component, *args, **kwargs)
      - `"debug"`: Returns pdb.runcall(component, *args, **kwargs)
    """

-    non_arg_keys = {"_target_", "_disabled_", "_requires_", "_mode_", "_args_"}
+    non_arg_keys = {"_target_", "_disabled_", "_mode_", "_args_"}

    def __init__(self, config: Any, id: str = "", source_location: Location | None = None) -> None:
        super().__init__(config=config, id=id, source_location=source_location)
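
A short sketch of `_mode_` in practice, following the docstring above (the optimizer target is illustrative):

```yaml
optimizer:
  _target_: torch.optim.Adam
  _mode_: callable   # yields functools.partial(torch.optim.Adam, lr=0.001), to be called later
  lr: 0.001
```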
2 changes: 1 addition & 1 deletion src/sparkwheel/schema.py
@@ -310,7 +310,7 @@ class AppConfig:
    if strict:
        unexpected_fields = set(config.keys()) - set(schema_fields.keys())
        # Filter out sparkwheel special keys
-        special_keys = {"_target_", "_disabled_", "_requires_", "_mode_"}
+        special_keys = {"_target_", "_disabled_", "_mode_", "_imports_", "_args_"}
        unexpected_fields = unexpected_fields - special_keys

        if unexpected_fields:
2 changes: 1 addition & 1 deletion tests/test_components.py
@@ -102,7 +102,6 @@ def test_resolve_args(self):
        config = {
            "_target_": "collections.Counter",
            "_disabled_": False,
-            "_requires_": [],
            "_mode_": "default",
            "iterable": [1, 2, 2],
        }
@@ -112,6 +111,7 @@ def test_resolve_args(self):
        assert kwargs == {"iterable": [1, 2, 2]}
        assert "_target_" not in kwargs
        assert "_disabled_" not in kwargs
+        assert "_mode_" not in kwargs

    def test_is_disabled_false(self):
        """Test is_disabled returns False."""