diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
new file mode 100644
index 0000000..9d8aaac
--- /dev/null
+++ b/.github/workflows/release.yml
@@ -0,0 +1,36 @@
+name: Release
+
+on:
+ push:
+ tags:
+ - '*'
+
+permissions:
+ contents: write
+
+jobs:
+ release:
+ name: Create GitHub Release
+ runs-on: ubuntu-latest
+ timeout-minutes: 5
+ if: startsWith(github.ref, 'refs/tags')
+
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ fetch-depth: 0
+
+ - name: Verify tag is on main branch
+ run: |
+ git fetch origin main
+ if ! git merge-base --is-ancestor ${{ github.sha }} origin/main; then
+ echo "Error: Tag is not on the main branch"
+ exit 1
+ fi
+
+ - name: Create Release
+ uses: softprops/action-gh-release@v2
+ with:
+ generate_release_notes: true
+ draft: false
+ prerelease: false
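+
+# Cutting a release (a sketch; the version number is hypothetical):
+#   git checkout main && git pull
+#   git tag v1.2.3
+#   git push origin v1.2.3
+# The workflow verifies the tagged commit is on main, then publishes the release.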
diff --git a/README.md b/README.md
index 118cd0e..5f0d961 100644
--- a/README.md
+++ b/README.md
@@ -10,25 +10,30 @@
-
-⚙️ YAML configuration meets Python 🐍
+YAML configuration meets Python
Define Python objects in YAML. Reference, compose, and instantiate them effortlessly.
-## What is Sparkwheel?
+## Quick Start
-Stop hardcoding parameters. Define complex Python objects in clean YAML files, compose them naturally, and instantiate with one line.
+```bash
+pip install sparkwheel
+```
```yaml
# config.yaml
+dataset:
+ num_classes: 10
+ batch_size: 32
+
model:
_target_: torch.nn.Linear
in_features: 784
- out_features: "%dataset::num_classes" # Reference other values
+ out_features: "%dataset::num_classes" # Reference
-dataset:
- num_classes: 10
+training:
+ steps_per_epoch: "$10000 // @dataset::batch_size" # Expression
```
```python
@@ -36,77 +41,25 @@ from sparkwheel import Config
config = Config()
config.update("config.yaml")
-model = config.resolve("model") # Actual torch.nn.Linear(784, 10) instance!
-```
-
-## Key Features
-
-- **Declarative Object Creation** - Instantiate any Python class from YAML with `_target_`
-- **Smart References** - `@` for resolved values, `%` for raw YAML
-- **Composition by Default** - Configs merge naturally (dicts merge, lists extend)
-- **Explicit Operators** - `=` to replace, `~` to delete when needed
-- **Python Expressions** - Compute values dynamically with `$` prefix
-- **Schema Validation** - Type-check configs with Python dataclasses
-- **CLI Overrides** - Override any value from command line
-
-## Installation
-```bash
-pip install sparkwheel
+model = config.resolve("model") # Actual torch.nn.Linear(784, 10)
```
-**[→ Get Started in 5 Minutes](https://project-lighter.github.io/sparkwheel/getting-started/quickstart/)**
+## Features
-## Coming from Hydra/OmegaConf?
-
-Sparkwheel builds on similar ideas but adds powerful features:
-
-| Feature | Hydra/OmegaConf | Sparkwheel |
-|---------|-----------------|------------|
-| Config composition | Explicit (`+`, `++`) | **By default** (dicts merge, lists extend) |
-| Replace semantics | Default | Explicit with `=` operator |
-| Delete keys | Not idempotent | Idempotent `~` operator |
-| References | OmegaConf interpolation | `@` (resolved) + `%` (raw YAML) |
-| Python expressions | Limited | Full Python with `$` |
-| Schema validation | Structured Configs | Python dataclasses |
-| List extension | Lists replace | **Lists extend by default** |
-
-**Composition by default** means configs merge naturally without operators:
-```yaml
-# base.yaml
-model:
- hidden_size: 256
- dropout: 0.1
-
-# experiment.yaml
-model:
- hidden_size: 512 # Override
- # dropout inherited
-```
-
-## Documentation
+- **Declarative Objects** - Instantiate any Python class with `_target_`
+- **Smart References** - `@` for resolved values, `%` for raw YAML
+- **Composition by Default** - Dicts merge, lists extend automatically
+- **Explicit Control** - `=` to replace, `~` to delete
+- **Python Expressions** - Dynamic values with `$`
+- **Schema Validation** - Type-check with dataclasses
-- [Full Documentation](https://project-lighter.github.io/sparkwheel/)
-- [Quick Start Guide](https://project-lighter.github.io/sparkwheel/getting-started/quickstart/)
-- [Core Concepts](https://project-lighter.github.io/sparkwheel/user-guide/basics/)
-- [API Reference](https://project-lighter.github.io/sparkwheel/reference/)
+**[Get Started](https://project-lighter.github.io/sparkwheel/getting-started/quickstart/)** · **[Documentation](https://project-lighter.github.io/sparkwheel/)** · **[Quick Reference](https://project-lighter.github.io/sparkwheel/user-guide/quick-reference/)**
## Community
-- [Discord Server](https://discord.gg/zJcnp6KrUp) - Chat with the community
-- [YouTube Channel](https://www.youtube.com/channel/UCef1oTpv2QEBrD2pZtrdk1Q) - Tutorials and demos
-- [GitHub Issues](https://github.com/project-lighter/sparkwheel/issues) - Bug reports and feature requests
-
-## Contributing
-
-We welcome contributions! See [CONTRIBUTING.md](CONTRIBUTING.md) for development setup and guidelines.
+- [Discord](https://discord.gg/zJcnp6KrUp) · [YouTube](https://www.youtube.com/channel/UCef1oTpv2QEBrD2pZtrdk1Q) · [Issues](https://github.com/project-lighter/sparkwheel/issues)
## About
-Sparkwheel is a hard fork of [MONAI Bundle](https://github.com/Project-MONAI/MONAI/tree/dev/monai/bundle)'s configuration system, refined and expanded for general-purpose use. We're deeply grateful to the MONAI team for their excellent foundation.
-
-Sparkwheel powers [Lighter](https://project-lighter.github.io/lighter/), our configuration-driven deep learning framework built on PyTorch Lightning.
-
-## License
-
-Apache License 2.0 - See [LICENSE](LICENSE) for details.
+Sparkwheel is a hard fork of [MONAI Bundle](https://github.com/Project-MONAI/MONAI/tree/dev/monai/bundle)'s config system, with the goal of making a more general-purpose configuration library for Python projects. It combines the best of MONAI Bundle and [Hydra](http://hydra.cc/)/[OmegaConf](https://omegaconf.readthedocs.io/), while introducing new features and improvements not found in either.
diff --git a/docs/index.md b/docs/index.md
index db33c86..f782043 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -199,7 +199,7 @@ Sparkwheel has two types of references with distinct purposes:
- **Composition-by-default** - Configs merge/extend naturally, no operators needed for common case
- **List extension** - Lists extend by default (unique vs Hydra!)
- **`=` replace operator** - Explicit control when you need replacement
- - **`~` delete operator** - Remove inherited keys cleanly (idempotent!)
+ - **`~` delete operator** - Remove inherited keys explicitly
- **Python expressions with `$`** - Compute values dynamically
- **Dataclass validation** - Type-safe configs without boilerplate
- **Dual reference system** - `@` for resolved values, `%` for raw YAML
diff --git a/docs/user-guide/advanced.md b/docs/user-guide/advanced.md
index 1555750..a3b7c9e 100644
--- a/docs/user-guide/advanced.md
+++ b/docs/user-guide/advanced.md
@@ -221,8 +221,6 @@ config.update({"~plugins": [0, 2]}) # Remove list items
config.update({"~dataloaders": ["train", "test"]}) # Remove dict keys
```
-**Note:** The `~` directive is idempotent - it doesn't error if the key doesn't exist, enabling reusable configs.
-
### Programmatic Updates
Apply operators programmatically:
diff --git a/docs/user-guide/cli.md b/docs/user-guide/cli.md
index dc3474d..f32f8a4 100644
--- a/docs/user-guide/cli.md
+++ b/docs/user-guide/cli.md
@@ -111,7 +111,7 @@ Three operators for fine-grained control:
|----------|--------|----------|---------|
| **Compose** (default) | `key=value` | Merges dicts, extends lists | `model::lr=0.001` |
| **Replace** | `=key=value` | Completely replaces value | `=model={'_target_': 'ResNet'}` |
-| **Delete** | `~key` | Removes key (idempotent) | `~debug` |
+| **Delete** | `~key` | Removes key (errors if missing) | `~debug` |
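+
+A minimal sketch of all three operators applied programmatically (the override strings match the CLI forms above):
+
+```python
+from sparkwheel import Config
+
+config = Config()
+config.update("config.yaml")
+config.update("model::lr=0.001")                   # compose: merge into model
+config.update({"=model": {"_target_": "ResNet"}})  # replace: overwrite model entirely
+config.update({"~debug": None})                    # delete: errors if debug is missing
+```
+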
!!! info "Type Inference"
Values are automatically typed using `ast.literal_eval()`:
diff --git a/docs/user-guide/operators.md b/docs/user-guide/operators.md
index 6baf2d1..9b03f63 100644
--- a/docs/user-guide/operators.md
+++ b/docs/user-guide/operators.md
@@ -126,11 +126,14 @@ Remove keys or list items with `~key`:
### Delete Entire Keys
```yaml
-# Remove keys (idempotent - no error if missing!)
+# Remove keys explicitly
~old_param: null
~debug_settings: null
```
+!!! warning "Key Must Exist"
+ The delete operator will raise an error if the key doesn't exist. This helps catch typos and configuration mistakes.
+
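+What strict deletion looks like in practice (a sketch; we catch the broad `Exception` since the exact error type is an implementation detail):
+
+```python
+from sparkwheel import Config
+
+config = Config()
+config.update({"model": {"lr": 0.001}})
+
+config.update({"~model::lr": None})  # OK: key exists and is removed
+
+try:
+    config.update({"~model::lr": None})  # key is already gone
+except Exception as err:
+    print(err)  # the error lists available keys to help catch typos
+```
+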
### Delete Dict Keys
Use path notation for nested keys:
@@ -214,28 +217,6 @@ dataloaders:
**Why?** Path notation is designed for dict keys, not list indices. The batch syntax handles index normalization and processes deletions correctly (high to low order).
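+
+For instance, removing indices `[0, 2]` from a four-item list (a sketch; processing high-to-low keeps the not-yet-deleted indices stable):
+
+```yaml
+# base.yaml
+plugins: [a, b, c, d]
+
+# override.yaml
+~plugins: [0, 2]  # index 2 ('c') is removed first, then index 0 ('a')
+# result after update: plugins == [b, d]
+```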
-### Idempotent Delete
-
-Delete operations don't error if the key doesn't exist:
-
-```yaml
-# production.yaml - Remove debug settings if they exist
-~debug_mode: null
-~dev_logger: null
-~test_data: null
-# No errors if these don't exist!
-```
-
-This enables **reusable configs** that work with multiple bases:
-
-```yaml
-# production.yaml works with ANY base config
-~debug_settings: null
-~verbose_logging: null
-database:
- pool_size: 100
-```
-
## Combining Operators
Mix composition, replace, and delete:
@@ -298,7 +279,7 @@ config.update({"model": {"hidden_size": 1024}})
# Replace explicitly
config.update({"=optimizer": {"type": "sgd", "lr": 0.1}})
-# Delete keys (idempotent)
+# Delete keys
config.update({
"~training::old_param": None,
"~model::dropout": None
@@ -454,17 +435,40 @@ model:
### Write Reusable Configs
-Use idempotent delete for portable configs:
+!!! warning "Delete Requires Key Existence"
+ The delete operator (`~`) is **strict** - it raises an error if the key doesn't exist. This helps catch typos and configuration mistakes.
+When writing configs that should work with different base configurations, you have a few options:
+
+**Option 1: Document required keys**
```yaml
-# production.yaml - works with ANY base!
-~debug_mode: null # Remove if exists
-~verbose_logging: null # No error if missing
+# production.yaml
+# Requires: base config must have debug_mode and verbose_logging
+~debug_mode: null
+~verbose_logging: null
database:
pool_size: 100
ssl: true
```
+**Option 2: Use composition order**
+```yaml
+# production.yaml - override instead of delete
+debug_mode: false # Overrides if exists, sets if not
+verbose_logging: false
+database:
+ pool_size: 100
+ ssl: true
+```
+
+**Option 3: Conditional deletion with lists**
+```yaml
+# Delete multiple optional keys - fails only if ALL are missing
+~: [debug_mode, verbose_logging] # At least one must exist
+database:
+ pool_size: 100
+```
+
## Common Mistakes
### Using `=` When Not Needed
@@ -519,17 +523,17 @@ plugins: [cache]
|---------|-------|------------|
| Dict merge default | Yes ✅ | Yes ✅ |
| List extend default | No ❌ | **Yes** ✅ |
-| Operators in YAML | No ❌ | Yes ✅ (`=`, `~`) |
-| Operator count | 4 (`+`, `++`, `~`) | **2** (`=`, `~`) ✅ |
-| Delete dict keys | No ❌ | Yes ✅ |
-| Delete list items | No ❌ | Yes ✅ |
-| Idempotent delete | N/A | Yes ✅ |
-
-Sparkwheel goes beyond Hydra with:
-- Full composition-first philosophy (dicts **and** lists)
-- Operators directly in YAML files
-- Just 2 simple operators
-- Delete operations for fine-grained control
+| Operators in YAML | CLI-only | **Yes** ✅ (YAML + CLI) |
+| Operator count | 4 (`=`, `+`, `++`, `~`) | **2** (`=`, `~`) ✅ |
+| Delete dict keys | CLI-only (`~foo.bar`) | **Yes** ✅ (YAML + CLI) |
+| Delete list items | No ❌ | **Yes** ✅ (by index) |
+
+Sparkwheel differs from Hydra:
+- **Full composition philosophy**: Both dicts AND lists compose by default
+- **Operators in YAML files**: Not just CLI overrides
+- **Simpler operator set**: Just 2 operators (`=`, `~`) vs 4 (`=`, `+`, `++`, `~`)
+- **List deletion**: Delete items by index with `~plugins: [0, 2]`
+- **Flexible delete**: Use `~` anywhere (YAML, CLI, programmatic)
## Next Steps
diff --git a/docs/user-guide/references.md b/docs/user-guide/references.md
index be2f42d..0c46aa6 100644
--- a/docs/user-guide/references.md
+++ b/docs/user-guide/references.md
@@ -10,7 +10,7 @@ Sparkwheel provides two types of references for linking configuration values:
| Feature | `@ref` (Resolved) | `%ref` (Raw) | `$expr` (Expression) |
|---------|-------------------|--------------|----------------------|
| **Returns** | Final computed value | Raw YAML content | Evaluated expression result |
-| **When processed** | Lazy (`resolve()`) | Eager (`update()`) | Lazy (`resolve()`) |
+| **When processed** | Lazy (`resolve()`) | External: Eager / Local: Lazy | Lazy (`resolve()`) |
| **Instantiates objects** | ✅ Yes | ❌ No | ✅ Yes (if referenced) |
| **Evaluates expressions** | ✅ Yes | ❌ No | ✅ Yes |
| **Use in dataclass validation** | ✅ Yes | ⚠️ Limited | ✅ Yes |
@@ -18,68 +18,80 @@ Sparkwheel provides two types of references for linking configuration values:
| **Cross-file references** | ✅ Yes | ✅ Yes | ❌ No |
| **When to use** | Get computed results | Copy config structures | Compute new values |
-## Two-Stage Processing Model
+## Two-Phase Processing Model
-Sparkwheel processes references at different times to enable safe config composition:
+Sparkwheel processes raw references (`%`) in two phases to support CLI overrides:
!!! abstract "When References Are Processed"
- **Stage 1: Eager Processing (during `update()`)**
+ **Phase 1: Eager Processing (during `update()`)**
- - **Raw References (`%`)** are expanded immediately when configs are merged
- - Enables safe config composition and pruning workflows
- - External file references resolved at load time
+ - **External file raw refs (`%file.yaml::key`)** are expanded immediately
+ - External files are frozen: their content won't change based on CLI overrides
+ - Enables copy-then-delete workflows with external files
- **Stage 2: Lazy Processing (during `resolve()`)**
+ **Phase 2: Lazy Processing (during `resolve()`)**
+ - **Local raw refs (`%key`)** are expanded after all composition is complete
- **Resolved References (`@`)** are processed on-demand
- **Expressions (`$`)** are evaluated when needed
- **Components (`_target_`)** are instantiated only when requested
- - Supports complex dependency graphs and deferred instantiation
+ - CLI overrides can affect local `%` refs
-**Why two stages?**
+**Why two phases?**
-This separation enables powerful workflows like config pruning:
+This design ensures CLI overrides work intuitively with local raw references:
```yaml
# base.yaml
-system:
- lr: 0.001
- batch_size: 32
+vars:
+ features_path: null # Default, will be overridden
-experiment:
- model:
- optimizer:
- lr: "%system::lr" # Copies raw value 0.001 eagerly
-
-~system: null # Delete system section after copying
+# model.yaml
+dataset:
+ path: "%vars::features_path" # Local ref - sees CLI override
```
```python
config = Config()
config.update("base.yaml")
-# % references already expanded during update()
-# ~system deletion applied after expansion
-# Result: experiment::model::optimizer::lr = 0.001 (system deleted safely)
+config.update("model.yaml")
+config.update("vars::features_path=/data/features.npz") # CLI override
+
+# Local % ref sees the override!
+path = config.resolve("dataset::path") # "/data/features.npz"
```
-With `@` references, this would fail because they resolve lazily after deletion.
+!!! tip "External vs Local Raw References"
+
+ | Type | Example | When Expanded | Use Case |
+ |------|---------|---------------|----------|
+ | **External** | `%file.yaml::key` | Eager (update) | Import from frozen files |
+ | **Local** | `%vars::key` | Lazy (resolve) | Reference config values |
+
+ External files are "frozen": their content is fixed at load time.
+ Local config values may be overridden via CLI, so local refs see the final state.
## Resolution Flow
!!! abstract "How References Are Resolved"
- **Step 1: Parse Config** → Detect references in YAML
+ **Step 1: Load Configs** → During `update()`
- **Step 2: Determine Type**
+ - Parse YAML files
+ - Expand external `%file.yaml::key` refs immediately
+ - Keep local `%key` refs as strings
- - **`%key`** → Expanded eagerly during `update()` ✅
- - **`@key`** → Proceed to dependency resolution (lazy)
+ **Step 2: Apply Overrides** → During `update()` calls
- **Step 3: Resolve Dependencies** (for `@` references during `resolve()`)
+ - CLI overrides modify local config values
+ - Local `%` refs still see the string form
+ **Step 3: Resolve** → During `resolve()`
+
+ - Expand local `%key` refs (now sees final values)
+ - Resolve `@` dependencies in order
- Check for circular references → ❌ **Error if found**
- - Resolve all dependencies first
- Evaluate expressions and instantiate objects
- Return final computed value ✅
@@ -223,6 +235,8 @@ model_template: "%base.yaml::model"
### Local Raw References
+Local raw references are expanded lazily during `resolve()`, which means CLI overrides can affect them:
+
```yaml
# config.yaml
defaults:
@@ -237,6 +251,17 @@ api_config:
backup_defaults: "%defaults" # Gets the whole defaults dict
```
+!!! tip "CLI Overrides Work with Local Raw Refs"
+
+ ```python
+ config = Config()
+ config.update("config.yaml")
+ config.update("defaults::timeout=60") # CLI override
+
+ # Local % ref sees the override!
+ config.resolve("api_config::timeout") # 60
+ ```
+
### Key Distinction
!!! abstract "@ vs % - When to Use Each"
diff --git a/pyproject.toml b/pyproject.toml
index 4a48981..916b139 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -77,7 +77,6 @@ show_traceback = true
allow_redefinition = false
check_untyped_defs = true
-disallow_any_generics = true
disallow_incomplete_defs = true
ignore_missing_imports = true
implicit_reexport = false
diff --git a/src/sparkwheel/__init__.py b/src/sparkwheel/__init__.py
index 16b8399..375ba14 100644
--- a/src/sparkwheel/__init__.py
+++ b/src/sparkwheel/__init__.py
@@ -5,7 +5,6 @@
"""
from .config import Config, parse_overrides
-from .errors import enable_colors
from .items import Component, Expression, Instantiable, Item
from .operators import apply_operators, validate_operators
from .resolver import Resolver
@@ -19,7 +18,7 @@
EvaluationError,
FrozenConfigError,
InstantiationError,
- SourceLocation,
+ Location,
TargetNotFoundError,
)
@@ -39,7 +38,6 @@
"validate_operators",
"validate",
"validator",
- "enable_colors",
"RESOLVED_REF_KEY",
"RAW_REF_KEY",
"ID_SEP_KEY",
@@ -55,5 +53,5 @@
"EvaluationError",
"FrozenConfigError",
"ValidationError",
- "SourceLocation",
+ "Location",
]
diff --git a/src/sparkwheel/config.py b/src/sparkwheel/config.py
index 9adf20b..1d19aec 100644
--- a/src/sparkwheel/config.py
+++ b/src/sparkwheel/config.py
@@ -3,26 +3,28 @@
Sparkwheel is a YAML-based configuration system with references, expressions, and dynamic instantiation.
This module provides the main Config class for loading, managing, and resolving configurations.
-## Two-Stage Processing
+## Two-Phase Processing
-Sparkwheel uses a two-stage processing model to handle different reference types at appropriate times:
+Sparkwheel uses a two-phase processing model to handle different reference types at appropriate times:
-### Stage 1: Eager Processing (during update())
-- **Raw References (`%`)**: Expanded immediately when configs are merged
-- **Purpose**: Enables safe config composition or deletion
+### Phase 1: Eager Processing (during update())
+- **External file raw refs (`%file.yaml::key`)**: Expanded immediately
+- **Purpose**: External files are frozen - their content won't change
- **Example**: `%base.yaml::lr` is replaced with the actual value from base.yaml
-### Stage 2: Lazy Processing (during resolve())
+### Phase 2: Lazy Processing (during resolve())
+- **Local raw refs (`%key`)**: Expanded after all composition is complete
- **Resolved References (`@`)**: Resolved on-demand to support circular dependencies
- **Expressions (`$`)**: Evaluated when needed using Python's eval()
- **Components (`_target_`)**: Instantiated only when requested
-- **Purpose**: Supports deferred instantiation and complex dependency graphs
+- **Purpose**: Lets CLI overrides affect local `%` refs; supports deferred instantiation
## Reference Types
| Symbol | Name | When Expanded | Purpose | Example |
|--------|------|---------------|---------|---------|
-| `%` | Raw Reference | Eager (update()) | Copy/paste YAML sections | `%base.yaml::lr` |
+| `%` | Raw Reference (external) | Eager (update()) | Copy from external files | `%base.yaml::lr` |
+| `%` | Raw Reference (local) | Lazy (resolve()) | Copy local config values | `%vars::lr` |
| `@` | Resolved Reference | Lazy (resolve()) | Reference config values | `@model::lr` |
| `$` | Expression | Lazy (resolve()) | Compute values dynamically | `$@lr * 2` |
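+
+Example combining the three reference types in one YAML file (a sketch):
+
+    vars:
+      lr: 0.001
+    optimizer:
+      base_lr: "%vars::lr"          # raw copy of the YAML value
+      lr: "@vars::lr"               # resolved reference
+      warmup_lr: "$@vars::lr / 10"  # computed expression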
@@ -97,15 +99,15 @@
from typing import Any
from .loader import Loader
-from .metadata import MetadataRegistry
-from .operators import _validate_delete_operator, apply_operators, validate_operators
+from .locations import LocationRegistry
+from .operators import MergeContext, _validate_delete_operator, apply_operators, validate_operators
from .parser import Parser
-from .path_utils import split_id
+from .path_utils import get_by_id, split_id
from .preprocessor import Preprocessor
from .resolver import Resolver
-from .utils import PathLike, look_up_option, optional_import
+from .utils import PathLike, optional_import
from .utils.constants import ID_SEP_KEY, REMOVE_KEY, REPLACE_KEY
-from .utils.exceptions import ConfigKeyError
+from .utils.exceptions import ConfigKeyError, build_missing_key_error
__all__ = ["Config", "parse_overrides"]
@@ -183,7 +185,7 @@ def __init__(
>>> config = Config(schema=MySchema).update("config.yaml")
"""
self._data: dict[str, Any] = data or {} # Start with provided data or empty
- self._metadata = MetadataRegistry()
+ self._locations = LocationRegistry()
self._resolver = Resolver()
self._is_parsed = False
self._frozen = False # Set via freeze() method later
@@ -257,7 +259,7 @@ def get(self, id: str = "", default: Any = None) -> Any:
"""
try:
return self._get_by_id(id)
- except (KeyError, IndexError, ValueError):
+ except (KeyError, IndexError, TypeError):
return default
def set(self, id: str, value: Any) -> None:
@@ -329,7 +331,7 @@ def validate(self, schema: type) -> None:
"""
from .schema import validate as validate_schema
- validate_schema(self._data, schema, metadata=self._metadata)
+ validate_schema(self._data, schema, metadata=self._locations)
def freeze(self) -> None:
"""Freeze config to prevent further modifications.
@@ -359,6 +361,20 @@ def is_frozen(self) -> bool:
"""
return self._frozen
+ @property
+ def locations(self) -> LocationRegistry:
+ """Get the location registry for this config.
+
+ Returns:
+ LocationRegistry tracking file locations of config keys
+
+ Example:
+ >>> config = Config().update("config.yaml")
+ >>> location = config.locations.get("model::lr")
+ >>> print(f"{location.filepath}:{location.line}")
+ """
+ return self._locations
+
def update(self, source: PathLike | dict[str, Any] | "Config" | str) -> "Config":
"""Update configuration with changes from another source.
@@ -376,7 +392,7 @@ def update(self, source: PathLike | dict[str, Any] | "Config" | str) -> "Config"
Operators:
- key=value - Compose (default): merge dict or extend list
- =key=value - Replace operator: completely replace value
- - ~key - Remove operator: delete key (idempotent)
+ - ~key - Remove operator: delete key (errors if missing)
Examples:
>>> # Update from file
@@ -423,10 +439,13 @@ def update(self, source: PathLike | dict[str, Any] | "Config" | str) -> "Config"
else:
self._update_from_file(source)
- # Eagerly expand raw references (%) immediately after update
- # This matches MONAI's behavior and allows safe pruning with delete operator (~)
- # Must happen BEFORE validation so schema sees final structure, not raw ref strings
- self._data = self._preprocessor.process_raw_refs(self._data, self._data, id="")
+ # Phase 1: Eagerly expand ONLY external file raw references (%file.yaml::key)
+ # Local refs (%key) are kept as strings - they'll be expanded in _parse() after
+ # all composition is complete. This allows CLI overrides to affect local refs.
+ # External files are frozen (their content won't change), so eager expansion is safe.
+ self._data = self._preprocessor.process_raw_refs(
+ self._data, self._data, id="", locations=self._locations, external_only=True
+ )
# Validate after raw ref expansion if schema exists
# This validates the final structure, not intermediate raw reference strings
@@ -436,7 +455,7 @@ def update(self, source: PathLike | dict[str, Any] | "Config" | str) -> "Config"
validate_schema(
self._data,
self._schema,
- metadata=self._metadata,
+ metadata=self._locations,
allow_missing=self._allow_missing,
strict=self._strict,
)
@@ -445,8 +464,9 @@ def update(self, source: PathLike | dict[str, Any] | "Config" | str) -> "Config"
def _update_from_config(self, source: "Config") -> None:
"""Update from another Config instance."""
- self._data = apply_operators(self._data, source._data)
- self._metadata.merge(source._metadata)
+ context = MergeContext(locations=source.locations)
+ self._data = apply_operators(self._data, source._data, context=context)
+ self._locations.merge(source.locations)
self._invalidate_resolution()
def _uses_nested_paths(self, source: dict[str, Any]) -> bool:
@@ -466,17 +486,43 @@ def _apply_path_updates(self, source: dict[str, Any]) -> None:
self.set(actual_key, value)
elif key.startswith(REMOVE_KEY):
- # Delete operator: ~key (idempotent)
+ # Delete operator: ~key
actual_key = key[1:]
_validate_delete_operator(actual_key, value)
- if actual_key in self:
- self._delete_nested_key(actual_key)
+ if actual_key not in self:
+ # Try to find source location for the key being deleted
+ source_location = self._locations.get(actual_key) if self._locations else None
+
+ # For nested keys, get available keys from the parent container
+ available_keys: list[str] = []
+ parent_key_name: str | None = None
+ error_key = actual_key # The key to show in error message
+
+ if ID_SEP_KEY in actual_key:
+ # Nested key like "model::lr"
+ parent_path, child_key = actual_key.rsplit(ID_SEP_KEY, 1)
+ parent_key_name = parent_path
+ error_key = child_key # Show only the child key in nested errors
+ try:
+ parent = self._get_by_id(parent_path)
+ if isinstance(parent, dict):
+ available_keys = list(parent.keys())
+ except (KeyError, IndexError, TypeError):
+ # Parent doesn't exist, fall back to top-level
+ available_keys = list(self._data.keys()) if isinstance(self._data, dict) else []
+ else:
+ # Top-level key
+ available_keys = list(self._data.keys()) if isinstance(self._data, dict) else []
+
+ raise build_missing_key_error(error_key, available_keys, source_location, parent_key=parent_key_name)
+ self._delete_nested_key(actual_key)
else:
# Default: compose (merge dict or extend list)
if key in self and isinstance(self[key], dict) and isinstance(value, dict):
- merged = apply_operators(self[key], value)
+ context = MergeContext(locations=self._locations, current_path=key)
+ merged = apply_operators(self[key], value, context=context)
self.set(key, merged)
elif key in self and isinstance(self[key], list) and isinstance(value, list):
self.set(key, self[key] + value)
@@ -501,7 +547,8 @@ def _delete_nested_key(self, key: str) -> None:
def _apply_structural_update(self, source: dict[str, Any]) -> None:
"""Apply structural update with operators."""
validate_operators(source)
- self._data = apply_operators(self._data, source)
+ context = MergeContext(locations=self._locations)
+ self._data = apply_operators(self._data, source, context=context)
self._invalidate_resolution()
def _update_from_file(self, source: PathLike) -> None:
@@ -512,12 +559,13 @@ def _update_from_file(self, source: PathLike) -> None:
# Check if loaded data uses :: path syntax
if self._uses_nested_paths(new_data):
# Expand nested paths using path updates
- self._metadata.merge(new_metadata)
+ self._locations.merge(new_metadata)
self._apply_path_updates(new_data)
else:
# Normal structural update
- self._data = apply_operators(self._data, new_data)
- self._metadata.merge(new_metadata)
+ context = MergeContext(locations=new_metadata)
+ self._data = apply_operators(self._data, new_data, context=context)
+ self._locations.merge(new_metadata)
self._invalidate_resolution()
@@ -623,7 +671,10 @@ def _parse(self, reset: bool = True) -> None:
"""Parse config tree and prepare for resolution.
Internal method called automatically by resolve().
- Note: % raw references are already expanded during update().
+
+ Two-phase raw reference expansion:
+ - Phase 1 (update): External file refs expanded eagerly
+ - Phase 2 (here): Local refs expanded now, after all composition
Args:
reset: Whether to reset the resolver before parsing (default: True)
@@ -632,12 +683,17 @@ def _parse(self, reset: bool = True) -> None:
if reset:
self._resolver.reset()
+ # Phase 2: Expand local raw references (%key) now that all composition is complete
+ # CLI overrides have been applied, so local refs will see final values
+ self._data = self._preprocessor.process_raw_refs(
+ self._data, self._data, id="", locations=self._locations, external_only=False
+ )
+
# Stage 1: Preprocess (@:: relative resolved IDs)
- # Note: % raw references were already expanded in update()
self._data = self._preprocessor.process(self._data, self._data, id="")
# Stage 2: Parse config tree to create Items
- parser = Parser(globals=self._globals, metadata=self._metadata)
+ parser = Parser(globals=self._globals, metadata=self._locations)
items = parser.parse(self._data)
# Stage 3: Add items to resolver
@@ -655,21 +711,10 @@ def _get_by_id(self, id: str) -> Any:
Config value at that path
Raises:
- KeyError: If path not found
+ KeyError: If path not found (includes available keys in message)
+ TypeError: If trying to index a non-dict/list value
"""
- if id == "":
- return self._data
-
- config = self._data
- for k in split_id(id):
- if not isinstance(config, (dict, list)):
- raise ValueError(f"Config must be dict or list for key `{k}`, but got {type(config)}: {config}")
- try:
- config = look_up_option(k, config, print_all_options=False) if isinstance(config, dict) else config[int(k)]
- except ValueError as e:
- raise KeyError(f"Key not found: {k}") from e
-
- return config
+ return get_by_id(self._data, id)
def _invalidate_resolution(self) -> None:
"""Invalidate cached resolution (called when config changes)."""
@@ -717,7 +762,7 @@ def __contains__(self, id: str) -> bool:
try:
self._get_by_id(id)
return True
- except (KeyError, IndexError, ValueError):
+ except (KeyError, IndexError, TypeError):
return False
def __repr__(self) -> str:
diff --git a/src/sparkwheel/errors/__init__.py b/src/sparkwheel/errors/__init__.py
index eb61dba..346291d 100644
--- a/src/sparkwheel/errors/__init__.py
+++ b/src/sparkwheel/errors/__init__.py
@@ -1,13 +1,7 @@
from .context import format_available_keys, format_resolution_chain
-from .formatters import enable_colors, format_code, format_error, format_suggestion
from .suggestions import format_suggestions, get_suggestions, levenshtein_distance
__all__ = [
- # Formatters
- "enable_colors",
- "format_error",
- "format_suggestion",
- "format_code",
# Suggestions
"levenshtein_distance",
"get_suggestions",
diff --git a/src/sparkwheel/errors/formatters.py b/src/sparkwheel/errors/formatters.py
deleted file mode 100644
index f3f5f06..0000000
--- a/src/sparkwheel/errors/formatters.py
+++ /dev/null
@@ -1,196 +0,0 @@
-"""Color formatting utilities for terminal output with auto-detection."""
-
-import os
-import sys
-
-__all__ = [
- "enable_colors",
- "format_error",
- "format_suggestion",
- "format_code",
- "RED",
- "YELLOW",
- "GREEN",
- "BLUE",
- "GRAY",
- "RESET",
-]
-
-# ANSI color codes
-RED = "\033[31m"
-YELLOW = "\033[33m"
-GREEN = "\033[32m"
-BLUE = "\033[34m"
-GRAY = "\033[90m"
-BOLD = "\033[1m"
-RESET = "\033[0m"
-
-# Global flag for color support
-_COLORS_ENABLED: bool | None = None
-
-
-def _supports_color() -> bool:
- """Auto-detect if the terminal supports colors.
-
- Follows industry standards for color detection:
- 1. NO_COLOR environment variable disables colors (https://no-color.org/)
- 2. SPARKWHEEL_NO_COLOR environment variable disables colors (sparkwheel-specific)
- 3. FORCE_COLOR environment variable enables colors (https://force-color.org/)
- 4. stdout TTY detection (auto-detect)
- 5. Default: disable colors
-
- Returns:
- True if colors should be enabled, False otherwise
- """
- # Check NO_COLOR environment variable (https://no-color.org/)
- # Highest priority - explicit user preference to disable
- if os.environ.get("NO_COLOR"):
- return False
-
- # Check sparkwheel-specific disable flag
- if os.environ.get("SPARKWHEEL_NO_COLOR"):
- return False
-
- # Check FORCE_COLOR environment variable (https://force-color.org/)
- # Explicit enable for CI environments, piping, etc.
- if os.environ.get("FORCE_COLOR"):
- return True
-
- # Auto-detect: Check if stdout is a TTY
- if hasattr(sys.stdout, "isatty") and sys.stdout.isatty():
- return True
-
- # Default: disable colors
- return False
-
-
-def enable_colors(enabled: bool | None = None) -> bool:
- """Enable or disable color output.
-
- Args:
- enabled: True to enable, False to disable, None for auto-detection
-
- Returns:
- Current color enable status
-
- Examples:
- >>> enable_colors(False) # Disable colors
- False
- >>> enable_colors(True) # Force enable colors
- True
- >>> enable_colors() # Auto-detect
- True # (if terminal supports it)
- """
- global _COLORS_ENABLED
-
- if enabled is None:
- _COLORS_ENABLED = _supports_color()
- else:
- _COLORS_ENABLED = enabled
-
- return _COLORS_ENABLED
-
-
-def _get_colors_enabled() -> bool:
- """Get current color enable status, initializing if needed."""
- global _COLORS_ENABLED
-
- if _COLORS_ENABLED is None:
- enable_colors() # Auto-detect
-
- return _COLORS_ENABLED # type: ignore[return-value]
-
-
-def _colorize(text: str, color: str) -> str:
- """Apply color to text if colors are enabled.
-
- Args:
- text: Text to colorize
- color: ANSI color code
-
- Returns:
- Colorized text if colors enabled, otherwise plain text
- """
- if _get_colors_enabled():
- return f"{color}{text}{RESET}"
- return text
-
-
-def format_error(text: str) -> str:
- """Format text as an error (red).
-
- Args:
- text: Text to format
-
- Returns:
- Formatted text
-
- Examples:
- >>> format_error("Error message")
- '\x1b[31mError message\x1b[0m' # With colors enabled
- >>> format_error("Error message")
- 'Error message' # With colors disabled
- """
- return _colorize(text, RED)
-
-
-def format_suggestion(text: str) -> str:
- """Format text as a suggestion (yellow).
-
- Args:
- text: Text to format
-
- Returns:
- Formatted text
- """
- return _colorize(text, YELLOW)
-
-
-def format_success(text: str) -> str:
- """Format text as success/correct (green).
-
- Args:
- text: Text to format
-
- Returns:
- Formatted text
- """
- return _colorize(text, GREEN)
-
-
-def format_code(text: str) -> str:
- """Format text as code/metadata (blue).
-
- Args:
- text: Text to format
-
- Returns:
- Formatted text
- """
- return _colorize(text, BLUE)
-
-
-def format_context(text: str) -> str:
- """Format text as context (gray).
-
- Args:
- text: Text to format
-
- Returns:
- Formatted text
- """
- return _colorize(text, GRAY)
-
-
-def format_bold(text: str) -> str:
- """Format text as bold.
-
- Args:
- text: Text to format
-
- Returns:
- Formatted text
- """
- if _get_colors_enabled():
- return f"{BOLD}{text}{RESET}"
- return text
diff --git a/src/sparkwheel/items.py b/src/sparkwheel/items.py
index 255ea24..dfe5f02 100644
--- a/src/sparkwheel/items.py
+++ b/src/sparkwheel/items.py
@@ -7,7 +7,7 @@
from .utils import CompInitMode, first, instantiate, optional_import, run_debug, run_eval
from .utils.constants import EXPR_KEY
-from .utils.exceptions import EvaluationError, InstantiationError, SourceLocation, TargetNotFoundError
+from .utils.exceptions import EvaluationError, InstantiationError, Location, TargetNotFoundError
__all__ = ["Item", "Expression", "Component", "Instantiable"]
@@ -46,7 +46,7 @@ class Item:
source_location: optional location in source file where this config item was defined.
"""
- def __init__(self, config: Any, id: str = "", source_location: SourceLocation | None = None) -> None:
+ def __init__(self, config: Any, id: str = "", source_location: Location | None = None) -> None:
self.config = config
self.id = id
self.source_location = source_location
@@ -126,7 +126,7 @@ class Component(Item, Instantiable):
non_arg_keys = {"_target_", "_disabled_", "_requires_", "_mode_", "_args_"}
- def __init__(self, config: Any, id: str = "", source_location: SourceLocation | None = None) -> None:
+ def __init__(self, config: Any, id: str = "", source_location: Location | None = None) -> None:
super().__init__(config=config, id=id, source_location=source_location)
@staticmethod
@@ -218,8 +218,15 @@ def instantiate(self, **kwargs: Any) -> object:
source_location=self.source_location,
suggestion=suggestion,
) from e
+ except InstantiationError as e:
+ # Add source location if not already present, preserving original message and suggestion
+ if e.source_location is None:
+ raise InstantiationError(
+ e._original_message, source_location=self.source_location, suggestion=e.suggestion
+ ) from e
+ raise # Already has location, re-raise as-is
except Exception as e:
- # Wrap other errors with location context (points to _target_ line)
+ # Wrap unexpected errors with component context and location
raise InstantiationError(
f"Failed to instantiate '{modname}': {type(e).__name__}: {e}",
source_location=self.source_location,
@@ -304,7 +311,7 @@ def __init__(
config: Any,
id: str = "",
globals: dict[str, Any] | None = None,
- source_location: SourceLocation | None = None,
+ source_location: Location | None = None,
) -> None:
super().__init__(config=config, id=id, source_location=source_location)
self.globals = globals if globals is not None else {}
diff --git a/src/sparkwheel/loader.py b/src/sparkwheel/loader.py
index 948da8f..748f2f1 100644
--- a/src/sparkwheel/loader.py
+++ b/src/sparkwheel/loader.py
@@ -7,23 +7,23 @@
import yaml # type: ignore[import-untyped]
-from .metadata import MetadataRegistry
+from .locations import LocationRegistry
from .path_utils import is_yaml_file
from .utils import CheckKeyDuplicatesYamlLoader, PathLike
from .utils.constants import ID_SEP_KEY
-from .utils.exceptions import SourceLocation
+from .utils.exceptions import Location
__all__ = ["Loader"]
class MetadataTrackingYamlLoader(CheckKeyDuplicatesYamlLoader):
- """YAML loader that tracks source locations into MetadataRegistry.
+ """YAML loader that tracks source locations into LocationRegistry.
Unlike the old approach that added __sparkwheel_metadata__ keys to dicts,
- this loader populates a separate MetadataRegistry during loading.
+ this loader populates a separate LocationRegistry during loading.
"""
- def __init__(self, stream, filepath: str, registry: MetadataRegistry): # type: ignore[no-untyped-def]
+ def __init__(self, stream, filepath: str, registry: LocationRegistry): # type: ignore[no-untyped-def]
super().__init__(stream)
self.filepath = filepath
self.registry = registry
@@ -35,7 +35,7 @@ def construct_mapping(self, node, deep=False):
current_id = ID_SEP_KEY.join(self.id_path_stack) if self.id_path_stack else ""
if node.start_mark:
- location = SourceLocation(
+ location = Location(
filepath=self.filepath,
line=node.start_mark.line + 1,
column=node.start_mark.column + 1,
@@ -53,6 +53,18 @@ def construct_mapping(self, node, deep=False):
# Push key onto path stack before constructing value
self.id_path_stack.append(str(key))
+ # Register source location for this specific key
+ # This allows us to track where each key was defined
+ key_id = ID_SEP_KEY.join(self.id_path_stack) if self.id_path_stack else ""
+ if key_node.start_mark:
+ key_location = Location(
+ filepath=self.filepath,
+ line=key_node.start_mark.line + 1,
+ column=key_node.start_mark.column + 1,
+ id=key_id,
+ )
+ self.registry.register(key_id, key_location)
+
# Construct value with updated path
value = self.construct_object(value_node, deep=True)
@@ -72,7 +84,7 @@ def construct_sequence(self, node, deep=False):
current_id = ID_SEP_KEY.join(self.id_path_stack) if self.id_path_stack else ""
if node.start_mark:
- location = SourceLocation(
+ location = Location(
filepath=self.filepath,
line=node.start_mark.line + 1,
column=node.start_mark.column + 1,
@@ -121,7 +133,7 @@ class Loader:
```
"""
- def load_file(self, filepath: PathLike) -> tuple[dict[str, Any], MetadataRegistry]:
+ def load_file(self, filepath: PathLike) -> tuple[dict[str, Any], LocationRegistry]:
"""Load a single YAML file with metadata tracking.
Args:
@@ -134,7 +146,7 @@ def load_file(self, filepath: PathLike) -> tuple[dict[str, Any], MetadataRegistr
ValueError: If file is not a YAML file
"""
if not filepath:
- return {}, MetadataRegistry()
+ return {}, LocationRegistry()
filepath_str = str(Path(filepath))
@@ -154,7 +166,7 @@ def load_file(self, filepath: PathLike) -> tuple[dict[str, Any], MetadataRegistr
)
# Load YAML with metadata tracking
- registry = MetadataRegistry()
+ registry = LocationRegistry()
with open(resolved_path) as f:
config = self._load_yaml_with_metadata(f, str(resolved_path), registry)
@@ -163,13 +175,13 @@ def load_file(self, filepath: PathLike) -> tuple[dict[str, Any], MetadataRegistr
return config, registry
- def _load_yaml_with_metadata(self, stream, filepath: str, registry: MetadataRegistry) -> dict[str, Any]: # type: ignore[no-untyped-def]
+ def _load_yaml_with_metadata(self, stream, filepath: str, registry: LocationRegistry) -> dict[str, Any]: # type: ignore[no-untyped-def]
"""Load YAML and populate metadata registry during construction.
Args:
stream: File stream to load from
filepath: Path string for error messages
- registry: MetadataRegistry to populate
+ registry: LocationRegistry to populate
Returns:
Config dictionary (clean, no metadata keys)
@@ -206,7 +218,7 @@ def _strip_metadata(config: Any) -> Any:
else:
return config
- def load_files(self, filepaths: Sequence[PathLike]) -> tuple[dict[str, Any], MetadataRegistry]:
+ def load_files(self, filepaths: Sequence[PathLike]) -> tuple[dict[str, Any], LocationRegistry]:
"""Load multiple YAML files sequentially.
Files are loaded in order and merged using simple dict update
@@ -219,7 +231,7 @@ def load_files(self, filepaths: Sequence[PathLike]) -> tuple[dict[str, Any], Met
Tuple of (merged_config_dict, merged_metadata_registry)
"""
combined_config = {}
- combined_registry = MetadataRegistry()
+ combined_registry = LocationRegistry()
for filepath in filepaths:
config, registry = self.load_file(filepath)
diff --git a/src/sparkwheel/locations.py b/src/sparkwheel/locations.py
new file mode 100644
index 0000000..e2cc0de
--- /dev/null
+++ b/src/sparkwheel/locations.py
@@ -0,0 +1,74 @@
+"""Location tracking for configuration keys."""
+
+from sparkwheel.utils.exceptions import Location
+
+__all__ = ["LocationRegistry"]
+
+
+class LocationRegistry:
+ """Registry that tracks file locations for configuration keys.
+
+ Maintains a clean separation between config data and location information
+ about where config items came from. This avoids polluting config dictionaries
+ with metadata keys.
+
+ Example:
+ ```python
+ registry = LocationRegistry()
+ registry.register("model::lr", Location("config.yaml", 10, 2, "model::lr"))
+
+ location = registry.get("model::lr")
+ print(location.filepath) # "config.yaml"
+ print(location.line) # 10
+ ```
+ """
+
+ def __init__(self):
+ """Initialize empty location registry."""
+ self._locations: dict[str, Location] = {}
+
+ def register(self, id_path: str, location: Location) -> None:
+ """Register location for a config path.
+
+ Args:
+ id_path: Configuration path (e.g., "model::lr", "optimizer::params::0")
+ location: Location information
+ """
+ self._locations[id_path] = location
+
+ def get(self, id_path: str) -> Location | None:
+ """Get location for a config path.
+
+ Args:
+ id_path: Configuration path to look up
+
+ Returns:
+ Location if registered, None otherwise
+ """
+ return self._locations.get(id_path)
+
+ def merge(self, other: "LocationRegistry") -> None:
+ """Merge another registry into this one.
+
+ Args:
+ other: LocationRegistry to merge from
+ """
+ self._locations.update(other._locations)
+
+ def copy(self) -> "LocationRegistry":
+ """Create a copy of this registry.
+
+ Returns:
+ New LocationRegistry with same data
+ """
+ new_registry = LocationRegistry()
+ new_registry._locations = self._locations.copy()
+ return new_registry
+
+ def __len__(self) -> int:
+ """Return number of registered locations."""
+ return len(self._locations)
+
+ def __contains__(self, id_path: str) -> bool:
+ """Check if id_path has registered location."""
+ return id_path in self._locations
diff --git a/src/sparkwheel/metadata.py b/src/sparkwheel/metadata.py
deleted file mode 100644
index f5f6ae4..0000000
--- a/src/sparkwheel/metadata.py
+++ /dev/null
@@ -1,74 +0,0 @@
-"""Source location metadata tracking."""
-
-from sparkwheel.utils.exceptions import SourceLocation
-
-__all__ = ["MetadataRegistry"]
-
-
-class MetadataRegistry:
- """Track source locations for config items.
-
- Maintains a clean separation between config data and metadata about where
- config items came from. This avoids polluting config dictionaries with
- metadata keys.
-
- Example:
- ```python
- registry = MetadataRegistry()
- registry.register("model::lr", SourceLocation("config.yaml", 10, 2, "model::lr"))
-
- location = registry.get("model::lr")
- print(location.filepath) # "config.yaml"
- print(location.line) # 10
- ```
- """
-
- def __init__(self):
- """Initialize empty metadata registry."""
- self._locations: dict[str, SourceLocation] = {}
-
- def register(self, id_path: str, location: SourceLocation) -> None:
- """Register source location for a config path.
-
- Args:
- id_path: Configuration path (e.g., "model::lr", "optimizer::params::0")
- location: Source location information
- """
- self._locations[id_path] = location
-
- def get(self, id_path: str) -> SourceLocation | None:
- """Get source location for a config path.
-
- Args:
- id_path: Configuration path to look up
-
- Returns:
- SourceLocation if registered, None otherwise
- """
- return self._locations.get(id_path)
-
- def merge(self, other: "MetadataRegistry") -> None:
- """Merge another registry into this one.
-
- Args:
- other: MetadataRegistry to merge from
- """
- self._locations.update(other._locations)
-
- def copy(self) -> "MetadataRegistry":
- """Create a copy of this registry.
-
- Returns:
- New MetadataRegistry with same data
- """
- new_registry = MetadataRegistry()
- new_registry._locations = self._locations.copy()
- return new_registry
-
- def __len__(self) -> int:
- """Return number of registered locations."""
- return len(self._locations)
-
- def __contains__(self, id_path: str) -> bool:
- """Check if id_path has registered location."""
- return id_path in self._locations
diff --git a/src/sparkwheel/operators.py b/src/sparkwheel/operators.py
index b324043..e2e9313 100644
--- a/src/sparkwheel/operators.py
+++ b/src/sparkwheel/operators.py
@@ -1,12 +1,97 @@
"""Configuration merging with composition-by-default and operators (=, ~)."""
from copy import deepcopy
-from typing import Any
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any
-from .utils.constants import REMOVE_KEY, REPLACE_KEY
-from .utils.exceptions import ConfigMergeError
+if TYPE_CHECKING:
+ from .locations import LocationRegistry
+ from .utils.exceptions import Location
-__all__ = ["apply_operators", "validate_operators", "_validate_delete_operator"]
+from .utils.constants import ID_SEP_KEY, REMOVE_KEY, REPLACE_KEY
+from .utils.exceptions import ConfigMergeError, build_missing_key_error
+
+__all__ = ["apply_operators", "validate_operators", "_validate_delete_operator", "MergeContext"]
+
+
+@dataclass
+class MergeContext:
+ """Context for configuration merging operations.
+
+ This class consolidates location tracking and path information needed
+ during recursive config merging. Using a context object reduces parameter
+ threading and makes it easier to extend with additional context in the future.
+
+ Attributes:
+ locations: Optional registry tracking file locations of config keys
+ current_path: Current path in config tree using :: separator (e.g., "model::optimizer::lr")
+
+ Examples:
+ >>> ctx = MergeContext()
+ >>> child_ctx = ctx.child_path("model")
+ >>> child_ctx.current_path
+ 'model'
+ >>> grandchild_ctx = child_ctx.child_path("lr")
+ >>> grandchild_ctx.current_path
+ 'model::lr'
+ """
+
+ locations: "LocationRegistry | None" = None
+ current_path: str = ""
+
+ def child_path(self, key: str) -> "MergeContext":
+ """Create a child context for a nested config key.
+
+ Args:
+ key: Key to append to current path
+
+ Returns:
+ New MergeContext with updated path, sharing the same source location registry
+
+ Examples:
+ >>> ctx = MergeContext(current_path="model")
+ >>> child = ctx.child_path("optimizer")
+ >>> child.current_path
+ 'model::optimizer'
+ """
+ new_path = f"{self.current_path}{ID_SEP_KEY}{key}" if self.current_path else key
+ return MergeContext(locations=self.locations, current_path=new_path)
+
+ def get_source_location(self, key: str) -> "Location | None":
+ """Get source location for a key in the current context.
+
+ Tries both the exact key and the key without operator prefix to handle
+ operator keys correctly (e.g., ~key, =key).
+
+ Args:
+ key: Key to look up (may include operators like ~key or =key)
+
+ Returns:
+ Location if found, None otherwise
+
+ Examples:
+ >>> ctx = MergeContext(locations=registry, current_path="model")
+ >>> ctx.get_source_location("~lr") # Looks up both "model::~lr" and "model::lr"
+ >>> ctx.get_source_location("=lr") # Looks up both "model::=lr" and "model::lr"
+ """
+ if not self.locations:
+ return None
+
+ # Build full path for the key
+ key_path = f"{self.current_path}{ID_SEP_KEY}{key}" if self.current_path else key
+
+ # Try the key as-is first
+ location = self.locations.get(key_path)
+ if location:
+ return location
+
+ # If key starts with an operator (~, =), also try without the operator
+ if key.startswith((REMOVE_KEY, REPLACE_KEY)):
+ actual_key = key[1:]
+ actual_path = f"{self.current_path}{ID_SEP_KEY}{actual_key}" if self.current_path else actual_key
+ return self.locations.get(actual_path)
+
+ return None
def _validate_delete_operator(key: str, value: Any) -> None:
@@ -58,7 +143,7 @@ def validate_operators(config: dict[str, Any], parent_key: str = "") -> None:
"""Validate operator usage in config tree.
With composition-by-default, validation is simpler:
- 1. Remove operators always work (idempotent delete)
+ 1. Remove operators raise an error if the key doesn't exist
2. Replace operators work on any type
3. No parent context requirements
@@ -98,13 +183,17 @@ def validate_operators(config: dict[str, Any], parent_key: str = "") -> None:
validate_operators(value, full_key)
-def apply_operators(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
+def apply_operators(
+ base: dict[str, Any],
+ override: dict[str, Any],
+ context: MergeContext | None = None,
+) -> dict[str, Any]:
"""Apply configuration changes with composition-by-default semantics.
Default behavior: Compose (merge dicts, extend lists)
Operators:
=key: value - Replace operator: completely replace value (override default)
- ~key: null - Remove operator: delete key or list items (idempotent)
+ ~key: null - Remove operator: delete key or list items (errors if missing)
key: value - Compose (default): merge dict or extend list
Composition-by-Default Philosophy:
@@ -112,11 +201,12 @@ def apply_operators(base: dict[str, Any], override: dict[str, Any]) -> dict[str,
- Lists extend by default (append new items)
- Only scalars and type mismatches replace
- Use = to explicitly replace entire dicts or lists
- - Use ~ to delete keys (idempotent - no error if missing)
+ - Use ~ to delete keys (errors if key doesn't exist)
Args:
base: Base configuration dict
override: Override configuration dict with optional =/~ operators
+ context: Optional merge context with metadata and path tracking
Returns:
Merged configuration dict
@@ -143,12 +233,21 @@ def apply_operators(base: dict[str, Any], override: dict[str, Any]) -> dict[str,
>>> apply_operators(base, override)
{"model": {"lr": 0.01}}
- >>> # Remove operator: delete key (idempotent)
+ >>> # Remove operator: delete key
>>> base = {"a": 1, "b": 2, "c": 3}
>>> override = {"b": 5, "~c": None}
>>> apply_operators(base, override)
{"a": 1, "b": 5}
+
+ >>> # With context for better error messages
+ >>> from .locations import LocationRegistry
+ >>> ctx = MergeContext(locations=LocationRegistry())
+ >>> apply_operators(base, override, context=ctx)
"""
+ # Create context if not provided
+ if context is None:
+ context = MergeContext()
+
if not isinstance(base, dict) or not isinstance(override, dict):
return deepcopy(override) # type: ignore[unreachable]
@@ -170,9 +269,11 @@ def apply_operators(base: dict[str, Any], override: dict[str, Any]) -> dict[str,
actual_key = key[1:]
_validate_delete_operator(actual_key, value)
- # Idempotent: no error if key doesn't exist
+ # Error if key doesn't exist
if actual_key not in result:
- continue # Silently skip
+ # Get source location using context helper
+ source_location = context.get_source_location(key)
+ raise build_missing_key_error(actual_key, list(result.keys()), source_location)
# Handle remove entire key (null or empty value)
if value is None or value == "":
@@ -223,11 +324,10 @@ def apply_operators(base: dict[str, Any], override: dict[str, Any]) -> dict[str,
elif isinstance(base_val, dict):
for del_key in value:
if del_key not in base_val:
- raise ConfigMergeError(
- f"Cannot remove non-existent key '{del_key}' from '{actual_key}'",
- suggestion=f"The key '{del_key}' does not exist in '{actual_key}'.\n"
- f"Available keys: {list(base_val.keys())}",
- )
+ # Use helper to build error with suggestions
+ available_keys = list(base_val.keys())
+ source_location = context.get_source_location(key)
+ raise build_missing_key_error(del_key, available_keys, source_location, parent_key=actual_key)
del base_val[del_key]
else:
@@ -246,7 +346,9 @@ def apply_operators(base: dict[str, Any], override: dict[str, Any]) -> dict[str,
# For dicts: MERGE (composition)
if isinstance(base_val, dict) and isinstance(value, dict):
- result[key] = apply_operators(base_val, value)
+ # Create child context for recursion
+ child_context = context.child_path(key)
+ result[key] = apply_operators(base_val, value, context=child_context)
continue
# For lists: EXTEND (composition)
diff --git a/src/sparkwheel/parser.py b/src/sparkwheel/parser.py
index 5264d66..66f6034 100644
--- a/src/sparkwheel/parser.py
+++ b/src/sparkwheel/parser.py
@@ -3,7 +3,7 @@
from typing import Any
from .items import Component, Expression, Item
-from .metadata import MetadataRegistry
+from .locations import LocationRegistry
from .utils.constants import ID_SEP_KEY
__all__ = ["Parser"]
@@ -27,7 +27,7 @@ class Parser:
}
}
- metadata = MetadataRegistry()
+ metadata = LocationRegistry()
parser = Parser(globals={}, metadata=metadata)
items = parser.parse(config)
@@ -43,15 +43,15 @@ class Parser:
Args:
globals: Global context for expression evaluation
- metadata: MetadataRegistry for source location lookup
+ metadata: LocationRegistry for source location lookup
"""
- def __init__(self, globals: dict[str, Any], metadata: MetadataRegistry):
+ def __init__(self, globals: dict[str, Any], metadata: LocationRegistry):
"""Initialize parser with globals and metadata.
Args:
globals: Dictionary of global variables for expression evaluation
- metadata: MetadataRegistry for looking up source locations
+ metadata: LocationRegistry for looking up source locations
"""
self._globals = globals
self._metadata = metadata
diff --git a/src/sparkwheel/path_utils.py b/src/sparkwheel/path_utils.py
index b0f5b2e..ad07245 100644
--- a/src/sparkwheel/path_utils.py
+++ b/src/sparkwheel/path_utils.py
@@ -18,6 +18,7 @@
"replace_references",
"split_file_and_id",
"is_yaml_file",
+ "get_by_id",
"PathPatterns", # Export for backward compatibility
]
@@ -286,6 +287,85 @@ def normalize_id(id: str | int) -> str:
return str(id)
+def _format_path_context(path_parts: list[str], index: int) -> str:
+ """Format parent path context for error messages.
+
+ Internal helper for get_by_id error messages.
+
+ Args:
+ path_parts: List of path components
+ index: Current index in path_parts (0 = first level)
+
+ Returns:
+ Empty string for first level, or " in 'parent::path'" for nested levels
+ """
+ if index == 0:
+ return ""
+ parent_path = ID_SEP_KEY.join(path_parts[:index])
+ return f" in '{parent_path}'"
+
+
+def get_by_id(config: dict[str, Any] | list[Any], id: str) -> Any:
+ """Navigate config structure by ID path.
+
+ Traverses nested dicts and lists using :: separated path components.
+ Provides detailed error messages with available keys when navigation fails.
+
+ Args:
+ config: Config dict or list to navigate
+ id: ID path (e.g., "model::optimizer::lr" or "items::0::value")
+
+ Returns:
+ Value at the specified path
+
+ Raises:
+ KeyError: If a dict key is not found (includes available keys in message)
+ TypeError: If trying to index a non-dict/list value
+
+ Examples:
+ >>> config = {"model": {"lr": 0.001, "layers": [64, 128]}}
+ >>> get_by_id(config, "model::lr")
+ 0.001
+ >>> get_by_id(config, "model::layers::1")
+ 128
+ >>> get_by_id(config, "")
+ {"model": {"lr": 0.001, "layers": [64, 128]}}
+
+ Error messages include context:
+ >>> get_by_id({"a": {"b": 1}}, "a::missing")
+ KeyError: "Key 'missing' not found in 'a'. Available keys: ['b']"
+ """
+ if not id:
+ return config
+
+ current = config
+ path_parts = split_id(id)
+
+ for i, key in enumerate(path_parts):
+ context = _format_path_context(path_parts, i)
+
+ if isinstance(current, dict):
+ if key not in current:
+ available_keys = list(current.keys())
+ error_msg = f"Key '{key}' not found{context}"
+ error_msg += f". Available keys: {available_keys[:10]}"
+ if len(available_keys) > 10:
+ error_msg += "..."
+ raise KeyError(error_msg)
+ current = current[key]
+ elif isinstance(current, list):
+ try:
+ current = current[int(key)]
+ except ValueError as e:
+ raise KeyError(f"Invalid list index '{key}'{context}: not an integer") from e
+ except IndexError as e:
+ raise KeyError(f"List index '{key}' out of range{context}: {e}") from e
+ else:
+ raise TypeError(f"Cannot index {type(current).__name__} with key '{key}'{context}")
+
+ return current
+
+
def resolve_relative_ids(current_id: str, value: str) -> str:
"""Resolve relative references (@::, @::::) to absolute paths.
diff --git a/src/sparkwheel/preprocessor.py b/src/sparkwheel/preprocessor.py
index cfc70ab..0e4262a 100644
--- a/src/sparkwheel/preprocessor.py
+++ b/src/sparkwheel/preprocessor.py
@@ -6,10 +6,14 @@
"""
from copy import deepcopy
-from typing import Any
+from typing import TYPE_CHECKING, Any, Optional
-from .path_utils import resolve_relative_ids, split_file_and_id, split_id
+from .path_utils import get_by_id, resolve_relative_ids, split_file_and_id
from .utils.constants import ID_SEP_KEY, RAW_REF_KEY
+from .utils.exceptions import CircularReferenceError, ConfigKeyError
+
+if TYPE_CHECKING:
+ from .locations import LocationRegistry
__all__ = ["Preprocessor"]
@@ -25,27 +29,38 @@ class Preprocessor:
Operates on raw Python dicts/lists, not on Item objects.
+ ## Two-Phase Raw Reference Expansion
+
+ Raw references are expanded in two phases to support CLI overrides:
+
+ **Phase 1 (Eager, during update()):**
+ - External file refs (`%file.yaml::key`) are expanded immediately
+ - The external file is frozen: its contents won't change
+
+ **Phase 2 (Lazy, during resolve()):**
+ - Local refs (`%key`) are expanded after all composition
+ - This allows CLI overrides to affect values referenced by local `%` refs
+
Example:
>>> loader = Loader()
>>> preprocessor = Preprocessor(loader)
>>>
>>> raw_config = {
... "lr": 0.001,
- ... "base": "%defaults.yaml::learning_rate", # Raw reference (external)
+ ... "base": "%defaults.yaml::learning_rate", # External - expanded eagerly
+ ... "ref": "%lr", # Local - expanded lazily
... "model": {
... "lr": "@::lr" # Relative resolved reference
... }
... }
>>>
- >>> preprocessed = preprocessor.process(raw_config, raw_config)
- >>> # Result:
- >>> # {
- >>> # "lr": 0.001,
- >>> # "base": 0.0005, # Loaded from defaults.yaml
- >>> # "model": {
- >>> # "lr": "@model::lr" # Converted to absolute
- >>> # }
- >>> # }
+ >>> # Phase 1: Expand only external refs
+ >>> preprocessed = preprocessor.process_raw_refs(raw_config, raw_config, external_only=True)
+ >>> # Result: base=0.0005, ref="%lr" (still string)
+ >>>
+ >>> # Phase 2: Expand local refs (after CLI overrides applied)
+ >>> preprocessed = preprocessor.process_raw_refs(preprocessed, preprocessed, external_only=False)
+ >>> # Result: ref=0.001 (now expanded)
"""
def __init__(self, loader, globals: dict[str, Any] | None = None): # type: ignore[no-untyped-def]
@@ -58,14 +73,29 @@ def __init__(self, loader, globals: dict[str, Any] | None = None): # type: igno
self.loader = loader
self.globals = globals or {}
- def process_raw_refs(self, config: Any, base_data: dict[str, Any], id: str = "") -> Any:
- """Preprocess config tree - expand only % raw references.
+ def process_raw_refs(
+ self,
+ config: Any,
+ base_data: dict[str, Any],
+ id: str = "",
+ locations: Optional["LocationRegistry"] = None,
+ *,
+ external_only: bool = False,
+ ) -> Any:
+ """Preprocess config tree - expand % raw references.
+
+ Supports two-phase expansion for CLI override compatibility:
+
+ **Phase 1 (external_only=True, during update()):**
+ - Only expands external file refs (`%file.yaml::key`)
+ - Local refs (`%key`) are kept as strings for later expansion
+ - This allows CLI overrides to affect values used by local refs
+
+ **Phase 2 (external_only=False, during resolve()):**
+ - Expands all remaining local refs (`%key`)
+ - At this point, all CLI overrides have been applied
- This is the first preprocessing stage that runs eagerly during update().
- It expands all % raw references including those with relative syntax:
- - Local: %key
- - External: %file.yaml::key
- - Relative: %::key, %::::key (converted to absolute before expansion)
+ Also handles relative syntax: %::key, %::::key (converted to absolute before expansion)
Leaves @ resolved references untouched (they're processed lazily during resolve()).
@@ -73,14 +103,18 @@ def process_raw_refs(self, config: Any, base_data: dict[str, Any], id: str = "")
config: Raw config structure to process
base_data: Root config dict (for resolving local raw references)
id: Current ID path in tree
+ locations: LocationRegistry for error reporting (optional)
+ external_only: If True, only expand external file refs (Phase 1).
+ If False, expand all refs including local (Phase 2).
Returns:
- Config with raw references expanded
+ Config with raw references expanded (or partially expanded if external_only=True)
Raises:
- ValueError: If circular raw reference detected
+ CircularReferenceError: If circular raw reference detected
+ ConfigKeyError: If referenced key not found
"""
- return self._process_raw_refs_recursive(config, base_data, id, set())
+ return self._process_raw_refs_recursive(config, base_data, id, set(), locations, external_only)
def process(self, config: Any, base_data: dict[str, Any], id: str = "") -> Any:
"""Preprocess entire config tree.
@@ -98,7 +132,7 @@ def process(self, config: Any, base_data: dict[str, Any], id: str = "") -> Any:
Preprocessed config ready for parsing
Raises:
- ValueError: If circular raw reference detected
+ CircularReferenceError: If circular raw reference detected
"""
return self._process_recursive(config, base_data, id, set())
@@ -108,10 +142,13 @@ def _process_raw_refs_recursive(
base_data: dict[str, Any],
id: str,
raw_ref_stack: set[str],
+ locations: Optional["LocationRegistry"] = None,
+ external_only: bool = False,
) -> Any:
- """Internal recursive implementation for expanding only raw references.
+ """Internal recursive implementation for expanding raw references.
- This method only expands % raw references and leaves @ references untouched.
+ This method expands % raw references and leaves @ references untouched.
+ When external_only=True, only external file refs are expanded.
Performance optimization: Skips recursion for nodes that don't contain any
raw reference strings, avoiding unnecessary tree traversal.
@@ -121,9 +158,11 @@ def _process_raw_refs_recursive(
base_data: Root config dict
id: Current ID path
raw_ref_stack: Circular reference detection
+ locations: LocationRegistry for error reporting (optional)
+ external_only: If True, skip local refs (expand only external file refs)
Returns:
- Config with raw references expanded
+ Config with raw references expanded (or partially if external_only=True)
"""
# Early exit optimization: Skip processing if this subtree has no raw references
# This avoids unnecessary recursion for large config sections without % refs
@@ -134,12 +173,16 @@ def _process_raw_refs_recursive(
if isinstance(config, dict):
for key in list(config.keys()):
sub_id = f"{id}{ID_SEP_KEY}{key}" if id else str(key)
- config[key] = self._process_raw_refs_recursive(config[key], base_data, sub_id, raw_ref_stack)
+ config[key] = self._process_raw_refs_recursive(
+ config[key], base_data, sub_id, raw_ref_stack, locations, external_only
+ )
elif isinstance(config, list):
for idx in range(len(config)):
sub_id = f"{id}{ID_SEP_KEY}{idx}" if id else str(idx)
- config[idx] = self._process_raw_refs_recursive(config[idx], base_data, sub_id, raw_ref_stack)
+ config[idx] = self._process_raw_refs_recursive(
+ config[idx], base_data, sub_id, raw_ref_stack, locations, external_only
+ )
# Process string values - only expand raw references (%)
if isinstance(config, str):
@@ -149,7 +192,7 @@ def _process_raw_refs_recursive(
# Then expand raw references
if config.startswith(RAW_REF_KEY):
- config = self._expand_raw_ref(config, base_data, raw_ref_stack)
+ config = self._expand_raw_ref(config, base_data, raw_ref_stack, id, locations, external_only)
return config
@@ -193,42 +236,95 @@ def _process_recursive(
return config
- def _expand_raw_ref(self, raw_ref: str, base_data: dict[str, Any], raw_ref_stack: set[str]) -> Any:
+ def _expand_raw_ref(
+ self,
+ raw_ref: str,
+ base_data: dict[str, Any],
+ raw_ref_stack: set[str],
+ current_id: str = "",
+ locations: Optional["LocationRegistry"] = None,
+ external_only: bool = False,
+ ) -> Any:
"""Expand a single raw reference by loading external file or local YAML.
Args:
raw_ref: Raw reference string (e.g., "%file.yaml::key" or "%key")
base_data: Root config for local raw references
raw_ref_stack: Circular reference detection
+ current_id: Current ID path (where this raw reference was found)
+ locations: LocationRegistry for error reporting (optional)
+ external_only: If True, skip local refs and return raw_ref unchanged
Returns:
- Value from raw reference (deep copied)
+ Value from raw reference (deep copied), or raw_ref unchanged if
+ external_only=True and this is a local reference
Raises:
- ValueError: If circular reference detected
+ CircularReferenceError: If circular reference detected
+ ConfigKeyError: If referenced key not found
"""
+ # Parse: "%file.yaml::key" → ("file.yaml", "key")
+ path, ids = split_file_and_id(raw_ref[len(RAW_REF_KEY) :])
+
+ # Phase 1 (external_only=True): Skip local refs, they'll be expanded later
+ # This allows CLI overrides to affect values used by local % refs
+ is_local_ref = not path
+ if external_only and is_local_ref:
+ return raw_ref # Keep as string, expand in Phase 2
+
# Circular reference check
if raw_ref in raw_ref_stack:
chain = " -> ".join(sorted(raw_ref_stack))
- raise ValueError(f"Circular raw reference detected: '{raw_ref}'\nRaw reference chain: {chain} -> {raw_ref}")
- # Parse: "%file.yaml::key" → ("file.yaml", "key")
- path, ids = split_file_and_id(raw_ref[len(RAW_REF_KEY) :])
+ # Get location information if available
+ location = None
+ if locations and current_id:
+ location = locations.get(current_id)
+
+ raise CircularReferenceError(
+ message=f"Circular raw reference detected: '{raw_ref}'\nReference chain: {chain} -> {raw_ref}",
+ source_location=location,
+ )
raw_ref_stack.add(raw_ref)
try:
# Load config (external file or local)
- if not path:
+ if is_local_ref:
loaded_config = base_data # Local raw reference: %key
+ loaded_locations = locations # Use same location registry
+ source_description = "local config"
else:
- loaded_config, _ = self.loader.load_file(path) # External: %file.yaml::key
+ loaded_config, loaded_locations = self.loader.load_file(path) # External: %file.yaml::key
+ source_description = f"'{path}'"
# Navigate to referenced value
- result = self._get_by_id(loaded_config, ids)
-
- # Recursively preprocess the loaded value (expand nested raw references only)
- result = self._process_raw_refs_recursive(result, loaded_config, ids, raw_ref_stack)
+ try:
+ result = get_by_id(loaded_config, ids)
+ except (KeyError, TypeError, IndexError) as e:
+ # Get location information if available
+ location = None
+ if locations and current_id:
+ location = locations.get(current_id)
+
+ # Build error message (source_description already names the local vs external source)
+ error_msg = f"Error resolving raw reference '{raw_ref}' from {source_description}:\n{e}"
+
+ # Raise custom error with proper formatting
+ raise ConfigKeyError(
+ message=error_msg,
+ source_location=location,
+ ) from e
+
+ # Recursively preprocess the loaded value (expand nested raw references)
+ # For external files, always expand all refs within that file
+ # For local refs (Phase 2), expand all nested refs too
+ result = self._process_raw_refs_recursive(
+ result, loaded_config, ids, raw_ref_stack, loaded_locations, external_only=False
+ )
# Deep copy for independence
return deepcopy(result)
@@ -255,32 +351,3 @@ def _contains_raw_refs(config: Any) -> bool:
elif isinstance(config, list):
return any(Preprocessor._contains_raw_refs(item) for item in config)
return False
-
- @staticmethod
- def _get_by_id(config: dict[str, Any], id: str) -> Any:
- """Navigate config dict by ID path.
-
- Args:
- config: Config dict to navigate
- id: ID path (e.g., "model::optimizer::lr")
-
- Returns:
- Value at ID path
-
- Raises:
- KeyError: If path not found
- TypeError: If trying to index non-dict/list
- """
- if not id:
- return config
-
- current = config
- for key in split_id(id):
- if isinstance(current, dict):
- current = current[key]
- elif isinstance(current, list): # type: ignore[unreachable]
- current = current[int(key)]
- else:
- raise TypeError(f"Cannot index {type(current).__name__} with key '{key}' at path '{id}'")
-
- return current
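
Condensed from the updated tests in `tests/test_config.py` below, the two phases look like this through the public API (a sketch, not part of the diff):

```python
from sparkwheel import Config

config = Config()
config.update({"vars": {"path": None}})
config.update({"data": {"path": "%vars::path"}})

# Phase 1 ran during update(): the local raw ref is still a plain string
assert config.get("data::path") == "%vars::path"

# A later override (e.g., from the CLI) changes the referenced value
config.update("vars::path=/data/features.npz")

# Phase 2 runs during resolve(): the ref now sees the overridden value
assert config.resolve("data::path") == "/data/features.npz"
```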
diff --git a/src/sparkwheel/schema.py b/src/sparkwheel/schema.py
index 973cfa6..2d8bf40 100644
--- a/src/sparkwheel/schema.py
+++ b/src/sparkwheel/schema.py
@@ -39,7 +39,7 @@ class ModelConfig:
import types
from typing import Any, Union, get_args, get_origin
-from .utils.exceptions import BaseError, SourceLocation
+from .utils.exceptions import BaseError, Location
__all__ = ["validate", "validator", "ValidationError", "MISSING"]
@@ -192,7 +192,7 @@ def __init__(
field_path: str = "",
expected_type: type | None = None,
actual_value: Any = None,
- source_location: SourceLocation | None = None,
+ source_location: Location | None = None,
):
"""Initialize validation error.
@@ -682,15 +682,15 @@ def _validate_field(
)
-def _get_source_location(metadata: Any, field_path: str) -> SourceLocation | None:
+def _get_source_location(metadata: Any, field_path: str) -> Location | None:
"""Get source location from metadata registry.
Args:
- metadata: MetadataRegistry instance
+ metadata: LocationRegistry instance
field_path: Dot-separated field path to look up
Returns:
- SourceLocation if found, None otherwise
+ Location if found, None otherwise
"""
if metadata is None:
return None
@@ -699,6 +699,6 @@ def _get_source_location(metadata: Any, field_path: str) -> SourceLocation | Non
# Convert dot notation to :: notation used by sparkwheel
id_path = field_path.replace(".", "::")
result = metadata.get(id_path)
- return result if result is None or isinstance(result, SourceLocation) else None
+ return result if isinstance(result, Location) else None
except Exception:
return None
diff --git a/src/sparkwheel/utils/constants.py b/src/sparkwheel/utils/constants.py
index 872b1ce..ebccc16 100644
--- a/src/sparkwheel/utils/constants.py
+++ b/src/sparkwheel/utils/constants.py
@@ -5,6 +5,7 @@
"EXPR_KEY",
"REMOVE_KEY",
"REPLACE_KEY",
+ "SIMILARITY_THRESHOLD",
]
RESOLVED_REF_KEY = "@" # start of a resolved reference (to instantiated/evaluated value)
@@ -13,3 +14,4 @@
EXPR_KEY = "$" # start of an Expression
REMOVE_KEY = "~" # remove operator for config modifications (delete keys/items)
REPLACE_KEY = "=" # replace operator for config modifications (explicit override)
+SIMILARITY_THRESHOLD = 0.6 # minimum similarity score for key name suggestions (0-1)
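
For intuition about the 0.6 cutoff, here is an illustrative `difflib`-based stand-in for `errors.get_suggestions()` (whose real implementation is not shown in this diff):

```python
import difflib

def suggest(key: str, candidates: list[str], threshold: float = 0.6) -> list[str]:
    """Keep candidates whose similarity ratio to `key` clears the threshold."""
    scored = [(c, difflib.SequenceMatcher(None, key, c).ratio()) for c in candidates]
    return [c for c, score in sorted(scored, key=lambda t: -t[1]) if score >= threshold]

print(suggest("paramters", ["parameters", "param_groups", "epochs"]))
# ['parameters', 'param_groups'] -- 'epochs' scores well below 0.6
```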
diff --git a/src/sparkwheel/utils/exceptions.py b/src/sparkwheel/utils/exceptions.py
index cbb9645..d04eedb 100644
--- a/src/sparkwheel/utils/exceptions.py
+++ b/src/sparkwheel/utils/exceptions.py
@@ -5,7 +5,7 @@
from typing import Any
__all__ = [
- "SourceLocation",
+ "Location",
"BaseError",
"TargetNotFoundError",
"CircularReferenceError",
@@ -14,18 +14,31 @@
"ConfigMergeError",
"EvaluationError",
"FrozenConfigError",
+ "build_missing_key_error",
]
@dataclass
-class SourceLocation:
- """Tracks the source location of a config item."""
+class Location:
+ """Tracks the location of a config item in source files.
+
+ Attributes:
+ filepath: Path to the source file
+ line: Line number in the file (must be >= 1)
+ column: Column number (0 if not available)
+ id: Config path ID (e.g., "model::lr")
+ """
filepath: str
line: int
column: int = 0
id: str = ""
+ def __post_init__(self) -> None:
+ """Validate line number after initialization."""
+ if self.line < 1:
+ raise ValueError(f"line must be >= 1, got {self.line}")
+
def __str__(self) -> str:
return f"{self.filepath}:{self.line}"
@@ -42,7 +55,7 @@ class BaseError(Exception):
def __init__(
self,
message: str,
- source_location: SourceLocation | None = None,
+ source_location: Location | None = None,
suggestion: str | None = None,
) -> None:
self.source_location = source_location
@@ -62,9 +75,9 @@ def _format_message(self) -> str:
if self.source_location:
location = f"{self.source_location.filepath}:{self.source_location.line}"
if self.source_location.id:
- parts.append(f"[{location} @ {self.source_location.id}] {self._original_message}")
+ parts.append(f"{self._original_message}\n\n[{location} β {self.source_location.id}]:")
else:
- parts.append(f"[{location}] {self._original_message}")
+ parts.append(f"{self._original_message}\n\n[{location}]:")
else:
parts.append(self._original_message)
@@ -72,10 +85,10 @@ def _format_message(self) -> str:
if self.source_location:
snippet = self._get_config_snippet()
if snippet:
- parts.append(f"\n\n{snippet}")
+ parts.append(f"\n{snippet}\n")
if self.suggestion:
- parts.append(f"\n\n π‘ {self.suggestion}")
+ parts.append(f"\n π‘ {self.suggestion}\n")
return "".join(parts)
@@ -136,7 +149,7 @@ class ConfigKeyError(BaseError):
def __init__(
self,
message: str,
- source_location: SourceLocation | None = None,
+ source_location: Location | None = None,
suggestion: str | None = None,
missing_key: str | None = None,
available_keys: list[str] | None = None,
@@ -217,3 +230,59 @@ def __init__(self, message: str, field_path: str = ""):
if field_path:
full_message = f"Cannot modify frozen config at '{field_path}': {message}"
super().__init__(full_message)
+
+
+def build_missing_key_error(
+ key: str,
+ available_keys: list[str],
+ source_location: Location | None = None,
+ *,
+ max_suggestions: int = 3,
+ max_available_keys: int = 10,
+ parent_key: str | None = None,
+) -> ConfigMergeError:
+ """Build a ConfigMergeError for a missing key with helpful suggestions.
+
+ Args:
+ key: The key that wasn't found
+ available_keys: List of available keys to compare against
+ source_location: Optional location where error occurred
+ max_suggestions: Maximum number of suggestions to show (default: 3)
+ max_available_keys: Maximum number of available keys to show (default: 10)
+ parent_key: Optional parent key for nested deletions (e.g., "model" when deleting "model.lr")
+
+ Returns:
+ ConfigMergeError with helpful suggestions
+
+ Examples:
+ >>> error = build_missing_key_error("paramters", ["parameters", "param_groups"])
+ >>> print(error._original_message)
+ Cannot delete key 'paramters': key does not exist
+
+ >>> error = build_missing_key_error("lr", ["learning_rate"], parent_key="model")
+ >>> print(error._original_message)
+ Cannot remove non-existent key 'lr' from 'model'
+ """
+ from ..errors import get_suggestions
+ from .constants import SIMILARITY_THRESHOLD
+
+ # Build appropriate message based on context
+ if parent_key:
+ message = f"Cannot remove non-existent key '{key}' from '{parent_key}'"
+ else:
+ message = f"Cannot delete key '{key}': key does not exist"
+
+ suggestion_parts = []
+ if available_keys:
+ suggestions = get_suggestions(
+ key, available_keys, max_suggestions=max_suggestions, similarity_threshold=SIMILARITY_THRESHOLD
+ )
+ if suggestions:
+ suggestion_keys = [s[0] for s in suggestions]
+ suggestion_parts.append(f"Did you mean: {', '.join(repr(s) for s in suggestion_keys)}?")
+
+ if len(available_keys) <= max_available_keys:
+ suggestion_parts.append(f"Available keys: {', '.join(repr(k) for k in available_keys)}")
+
+ suggestion = "\n".join(suggestion_parts) if suggestion_parts else None
+ return ConfigMergeError(message, source_location=source_location, suggestion=suggestion)
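
End to end, the validated `Location` dataclass and the new helper compose as sketched below (file name and line numbers are illustrative):

```python
from sparkwheel.utils.exceptions import Location, build_missing_key_error

try:
    Location(filepath="config.yaml", line=0)  # __post_init__ rejects line < 1
except ValueError as err:
    print(err)  # line must be >= 1, got 0

error = build_missing_key_error(
    "paramters",
    ["parameters", "param_groups"],
    source_location=Location(filepath="config.yaml", line=7),
)
print(error)  # message, a "[config.yaml:7]:" header, and a "Did you mean: 'parameters'?" hint
```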
diff --git a/src/sparkwheel/utils/module.py b/src/sparkwheel/utils/module.py
index 3eabb65..8a21d0b 100644
--- a/src/sparkwheel/utils/module.py
+++ b/src/sparkwheel/utils/module.py
@@ -248,13 +248,8 @@ def instantiate(__path: str, __mode: str, *args: Any, **kwargs: Any) -> Any:
)
return pdb.runcall(component, *args, **kwargs)
except Exception as e:
- # Preserve the original exception type and message for better debugging
- args_str = f"{len(args)} positional args, " if args else ""
- error_msg = (
- f"Failed to instantiate component '{__path}' with {args_str}keywords: {','.join(kwargs.keys())}\n"
- f" Original error ({type(e).__name__}): {str(e)}\n"
- f" Set '_mode_={CompInitMode.DEBUG}' to enter debugging mode."
- )
+ error_msg = f"Could not instantiate component '{__path}':\n From {type(e).__name__}: {e}"
+
raise InstantiationError(error_msg) from e
warnings.warn(f"Component to instantiate must represent a valid class or function, but got {__path}.", stacklevel=2)
diff --git a/tests/test_config.py b/tests/test_config.py
index 80b8359..874bbdb 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -106,7 +106,7 @@ def test_merge_nested_paths(self):
def test_getitem_invalid_config_type(self):
"""Test __getitem__ raises error for invalid config type."""
parser = Config({"scalar": 42})
- with pytest.raises(ValueError, match="Config must be dict or list"):
+ with pytest.raises(TypeError, match="Cannot index int"):
_ = parser["scalar::invalid"]
def test_getitem_list_indexing(self):
@@ -277,45 +277,86 @@ def test_do_resolve_macro_load(self):
finally:
Path(filepath).unlink()
- def test_eager_raw_reference_expansion(self):
- """Test that raw references are expanded during update(), not resolve()."""
+ def test_local_raw_reference_lazy_expansion(self):
+ """Test that local raw references are expanded lazily (during resolve()).
+
+ This allows CLI overrides to affect values used by local % refs.
+ """
config = {"original": {"a": 1, "b": 2}, "copy": "%original"}
parser = Config().update(config)
- # After update(), raw references should be expanded (not resolve() yet!)
- # The copy should be the actual value, not the "%original" string
- assert parser.get("copy") == {"a": 1, "b": 2}
- assert parser.get("copy") is not parser.get("original") # Deep copy
+ # After update(), LOCAL raw references are NOT expanded yet
+ # They remain as strings until resolve() is called
+ assert parser.get("copy") == "%original"
+
+ # After resolve(), local refs are expanded
+ resolved = parser.resolve("copy")
+ assert resolved == {"a": 1, "b": 2}
+
+ # Verify it's a deep copy (independent of original)
+ assert resolved is not parser.resolve("original")
+
+ def test_cli_override_affects_local_raw_ref(self):
+ """Test that CLI overrides affect values used by local % refs.
+
+ This is the key use case: vars::path can be overridden via CLI
+ and local % refs will see the overridden value.
+ """
+ parser = Config()
+ parser.update({"vars": {"path": None}})
+ parser.update({"data": {"path": "%vars::path"}})
+
+ # Before CLI override, local ref is still a string
+ assert parser.get("data::path") == "%vars::path"
+
+ # Apply CLI override
+ parser.update("vars::path=/data/features.npz")
+
+ # Now resolve - local ref should see the overridden value
+ assert parser.resolve("data::path") == "/data/features.npz"
+
+ def test_external_raw_reference_eager_expansion(self, tmp_path):
+ """Test that external file raw references are expanded eagerly."""
+ # Create external file
+ external = tmp_path / "external.yaml"
+ external.write_text("value: 42\nnested:\n a: 1\n b: 2")
+
+ parser = Config()
+ parser.update({"imported": f"%{external}::value", "section": f"%{external}::nested"})
+
+ # After update(), EXTERNAL raw references ARE expanded (eager)
+ assert parser.get("imported") == 42
+ assert parser.get("section") == {"a": 1, "b": 2}
+
+ def test_pruning_with_external_raw_references(self, tmp_path):
+ """Test that pruning works with external file raw references.
- # Verify it's a real copy by modifying it
- parser.set("copy::c", 3)
- assert parser.get("copy::c") == 3
- assert "c" not in parser.get("original")
+ External file refs are expanded eagerly, so copy-then-delete works.
+ For local refs, use @ references instead (they also support this pattern).
+ """
+ # Create external file with dataloader configs
+ external = tmp_path / "dataloaders.yaml"
+ external.write_text("train:\n batch_size: 32\nval:\n batch_size: 64")
- def test_pruning_with_raw_references(self):
- """Test that pruning works with raw references (the Lighter use case)."""
- # This is the critical test: raw refs should be expanded before pruning
config = {
"system": {
- "dataloaders": {
- "train": {"batch_size": 32},
- "val": {"batch_size": 64},
- }
+ "dataloaders": f"%{external}", # External ref - expanded eagerly
},
"train": {
- "dataloader": "%system::dataloaders::train" # Raw reference
+ "dataloader": f"%{external}::train", # External ref - expanded eagerly
},
}
parser = Config().update(config)
- # Raw reference should already be expanded
+ # External raw reference should already be expanded (eager)
assert parser.get("train::dataloader") == {"batch_size": 32}
+ assert parser.get("system::dataloaders") == {"train": {"batch_size": 32}, "val": {"batch_size": 64}}
# Now prune the system section (delete it)
parser.update("~system")
- # The raw reference was already expanded, so train::dataloader should still exist
+ # The external raw reference was already expanded, so train::dataloader still exists
assert parser.get("train::dataloader") == {"batch_size": 32}
assert "system" not in parser.get() # system is deleted
@@ -323,6 +364,31 @@ def test_pruning_with_raw_references(self):
result = parser.resolve("train::dataloader")
assert result == {"batch_size": 32}
+ def test_local_raw_ref_to_deleted_key_fails(self):
+ """Test that local % ref to deleted key fails clearly.
+
+        With lazy local refs, deleting the source key makes a later resolve() fail.
+ Use external file refs or @ refs if you need copy-then-delete.
+ """
+ from sparkwheel.utils.exceptions import ConfigKeyError
+
+ config = {
+ "system": {"train": {"batch_size": 32}},
+ "train": {"dataloader": "%system::train"}, # Local ref - expanded lazily
+ }
+
+ parser = Config().update(config)
+
+ # Local ref is NOT expanded yet
+ assert parser.get("train::dataloader") == "%system::train"
+
+ # Delete the source
+ parser.update("~system")
+
+ # Attempting to resolve will fail - source was deleted
+ with pytest.raises(ConfigKeyError, match="system"):
+ parser.resolve("train::dataloader")
+
class TestComponents:
"""Test component instantiation and handling."""
@@ -430,6 +496,34 @@ def test_split_path_id_no_path(self):
assert path == ""
assert ids == "key::subkey"
+ def test_update_from_file_with_nested_paths_merges_locations(self, tmp_path):
+ """Test that nested-path syntax (::) in YAML files properly merges location tracking."""
+ # Create a YAML file with nested-path syntax
+ config_file = tmp_path / "config.yaml"
+ config_file.write_text("model::lr: 0.001\nmodel::dropout: 0.5\ntrainer::epochs: 10")
+
+ # Load the file
+ config = Config()
+ config.update(str(config_file))
+
+ # Verify the values were set correctly
+ assert config["model"]["lr"] == 0.001
+ assert config["model"]["dropout"] == 0.5
+ assert config["trainer"]["epochs"] == 10
+
+ # Verify that locations were tracked
+ # The location registry should have entries for the nested paths
+ assert "model::lr" in config._locations or "model" in config._locations
+ assert "model::dropout" in config._locations or "model" in config._locations
+ assert "trainer::epochs" in config._locations or "trainer" in config._locations
+
+ # Verify the location points to the correct file
+ if "model::lr" in config._locations:
+ location = config._locations.get("model::lr")
+ assert location is not None
+ assert location.filepath == str(config_file)
+ assert location.line >= 1
+
class TestConfigMerging:
"""Test merging configurations with composition-by-default and =/~ operators."""
@@ -629,14 +723,16 @@ def test_merge_combined_operators(self):
assert "c" not in parser
assert parser["d"] == {"new": 4} # Replaced
- def test_delete_on_nonexistent_key_idempotent(self):
- """Test that ~key is idempotent (no error when key doesn't exist)."""
+ def test_delete_on_nonexistent_key_raises_error(self):
+ """Test that ~key raises error when key doesn't exist."""
+ from sparkwheel.utils.exceptions import ConfigMergeError
+
base = {"a": 1}
override = {"~b": None}
- # NEW: No error! Idempotent delete
- result = apply_operators(base, override)
- assert result == {"a": 1}
+ # Should raise error when key doesn't exist
+ with pytest.raises(ConfigMergeError, match="Cannot delete key 'b': key does not exist"):
+ apply_operators(base, override)
def test_delete_directive_with_invalid_value_raises_error(self):
"""Test that ~key raises error when value is not null, empty, or list."""
@@ -888,6 +984,52 @@ def test_merge_with_empty_list(self):
assert result == {"items": ["a", "b"]}
+ def test_delete_nonexistent_top_level_key_shows_available_keys(self):
+ """Test that deleting a nonexistent top-level key shows available top-level keys in error."""
+ from sparkwheel.utils.exceptions import ConfigMergeError
+
+ config = Config().update({"model": {"lr": 0.001}, "trainer": {"epochs": 10}})
+
+ with pytest.raises(ConfigMergeError) as exc_info:
+ config.update({"~missing": None})
+
+ error_msg = str(exc_info.value)
+ assert "Cannot delete key 'missing'" in error_msg
+ # Should suggest top-level keys
+ assert "'model'" in error_msg
+ assert "'trainer'" in error_msg
+
+ def test_delete_nonexistent_nested_key_shows_parent_keys(self):
+ """Test that deleting a nonexistent nested key shows keys from parent container."""
+ from sparkwheel.utils.exceptions import ConfigMergeError
+
+ config = Config().update({"model": {"lr": 0.001, "dropout": 0.5, "hidden_size": 1024}})
+
+ with pytest.raises(ConfigMergeError) as exc_info:
+ config.update({"~model::missing": None})
+
+ error_msg = str(exc_info.value)
+ assert "Cannot remove non-existent key 'missing' from 'model'" in error_msg
+ # Should suggest keys from the parent (model)
+ assert "'lr'" in error_msg
+ assert "'dropout'" in error_msg
+ assert "'hidden_size'" in error_msg
+ # Should NOT show unrelated top-level keys
+ assert "'trainer'" not in error_msg or "trainer" not in config._data
+
+ def test_delete_nested_key_when_parent_doesnt_exist(self):
+ """Test error when trying to delete nested key but parent doesn't exist."""
+ from sparkwheel.utils.exceptions import ConfigMergeError
+
+ config = Config().update({"model": {"lr": 0.001}})
+
+ with pytest.raises(ConfigMergeError) as exc_info:
+ config.update({"~trainer::epochs": None})
+
+ # Should fail because 'trainer' doesn't exist
+ error_msg = str(exc_info.value)
+ assert "Cannot remove non-existent key 'epochs' from 'trainer'" in error_msg
+
class TestConfigAdvanced:
"""Test advanced Config features."""
diff --git a/tests/test_error_messages.py b/tests/test_error_messages.py
index a50a1d5..a2b6b86 100644
--- a/tests/test_error_messages.py
+++ b/tests/test_error_messages.py
@@ -2,7 +2,6 @@
from sparkwheel import Config
from sparkwheel.errors import (
- enable_colors,
format_available_keys,
format_resolution_chain,
format_suggestions,
@@ -341,221 +340,6 @@ def test_format_chain_without_reference(self):
assert "β" in result
-class TestColorFormatting:
- """Test color formatting with auto-detection."""
-
- def test_enable_colors_explicit(self):
- """Test explicitly enabling colors."""
- assert enable_colors(True) is True
- assert enable_colors(False) is False
-
- def test_enable_colors_auto_detect(self):
- """Test auto-detection."""
- result = enable_colors(None)
- # Should return a boolean
- assert isinstance(result, bool)
-
- def test_colors_disabled_in_tests(self):
- """Test that colors can be disabled."""
- from sparkwheel.errors.formatters import format_error
-
- enable_colors(False)
- formatted = format_error("error message")
-
- # Should not contain ANSI codes when disabled
- assert "\033[" not in formatted
- assert formatted == "error message"
-
- def test_colors_enabled(self):
- """Test that colors work when enabled."""
- from sparkwheel.errors.formatters import format_error
-
- enable_colors(True)
- formatted = format_error("error message")
-
- # Should contain ANSI codes when enabled
- assert "\033[" in formatted
- assert "error message" in formatted
-
- def test_format_suggestion(self):
- """Test format_suggestion function."""
- from sparkwheel.errors.formatters import format_suggestion
-
- enable_colors(True)
- result = format_suggestion("suggestion text")
- assert "suggestion text" in result
- assert "\033[" in result # Contains ANSI codes
-
- enable_colors(False)
- result = format_suggestion("suggestion text")
- assert result == "suggestion text"
-
- def test_format_success(self):
- """Test format_success function."""
- from sparkwheel.errors.formatters import format_success
-
- enable_colors(True)
- result = format_success("success text")
- assert "success text" in result
- assert "\033[" in result # Contains ANSI codes
-
- enable_colors(False)
- result = format_success("success text")
- assert result == "success text"
-
- def test_format_code(self):
- """Test format_code function."""
- from sparkwheel.errors.formatters import format_code
-
- enable_colors(True)
- result = format_code("code text")
- assert "code text" in result
- assert "\033[" in result # Contains ANSI codes
-
- enable_colors(False)
- result = format_code("code text")
- assert result == "code text"
-
- def test_format_context(self):
- """Test format_context function."""
- from sparkwheel.errors.formatters import format_context
-
- enable_colors(True)
- result = format_context("context text")
- assert "context text" in result
- assert "\033[" in result # Contains ANSI codes
-
- enable_colors(False)
- result = format_context("context text")
- assert result == "context text"
-
- def test_format_bold(self):
- """Test format_bold function."""
- from sparkwheel.errors.formatters import format_bold
-
- enable_colors(True)
- result = format_bold("bold text")
- assert "bold text" in result
- assert "\033[" in result # Contains ANSI codes
-
- enable_colors(False)
- result = format_bold("bold text")
- assert result == "bold text"
-
- def test_supports_color_with_no_color_env(self, monkeypatch):
- """Test that NO_COLOR environment variable disables colors."""
- from sparkwheel.errors.formatters import _supports_color
-
- monkeypatch.setenv("NO_COLOR", "1")
- assert _supports_color() is False
-
- def test_supports_color_with_sparkwheel_no_color_env(self, monkeypatch):
- """Test that SPARKWHEEL_NO_COLOR environment variable disables colors."""
- from sparkwheel.errors.formatters import _supports_color
-
- monkeypatch.setenv("SPARKWHEEL_NO_COLOR", "1")
- assert _supports_color() is False
-
- def test_get_colors_enabled_lazy_init(self):
- """Test that _get_colors_enabled initializes colors if needed."""
- from sparkwheel.errors import formatters
-
- # Reset global state
- formatters._COLORS_ENABLED = None
-
- # Should auto-detect when None
- result = formatters._get_colors_enabled()
- assert isinstance(result, bool)
- assert formatters._COLORS_ENABLED is not None
-
- def test_supports_color_force_color(self, monkeypatch):
- """Test FORCE_COLOR enables colors even without TTY."""
- import sys
-
- from sparkwheel.errors.formatters import _supports_color
-
- # Clear NO_COLOR and set FORCE_COLOR
- monkeypatch.delenv("NO_COLOR", raising=False)
- monkeypatch.delenv("SPARKWHEEL_NO_COLOR", raising=False)
- monkeypatch.setenv("FORCE_COLOR", "1")
-
- # Mock stdout without isatty
- class MockStdout:
- def isatty(self):
- return False
-
- original_stdout = sys.stdout
- try:
- sys.stdout = MockStdout()
- assert _supports_color() is True
- finally:
- sys.stdout = original_stdout
-
- def test_supports_color_no_color_overrides_force_color(self, monkeypatch):
- """Test NO_COLOR takes precedence over FORCE_COLOR."""
- import sys
-
- from sparkwheel.errors.formatters import _supports_color
-
- # Set both NO_COLOR and FORCE_COLOR - NO_COLOR should win
- monkeypatch.setenv("NO_COLOR", "1")
- monkeypatch.setenv("FORCE_COLOR", "1")
-
- # Mock stdout with TTY
- class MockStdout:
- def isatty(self):
- return True
-
- original_stdout = sys.stdout
- try:
- sys.stdout = MockStdout()
- # NO_COLOR should override both FORCE_COLOR and TTY
- assert _supports_color() is False
- finally:
- sys.stdout = original_stdout
-
- def test_supports_color_no_isatty(self, monkeypatch):
- """Test color detection when stdout has no isatty method."""
- import sys
-
- from sparkwheel.errors.formatters import _supports_color
-
- # Disable colors to override any CI environment settings
- monkeypatch.setenv("NO_COLOR", "1")
-
- # Mock stdout without isatty
- class MockStdout:
- pass
-
- original_stdout = sys.stdout
- try:
- sys.stdout = MockStdout()
- assert _supports_color() is False
- finally:
- sys.stdout = original_stdout
-
- def test_supports_color_isatty_false(self, monkeypatch):
- """Test color detection when isatty returns False."""
- import sys
-
- from sparkwheel.errors.formatters import _supports_color
-
- # Disable colors to override any CI environment settings
- monkeypatch.setenv("NO_COLOR", "1")
-
- # Mock stdout with isatty that returns False
- class MockStdout:
- def isatty(self):
- return False
-
- original_stdout = sys.stdout
- try:
- sys.stdout = MockStdout()
- assert _supports_color() is False
- finally:
- sys.stdout = original_stdout
-
-
class TestConfigKeyErrorEnhanced:
"""Test enhanced ConfigKeyError with suggestions."""
@@ -642,17 +426,17 @@ class TestExceptionEdgeCases:
"""Test edge cases in exception handling."""
def test_source_location_without_id(self):
- """Test SourceLocation string formatting without ID."""
- from sparkwheel.utils.exceptions import SourceLocation
+ """Test Location string formatting without ID."""
+ from sparkwheel.utils.exceptions import Location
- loc = SourceLocation(filepath="test.yaml", line=10)
+ loc = Location(filepath="test.yaml", line=10)
assert str(loc) == "test.yaml:10"
def test_base_error_without_source_location_id(self):
"""Test BaseError formatting when source_location has no ID."""
- from sparkwheel.utils.exceptions import BaseError, SourceLocation
+ from sparkwheel.utils.exceptions import BaseError, Location
- loc = SourceLocation(filepath="test.yaml", line=5)
+ loc = Location(filepath="test.yaml", line=5)
error = BaseError("Test error", source_location=loc)
msg = str(error)
assert "[test.yaml:5]" in msg
@@ -660,10 +444,10 @@ def test_base_error_without_source_location_id(self):
def test_base_error_snippet_file_read_error(self, tmp_path):
"""Test BaseError snippet handling when file can't be read."""
- from sparkwheel.utils.exceptions import BaseError, SourceLocation
+ from sparkwheel.utils.exceptions import BaseError, Location
# Create a source location pointing to non-existent file
- loc = SourceLocation(filepath="/nonexistent/file.yaml", line=5)
+ loc = Location(filepath="/nonexistent/file.yaml", line=5)
error = BaseError("Test error", source_location=loc)
# Should not raise, just skip snippet
msg = str(error)
diff --git a/tests/test_items.py b/tests/test_items.py
index 0a664f6..bbd07da 100644
--- a/tests/test_items.py
+++ b/tests/test_items.py
@@ -193,7 +193,7 @@ def test_instantiate_instantiation_error(self):
config = {"_target_": "builtins.int", "invalid_arg": 123}
component = Component(config=config)
- with pytest.raises(InstantiationError, match="Failed to instantiate"):
+ with pytest.raises(InstantiationError):
component.instantiate()
def test_instantiate_with_kwargs_override(self):
@@ -287,7 +287,7 @@ def test_instantiate_args_not_in_non_arg_keys(self):
"""Test that _args_ is properly excluded from kwargs."""
config = {"_target_": "builtins.dict", "_args_": [], "a": 1}
component = Component(config=config)
- args, kwargs = component.resolve_args()
+ _, kwargs = component.resolve_args()
# Verify _args_ is not passed as a kwarg
assert "_args_" not in kwargs
@@ -454,14 +454,14 @@ def test_evaluate_debug_mode_enabled(self):
sparkwheel.utils.run_debug = original_debug
-class TestItemWithSourceLocation:
+class TestItemWithLocation:
"""Test Item classes with source location tracking."""
def test_item_with_source_location(self):
"""Test creating item with source location."""
- from sparkwheel.utils.exceptions import SourceLocation
+ from sparkwheel.utils.exceptions import Location
- location = SourceLocation(filepath="/tmp/config.yaml", line=10, column=5, id="model::lr")
+ location = Location(filepath="/tmp/config.yaml", line=10, column=5, id="model::lr")
item = Item(config={"lr": 0.001}, id="model::lr", source_location=location)
assert item.source_location == location
@@ -470,9 +470,9 @@ def test_item_with_source_location(self):
def test_component_error_includes_source_location(self):
"""Test that component errors include source location."""
- from sparkwheel.utils.exceptions import SourceLocation
+ from sparkwheel.utils.exceptions import Location
- location = SourceLocation(filepath="/tmp/config.yaml", line=15, column=2, id="model")
+ location = Location(filepath="/tmp/config.yaml", line=15, column=2, id="model")
config = {"_target_": "nonexistent.Module"}
component = Component(config=config, id="model", source_location=location)
@@ -484,9 +484,9 @@ def test_component_error_includes_source_location(self):
def test_expression_error_includes_source_location(self):
"""Test that expression errors include source location."""
- from sparkwheel.utils.exceptions import SourceLocation
+ from sparkwheel.utils.exceptions import Location
- location = SourceLocation(filepath="/tmp/config.yaml", line=20, column=2, id="calc")
+ location = Location(filepath="/tmp/config.yaml", line=20, column=2, id="calc")
expr = Expression(config="$undefined_var", id="calc", source_location=location)
with pytest.raises(EvaluationError) as exc_info:
@@ -495,6 +495,47 @@ def test_expression_error_includes_source_location(self):
error = exc_info.value
assert error.source_location == location
+ def test_instantiation_error_with_location_reraises_unchanged(self):
+ """Test that InstantiationError with source_location is re-raised unchanged.
+
+ This tests the re-raise path in Component.instantiate() where an InstantiationError
+ already has source_location set and should be re-raised as-is without modification.
+ """
+ from unittest.mock import patch
+
+ from sparkwheel.utils.exceptions import Location
+
+ # Create a location for the inner error (different from component's location)
+ inner_location = Location(filepath="/tmp/inner.yaml", line=99, column=1, id="inner_id")
+ component_location = Location(filepath="/tmp/outer.yaml", line=5, column=2, id="outer_id")
+
+ # Create the exception to be raised
+ inner_error = InstantiationError(
+ "Inner error message",
+ source_location=inner_location,
+ suggestion="Inner suggestion",
+ )
+
+ # Mock the instantiate function to raise InstantiationError with source_location
+ def mock_instantiate(*args, **kwargs):
+ raise inner_error
+
+ config = {"_target_": "some.module.Class"}
+ component = Component(config=config, id="test", source_location=component_location)
+
+ with patch("sparkwheel.items.instantiate", mock_instantiate):
+ with pytest.raises(InstantiationError) as exc_info:
+ component.instantiate()
+
+ error = exc_info.value
+ # The exception should be re-raised unchanged (not wrapped with component's location)
+ assert error is inner_error
+ assert error._original_message == "Inner error message"
+ assert error.suggestion == "Inner suggestion"
+ assert error.source_location == inner_location
+ assert error.source_location.filepath == "/tmp/inner.yaml"
+ assert error.source_location.line == 99
+
class TestItemsEdgeCases:
"""Test edge cases in items module."""
diff --git a/tests/test_loader.py b/tests/test_loader.py
index 7b40811..7ec67dd 100644
--- a/tests/test_loader.py
+++ b/tests/test_loader.py
@@ -3,8 +3,8 @@
import pytest
from sparkwheel.loader import Loader
-from sparkwheel.metadata import MetadataRegistry
-from sparkwheel.utils.exceptions import SourceLocation
+from sparkwheel.locations import LocationRegistry
+from sparkwheel.utils.exceptions import Location
class TestLoaderBasic:
@@ -19,7 +19,7 @@ def test_load_file_basic(self, tmp_path):
config, metadata = loader.load_file(str(config_file))
assert config == {"key": "value", "number": 42}
- assert isinstance(metadata, MetadataRegistry)
+ assert isinstance(metadata, LocationRegistry)
def test_load_file_empty_filepath(self):
"""Test loading with empty filepath returns empty config."""
@@ -27,7 +27,7 @@ def test_load_file_empty_filepath(self):
config, metadata = loader.load_file("")
assert config == {}
- assert isinstance(metadata, MetadataRegistry)
+ assert isinstance(metadata, LocationRegistry)
def test_load_file_none_filepath(self):
"""Test loading with None filepath returns empty config."""
@@ -35,7 +35,7 @@ def test_load_file_none_filepath(self):
config, metadata = loader.load_file(None)
assert config == {}
- assert isinstance(metadata, MetadataRegistry)
+ assert isinstance(metadata, LocationRegistry)
def test_load_file_non_yaml_extension(self, tmp_path):
"""Test loading non-YAML file raises ValueError."""
@@ -136,7 +136,7 @@ def test_load_files_empty_list(self):
config, metadata = loader.load_files([])
assert config == {}
- assert isinstance(metadata, MetadataRegistry)
+ assert isinstance(metadata, LocationRegistry)
def test_load_files_single_file(self, tmp_path):
"""Test loading single file via load_files."""
@@ -236,18 +236,18 @@ def test_strip_metadata_from_dicts_in_lists(self, tmp_path):
assert "__sparkwheel_metadata__" not in str(config)
-class TestMetadataRegistry:
- """Test MetadataRegistry functionality."""
+class TestLocationRegistry:
+ """Test LocationRegistry functionality."""
def test_create_registry(self):
"""Test creating empty registry."""
- registry = MetadataRegistry()
+ registry = LocationRegistry()
assert len(registry) == 0
def test_register_and_get(self):
"""Test registering and getting source locations."""
- registry = MetadataRegistry()
- location = SourceLocation(filepath="config.yaml", line=10, column=5, id="model::lr")
+ registry = LocationRegistry()
+ location = Location(filepath="config.yaml", line=10, column=5, id="model::lr")
registry.register("model::lr", location)
@@ -257,13 +257,13 @@ def test_register_and_get(self):
def test_get_nonexistent_returns_none(self):
"""Test getting nonexistent location returns None."""
- registry = MetadataRegistry()
+ registry = LocationRegistry()
assert registry.get("nonexistent") is None
def test_len(self):
"""Test registry length."""
- registry = MetadataRegistry()
- location = SourceLocation(filepath="config.yaml", line=10, column=5, id="key")
+ registry = LocationRegistry()
+ location = Location(filepath="config.yaml", line=10, column=5, id="key")
assert len(registry) == 0
@@ -275,8 +275,8 @@ def test_len(self):
def test_contains(self):
"""Test __contains__ operator."""
- registry = MetadataRegistry()
- location = SourceLocation(filepath="config.yaml", line=10, column=5, id="key")
+ registry = LocationRegistry()
+ location = Location(filepath="config.yaml", line=10, column=5, id="key")
assert "key" not in registry
@@ -285,11 +285,11 @@ def test_contains(self):
def test_merge(self):
"""Test merging registries."""
- registry1 = MetadataRegistry()
- registry2 = MetadataRegistry()
+ registry1 = LocationRegistry()
+ registry2 = LocationRegistry()
- location1 = SourceLocation(filepath="file1.yaml", line=5, column=2, id="key1")
- location2 = SourceLocation(filepath="file2.yaml", line=10, column=3, id="key2")
+ location1 = Location(filepath="file1.yaml", line=5, column=2, id="key1")
+ location2 = Location(filepath="file2.yaml", line=10, column=3, id="key2")
registry1.register("key1", location1)
registry2.register("key2", location2)
@@ -302,8 +302,8 @@ def test_merge(self):
def test_copy(self):
"""Test copying registry."""
- registry = MetadataRegistry()
- location = SourceLocation(filepath="config.yaml", line=10, column=5, id="key")
+ registry = LocationRegistry()
+ location = Location(filepath="config.yaml", line=10, column=5, id="key")
registry.register("key", location)
@@ -313,7 +313,7 @@ def test_copy(self):
assert copied.get("key") == location
# Ensure it's a real copy, not a reference
- location2 = SourceLocation(filepath="other.yaml", line=20, column=1, id="other")
+ location2 = Location(filepath="other.yaml", line=20, column=1, id="other")
copied.register("other", location2)
assert "other" in copied
diff --git a/tests/test_operators.py b/tests/test_operators.py
new file mode 100644
index 0000000..93ee68f
--- /dev/null
+++ b/tests/test_operators.py
@@ -0,0 +1,395 @@
+"""Comprehensive tests for operators module.
+
+This module tests the operators.py module functionality:
+- MergeContext class and location tracking
+- Operator validation (_validate_delete_operator)
+- apply_operators function with all edge cases
+"""
+
+import pytest
+
+from sparkwheel.locations import LocationRegistry
+from sparkwheel.operators import MergeContext, _validate_delete_operator, apply_operators, validate_operators
+from sparkwheel.utils.exceptions import ConfigMergeError, Location
+
+
+class TestMergeContext:
+ """Test MergeContext class for tracking merge operations."""
+
+ def test_child_path_empty_base(self):
+ """Test creating child path from empty base."""
+ ctx = MergeContext()
+ child = ctx.child_path("model")
+ assert child.current_path == "model"
+
+ def test_child_path_with_base(self):
+ """Test creating child path with existing base."""
+ ctx = MergeContext(current_path="model")
+ child = ctx.child_path("optimizer")
+ assert child.current_path == "model::optimizer"
+
+ def test_child_path_nested(self):
+ """Test creating deeply nested child paths."""
+ ctx = MergeContext()
+ child1 = ctx.child_path("model")
+ child2 = child1.child_path("optimizer")
+ child3 = child2.child_path("lr")
+ assert child3.current_path == "model::optimizer::lr"
+
+ def test_get_source_location_no_registry(self):
+ """Test get_source_location returns None when no registry."""
+ ctx = MergeContext()
+ location = ctx.get_source_location("key")
+ assert location is None
+
+ def test_get_source_location_with_registry(self):
+ """Test get_source_location with registry."""
+ registry = LocationRegistry()
+ test_location = Location(filepath="config.yaml", line=10)
+ registry.register("model::lr", test_location)
+
+ ctx = MergeContext(locations=registry, current_path="model")
+ location = ctx.get_source_location("lr")
+ assert location == test_location
+
+ def test_get_source_location_with_remove_operator(self):
+ """Test get_source_location strips ~ operator prefix."""
+ registry = LocationRegistry()
+ test_location = Location(filepath="config.yaml", line=10)
+ registry.register("model::lr", test_location)
+
+ ctx = MergeContext(locations=registry, current_path="model")
+ # Should strip ~ and find "lr"
+ location = ctx.get_source_location("~lr")
+ assert location == test_location
+
+ def test_get_source_location_with_replace_operator(self):
+ """Test get_source_location strips = operator prefix."""
+ registry = LocationRegistry()
+ test_location = Location(filepath="config.yaml", line=10)
+ registry.register("model::lr", test_location)
+
+ ctx = MergeContext(locations=registry, current_path="model")
+ # Should strip = and find "lr"
+ location = ctx.get_source_location("=lr")
+ assert location == test_location
+
+ def test_get_source_location_operator_key_takes_precedence(self):
+ """Test that exact operator key match takes precedence."""
+ registry = LocationRegistry()
+ exact_location = Location(filepath="config.yaml", line=5)
+ fallback_location = Location(filepath="config.yaml", line=10)
+
+ # Register both ~lr and lr
+ registry.register("model::~lr", exact_location)
+ registry.register("model::lr", fallback_location)
+
+ ctx = MergeContext(locations=registry, current_path="model")
+ # Should find exact match first
+ location = ctx.get_source_location("~lr")
+ assert location == exact_location
+
+ def test_get_source_location_not_found(self):
+ """Test get_source_location returns None for missing keys."""
+ registry = LocationRegistry()
+ ctx = MergeContext(locations=registry, current_path="model")
+ location = ctx.get_source_location("missing_key")
+ assert location is None
+
+ def test_get_source_location_empty_path(self):
+ """Test get_source_location with empty current path."""
+ registry = LocationRegistry()
+ test_location = Location(filepath="config.yaml", line=10)
+ registry.register("lr", test_location)
+
+ ctx = MergeContext(locations=registry, current_path="")
+ location = ctx.get_source_location("lr")
+ assert location == test_location
+
+ def test_get_source_location_operator_key_not_found(self):
+ """Test get_source_location returns None when operator key not found."""
+ registry = LocationRegistry()
+ # Register only "other", not "lr" or "~lr"
+ test_location = Location(filepath="config.yaml", line=10)
+ registry.register("model::other", test_location)
+
+ ctx = MergeContext(locations=registry, current_path="model")
+ # Should not find ~lr or lr
+ location = ctx.get_source_location("~lr")
+ assert location is None
+
+ def test_get_source_location_regular_key_not_found(self):
+ """Test get_source_location returns None for regular key not found."""
+ registry = LocationRegistry()
+ test_location = Location(filepath="config.yaml", line=10)
+ registry.register("model::other", test_location)
+
+ ctx = MergeContext(locations=registry, current_path="model")
+ # Should not find "missing" (regular key, no operator)
+ location = ctx.get_source_location("missing")
+ assert location is None
+
+
+class TestValidateDeleteOperator:
+ """Test _validate_delete_operator function."""
+
+ def test_valid_null_value(self):
+ """Test that null value is valid."""
+ _validate_delete_operator("key", None) # Should not raise
+
+ def test_valid_empty_string(self):
+ """Test that empty string is valid."""
+ _validate_delete_operator("key", "") # Should not raise
+
+ def test_valid_list_value(self):
+ """Test that list value is valid."""
+ _validate_delete_operator("key", [0, 1, 2]) # Should not raise
+
+ def test_invalid_string_value(self):
+ """Test that non-empty string value raises error."""
+ with pytest.raises(ConfigMergeError, match="must have null, empty, or list value"):
+ _validate_delete_operator("key", "invalid")
+
+ def test_invalid_dict_value(self):
+ """Test that dict value raises error."""
+ with pytest.raises(ConfigMergeError, match="must have null, empty, or list value"):
+ _validate_delete_operator("key", {"nested": "value"})
+
+ def test_invalid_int_value(self):
+ """Test that int value raises error."""
+ with pytest.raises(ConfigMergeError, match="must have null, empty, or list value"):
+ _validate_delete_operator("key", 123)
+
+ def test_empty_list_raises_error(self):
+ """Test that empty list raises error."""
+ with pytest.raises(ConfigMergeError, match="cannot be empty"):
+ _validate_delete_operator("key", [])
+
+
+class TestValidateOperators:
+ """Test validate_operators function."""
+
+ def test_validate_non_dict_config(self):
+ """Test that non-dict config is handled gracefully."""
+ # Should not raise for non-dict
+ validate_operators("not a dict") # type: ignore[arg-type]
+ validate_operators(123) # type: ignore[arg-type]
+ validate_operators(None) # type: ignore[arg-type]
+
+ def test_validate_remove_operator(self):
+ """Test validation of remove operator."""
+ config = {"~key": None}
+ validate_operators(config) # Should not raise
+
+ def test_validate_remove_operator_invalid_value(self):
+ """Test validation catches invalid remove operator value."""
+ config = {"~key": "invalid"}
+ with pytest.raises(ConfigMergeError, match="must have null, empty, or list value"):
+ validate_operators(config)
+
+ def test_validate_replace_operator(self):
+ """Test validation of replace operator."""
+ config = {"=key": "value"}
+ validate_operators(config) # Should not raise
+
+ def test_validate_nested_remove_operator(self):
+ """Test validation of nested remove operator."""
+ config = {"model": {"~lr": None}}
+ validate_operators(config) # Should not raise
+
+ def test_validate_nested_remove_operator_invalid(self):
+ """Test validation catches invalid nested remove operator."""
+ config = {"model": {"~lr": 123}}
+ with pytest.raises(ConfigMergeError, match="must have null, empty, or list value"):
+ validate_operators(config)
+
+ def test_validate_catches_dict_under_remove(self):
+ """Test that validation catches dict value under remove operator."""
+ # A dict value under a remove operator is invalid
+ # Remove operators must have null, empty, or list values
+ config = {"~key": {"nested": "value"}}
+ with pytest.raises(ConfigMergeError, match="must have null, empty, or list value"):
+ validate_operators(config)
+
+
+class TestApplyOperatorsEdgeCases:
+ """Test edge cases in apply_operators function."""
+
+ def test_non_dict_base_returns_override(self):
+ """Test that non-dict base returns deepcopy of override."""
+ result = apply_operators("not a dict", {"key": "value"}) # type: ignore[arg-type]
+ assert result == {"key": "value"}
+
+ def test_non_dict_override_returns_override(self):
+ """Test that non-dict override returns deepcopy of override."""
+ result = apply_operators({"key": "value"}, "not a dict") # type: ignore[arg-type]
+ assert result == "not a dict"
+
+ def test_both_non_dict_returns_override(self):
+ """Test that when both base and override are non-dicts, the override is returned."""
+ result = apply_operators("base", "override") # type: ignore[arg-type]
+ assert result == "override"
+
+ def test_non_string_key_copied_directly(self):
+ """Test that non-string keys are copied directly."""
+ base = {}
+ override = {123: "numeric_key"} # type: ignore[dict-item]
+ result = apply_operators(base, override)
+ assert result[123] == "numeric_key" # type: ignore[index]
+
+ def test_context_propagates_to_nested_merges(self):
+ """Test that context is properly propagated in nested merges."""
+ registry = LocationRegistry()
+ ctx = MergeContext(locations=registry)
+
+ base = {"model": {"lr": 0.001}}
+ override = {"model": {"dropout": 0.1}}
+
+ result = apply_operators(base, override, context=ctx)
+ assert result == {"model": {"lr": 0.001, "dropout": 0.1}}
+
+ def test_delete_with_context_location_tracking(self):
+ """Test that delete operator uses context for error messages."""
+ registry = LocationRegistry()
+ test_location = Location(filepath="config.yaml", line=10)
+ registry.register("~missing", test_location)
+
+ ctx = MergeContext(locations=registry)
+
+ base = {"existing": "value"}
+ override = {"~missing": None}
+
+ with pytest.raises(ConfigMergeError, match="Cannot delete key 'missing'"):
+ apply_operators(base, override, context=ctx)
+
+ def test_delete_list_items_no_context(self):
+ """Test delete list items works without context."""
+ base = {"items": [1, 2, 3, 4, 5]}
+ override = {"~items": [0, 2, 4]}
+ result = apply_operators(base, override)
+ assert result == {"items": [2, 4]}
+
+ def test_delete_dict_keys_with_context(self):
+ """Test delete dict keys with context for better error messages."""
+ registry = LocationRegistry()
+ test_location = Location(filepath="config.yaml", line=10)
+ registry.register("~model", test_location)
+
+ ctx = MergeContext(locations=registry)
+
+ base = {"model": {"lr": 0.001, "dropout": 0.1}}
+ override = {"~model": ["missing_key"]}
+
+ with pytest.raises(ConfigMergeError, match="Cannot remove non-existent key 'missing_key'"):
+ apply_operators(base, override, context=ctx)
+
+ def test_replace_operator_with_none_value(self):
+ """Test replace operator can set value to None."""
+ base = {"key": "value"}
+ override = {"=key": None}
+ result = apply_operators(base, override)
+ assert result == {"key": None}
+
+ def test_composition_list_extend_preserves_order(self):
+ """Test that list composition preserves order."""
+ base = {"items": [1, 2, 3]}
+ override = {"items": [4, 5, 6]}
+ result = apply_operators(base, override)
+ assert result == {"items": [1, 2, 3, 4, 5, 6]}
+
+ def test_composition_dict_merge_deep(self):
+ """Test that dict composition merges deeply."""
+ base = {"model": {"optimizer": {"lr": 0.001, "momentum": 0.9}}}
+ override = {"model": {"optimizer": {"lr": 0.01}}}
+ result = apply_operators(base, override)
+ assert result == {"model": {"optimizer": {"lr": 0.01, "momentum": 0.9}}}
+
+ def test_scalar_replacement_on_type_mismatch(self):
+ """Test that type mismatches cause replacement."""
+ base = {"value": [1, 2, 3]}
+ override = {"value": "string"}
+ result = apply_operators(base, override)
+ assert result == {"value": "string"}
+
+ def test_new_key_addition(self):
+ """Test adding new keys to config."""
+ base = {"existing": "value"}
+ override = {"new_key": "new_value"}
+ result = apply_operators(base, override)
+ assert result == {"existing": "value", "new_key": "new_value"}
+
+ def test_delete_list_negative_indices_normalized(self):
+ """Test that negative indices are properly normalized."""
+ base = {"items": [1, 2, 3, 4, 5]}
+ override = {"~items": [-1, -2]} # Remove last two items
+ result = apply_operators(base, override)
+ assert result == {"items": [1, 2, 3]}
+
+ def test_delete_list_duplicate_indices_handled(self):
+ """Test that duplicate indices are handled correctly."""
+ base = {"items": [1, 2, 3, 4, 5]}
+ override = {"~items": [1, 1, 1]} # Duplicate index
+ result = apply_operators(base, override)
+ assert result == {"items": [1, 3, 4, 5]} # Only removed once
+
+ def test_delete_items_errors_on_scalar(self):
+ """Test that deleting items from scalar raises error."""
+ base = {"value": "scalar"}
+ override = {"~value": [0]}
+
+ with pytest.raises(ConfigMergeError, match="expected list or dict"):
+ apply_operators(base, override)
+
+ def test_multiple_operators_in_one_override(self):
+ """Test multiple operators in single override dict."""
+ base = {"a": 1, "b": 2, "c": 3, "d": {"x": 1, "y": 2}}
+ override = {
+ "a": 10, # Compose (replace scalar)
+ "=b": 20, # Explicit replace
+ "~c": None, # Delete
+ "d": {"x": 10}, # Compose dict (merge)
+ }
+ result = apply_operators(base, override)
+ assert result == {"a": 10, "b": 20, "d": {"x": 10, "y": 2}}
+
+
+class TestApplyOperatorsDeepCopy:
+ """Test that apply_operators properly deep copies values."""
+
+ def test_base_not_mutated(self):
+ """Test that base dict is not mutated."""
+ base = {"model": {"lr": 0.001}}
+ override = {"model": {"dropout": 0.1}}
+ result = apply_operators(base, override)
+
+ # Modify result
+ result["model"]["lr"] = 0.01
+
+ # Base should be unchanged
+ assert base["model"]["lr"] == 0.001
+
+ def test_override_not_mutated(self):
+ """Test that override dict is not mutated."""
+ base = {"model": {"lr": 0.001}}
+ override = {"model": {"dropout": 0.1}}
+ result = apply_operators(base, override)
+
+ # Modify result
+ result["model"]["dropout"] = 0.5
+
+ # Override should be unchanged
+ assert override["model"]["dropout"] == 0.1
+
+ def test_result_is_independent(self):
+ """Test that result is independent of base and override."""
+ base = {"items": [1, 2, 3]}
+ override = {"items": [4, 5]}
+ result = apply_operators(base, override)
+
+ # Modify result
+ result["items"].append(6)
+
+ # Base and override should be unchanged
+ assert base["items"] == [1, 2, 3]
+ assert override["items"] == [4, 5]
+ assert result["items"] == [1, 2, 3, 4, 5, 6]
diff --git a/tests/test_path_utils.py b/tests/test_path_utils.py
index b90f607..a0c69ac 100644
--- a/tests/test_path_utils.py
+++ b/tests/test_path_utils.py
@@ -1,6 +1,8 @@
"""Tests for path utility functions."""
-from sparkwheel.path_utils import PathPatterns, find_references
+import pytest
+
+from sparkwheel.path_utils import PathPatterns, find_references, get_by_id
class TestPathPatterns:
@@ -28,3 +30,131 @@ def test_find_references_empty_for_plain_text(self):
"""Test find_references returns empty for plain text."""
refs = find_references("just plain text")
assert refs == []
+
+
+class TestGetById:
+ """Test get_by_id function for navigating config structures."""
+
+ def test_empty_id_returns_whole_config(self):
+ """Test get_by_id with empty ID returns whole config."""
+ config = {"key": "value", "nested": {"item": 123}}
+ result = get_by_id(config, "")
+
+ assert result == config
+
+ def test_list_indexing(self):
+ """Test get_by_id with list indexing."""
+ config = {"items": [10, 20, 30]}
+ result = get_by_id(config, "items::1")
+
+ assert result == 20
+
+ def test_nested_list(self):
+ """Test get_by_id with nested structures including lists."""
+ config = {"data": {"values": [{"x": 1}, {"x": 2}, {"x": 3}]}}
+ result = get_by_id(config, "data::values::2::x")
+
+ assert result == 3
+
+ def test_type_error_on_primitive(self):
+ """Test get_by_id raises TypeError when trying to index a primitive value."""
+ config = {"value": 42}
+
+ with pytest.raises(TypeError, match="Cannot index int"):
+ get_by_id(config, "value::subkey")
+
+ def test_missing_key_first_level(self):
+ """Test get_by_id with missing key at first level shows non-redundant error."""
+ config = {"foo": 1, "bar": 2}
+
+ with pytest.raises(KeyError) as exc_info:
+ get_by_id(config, "missing")
+
+ error_msg = str(exc_info.value)
+ assert "Key 'missing' not found" in error_msg
+ assert "Available keys:" in error_msg
+ assert "'foo'" in error_msg
+ assert "'bar'" in error_msg
+ # Should NOT say "at path 'missing'" for first level
+ assert "at path" not in error_msg
+
+ def test_missing_key_nested(self):
+ """Test get_by_id with missing nested key shows parent path."""
+ config = {"data": {"train": {"lr": 0.001, "epochs": 10}}}
+
+ with pytest.raises(KeyError) as exc_info:
+ get_by_id(config, "data::train::missing")
+
+ error_msg = str(exc_info.value)
+ assert "Key 'missing' not found in 'data::train'" in error_msg
+ assert "Available keys:" in error_msg
+ assert "'lr'" in error_msg
+ assert "'epochs'" in error_msg
+
+ def test_invalid_list_index(self):
+ """Test get_by_id with invalid list index."""
+ config = {"items": [1, 2, 3]}
+
+ with pytest.raises(KeyError) as exc_info:
+ get_by_id(config, "items::10")
+
+ error_msg = str(exc_info.value)
+ assert "List index '10' out of range" in error_msg
+ assert "in 'items'" in error_msg
+
+ def test_invalid_list_index_first_level(self):
+ """Test get_by_id with invalid list index at first level."""
+ config = [1, 2, 3]
+
+ with pytest.raises(KeyError) as exc_info:
+ get_by_id(config, "10")
+
+ error_msg = str(exc_info.value)
+ assert "List index '10' out of range" in error_msg
+ # Should not mention parent path for first level
+ assert "in '" not in error_msg
+
+ def test_invalid_list_index_non_integer(self):
+ """Test get_by_id with non-integer list index."""
+ config = {"items": [1, 2, 3]}
+
+ with pytest.raises(KeyError) as exc_info:
+ get_by_id(config, "items::abc")
+
+ error_msg = str(exc_info.value)
+ assert "Invalid list index 'abc'" in error_msg
+ assert "not an integer" in error_msg
+
+ def test_type_error_first_level(self):
+ """Test type error at first level shows clean message."""
+ config = "string_value"
+
+ with pytest.raises(TypeError) as exc_info:
+ get_by_id(config, "foo")
+
+ error_msg = str(exc_info.value)
+ assert "Cannot index str with key 'foo'" in error_msg
+ # Should not mention parent path for first level
+ assert "in '" not in error_msg
+
+ def test_type_error_nested(self):
+ """Test type error in nested path shows parent."""
+ config = {"data": {"value": 42}}
+
+ with pytest.raises(TypeError) as exc_info:
+ get_by_id(config, "data::value::foo")
+
+ error_msg = str(exc_info.value)
+ assert "Cannot index int with key 'foo'" in error_msg
+ assert "in 'data::value'" in error_msg
+
+ def test_available_keys_truncated(self):
+ """Test that available keys list is truncated when > 10 keys."""
+ config = {f"key_{i}": i for i in range(20)}
+
+ with pytest.raises(KeyError) as exc_info:
+ get_by_id(config, "missing")
+
+ error_msg = str(exc_info.value)
+ assert "Available keys:" in error_msg
+ assert "..." in error_msg # Should be truncated
diff --git a/tests/test_preprocessor.py b/tests/test_preprocessor.py
index f33d0f7..121ad38 100644
--- a/tests/test_preprocessor.py
+++ b/tests/test_preprocessor.py
@@ -4,6 +4,7 @@
from sparkwheel.loader import Loader
from sparkwheel.preprocessor import Preprocessor
+from sparkwheel.utils.exceptions import CircularReferenceError, ConfigKeyError
class TestPreprocessor:
@@ -19,35 +20,247 @@ def test_circular_raw_reference(self, tmp_path):
preprocessor = Preprocessor(loader)
# Load the config and try to process it
- config, _ = loader.load_file(str(config_file))
+ config, locations = loader.load_file(str(config_file))
- with pytest.raises(ValueError, match="Circular raw reference detected"):
- preprocessor.process(config, config)
+ with pytest.raises(CircularReferenceError, match="Circular raw reference detected"):
+ preprocessor.process_raw_refs(config, config, locations=locations)
- def test_get_by_id_empty_id(self):
- """Test _get_by_id with empty ID returns whole config."""
- config = {"key": "value", "nested": {"item": 123}}
- result = Preprocessor._get_by_id(config, "")
+ def test_raw_ref_missing_key_with_location(self, tmp_path):
+ """Test raw reference error includes source location."""
+ # Create a config file with a raw reference to a missing key
+ config_file = tmp_path / "config.yaml"
+ config_file.write_text('value: "%missing::key"')
+
+ loader = Loader()
+ preprocessor = Preprocessor(loader)
+
+ config, locations = loader.load_file(str(config_file))
+
+ with pytest.raises(ConfigKeyError) as exc_info:
+ preprocessor.process_raw_refs(config, config, locations=locations)
+
+ error = exc_info.value
+ assert error.source_location is not None
+ assert error.source_location.filepath == str(config_file)
+ assert error.source_location.line == 1
+ assert "Error resolving raw reference" in error._original_message
+ assert "Key 'missing' not found" in error._original_message
+
+ def test_raw_ref_external_file_missing_key(self, tmp_path):
+ """Test raw reference to external file with missing key."""
+ # Create external file
+ external_file = tmp_path / "external.yaml"
+ external_file.write_text("foo: 1\nbar: 2")
+
+ # Create main config that references missing key in external file
+ config_file = tmp_path / "config.yaml"
+ config_file.write_text(f'value: "%{external_file}::missing"')
+
+ loader = Loader()
+ preprocessor = Preprocessor(loader)
+
+ config, locations = loader.load_file(str(config_file))
+
+ with pytest.raises(ConfigKeyError) as exc_info:
+ preprocessor.process_raw_refs(config, config, locations=locations)
+
+ error = exc_info.value
+ assert error.source_location is not None
+ assert error.source_location.filepath == str(config_file)
+ assert f"from '{external_file}'" in error._original_message
+ assert "Key 'missing' not found" in error._original_message
+
+ def test_raw_ref_nested_missing_key(self, tmp_path):
+ """Test raw reference with nested path where middle key is missing."""
+ config_file = tmp_path / "config.yaml"
+ config_file.write_text('data:\n foo: 1\nvalue: "%data::missing::key"')
+
+ loader = Loader()
+ preprocessor = Preprocessor(loader)
+
+ config, locations = loader.load_file(str(config_file))
+
+ with pytest.raises(ConfigKeyError) as exc_info:
+ preprocessor.process_raw_refs(config, config, locations=locations)
+
+ error = exc_info.value
+ assert "Key 'missing' not found in 'data'" in error._original_message
+
+ def test_circular_reference_with_location(self, tmp_path):
+ """Test circular reference error includes source location."""
+ config_file = tmp_path / "config.yaml"
+ config_file.write_text('a: "%b"\nb: "%a"')
+
+ loader = Loader()
+ preprocessor = Preprocessor(loader)
+
+ config, locations = loader.load_file(str(config_file))
+
+ with pytest.raises(CircularReferenceError) as exc_info:
+ preprocessor.process_raw_refs(config, config, locations=locations)
+
+ error = exc_info.value
+ assert error.source_location is not None
+ assert error.source_location.filepath == str(config_file)
+ assert "Reference chain:" in error._original_message
+
+ def test_raw_ref_expansion_success(self, tmp_path):
+ """Test successful raw reference expansion."""
+ config_file = tmp_path / "config.yaml"
+ config_file.write_text('base_lr: 0.001\nmodel:\n lr: "%base_lr"')
+
+ loader = Loader()
+ preprocessor = Preprocessor(loader)
+
+ config, locations = loader.load_file(str(config_file))
+ result = preprocessor.process_raw_refs(config, config, locations=locations)
+
+ assert result["model"]["lr"] == 0.001
+
+ def test_raw_ref_nested_expansion(self, tmp_path):
+ """Test nested raw reference expansion."""
+ config_file = tmp_path / "config.yaml"
+ config_file.write_text('a:\n b:\n c: 42\nvalue: "%a::b::c"')
+
+ loader = Loader()
+ preprocessor = Preprocessor(loader)
+
+ config, locations = loader.load_file(str(config_file))
+ result = preprocessor.process_raw_refs(config, config, locations=locations)
- assert result == config
+ assert result["value"] == 42
- def test_get_by_id_list_indexing(self):
- """Test _get_by_id with list indexing."""
- config = {"items": [10, 20, 30]}
- result = Preprocessor._get_by_id(config, "items::1")
+ def test_raw_ref_external_file_success(self, tmp_path):
+ """Test successful raw reference from external file."""
+ external_file = tmp_path / "base.yaml"
+ external_file.write_text("learning_rate: 0.001")
- assert result == 20
+ config_file = tmp_path / "config.yaml"
+ config_file.write_text(f'lr: "%{external_file}::learning_rate"')
+
+ loader = Loader()
+ preprocessor = Preprocessor(loader)
+
+ config, locations = loader.load_file(str(config_file))
+ result = preprocessor.process_raw_refs(config, config, locations=locations)
+
+ assert result["lr"] == 0.001
+
+ def test_raw_ref_in_list(self, tmp_path):
+ """Test raw reference inside a list item is expanded."""
+ config_file = tmp_path / "config.yaml"
+ config_file.write_text('base_value: 42\nitems:\n - "%base_value"\n - 100')
+
+ loader = Loader()
+ preprocessor = Preprocessor(loader)
+
+ config, locations = loader.load_file(str(config_file))
+ result = preprocessor.process_raw_refs(config, config, locations=locations)
+
+ assert result["items"][0] == 42
+ assert result["items"][1] == 100
+
+
+class TestPreprocessorExternalOnly:
+ """Test external_only parameter for two-phase raw reference expansion."""
+
+ def test_external_only_expands_external_refs(self, tmp_path):
+ """Test that external_only=True expands external file refs."""
+ external_file = tmp_path / "external.yaml"
+ external_file.write_text("value: 42")
+
+ loader = Loader()
+ preprocessor = Preprocessor(loader)
+
+ config = {"external_ref": f"%{external_file}::value", "local_ref": "%local_key", "local_key": 100}
+
+ result = preprocessor.process_raw_refs(config, config, external_only=True)
+
+ # External ref should be expanded
+ assert result["external_ref"] == 42
+ # Local ref should remain as string
+ assert result["local_ref"] == "%local_key"
+ # Local key unchanged
+ assert result["local_key"] == 100
+
+ def test_external_only_false_expands_all_refs(self, tmp_path):
+ """Test that external_only=False expands all refs including local."""
+ external_file = tmp_path / "external.yaml"
+ external_file.write_text("value: 42")
- def test_get_by_id_nested_list(self):
- """Test _get_by_id with nested structures including lists."""
- config = {"data": {"values": [{"x": 1}, {"x": 2}, {"x": 3}]}}
- result = Preprocessor._get_by_id(config, "data::values::2::x")
+ loader = Loader()
+ preprocessor = Preprocessor(loader)
+
+ config = {"external_ref": f"%{external_file}::value", "local_ref": "%local_key", "local_key": 100}
+
+ result = preprocessor.process_raw_refs(config, config, external_only=False)
+
+ # Both should be expanded
+ assert result["external_ref"] == 42
+ assert result["local_ref"] == 100
+
+ def test_two_phase_expansion_with_override(self, tmp_path):
+ """Test that two-phase expansion allows overrides to affect local refs."""
+ external_file = tmp_path / "external.yaml"
+ external_file.write_text("external_value: 1")
+
+ loader = Loader()
+ preprocessor = Preprocessor(loader)
+
+ # Initial config with both external and local refs
+ config = {
+ "external_ref": f"%{external_file}::external_value",
+ "local_ref": "%vars::value",
+ "vars": {"value": None}, # Will be overridden
+ }
+
+ # Phase 1: Expand only external refs
+ config = preprocessor.process_raw_refs(config, config, external_only=True)
+ assert config["external_ref"] == 1
+ assert config["local_ref"] == "%vars::value" # Still string
+
+ # Simulate CLI override
+ config["vars"]["value"] = "/data/features.npz"
+
+ # Phase 2: Expand local refs (now sees override)
+ config = preprocessor.process_raw_refs(config, config, external_only=False)
+ assert config["local_ref"] == "/data/features.npz"
+
+ def test_nested_local_refs_expanded_together(self, tmp_path):
+ """Test that nested local refs are all expanded in phase 2."""
+ loader = Loader()
+ preprocessor = Preprocessor(loader)
+
+ config = {"a": {"b": {"c": 42}}, "ref_to_b": "%a::b", "ref_to_c": "%a::b::c"}
+
+ # Phase 1: Nothing to expand (no external refs)
+ result = preprocessor.process_raw_refs(config, config, external_only=True)
+ assert result["ref_to_b"] == "%a::b"
+ assert result["ref_to_c"] == "%a::b::c"
+
+ # Phase 2: Expand all local refs
+ result = preprocessor.process_raw_refs(result, result, external_only=False)
+ assert result["ref_to_b"] == {"c": 42}
+ assert result["ref_to_c"] == 42
+
+ def test_external_ref_within_local_ref_expanded_correctly(self, tmp_path):
+ """Test that external refs within locally-referenced values are expanded."""
+ external_file = tmp_path / "external.yaml"
+ external_file.write_text("nested:\n value: 99")
+
+ loader = Loader()
+ preprocessor = Preprocessor(loader)
- assert result == 3
+ config = {
+ "template": {"external": f"%{external_file}::nested"},
+ "copy": "%template",
+ }
- def test_get_by_id_type_error_on_primitive(self):
- """Test _get_by_id raises TypeError when trying to index a primitive value."""
- config = {"value": 42}
+ # Phase 1: Expand external ref inside template
+ result = preprocessor.process_raw_refs(config, config, external_only=True)
+ assert result["template"]["external"] == {"value": 99}
+ assert result["copy"] == "%template" # Local ref still string
- with pytest.raises(TypeError, match="Cannot index int"):
- Preprocessor._get_by_id(config, "value::subkey")
+ # Phase 2: Expand local ref - should get the already-expanded template
+ result = preprocessor.process_raw_refs(result, result, external_only=False)
+ assert result["copy"] == {"external": {"value": 99}}
diff --git a/uv.lock b/uv.lock
index 43c5e30..08e3a48 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1714,7 +1714,7 @@ wheels = [
[[package]]
name = "sparkwheel"
-version = "0.0.7"
+version = "0.0.8"
source = { editable = "." }
dependencies = [
{ name = "pyyaml" },