diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..9d8aaac --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,36 @@ +name: Release + +on: + push: + tags: + - '*' + +permissions: + contents: write + +jobs: + release: + name: Create GitHub Release + runs-on: ubuntu-latest + timeout-minutes: 5 + if: startsWith(github.ref, 'refs/tags') + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Verify tag is on main branch + run: | + git fetch origin main + if ! git merge-base --is-ancestor ${{ github.sha }} origin/main; then + echo "Error: Tag is not on the main branch" + exit 1 + fi + + - name: Create Release + uses: softprops/action-gh-release@v2 + with: + generate_release_notes: true + draft: false + prerelease: false diff --git a/README.md b/README.md index 118cd0e..5f0d961 100644 --- a/README.md +++ b/README.md @@ -10,25 +10,30 @@ License Documentation

-
-

βš™οΈ YAML configuration meets Python 🐍

+

YAML configuration meets Python

Define Python objects in YAML. Reference, compose, and instantiate them effortlessly.


-## What is Sparkwheel? +## Quick Start -Stop hardcoding parameters. Define complex Python objects in clean YAML files, compose them naturally, and instantiate with one line. +```bash +pip install sparkwheel +``` ```yaml # config.yaml +dataset: + num_classes: 10 + batch_size: 32 + model: _target_: torch.nn.Linear in_features: 784 - out_features: "%dataset::num_classes" # Reference other values + out_features: "%dataset::num_classes" # Reference -dataset: - num_classes: 10 +training: + steps_per_epoch: "$10000 // @dataset::batch_size" # Expression ``` ```python @@ -36,77 +41,25 @@ from sparkwheel import Config config = Config() config.update("config.yaml") -model = config.resolve("model") # Actual torch.nn.Linear(784, 10) instance! -``` - -## Key Features - -- **Declarative Object Creation** - Instantiate any Python class from YAML with `_target_` -- **Smart References** - `@` for resolved values, `%` for raw YAML -- **Composition by Default** - Configs merge naturally (dicts merge, lists extend) -- **Explicit Operators** - `=` to replace, `~` to delete when needed -- **Python Expressions** - Compute values dynamically with `$` prefix -- **Schema Validation** - Type-check configs with Python dataclasses -- **CLI Overrides** - Override any value from command line - -## Installation -```bash -pip install sparkwheel +model = config.resolve("model") # Actual torch.nn.Linear(784, 10) ``` -**[→ Get Started in 5 Minutes](https://project-lighter.github.io/sparkwheel/getting-started/quickstart/)** +## Features -## Coming from Hydra/OmegaConf? - -Sparkwheel builds on similar ideas but adds powerful features: - -| Feature | Hydra/OmegaConf | Sparkwheel | |---------|-----------------|------------| | Config composition | Explicit (`+`, `++`) | **By default** (dicts merge, lists extend) | | Replace semantics | Default | Explicit with `=` operator | | Delete keys | Not idempotent | Idempotent `~` operator | | References | OmegaConf interpolation | `@` (resolved) + `%` (raw YAML) | | Python expressions | Limited | Full Python with `$` | | Schema validation | Structured Configs | Python dataclasses | | List extension | Lists replace | **Lists extend by default** | -**Composition by default** means configs merge naturally without operators: -```yaml -# base.yaml -model: - hidden_size: 256 - dropout: 0.1 - -# experiment.yaml -model: - hidden_size: 512 # Override - # dropout inherited -``` - -## Documentation +- **Declarative Objects** - Instantiate any Python class with `_target_` +- **Smart References** - `@` for resolved values, `%` for raw YAML +- **Composition by Default** - Dicts merge, lists extend automatically +- **Explicit Control** - `=` to replace, `~` to delete +- **Python Expressions** - Dynamic values with `$` +- **Schema Validation** - Type-check with dataclasses -- [Full Documentation](https://project-lighter.github.io/sparkwheel/) -- [Quick Start Guide](https://project-lighter.github.io/sparkwheel/getting-started/quickstart/) -- [Core Concepts](https://project-lighter.github.io/sparkwheel/user-guide/basics/) -- [API Reference](https://project-lighter.github.io/sparkwheel/reference/) +**[Get Started](https://project-lighter.github.io/sparkwheel/getting-started/quickstart/)** · **[Documentation](https://project-lighter.github.io/sparkwheel/)** · **[Quick Reference](https://project-lighter.github.io/sparkwheel/user-guide/quick-reference/)** ## Community -- [Discord Server](https://discord.gg/zJcnp6KrUp) - Chat with the community -- [YouTube 
Channel](https://www.youtube.com/channel/UCef1oTpv2QEBrD2pZtrdk1Q) - Tutorials and demos -- [GitHub Issues](https://github.com/project-lighter/sparkwheel/issues) - Bug reports and feature requests - -## Contributing - -We welcome contributions! See [CONTRIBUTING.md](CONTRIBUTING.md) for development setup and guidelines. +- [Discord](https://discord.gg/zJcnp6KrUp) · [YouTube](https://www.youtube.com/channel/UCef1oTpv2QEBrD2pZtrdk1Q) · [Issues](https://github.com/project-lighter/sparkwheel/issues) ## About -Sparkwheel is a hard fork of [MONAI Bundle](https://github.com/Project-MONAI/MONAI/tree/dev/monai/bundle)'s configuration system, refined and expanded for general-purpose use. We're deeply grateful to the MONAI team for their excellent foundation. - -Sparkwheel powers [Lighter](https://project-lighter.github.io/lighter/), our configuration-driven deep learning framework built on PyTorch Lightning. - -## License - -Apache License 2.0 - See [LICENSE](LICENSE) for details. +Sparkwheel is a hard fork of [MONAI Bundle](https://github.com/Project-MONAI/MONAI/tree/dev/monai/bundle)'s config system, with the goal of making a more general-purpose configuration library for Python projects. It combines the best of MONAI Bundle and [Hydra](http://hydra.cc/)/[OmegaConf](https://omegaconf.readthedocs.io/), while introducing new features and improvements not found in either. diff --git a/docs/index.md b/docs/index.md index db33c86..f782043 100644 --- a/docs/index.md +++ b/docs/index.md @@ -199,7 +199,7 @@ Sparkwheel has two types of references with distinct purposes: - **Composition-by-default** - Configs merge/extend naturally, no operators needed for common case - **List extension** - Lists extend by default (unique vs Hydra!) - **`=` replace operator** - Explicit control when you need replacement - - **`~` delete operator** - Remove inherited keys cleanly (idempotent!) + - **`~` delete operator** - Remove inherited keys explicitly - **Python expressions with `$`** - Compute values dynamically - **Dataclass validation** - Type-safe configs without boilerplate - **Dual reference system** - `@` for resolved values, `%` for raw YAML diff --git a/docs/user-guide/advanced.md b/docs/user-guide/advanced.md index 1555750..a3b7c9e 100644 --- a/docs/user-guide/advanced.md +++ b/docs/user-guide/advanced.md @@ -221,8 +221,6 @@ config.update({"~plugins": [0, 2]}) # Remove list items config.update({"~dataloaders": ["train", "test"]}) # Remove dict keys ``` -**Note:** The `~` directive is idempotent - it doesn't error if the key doesn't exist, enabling reusable configs. - ### Programmatic Updates Apply operators programmatically: diff --git a/docs/user-guide/cli.md b/docs/user-guide/cli.md index dc3474d..f32f8a4 100644 --- a/docs/user-guide/cli.md +++ b/docs/user-guide/cli.md @@ -111,7 +111,7 @@ Three operators for fine-grained control: |----------|--------|----------|---------| | **Compose** (default) | `key=value` | Merges dicts, extends lists | `model::lr=0.001` | | **Replace** | `=key=value` | Completely replaces value | `=model={'_target_': 'ResNet'}` | -| **Delete** | `~key` | Removes key (idempotent) | `~debug` | +| **Delete** | `~key` | Removes key (errors if missing) | `~debug` | !!! 
info "Type Inference" Values are automatically typed using `ast.literal_eval()`: diff --git a/docs/user-guide/operators.md b/docs/user-guide/operators.md index 6baf2d1..9b03f63 100644 --- a/docs/user-guide/operators.md +++ b/docs/user-guide/operators.md @@ -126,11 +126,14 @@ Remove keys or list items with `~key`: ### Delete Entire Keys ```yaml -# Remove keys (idempotent - no error if missing!) +# Remove keys explicitly ~old_param: null ~debug_settings: null ``` +!!! warning "Key Must Exist" + The delete operator will raise an error if the key doesn't exist. This helps catch typos and configuration mistakes. + ### Delete Dict Keys Use path notation for nested keys: @@ -214,28 +217,6 @@ dataloaders: **Why?** Path notation is designed for dict keys, not list indices. The batch syntax handles index normalization and processes deletions correctly (high to low order). -### Idempotent Delete - -Delete operations don't error if the key doesn't exist: - -```yaml -# production.yaml - Remove debug settings if they exist -~debug_mode: null -~dev_logger: null -~test_data: null -# No errors if these don't exist! -``` - -This enables **reusable configs** that work with multiple bases: - -```yaml -# production.yaml works with ANY base config -~debug_settings: null -~verbose_logging: null -database: - pool_size: 100 -``` - ## Combining Operators Mix composition, replace, and delete: @@ -298,7 +279,7 @@ config.update({"model": {"hidden_size": 1024}}) # Replace explicitly config.update({"=optimizer": {"type": "sgd", "lr": 0.1}}) -# Delete keys (idempotent) +# Delete keys config.update({ "~training::old_param": None, "~model::dropout": None @@ -454,17 +435,40 @@ model: ### Write Reusable Configs -Use idempotent delete for portable configs: +!!! warning "Delete Requires Key Existence" + The delete operator (`~`) is **strict** - it raises an error if the key doesn't exist. This helps catch typos and configuration mistakes. +When writing configs that should work with different base configurations, you have a few options: + +**Option 1: Document required keys** ```yaml -# production.yaml - works with ANY base! 
-~debug_mode: null # Remove if exists -~verbose_logging: null # No error if missing +# production.yaml +# Requires: base config must have debug_mode and verbose_logging +~debug_mode: null +~verbose_logging: null database: pool_size: 100 ssl: true ``` +**Option 2: Use composition order** +```yaml +# production.yaml - override instead of delete +debug_mode: false # Overrides if exists, sets if not +verbose_logging: false +database: + pool_size: 100 + ssl: true +``` + +**Option 3: Batch deletion with lists** +```yaml +# Delete multiple keys at once - errors if any listed key is missing +~: [debug_mode, verbose_logging] +database: + pool_size: 100 +``` + ## Common Mistakes ### Using `=` When Not Needed @@ -519,17 +523,17 @@ plugins: [cache] |---------|-------|------------| | Dict merge default | Yes ✅ | Yes ✅ | | List extend default | No ❌ | **Yes** ✅ | -| Operators in YAML | No ❌ | Yes ✅ (`=`, `~`) | -| Operator count | 4 (`+`, `++`, `~`) | **2** (`=`, `~`) ✅ | -| Delete dict keys | No ❌ | Yes ✅ | -| Delete list items | No ❌ | Yes ✅ | -| Idempotent delete | N/A | Yes ✅ | - -Sparkwheel goes beyond Hydra with: -- Full composition-first philosophy (dicts **and** lists) -- Operators directly in YAML files -- Just 2 simple operators -- Delete operations for fine-grained control +| Operators in YAML | CLI-only | **Yes** ✅ (YAML + CLI) | +| Operator count | 4 (`=`, `+`, `++`, `~`) | **2** (`=`, `~`) ✅ | +| Delete dict keys | CLI-only (`~foo.bar`) | **Yes** ✅ (YAML + CLI) | +| Delete list items | No ❌ | **Yes** ✅ (by index) | + +Sparkwheel differs from Hydra: +- **Full composition philosophy**: Both dicts AND lists compose by default +- **Operators in YAML files**: Not just CLI overrides +- **Simpler operator set**: Just 2 operators (`=`, `~`) vs 4 (`=`, `+`, `++`, `~`) +- **List deletion**: Delete items by index with `~plugins: [0, 2]` +- **Flexible delete**: Use `~` anywhere (YAML, CLI, programmatic) ## Next Steps diff --git a/docs/user-guide/references.md b/docs/user-guide/references.md index be2f42d..0c46aa6 100644 --- a/docs/user-guide/references.md +++ b/docs/user-guide/references.md @@ -10,7 +10,7 @@ Sparkwheel provides two types of references for linking configuration values: | Feature | `@ref` (Resolved) | `%ref` (Raw) | `$expr` (Expression) | |---------|-------------------|--------------|----------------------| | **Returns** | Final computed value | Raw YAML content | Evaluated expression result | -| **When processed** | Lazy (`resolve()`) | Eager (`update()`) | Lazy (`resolve()`) | +| **When processed** | Lazy (`resolve()`) | External: Eager / Local: Lazy | Lazy (`resolve()`) | | **Instantiates objects** | ✅ Yes | ❌ No | ✅ Yes (if referenced) | | **Evaluates expressions** | ✅ Yes | ❌ No | ✅ Yes | | **Use in dataclass validation** | ✅ Yes | ⚠️ Limited | ✅ Yes | | **Cross-file references** | ✅ Yes | ✅ Yes | ❌ No | | **When to use** | Get computed results | Copy config structures | Compute new values | -## Two-Stage Processing Model +## Two-Phase Processing Model -Sparkwheel processes references at different times to enable safe config composition: +Sparkwheel processes raw references (`%`) in two phases to support CLI overrides: !!! 
abstract "When References Are Processed" - **Stage 1: Eager Processing (during `update()`)** + **Phase 1: Eager Processing (during `update()`)** - - **Raw References (`%`)** are expanded immediately when configs are merged - - Enables safe config composition and pruning workflows - - External file references resolved at load time + - **External file raw refs (`%file.yaml::key`)** are expanded immediately + - External files are frozenβ€”their content won't change based on CLI overrides + - Enables copy-then-delete workflows with external files - **Stage 2: Lazy Processing (during `resolve()`)** + **Phase 2: Lazy Processing (during `resolve()`)** + - **Local raw refs (`%key`)** are expanded after all composition is complete - **Resolved References (`@`)** are processed on-demand - **Expressions (`$`)** are evaluated when needed - **Components (`_target_`)** are instantiated only when requested - - Supports complex dependency graphs and deferred instantiation + - CLI overrides can affect local `%` refs -**Why two stages?** +**Why two phases?** -This separation enables powerful workflows like config pruning: +This design ensures CLI overrides work intuitively with local raw references: ```yaml # base.yaml -system: - lr: 0.001 - batch_size: 32 +vars: + features_path: null # Default, will be overridden -experiment: - model: - optimizer: - lr: "%system::lr" # Copies raw value 0.001 eagerly - -~system: null # Delete system section after copying +# model.yaml +dataset: + path: "%vars::features_path" # Local ref - sees CLI override ``` ```python config = Config() config.update("base.yaml") -# % references already expanded during update() -# ~system deletion applied after expansion -# Result: experiment::model::optimizer::lr = 0.001 (system deleted safely) +config.update("model.yaml") +config.update("vars::features_path=/data/features.npz") # CLI override + +# Local % ref sees the override! +path = config.resolve("dataset::path") # "/data/features.npz" ``` -With `@` references, this would fail because they resolve lazily after deletion. +!!! tip "External vs Local Raw References" + + | Type | Example | When Expanded | Use Case | + |------|---------|---------------|----------| + | **External** | `%file.yaml::key` | Eager (update) | Import from frozen files | + | **Local** | `%vars::key` | Lazy (resolve) | Reference config values | + + External files are "frozen"β€”their content is fixed at load time. + Local config values may be overridden via CLI, so local refs see the final state. ## Resolution Flow !!! 
abstract "How References Are Resolved" - **Step 1: Parse Config** β†’ Detect references in YAML + **Step 1: Load Configs** β†’ During `update()` - **Step 2: Determine Type** + - Parse YAML files + - Expand external `%file.yaml::key` refs immediately + - Keep local `%key` refs as strings - - **`%key`** β†’ Expanded eagerly during `update()` βœ… - - **`@key`** β†’ Proceed to dependency resolution (lazy) + **Step 2: Apply Overrides** β†’ During `update()` calls - **Step 3: Resolve Dependencies** (for `@` references during `resolve()`) + - CLI overrides modify local config values + - Local `%` refs still see the string form + **Step 3: Resolve** β†’ During `resolve()` + + - Expand local `%key` refs (now sees final values) + - Resolve `@` dependencies in order - Check for circular references β†’ ❌ **Error if found** - - Resolve all dependencies first - Evaluate expressions and instantiate objects - Return final computed value βœ… @@ -223,6 +235,8 @@ model_template: "%base.yaml::model" ### Local Raw References +Local raw references are expanded lazily during `resolve()`, which means CLI overrides can affect them: + ```yaml # config.yaml defaults: @@ -237,6 +251,17 @@ api_config: backup_defaults: "%defaults" # Gets the whole defaults dict ``` +!!! tip "CLI Overrides Work with Local Raw Refs" + + ```python + config = Config() + config.update("config.yaml") + config.update("defaults::timeout=60") # CLI override + + # Local % ref sees the override! + config.resolve("api_config::timeout") # 60 + ``` + ### Key Distinction !!! abstract "@ vs % - When to Use Each" diff --git a/pyproject.toml b/pyproject.toml index 4a48981..916b139 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,7 +77,6 @@ show_traceback = true allow_redefinition = false check_untyped_defs = true -disallow_any_generics = true disallow_incomplete_defs = true ignore_missing_imports = true implicit_reexport = false diff --git a/src/sparkwheel/__init__.py b/src/sparkwheel/__init__.py index 16b8399..375ba14 100644 --- a/src/sparkwheel/__init__.py +++ b/src/sparkwheel/__init__.py @@ -5,7 +5,6 @@ """ from .config import Config, parse_overrides -from .errors import enable_colors from .items import Component, Expression, Instantiable, Item from .operators import apply_operators, validate_operators from .resolver import Resolver @@ -19,7 +18,7 @@ EvaluationError, FrozenConfigError, InstantiationError, - SourceLocation, + Location, TargetNotFoundError, ) @@ -39,7 +38,6 @@ "validate_operators", "validate", "validator", - "enable_colors", "RESOLVED_REF_KEY", "RAW_REF_KEY", "ID_SEP_KEY", @@ -55,5 +53,5 @@ "EvaluationError", "FrozenConfigError", "ValidationError", - "SourceLocation", + "Location", ] diff --git a/src/sparkwheel/config.py b/src/sparkwheel/config.py index 9adf20b..1d19aec 100644 --- a/src/sparkwheel/config.py +++ b/src/sparkwheel/config.py @@ -3,26 +3,28 @@ Sparkwheel is a YAML-based configuration system with references, expressions, and dynamic instantiation. This module provides the main Config class for loading, managing, and resolving configurations. 
-## Two-Stage Processing +## Two-Phase Processing -Sparkwheel uses a two-stage processing model to handle different reference types at appropriate times: +Sparkwheel uses a two-phase processing model to handle different reference types at appropriate times: -### Stage 1: Eager Processing (during update()) -- **Raw References (`%`)**: Expanded immediately when configs are merged -- **Purpose**: Enables safe config composition or deletion +### Phase 1: Eager Processing (during update()) +- **External file raw refs (`%file.yaml::key`)**: Expanded immediately +- **Purpose**: External files are frozen - their content won't change - **Example**: `%base.yaml::lr` is replaced with the actual value from base.yaml -### Stage 2: Lazy Processing (during resolve()) +### Phase 2: Lazy Processing (during resolve()) +- **Local raw refs (`%key`)**: Expanded after all composition is complete - **Resolved References (`@`)**: Resolved on-demand to support circular dependencies - **Expressions (`$`)**: Evaluated when needed using Python's eval() - **Components (`_target_`)**: Instantiated only when requested -- **Purpose**: Supports deferred instantiation and complex dependency graphs +- **Purpose**: CLI overrides can affect local `%` refs, supports deferred instantiation ## Reference Types | Symbol | Name | When Expanded | Purpose | Example | |--------|------|---------------|---------|---------| -| `%` | Raw Reference | Eager (update()) | Copy/paste YAML sections | `%base.yaml::lr` | +| `%` | Raw Reference (external) | Eager (update()) | Copy from external files | `%base.yaml::lr` | +| `%` | Raw Reference (local) | Lazy (resolve()) | Copy local config values | `%vars::lr` | | `@` | Resolved Reference | Lazy (resolve()) | Reference config values | `@model::lr` | | `$` | Expression | Lazy (resolve()) | Compute values dynamically | `$@lr * 2` | @@ -97,15 +99,15 @@ from typing import Any from .loader import Loader -from .metadata import MetadataRegistry -from .operators import _validate_delete_operator, apply_operators, validate_operators +from .locations import LocationRegistry +from .operators import MergeContext, _validate_delete_operator, apply_operators, validate_operators from .parser import Parser -from .path_utils import split_id +from .path_utils import get_by_id, split_id from .preprocessor import Preprocessor from .resolver import Resolver -from .utils import PathLike, look_up_option, optional_import +from .utils import PathLike, optional_import from .utils.constants import ID_SEP_KEY, REMOVE_KEY, REPLACE_KEY -from .utils.exceptions import ConfigKeyError +from .utils.exceptions import ConfigKeyError, build_missing_key_error __all__ = ["Config", "parse_overrides"] @@ -183,7 +185,7 @@ def __init__( >>> config = Config(schema=MySchema).update("config.yaml") """ self._data: dict[str, Any] = data or {} # Start with provided data or empty - self._metadata = MetadataRegistry() + self._locations = LocationRegistry() self._resolver = Resolver() self._is_parsed = False self._frozen = False # Set via freeze() method later @@ -257,7 +259,7 @@ def get(self, id: str = "", default: Any = None) -> Any: """ try: return self._get_by_id(id) - except (KeyError, IndexError, ValueError): + except (KeyError, IndexError, TypeError): return default def set(self, id: str, value: Any) -> None: @@ -329,7 +331,7 @@ def validate(self, schema: type) -> None: """ from .schema import validate as validate_schema - validate_schema(self._data, schema, metadata=self._metadata) + validate_schema(self._data, schema, 
metadata=self._locations) def freeze(self) -> None: """Freeze config to prevent further modifications. @@ -359,6 +361,20 @@ def is_frozen(self) -> bool: """ return self._frozen + @property + def locations(self) -> LocationRegistry: + """Get the location registry for this config. + + Returns: + LocationRegistry tracking file locations of config keys + + Example: + >>> config = Config().update("config.yaml") + >>> location = config.locations.get("model::lr") + >>> print(f"{location.filepath}:{location.line}") + """ + return self._locations + def update(self, source: PathLike | dict[str, Any] | "Config" | str) -> "Config": """Update configuration with changes from another source. @@ -376,7 +392,7 @@ def update(self, source: PathLike | dict[str, Any] | "Config" | str) -> "Config" Operators: - key=value - Compose (default): merge dict or extend list - =key=value - Replace operator: completely replace value - - ~key - Remove operator: delete key (idempotent) + - ~key - Remove operator: delete key (errors if missing) Examples: >>> # Update from file @@ -423,10 +439,13 @@ def update(self, source: PathLike | dict[str, Any] | "Config" | str) -> "Config" else: self._update_from_file(source) - # Eagerly expand raw references (%) immediately after update - # This matches MONAI's behavior and allows safe pruning with delete operator (~) - # Must happen BEFORE validation so schema sees final structure, not raw ref strings - self._data = self._preprocessor.process_raw_refs(self._data, self._data, id="") + # Phase 1: Eagerly expand ONLY external file raw references (%file.yaml::key) + # Local refs (%key) are kept as strings - they'll be expanded in _parse() after + # all composition is complete. This allows CLI overrides to affect local refs. + # External files are frozen (their content won't change), so eager expansion is safe. 
+ self._data = self._preprocessor.process_raw_refs( + self._data, self._data, id="", locations=self._locations, external_only=True + ) # Validate after raw ref expansion if schema exists # This validates the final structure, not intermediate raw reference strings @@ -436,7 +455,7 @@ def update(self, source: PathLike | dict[str, Any] | "Config" | str) -> "Config" validate_schema( self._data, self._schema, - metadata=self._metadata, + metadata=self._locations, allow_missing=self._allow_missing, strict=self._strict, ) @@ -445,8 +464,9 @@ def update(self, source: PathLike | dict[str, Any] | "Config" | str) -> "Config" def _update_from_config(self, source: "Config") -> None: """Update from another Config instance.""" - self._data = apply_operators(self._data, source._data) - self._metadata.merge(source._metadata) + context = MergeContext(locations=source.locations) + self._data = apply_operators(self._data, source._data, context=context) + self._locations.merge(source.locations) self._invalidate_resolution() def _uses_nested_paths(self, source: dict[str, Any]) -> bool: @@ -466,17 +486,43 @@ def _apply_path_updates(self, source: dict[str, Any]) -> None: self.set(actual_key, value) elif key.startswith(REMOVE_KEY): - # Delete operator: ~key (idempotent) + # Delete operator: ~key actual_key = key[1:] _validate_delete_operator(actual_key, value) - if actual_key in self: - self._delete_nested_key(actual_key) + if actual_key not in self: + # Try to find source location for the key being deleted + source_location = self._locations.get(actual_key) if self._locations else None + + # For nested keys, get available keys from the parent container + available_keys: list[str] = [] + parent_key_name: str | None = None + error_key = actual_key # The key to show in error message + + if ID_SEP_KEY in actual_key: + # Nested key like "model::lr" + parent_path, child_key = actual_key.rsplit(ID_SEP_KEY, 1) + parent_key_name = parent_path + error_key = child_key # Show only the child key in nested errors + try: + parent = self._get_by_id(parent_path) + if isinstance(parent, dict): + available_keys = list(parent.keys()) + except (KeyError, IndexError, TypeError): + # Parent doesn't exist, fall back to top-level + available_keys = list(self._data.keys()) if isinstance(self._data, dict) else [] + else: + # Top-level key + available_keys = list(self._data.keys()) if isinstance(self._data, dict) else [] + + raise build_missing_key_error(error_key, available_keys, source_location, parent_key=parent_key_name) + self._delete_nested_key(actual_key) else: # Default: compose (merge dict or extend list) if key in self and isinstance(self[key], dict) and isinstance(value, dict): - merged = apply_operators(self[key], value) + context = MergeContext(locations=self._locations, current_path=key) + merged = apply_operators(self[key], value, context=context) self.set(key, merged) elif key in self and isinstance(self[key], list) and isinstance(value, list): self.set(key, self[key] + value) @@ -501,7 +547,8 @@ def _delete_nested_key(self, key: str) -> None: def _apply_structural_update(self, source: dict[str, Any]) -> None: """Apply structural update with operators.""" validate_operators(source) - self._data = apply_operators(self._data, source) + context = MergeContext(locations=self._locations) + self._data = apply_operators(self._data, source, context=context) self._invalidate_resolution() def _update_from_file(self, source: PathLike) -> None: @@ -512,12 +559,13 @@ def _update_from_file(self, source: PathLike) -> None: # Check if 
loaded data uses :: path syntax if self._uses_nested_paths(new_data): # Expand nested paths using path updates - self._metadata.merge(new_metadata) + self._locations.merge(new_metadata) self._apply_path_updates(new_data) else: # Normal structural update - self._data = apply_operators(self._data, new_data) - self._metadata.merge(new_metadata) + context = MergeContext(locations=new_metadata) + self._data = apply_operators(self._data, new_data, context=context) + self._locations.merge(new_metadata) self._invalidate_resolution() @@ -623,7 +671,10 @@ def _parse(self, reset: bool = True) -> None: """Parse config tree and prepare for resolution. Internal method called automatically by resolve(). - Note: % raw references are already expanded during update(). + + Two-phase raw reference expansion: + - Phase 1 (update): External file refs expanded eagerly + - Phase 2 (here): Local refs expanded now, after all composition Args: reset: Whether to reset the resolver before parsing (default: True) @@ -632,12 +683,17 @@ def _parse(self, reset: bool = True) -> None: if reset: self._resolver.reset() + # Phase 2: Expand local raw references (%key) now that all composition is complete + # CLI overrides have been applied, so local refs will see final values + self._data = self._preprocessor.process_raw_refs( + self._data, self._data, id="", locations=self._locations, external_only=False + ) + # Stage 1: Preprocess (@:: relative resolved IDs) - # Note: % raw references were already expanded in update() self._data = self._preprocessor.process(self._data, self._data, id="") # Stage 2: Parse config tree to create Items - parser = Parser(globals=self._globals, metadata=self._metadata) + parser = Parser(globals=self._globals, metadata=self._locations) items = parser.parse(self._data) # Stage 3: Add items to resolver @@ -655,21 +711,10 @@ def _get_by_id(self, id: str) -> Any: Config value at that path Raises: - KeyError: If path not found + KeyError: If path not found (includes available keys in message) + TypeError: If trying to index a non-dict/list value """ - if id == "": - return self._data - - config = self._data - for k in split_id(id): - if not isinstance(config, (dict, list)): - raise ValueError(f"Config must be dict or list for key `{k}`, but got {type(config)}: {config}") - try: - config = look_up_option(k, config, print_all_options=False) if isinstance(config, dict) else config[int(k)] - except ValueError as e: - raise KeyError(f"Key not found: {k}") from e - - return config + return get_by_id(self._data, id) def _invalidate_resolution(self) -> None: """Invalidate cached resolution (called when config changes).""" @@ -717,7 +762,7 @@ def __contains__(self, id: str) -> bool: try: self._get_by_id(id) return True - except (KeyError, IndexError, ValueError): + except (KeyError, IndexError, TypeError): return False def __repr__(self) -> str: diff --git a/src/sparkwheel/errors/__init__.py b/src/sparkwheel/errors/__init__.py index eb61dba..346291d 100644 --- a/src/sparkwheel/errors/__init__.py +++ b/src/sparkwheel/errors/__init__.py @@ -1,13 +1,7 @@ from .context import format_available_keys, format_resolution_chain -from .formatters import enable_colors, format_code, format_error, format_suggestion from .suggestions import format_suggestions, get_suggestions, levenshtein_distance __all__ = [ - # Formatters - "enable_colors", - "format_error", - "format_suggestion", - "format_code", # Suggestions "levenshtein_distance", "get_suggestions", diff --git a/src/sparkwheel/errors/formatters.py 
b/src/sparkwheel/errors/formatters.py deleted file mode 100644 index f3f5f06..0000000 --- a/src/sparkwheel/errors/formatters.py +++ /dev/null @@ -1,196 +0,0 @@ -"""Color formatting utilities for terminal output with auto-detection.""" - -import os -import sys - -__all__ = [ - "enable_colors", - "format_error", - "format_suggestion", - "format_code", - "RED", - "YELLOW", - "GREEN", - "BLUE", - "GRAY", - "RESET", -] - -# ANSI color codes -RED = "\033[31m" -YELLOW = "\033[33m" -GREEN = "\033[32m" -BLUE = "\033[34m" -GRAY = "\033[90m" -BOLD = "\033[1m" -RESET = "\033[0m" - -# Global flag for color support -_COLORS_ENABLED: bool | None = None - - -def _supports_color() -> bool: - """Auto-detect if the terminal supports colors. - - Follows industry standards for color detection: - 1. NO_COLOR environment variable disables colors (https://no-color.org/) - 2. SPARKWHEEL_NO_COLOR environment variable disables colors (sparkwheel-specific) - 3. FORCE_COLOR environment variable enables colors (https://force-color.org/) - 4. stdout TTY detection (auto-detect) - 5. Default: disable colors - - Returns: - True if colors should be enabled, False otherwise - """ - # Check NO_COLOR environment variable (https://no-color.org/) - # Highest priority - explicit user preference to disable - if os.environ.get("NO_COLOR"): - return False - - # Check sparkwheel-specific disable flag - if os.environ.get("SPARKWHEEL_NO_COLOR"): - return False - - # Check FORCE_COLOR environment variable (https://force-color.org/) - # Explicit enable for CI environments, piping, etc. - if os.environ.get("FORCE_COLOR"): - return True - - # Auto-detect: Check if stdout is a TTY - if hasattr(sys.stdout, "isatty") and sys.stdout.isatty(): - return True - - # Default: disable colors - return False - - -def enable_colors(enabled: bool | None = None) -> bool: - """Enable or disable color output. - - Args: - enabled: True to enable, False to disable, None for auto-detection - - Returns: - Current color enable status - - Examples: - >>> enable_colors(False) # Disable colors - False - >>> enable_colors(True) # Force enable colors - True - >>> enable_colors() # Auto-detect - True # (if terminal supports it) - """ - global _COLORS_ENABLED - - if enabled is None: - _COLORS_ENABLED = _supports_color() - else: - _COLORS_ENABLED = enabled - - return _COLORS_ENABLED - - -def _get_colors_enabled() -> bool: - """Get current color enable status, initializing if needed.""" - global _COLORS_ENABLED - - if _COLORS_ENABLED is None: - enable_colors() # Auto-detect - - return _COLORS_ENABLED # type: ignore[return-value] - - -def _colorize(text: str, color: str) -> str: - """Apply color to text if colors are enabled. - - Args: - text: Text to colorize - color: ANSI color code - - Returns: - Colorized text if colors enabled, otherwise plain text - """ - if _get_colors_enabled(): - return f"{color}{text}{RESET}" - return text - - -def format_error(text: str) -> str: - """Format text as an error (red). - - Args: - text: Text to format - - Returns: - Formatted text - - Examples: - >>> format_error("Error message") - '\x1b[31mError message\x1b[0m' # With colors enabled - >>> format_error("Error message") - 'Error message' # With colors disabled - """ - return _colorize(text, RED) - - -def format_suggestion(text: str) -> str: - """Format text as a suggestion (yellow). - - Args: - text: Text to format - - Returns: - Formatted text - """ - return _colorize(text, YELLOW) - - -def format_success(text: str) -> str: - """Format text as success/correct (green). 
- - Args: - text: Text to format - - Returns: - Formatted text - """ - return _colorize(text, GREEN) - - -def format_code(text: str) -> str: - """Format text as code/metadata (blue). - - Args: - text: Text to format - - Returns: - Formatted text - """ - return _colorize(text, BLUE) - - -def format_context(text: str) -> str: - """Format text as context (gray). - - Args: - text: Text to format - - Returns: - Formatted text - """ - return _colorize(text, GRAY) - - -def format_bold(text: str) -> str: - """Format text as bold. - - Args: - text: Text to format - - Returns: - Formatted text - """ - if _get_colors_enabled(): - return f"{BOLD}{text}{RESET}" - return text diff --git a/src/sparkwheel/items.py b/src/sparkwheel/items.py index 255ea24..dfe5f02 100644 --- a/src/sparkwheel/items.py +++ b/src/sparkwheel/items.py @@ -7,7 +7,7 @@ from .utils import CompInitMode, first, instantiate, optional_import, run_debug, run_eval from .utils.constants import EXPR_KEY -from .utils.exceptions import EvaluationError, InstantiationError, SourceLocation, TargetNotFoundError +from .utils.exceptions import EvaluationError, InstantiationError, Location, TargetNotFoundError __all__ = ["Item", "Expression", "Component", "Instantiable"] @@ -46,7 +46,7 @@ class Item: source_location: optional location in source file where this config item was defined. """ - def __init__(self, config: Any, id: str = "", source_location: SourceLocation | None = None) -> None: + def __init__(self, config: Any, id: str = "", source_location: Location | None = None) -> None: self.config = config self.id = id self.source_location = source_location @@ -126,7 +126,7 @@ class Component(Item, Instantiable): non_arg_keys = {"_target_", "_disabled_", "_requires_", "_mode_", "_args_"} - def __init__(self, config: Any, id: str = "", source_location: SourceLocation | None = None) -> None: + def __init__(self, config: Any, id: str = "", source_location: Location | None = None) -> None: super().__init__(config=config, id=id, source_location=source_location) @staticmethod @@ -218,8 +218,15 @@ def instantiate(self, **kwargs: Any) -> object: source_location=self.source_location, suggestion=suggestion, ) from e + except InstantiationError as e: + # Add source location if not already present, preserving original message and suggestion + if e.source_location is None: + raise InstantiationError( + e._original_message, source_location=self.source_location, suggestion=e.suggestion + ) from e + raise # Already has location, re-raise as-is except Exception as e: - # Wrap other errors with location context (points to _target_ line) + # Wrap unexpected errors with component context and location raise InstantiationError( f"Failed to instantiate '{modname}': {type(e).__name__}: {e}", source_location=self.source_location, @@ -304,7 +311,7 @@ def __init__( config: Any, id: str = "", globals: dict[str, Any] | None = None, - source_location: SourceLocation | None = None, + source_location: Location | None = None, ) -> None: super().__init__(config=config, id=id, source_location=source_location) self.globals = globals if globals is not None else {} diff --git a/src/sparkwheel/loader.py b/src/sparkwheel/loader.py index 948da8f..748f2f1 100644 --- a/src/sparkwheel/loader.py +++ b/src/sparkwheel/loader.py @@ -7,23 +7,23 @@ import yaml # type: ignore[import-untyped] -from .metadata import MetadataRegistry +from .locations import LocationRegistry from .path_utils import is_yaml_file from .utils import CheckKeyDuplicatesYamlLoader, PathLike from .utils.constants import 
ID_SEP_KEY -from .utils.exceptions import SourceLocation +from .utils.exceptions import Location __all__ = ["Loader"] class MetadataTrackingYamlLoader(CheckKeyDuplicatesYamlLoader): - """YAML loader that tracks source locations into MetadataRegistry. + """YAML loader that tracks source locations into LocationRegistry. Unlike the old approach that added __sparkwheel_metadata__ keys to dicts, - this loader populates a separate MetadataRegistry during loading. + this loader populates a separate LocationRegistry during loading. """ - def __init__(self, stream, filepath: str, registry: MetadataRegistry): # type: ignore[no-untyped-def] + def __init__(self, stream, filepath: str, registry: LocationRegistry): # type: ignore[no-untyped-def] super().__init__(stream) self.filepath = filepath self.registry = registry @@ -35,7 +35,7 @@ def construct_mapping(self, node, deep=False): current_id = ID_SEP_KEY.join(self.id_path_stack) if self.id_path_stack else "" if node.start_mark: - location = SourceLocation( + location = Location( filepath=self.filepath, line=node.start_mark.line + 1, column=node.start_mark.column + 1, @@ -53,6 +53,18 @@ def construct_mapping(self, node, deep=False): # Push key onto path stack before constructing value self.id_path_stack.append(str(key)) + # Register source location for this specific key + # This allows us to track where each key was defined + key_id = ID_SEP_KEY.join(self.id_path_stack) if self.id_path_stack else "" + if key_node.start_mark: + key_location = Location( + filepath=self.filepath, + line=key_node.start_mark.line + 1, + column=key_node.start_mark.column + 1, + id=key_id, + ) + self.registry.register(key_id, key_location) + # Construct value with updated path value = self.construct_object(value_node, deep=True) @@ -72,7 +84,7 @@ def construct_sequence(self, node, deep=False): current_id = ID_SEP_KEY.join(self.id_path_stack) if self.id_path_stack else "" if node.start_mark: - location = SourceLocation( + location = Location( filepath=self.filepath, line=node.start_mark.line + 1, column=node.start_mark.column + 1, @@ -121,7 +133,7 @@ class Loader: ``` """ - def load_file(self, filepath: PathLike) -> tuple[dict[str, Any], MetadataRegistry]: + def load_file(self, filepath: PathLike) -> tuple[dict[str, Any], LocationRegistry]: """Load a single YAML file with metadata tracking. Args: @@ -134,7 +146,7 @@ def load_file(self, filepath: PathLike) -> tuple[dict[str, Any], MetadataRegistr ValueError: If file is not a YAML file """ if not filepath: - return {}, MetadataRegistry() + return {}, LocationRegistry() filepath_str = str(Path(filepath)) @@ -154,7 +166,7 @@ def load_file(self, filepath: PathLike) -> tuple[dict[str, Any], MetadataRegistr ) # Load YAML with metadata tracking - registry = MetadataRegistry() + registry = LocationRegistry() with open(resolved_path) as f: config = self._load_yaml_with_metadata(f, str(resolved_path), registry) @@ -163,13 +175,13 @@ def load_file(self, filepath: PathLike) -> tuple[dict[str, Any], MetadataRegistr return config, registry - def _load_yaml_with_metadata(self, stream, filepath: str, registry: MetadataRegistry) -> dict[str, Any]: # type: ignore[no-untyped-def] + def _load_yaml_with_metadata(self, stream, filepath: str, registry: LocationRegistry) -> dict[str, Any]: # type: ignore[no-untyped-def] """Load YAML and populate metadata registry during construction. 
Args: stream: File stream to load from filepath: Path string for error messages - registry: MetadataRegistry to populate + registry: LocationRegistry to populate Returns: Config dictionary (clean, no metadata keys) @@ -206,7 +218,7 @@ def _strip_metadata(config: Any) -> Any: else: return config - def load_files(self, filepaths: Sequence[PathLike]) -> tuple[dict[str, Any], MetadataRegistry]: + def load_files(self, filepaths: Sequence[PathLike]) -> tuple[dict[str, Any], LocationRegistry]: """Load multiple YAML files sequentially. Files are loaded in order and merged using simple dict update @@ -219,7 +231,7 @@ def load_files(self, filepaths: Sequence[PathLike]) -> tuple[dict[str, Any], Met Tuple of (merged_config_dict, merged_metadata_registry) """ combined_config = {} - combined_registry = MetadataRegistry() + combined_registry = LocationRegistry() for filepath in filepaths: config, registry = self.load_file(filepath) diff --git a/src/sparkwheel/locations.py b/src/sparkwheel/locations.py new file mode 100644 index 0000000..e2cc0de --- /dev/null +++ b/src/sparkwheel/locations.py @@ -0,0 +1,74 @@ +"""Location tracking for configuration keys.""" + +from sparkwheel.utils.exceptions import Location + +__all__ = ["LocationRegistry"] + + +class LocationRegistry: + """Registry that tracks file locations for configuration keys. + + Maintains a clean separation between config data and location information + about where config items came from. This avoids polluting config dictionaries + with metadata keys. + + Example: + ```python + registry = LocationRegistry() + registry.register("model::lr", Location("config.yaml", 10, 2, "model::lr")) + + location = registry.get("model::lr") + print(location.filepath) # "config.yaml" + print(location.line) # 10 + ``` + """ + + def __init__(self): + """Initialize empty location registry.""" + self._locations: dict[str, Location] = {} + + def register(self, id_path: str, location: Location) -> None: + """Register location for a config path. + + Args: + id_path: Configuration path (e.g., "model::lr", "optimizer::params::0") + location: Location information + """ + self._locations[id_path] = location + + def get(self, id_path: str) -> Location | None: + """Get location for a config path. + + Args: + id_path: Configuration path to look up + + Returns: + Location if registered, None otherwise + """ + return self._locations.get(id_path) + + def merge(self, other: "LocationRegistry") -> None: + """Merge another registry into this one. + + Args: + other: LocationRegistry to merge from + """ + self._locations.update(other._locations) + + def copy(self) -> "LocationRegistry": + """Create a copy of this registry. + + Returns: + New LocationRegistry with same data + """ + new_registry = LocationRegistry() + new_registry._locations = self._locations.copy() + return new_registry + + def __len__(self) -> int: + """Return number of registered locations.""" + return len(self._locations) + + def __contains__(self, id_path: str) -> bool: + """Check if id_path has registered location.""" + return id_path in self._locations diff --git a/src/sparkwheel/metadata.py b/src/sparkwheel/metadata.py deleted file mode 100644 index f5f6ae4..0000000 --- a/src/sparkwheel/metadata.py +++ /dev/null @@ -1,74 +0,0 @@ -"""Source location metadata tracking.""" - -from sparkwheel.utils.exceptions import SourceLocation - -__all__ = ["MetadataRegistry"] - - -class MetadataRegistry: - """Track source locations for config items. 
- - Maintains a clean separation between config data and metadata about where - config items came from. This avoids polluting config dictionaries with - metadata keys. - - Example: - ```python - registry = MetadataRegistry() - registry.register("model::lr", SourceLocation("config.yaml", 10, 2, "model::lr")) - - location = registry.get("model::lr") - print(location.filepath) # "config.yaml" - print(location.line) # 10 - ``` - """ - - def __init__(self): - """Initialize empty metadata registry.""" - self._locations: dict[str, SourceLocation] = {} - - def register(self, id_path: str, location: SourceLocation) -> None: - """Register source location for a config path. - - Args: - id_path: Configuration path (e.g., "model::lr", "optimizer::params::0") - location: Source location information - """ - self._locations[id_path] = location - - def get(self, id_path: str) -> SourceLocation | None: - """Get source location for a config path. - - Args: - id_path: Configuration path to look up - - Returns: - SourceLocation if registered, None otherwise - """ - return self._locations.get(id_path) - - def merge(self, other: "MetadataRegistry") -> None: - """Merge another registry into this one. - - Args: - other: MetadataRegistry to merge from - """ - self._locations.update(other._locations) - - def copy(self) -> "MetadataRegistry": - """Create a copy of this registry. - - Returns: - New MetadataRegistry with same data - """ - new_registry = MetadataRegistry() - new_registry._locations = self._locations.copy() - return new_registry - - def __len__(self) -> int: - """Return number of registered locations.""" - return len(self._locations) - - def __contains__(self, id_path: str) -> bool: - """Check if id_path has registered location.""" - return id_path in self._locations diff --git a/src/sparkwheel/operators.py b/src/sparkwheel/operators.py index b324043..e2e9313 100644 --- a/src/sparkwheel/operators.py +++ b/src/sparkwheel/operators.py @@ -1,12 +1,97 @@ """Configuration merging with composition-by-default and operators (=, ~).""" from copy import deepcopy -from typing import Any +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any -from .utils.constants import REMOVE_KEY, REPLACE_KEY -from .utils.exceptions import ConfigMergeError +if TYPE_CHECKING: + from .locations import LocationRegistry + from .utils.exceptions import Location -__all__ = ["apply_operators", "validate_operators", "_validate_delete_operator"] +from .utils.constants import ID_SEP_KEY, REMOVE_KEY, REPLACE_KEY +from .utils.exceptions import ConfigMergeError, build_missing_key_error + +__all__ = ["apply_operators", "validate_operators", "_validate_delete_operator", "MergeContext"] + + +@dataclass +class MergeContext: + """Context for configuration merging operations. + + This class consolidates location tracking and path information needed + during recursive config merging. Using a context object reduces parameter + threading and makes it easier to extend with additional context in the future. 
+ + Attributes: + locations: Optional registry tracking file locations of config keys + current_path: Current path in config tree using :: separator (e.g., "model::optimizer::lr") + + Examples: + >>> ctx = MergeContext() + >>> child_ctx = ctx.child_path("model") + >>> child_ctx.current_path + 'model' + >>> grandchild_ctx = child_ctx.child_path("lr") + >>> grandchild_ctx.current_path + 'model::lr' + """ + + locations: "LocationRegistry | None" = None + current_path: str = "" + + def child_path(self, key: str) -> "MergeContext": + """Create a child context for a nested config key. + + Args: + key: Key to append to current path + + Returns: + New MergeContext with updated path, sharing the same source location registry + + Examples: + >>> ctx = MergeContext(current_path="model") + >>> child = ctx.child_path("optimizer") + >>> child.current_path + 'model::optimizer' + """ + new_path = f"{self.current_path}{ID_SEP_KEY}{key}" if self.current_path else key + return MergeContext(locations=self.locations, current_path=new_path) + + def get_source_location(self, key: str) -> "Location | None": + """Get source location for a key in the current context. + + Tries both the exact key and the key without operator prefix to handle + operator keys correctly (e.g., ~key, =key). + + Args: + key: Key to look up (may include operators like ~key or =key) + + Returns: + Location if found, None otherwise + + Examples: + >>> ctx = MergeContext(locations=registry, current_path="model") + >>> ctx.get_source_location("~lr") # Looks up both "model::~lr" and "model::lr" + >>> ctx.get_source_location("=lr") # Looks up both "model::=lr" and "model::lr" + """ + if not self.locations: + return None + + # Build full path for the key + key_path = f"{self.current_path}{ID_SEP_KEY}{key}" if self.current_path else key + + # Try the key as-is first + location = self.locations.get(key_path) + if location: + return location + + # If key starts with an operator (~, =), also try without the operator + if key.startswith((REMOVE_KEY, REPLACE_KEY)): + actual_key = key[1:] + actual_path = f"{self.current_path}{ID_SEP_KEY}{actual_key}" if self.current_path else actual_key + return self.locations.get(actual_path) + + return None def _validate_delete_operator(key: str, value: Any) -> None: @@ -58,7 +143,7 @@ def validate_operators(config: dict[str, Any], parent_key: str = "") -> None: """Validate operator usage in config tree. With composition-by-default, validation is simpler: - 1. Remove operators always work (idempotent delete) + 1. Remove operators error if key doesn't exist 2. Replace operators work on any type 3. No parent context requirements @@ -98,13 +183,17 @@ def validate_operators(config: dict[str, Any], parent_key: str = "") -> None: validate_operators(value, full_key) -def apply_operators(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]: +def apply_operators( + base: dict[str, Any], + override: dict[str, Any], + context: MergeContext | None = None, +) -> dict[str, Any]: """Apply configuration changes with composition-by-default semantics. 
Default behavior: Compose (merge dicts, extend lists) Operators: =key: value - Replace operator: completely replace value (override default) - ~key: null - Remove operator: delete key or list items (idempotent) + ~key: null - Remove operator: delete key or list items (errors if missing) key: value - Compose (default): merge dict or extend list Composition-by-Default Philosophy: @@ -112,11 +201,12 @@ def apply_operators(base: dict[str, Any], override: dict[str, Any]) -> dict[str, - Lists extend by default (append new items) - Only scalars and type mismatches replace - Use = to explicitly replace entire dicts or lists - - Use ~ to delete keys (idempotent - no error if missing) + - Use ~ to delete keys (errors if key doesn't exist) Args: base: Base configuration dict override: Override configuration dict with optional =/~ operators + context: Optional merge context with metadata and path tracking Returns: Merged configuration dict @@ -143,12 +233,21 @@ def apply_operators(base: dict[str, Any], override: dict[str, Any]) -> dict[str, >>> apply_operators(base, override) {"model": {"lr": 0.01}} - >>> # Remove operator: delete key (idempotent) + >>> # Remove operator: delete key >>> base = {"a": 1, "b": 2, "c": 3} >>> override = {"b": 5, "~c": None} >>> apply_operators(base, override) {"a": 1, "b": 5} + + >>> # With context for better error messages + >>> from .locations import LocationRegistry + >>> ctx = MergeContext(locations=LocationRegistry()) + >>> apply_operators(base, override, context=ctx) """ + # Create context if not provided + if context is None: + context = MergeContext() + if not isinstance(base, dict) or not isinstance(override, dict): return deepcopy(override) # type: ignore[unreachable] @@ -170,9 +269,11 @@ def apply_operators(base: dict[str, Any], override: dict[str, Any]) -> dict[str, actual_key = key[1:] _validate_delete_operator(actual_key, value) - # Idempotent: no error if key doesn't exist + # Error if key doesn't exist if actual_key not in result: - continue # Silently skip + # Get source location using context helper + source_location = context.get_source_location(key) + raise build_missing_key_error(actual_key, list(result.keys()), source_location) # Handle remove entire key (null or empty value) if value is None or value == "": @@ -223,11 +324,10 @@ def apply_operators(base: dict[str, Any], override: dict[str, Any]) -> dict[str, elif isinstance(base_val, dict): for del_key in value: if del_key not in base_val: - raise ConfigMergeError( - f"Cannot remove non-existent key '{del_key}' from '{actual_key}'", - suggestion=f"The key '{del_key}' does not exist in '{actual_key}'.\n" - f"Available keys: {list(base_val.keys())}", - ) + # Use helper to build error with suggestions + available_keys = list(base_val.keys()) + source_location = context.get_source_location(key) + raise build_missing_key_error(del_key, available_keys, source_location, parent_key=actual_key) del base_val[del_key] else: @@ -246,7 +346,9 @@ def apply_operators(base: dict[str, Any], override: dict[str, Any]) -> dict[str, # For dicts: MERGE (composition) if isinstance(base_val, dict) and isinstance(value, dict): - result[key] = apply_operators(base_val, value) + # Create child context for recursion + child_context = context.child_path(key) + result[key] = apply_operators(base_val, value, context=child_context) continue # For lists: EXTEND (composition) diff --git a/src/sparkwheel/parser.py b/src/sparkwheel/parser.py index 5264d66..66f6034 100644 --- a/src/sparkwheel/parser.py +++ b/src/sparkwheel/parser.py @@ 
-3,7 +3,7 @@ from typing import Any from .items import Component, Expression, Item -from .metadata import MetadataRegistry +from .locations import LocationRegistry from .utils.constants import ID_SEP_KEY __all__ = ["Parser"] @@ -27,7 +27,7 @@ class Parser: } } - metadata = MetadataRegistry() + metadata = LocationRegistry() parser = Parser(globals={}, metadata=metadata) items = parser.parse(config) @@ -43,15 +43,15 @@ class Parser: Args: globals: Global context for expression evaluation - metadata: MetadataRegistry for source location lookup + metadata: LocationRegistry for source location lookup """ - def __init__(self, globals: dict[str, Any], metadata: MetadataRegistry): + def __init__(self, globals: dict[str, Any], metadata: LocationRegistry): """Initialize parser with globals and metadata. Args: globals: Dictionary of global variables for expression evaluation - metadata: MetadataRegistry for looking up source locations + metadata: LocationRegistry for looking up source locations """ self._globals = globals self._metadata = metadata diff --git a/src/sparkwheel/path_utils.py b/src/sparkwheel/path_utils.py index b0f5b2e..ad07245 100644 --- a/src/sparkwheel/path_utils.py +++ b/src/sparkwheel/path_utils.py @@ -18,6 +18,7 @@ "replace_references", "split_file_and_id", "is_yaml_file", + "get_by_id", "PathPatterns", # Export for backward compatibility ] @@ -286,6 +287,85 @@ def normalize_id(id: str | int) -> str: return str(id) +def _format_path_context(path_parts: list[str], index: int) -> str: + """Format parent path context for error messages. + + Internal helper for get_by_id error messages. + + Args: + path_parts: List of path components + index: Current index in path_parts (0 = first level) + + Returns: + Empty string for first level, or " in 'parent::path'" for nested levels + """ + if index == 0: + return "" + parent_path = ID_SEP_KEY.join(path_parts[:index]) + return f" in '{parent_path}'" + + +def get_by_id(config: dict[str, Any] | list[Any], id: str) -> Any: + """Navigate config structure by ID path. + + Traverses nested dicts and lists using :: separated path components. + Provides detailed error messages with available keys when navigation fails. + + Args: + config: Config dict or list to navigate + id: ID path (e.g., "model::optimizer::lr" or "items::0::value") + + Returns: + Value at the specified path + + Raises: + KeyError: If a dict key is not found (includes available keys in message) + TypeError: If trying to index a non-dict/list value + + Examples: + >>> config = {"model": {"lr": 0.001, "layers": [64, 128]}} + >>> get_by_id(config, "model::lr") + 0.001 + >>> get_by_id(config, "model::layers::1") + 128 + >>> get_by_id(config, "") + {"model": {"lr": 0.001, "layers": [64, 128]}} + + Error messages include context: + >>> get_by_id({"a": {"b": 1}}, "a::missing") + KeyError: "Key 'missing' not found in 'a'. Available keys: ['b']" + """ + if not id: + return config + + current = config + path_parts = split_id(id) + + for i, key in enumerate(path_parts): + context = _format_path_context(path_parts, i) + + if isinstance(current, dict): + if key not in current: + available_keys = list(current.keys()) + error_msg = f"Key '{key}' not found{context}" + error_msg += f". Available keys: {available_keys[:10]}" + if len(available_keys) > 10: + error_msg += "..." 
+ raise KeyError(error_msg) + current = current[key] + elif isinstance(current, list): + try: + current = current[int(key)] + except ValueError as e: + raise KeyError(f"Invalid list index '{key}'{context}: not an integer") from e + except IndexError as e: + raise KeyError(f"List index '{key}' out of range{context}: {e}") from e + else: + raise TypeError(f"Cannot index {type(current).__name__} with key '{key}'{context}") + + return current + + def resolve_relative_ids(current_id: str, value: str) -> str: """Resolve relative references (@::, @::::) to absolute paths. diff --git a/src/sparkwheel/preprocessor.py b/src/sparkwheel/preprocessor.py index cfc70ab..0e4262a 100644 --- a/src/sparkwheel/preprocessor.py +++ b/src/sparkwheel/preprocessor.py @@ -6,10 +6,14 @@ """ from copy import deepcopy -from typing import Any +from typing import TYPE_CHECKING, Any, Optional -from .path_utils import resolve_relative_ids, split_file_and_id, split_id +from .path_utils import get_by_id, resolve_relative_ids, split_file_and_id from .utils.constants import ID_SEP_KEY, RAW_REF_KEY +from .utils.exceptions import CircularReferenceError, ConfigKeyError + +if TYPE_CHECKING: + from .locations import LocationRegistry __all__ = ["Preprocessor"] @@ -25,27 +29,38 @@ class Preprocessor: Operates on raw Python dicts/lists, not on Item objects. + ## Two-Phase Raw Reference Expansion + + Raw references are expanded in two phases to support CLI overrides: + + **Phase 1 (Eager, during update()):** + - External file refs (`%file.yaml::key`) are expanded immediately + - The external file is frozen - its contents won't change + + **Phase 2 (Lazy, during resolve()):** + - Local refs (`%key`) are expanded after all composition + - This allows CLI overrides to affect values referenced by local `%` refs + Example: >>> loader = Loader() >>> preprocessor = Preprocessor(loader) >>> >>> raw_config = { ... "lr": 0.001, - ... "base": "%defaults.yaml::learning_rate", # Raw reference (external) + ... "base": "%defaults.yaml::learning_rate", # External - expanded eagerly + ... "ref": "%lr", # Local - expanded lazily ... "model": { ... "lr": "@::lr" # Relative resolved reference ... } ... } >>> - >>> preprocessed = preprocessor.process(raw_config, raw_config) - >>> # Result: - >>> # { - >>> # "lr": 0.001, - >>> # "base": 0.0005, # Loaded from defaults.yaml - >>> # "model": { - >>> # "lr": "@model::lr" # Converted to absolute - >>> # } - >>> # } + >>> # Phase 1: Expand only external refs + >>> preprocessed = preprocessor.process_raw_refs(raw_config, raw_config, external_only=True) + >>> # Result: base=0.0005, ref="%lr" (still string) + >>> + >>> # Phase 2: Expand local refs (after CLI overrides applied) + >>> preprocessed = preprocessor.process_raw_refs(preprocessed, preprocessed, external_only=False) + >>> # Result: ref=0.001 (now expanded) """ def __init__(self, loader, globals: dict[str, Any] | None = None): # type: ignore[no-untyped-def] @@ -58,14 +73,29 @@ def __init__(self, loader, globals: dict[str, Any] | None = None): # type: igno self.loader = loader self.globals = globals or {} - def process_raw_refs(self, config: Any, base_data: dict[str, Any], id: str = "") -> Any: - """Preprocess config tree - expand only % raw references. + def process_raw_refs( + self, + config: Any, + base_data: dict[str, Any], + id: str = "", + locations: Optional["LocationRegistry"] = None, + *, + external_only: bool = False, + ) -> Any: + """Preprocess config tree - expand % raw references. 
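+
+        Minimal sketch of the two phases (assuming preprocessor = Preprocessor(loader)
+        as in the class example; the local "%a" ref survives Phase 1 as a string):
+
+        >>> cfg = {"a": 1, "b": "%a"}
+        >>> preprocessor.process_raw_refs(cfg, cfg, external_only=True)
+        {'a': 1, 'b': '%a'}
+        >>> preprocessor.process_raw_refs(cfg, cfg, external_only=False)
+        {'a': 1, 'b': 1}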
+ + Supports two-phase expansion for CLI override compatibility: + + **Phase 1 (external_only=True, during update()):** + - Only expands external file refs (`%file.yaml::key`) + - Local refs (`%key`) are kept as strings for later expansion + - This allows CLI overrides to affect values used by local refs + + **Phase 2 (external_only=False, during resolve()):** + - Expands all remaining local refs (`%key`) + - At this point, all CLI overrides have been applied - This is the first preprocessing stage that runs eagerly during update(). - It expands all % raw references including those with relative syntax: - - Local: %key - - External: %file.yaml::key - - Relative: %::key, %::::key (converted to absolute before expansion) + Also handles relative syntax: %::key, %::::key (converted to absolute before expansion) Leaves @ resolved references untouched (they're processed lazily during resolve()). @@ -73,14 +103,18 @@ def process_raw_refs(self, config: Any, base_data: dict[str, Any], id: str = "") config: Raw config structure to process base_data: Root config dict (for resolving local raw references) id: Current ID path in tree + locations: LocationRegistry for error reporting (optional) + external_only: If True, only expand external file refs (Phase 1). + If False, expand all refs including local (Phase 2). Returns: - Config with raw references expanded + Config with raw references expanded (or partially expanded if external_only=True) Raises: - ValueError: If circular raw reference detected + CircularReferenceError: If circular raw reference detected + ConfigKeyError: If referenced key not found """ - return self._process_raw_refs_recursive(config, base_data, id, set()) + return self._process_raw_refs_recursive(config, base_data, id, set(), locations, external_only) def process(self, config: Any, base_data: dict[str, Any], id: str = "") -> Any: """Preprocess entire config tree. @@ -98,7 +132,7 @@ def process(self, config: Any, base_data: dict[str, Any], id: str = "") -> Any: Preprocessed config ready for parsing Raises: - ValueError: If circular raw reference detected + CircularReferenceError: If circular raw reference detected """ return self._process_recursive(config, base_data, id, set()) @@ -108,10 +142,13 @@ def _process_raw_refs_recursive( base_data: dict[str, Any], id: str, raw_ref_stack: set[str], + locations: Optional["LocationRegistry"] = None, + external_only: bool = False, ) -> Any: - """Internal recursive implementation for expanding only raw references. + """Internal recursive implementation for expanding raw references. - This method only expands % raw references and leaves @ references untouched. + This method expands % raw references and leaves @ references untouched. + When external_only=True, only external file refs are expanded. Performance optimization: Skips recursion for nodes that don't contain any raw reference strings, avoiding unnecessary tree traversal. 
@@ -121,9 +158,11 @@ def _process_raw_refs_recursive( base_data: Root config dict id: Current ID path raw_ref_stack: Circular reference detection + locations: LocationRegistry for error reporting (optional) + external_only: If True, skip local refs (expand only external file refs) Returns: - Config with raw references expanded + Config with raw references expanded (or partially if external_only=True) """ # Early exit optimization: Skip processing if this subtree has no raw references # This avoids unnecessary recursion for large config sections without % refs @@ -134,12 +173,16 @@ def _process_raw_refs_recursive( if isinstance(config, dict): for key in list(config.keys()): sub_id = f"{id}{ID_SEP_KEY}{key}" if id else str(key) - config[key] = self._process_raw_refs_recursive(config[key], base_data, sub_id, raw_ref_stack) + config[key] = self._process_raw_refs_recursive( + config[key], base_data, sub_id, raw_ref_stack, locations, external_only + ) elif isinstance(config, list): for idx in range(len(config)): sub_id = f"{id}{ID_SEP_KEY}{idx}" if id else str(idx) - config[idx] = self._process_raw_refs_recursive(config[idx], base_data, sub_id, raw_ref_stack) + config[idx] = self._process_raw_refs_recursive( + config[idx], base_data, sub_id, raw_ref_stack, locations, external_only + ) # Process string values - only expand raw references (%) if isinstance(config, str): @@ -149,7 +192,7 @@ def _process_raw_refs_recursive( # Then expand raw references if config.startswith(RAW_REF_KEY): - config = self._expand_raw_ref(config, base_data, raw_ref_stack) + config = self._expand_raw_ref(config, base_data, raw_ref_stack, id, locations, external_only) return config @@ -193,42 +236,95 @@ def _process_recursive( return config - def _expand_raw_ref(self, raw_ref: str, base_data: dict[str, Any], raw_ref_stack: set[str]) -> Any: + def _expand_raw_ref( + self, + raw_ref: str, + base_data: dict[str, Any], + raw_ref_stack: set[str], + current_id: str = "", + locations: Optional["LocationRegistry"] = None, + external_only: bool = False, + ) -> Any: """Expand a single raw reference by loading external file or local YAML. 
        Args:
            raw_ref: Raw reference string (e.g., "%file.yaml::key" or "%key")
            base_data: Root config for local raw references
            raw_ref_stack: Circular reference detection
+            current_id: Current ID path (where this raw reference was found)
+            locations: LocationRegistry for error reporting (optional)
+            external_only: If True, skip local refs and return raw_ref unchanged

        Returns:
-            Value from raw reference (deep copied)
+            Value from raw reference (deep copied), or raw_ref unchanged if
+            external_only=True and this is a local reference

        Raises:
-            ValueError: If circular reference detected
+            CircularReferenceError: If circular reference detected
+            ConfigKeyError: If referenced key not found
        """
+        # Parse: "%file.yaml::key" → ("file.yaml", "key")
+        path, ids = split_file_and_id(raw_ref[len(RAW_REF_KEY) :])
+
+        # Phase 1 (external_only=True): Skip local refs; they'll be expanded later.
+        # This allows CLI overrides to affect values used by local % refs
+        is_local_ref = not path
+        if external_only and is_local_ref:
+            return raw_ref  # Keep as string, expand in Phase 2
+
        # Circular reference check
        if raw_ref in raw_ref_stack:
            chain = " -> ".join(sorted(raw_ref_stack))
-            raise ValueError(f"Circular raw reference detected: '{raw_ref}'\nRaw reference chain: {chain} -> {raw_ref}")
-        # Parse: "%file.yaml::key" → ("file.yaml", "key")
-        path, ids = split_file_and_id(raw_ref[len(RAW_REF_KEY) :])
+            # Get location information if available
+            location = None
+            if locations and current_id:
+                location = locations.get(current_id)
+
+            raise CircularReferenceError(
+                message=f"Circular raw reference detected: '{raw_ref}'\nReference chain: {chain} -> {raw_ref}",
+                source_location=location,
+            )

        raw_ref_stack.add(raw_ref)
        try:
            # Load config (external file or local)
-            if not path:
+            if is_local_ref:
                loaded_config = base_data  # Local raw reference: %key
+                loaded_locations = locations  # Use same location registry
+                source_description = "local config"
            else:
-                loaded_config, _ = self.loader.load_file(path)  # External: %file.yaml::key
+                loaded_config, loaded_locations = self.loader.load_file(path)  # External: %file.yaml::key
+                source_description = f"'{path}'"

            # Navigate to referenced value
-            result = self._get_by_id(loaded_config, ids)
-
-            # Recursively preprocess the loaded value (expand nested raw references only)
-            result = self._process_raw_refs_recursive(result, loaded_config, ids, raw_ref_stack)
+            try:
+                result = get_by_id(loaded_config, ids)
+            except (KeyError, TypeError, IndexError) as e:
+                # Get location information if available
+                location = None
+                if locations and current_id:
+                    location = locations.get(current_id)
+
+                # Build error message
+                if is_local_ref:
+                    error_msg = f"Error resolving raw reference '{raw_ref}' from local config:\n{e}"
+                else:
+                    error_msg = f"Error resolving raw reference '{raw_ref}' from {source_description}:\n{e}"
+
+                # Raise custom error with proper formatting
+                raise ConfigKeyError(
+                    message=error_msg,
+                    source_location=location,
+                ) from e
+
+            # Recursively preprocess the loaded value (expand nested raw references)
+            # For external files, always expand all refs within that file
+            # For local refs (Phase 2), expand all nested refs too
+            result = self._process_raw_refs_recursive(
+                result, loaded_config, ids, raw_ref_stack, loaded_locations, external_only=False
+            )

            # Deep copy for independence
            return deepcopy(result)
@@ -255,32 +351,3 @@ def _contains_raw_refs(config: Any) -> bool:
        elif isinstance(config, list):
            return any(Preprocessor._contains_raw_refs(item) for item in config)
        return False
-
-
@staticmethod - def _get_by_id(config: dict[str, Any], id: str) -> Any: - """Navigate config dict by ID path. - - Args: - config: Config dict to navigate - id: ID path (e.g., "model::optimizer::lr") - - Returns: - Value at ID path - - Raises: - KeyError: If path not found - TypeError: If trying to index non-dict/list - """ - if not id: - return config - - current = config - for key in split_id(id): - if isinstance(current, dict): - current = current[key] - elif isinstance(current, list): # type: ignore[unreachable] - current = current[int(key)] - else: - raise TypeError(f"Cannot index {type(current).__name__} with key '{key}' at path '{id}'") - - return current diff --git a/src/sparkwheel/schema.py b/src/sparkwheel/schema.py index 973cfa6..2d8bf40 100644 --- a/src/sparkwheel/schema.py +++ b/src/sparkwheel/schema.py @@ -39,7 +39,7 @@ class ModelConfig: import types from typing import Any, Union, get_args, get_origin -from .utils.exceptions import BaseError, SourceLocation +from .utils.exceptions import BaseError, Location __all__ = ["validate", "validator", "ValidationError", "MISSING"] @@ -192,7 +192,7 @@ def __init__( field_path: str = "", expected_type: type | None = None, actual_value: Any = None, - source_location: SourceLocation | None = None, + source_location: Location | None = None, ): """Initialize validation error. @@ -682,15 +682,15 @@ def _validate_field( ) -def _get_source_location(metadata: Any, field_path: str) -> SourceLocation | None: +def _get_source_location(metadata: Any, field_path: str) -> Location | None: """Get source location from metadata registry. Args: - metadata: MetadataRegistry instance + metadata: LocationRegistry instance field_path: Dot-separated field path to look up Returns: - SourceLocation if found, None otherwise + Location if found, None otherwise """ if metadata is None: return None @@ -699,6 +699,6 @@ def _get_source_location(metadata: Any, field_path: str) -> SourceLocation | Non # Convert dot notation to :: notation used by sparkwheel id_path = field_path.replace(".", "::") result = metadata.get(id_path) - return result if result is None or isinstance(result, SourceLocation) else None + return result if result is None or isinstance(result, Location) else None except Exception: return None diff --git a/src/sparkwheel/utils/constants.py b/src/sparkwheel/utils/constants.py index 872b1ce..ebccc16 100644 --- a/src/sparkwheel/utils/constants.py +++ b/src/sparkwheel/utils/constants.py @@ -5,6 +5,7 @@ "EXPR_KEY", "REMOVE_KEY", "REPLACE_KEY", + "SIMILARITY_THRESHOLD", ] RESOLVED_REF_KEY = "@" # start of a resolved reference (to instantiated/evaluated value) @@ -13,3 +14,4 @@ EXPR_KEY = "$" # start of an Expression REMOVE_KEY = "~" # remove operator for config modifications (delete keys/items) REPLACE_KEY = "=" # replace operator for config modifications (explicit override) +SIMILARITY_THRESHOLD = 0.6 # minimum similarity score for key name suggestions (0-1) diff --git a/src/sparkwheel/utils/exceptions.py b/src/sparkwheel/utils/exceptions.py index cbb9645..d04eedb 100644 --- a/src/sparkwheel/utils/exceptions.py +++ b/src/sparkwheel/utils/exceptions.py @@ -5,7 +5,7 @@ from typing import Any __all__ = [ - "SourceLocation", + "Location", "BaseError", "TargetNotFoundError", "CircularReferenceError", @@ -14,18 +14,31 @@ "ConfigMergeError", "EvaluationError", "FrozenConfigError", + "build_missing_key_error", ] @dataclass -class SourceLocation: - """Tracks the source location of a config item.""" +class Location: + """Tracks the location of a config item in 
source files.
+
+    Attributes:
+        filepath: Path to the source file
+        line: Line number in the file (must be >= 1)
+        column: Column number (0 if not available)
+        id: Config path ID (e.g., "model::lr")
+    """

    filepath: str
    line: int
    column: int = 0
    id: str = ""

+    def __post_init__(self) -> None:
+        """Validate line number after initialization."""
+        if self.line < 1:
+            raise ValueError(f"line must be >= 1, got {self.line}")
+
    def __str__(self) -> str:
        return f"{self.filepath}:{self.line}"
@@ -42,7 +55,7 @@ class BaseError(Exception):
    def __init__(
        self,
        message: str,
-        source_location: SourceLocation | None = None,
+        source_location: Location | None = None,
        suggestion: str | None = None,
    ) -> None:
        self.source_location = source_location
@@ -62,9 +75,9 @@ def _format_message(self) -> str:
        if self.source_location:
            location = f"{self.source_location.filepath}:{self.source_location.line}"
            if self.source_location.id:
-                parts.append(f"[{location} @ {self.source_location.id}] {self._original_message}")
+                parts.append(f"{self._original_message}\n\n[{location} → {self.source_location.id}]:")
            else:
-                parts.append(f"[{location}] {self._original_message}")
+                parts.append(f"{self._original_message}\n\n[{location}]:")
        else:
            parts.append(self._original_message)
@@ -72,10 +85,10 @@
        if self.source_location:
            snippet = self._get_config_snippet()
            if snippet:
-                parts.append(f"\n\n{snippet}")
+                parts.append(f"\n{snippet}\n")

        if self.suggestion:
-            parts.append(f"\n\n  💡 {self.suggestion}")
+            parts.append(f"\n  💡 {self.suggestion}\n")

        return "".join(parts)
@@ -136,7 +149,7 @@ class ConfigKeyError(BaseError):
    def __init__(
        self,
        message: str,
-        source_location: SourceLocation | None = None,
+        source_location: Location | None = None,
        suggestion: str | None = None,
        missing_key: str | None = None,
        available_keys: list[str] | None = None,
@@ -217,3 +230,59 @@ def __init__(self, message: str, field_path: str = ""):
        if field_path:
            full_message = f"Cannot modify frozen config at '{field_path}': {message}"
        super().__init__(full_message)
+
+
+def build_missing_key_error(
+    key: str,
+    available_keys: list[str],
+    source_location: Location | None = None,
+    *,
+    max_suggestions: int = 3,
+    max_available_keys: int = 10,
+    parent_key: str | None = None,
+) -> ConfigMergeError:
+    """Build a ConfigMergeError for a missing key with helpful suggestions.
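+
+    Close matches are computed by get_suggestions() and filtered by
+    SIMILARITY_THRESHOLD; the full key listing is only included when there
+    are at most max_available_keys keys.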
+ + Args: + key: The key that wasn't found + available_keys: List of available keys to compare against + source_location: Optional location where error occurred + max_suggestions: Maximum number of suggestions to show (default: 3) + max_available_keys: Maximum number of available keys to show (default: 10) + parent_key: Optional parent key for nested deletions (e.g., "model" when deleting "model.lr") + + Returns: + ConfigMergeError with helpful suggestions + + Examples: + >>> error = build_missing_key_error("paramters", ["parameters", "param_groups"]) + >>> print(error._original_message) + Cannot delete key 'paramters': key does not exist + + >>> error = build_missing_key_error("lr", ["learning_rate"], parent_key="model") + >>> print(error._original_message) + Cannot remove non-existent key 'lr' from 'model' + """ + from ..errors import get_suggestions + from .constants import SIMILARITY_THRESHOLD + + # Build appropriate message based on context + if parent_key: + message = f"Cannot remove non-existent key '{key}' from '{parent_key}'" + else: + message = f"Cannot delete key '{key}': key does not exist" + + suggestion_parts = [] + if available_keys: + suggestions = get_suggestions( + key, available_keys, max_suggestions=max_suggestions, similarity_threshold=SIMILARITY_THRESHOLD + ) + if suggestions: + suggestion_keys = [s[0] for s in suggestions] + suggestion_parts.append(f"Did you mean: {', '.join(repr(s) for s in suggestion_keys)}?") + + if len(available_keys) <= max_available_keys: + suggestion_parts.append(f"Available keys: {', '.join(repr(k) for k in available_keys)}") + + suggestion = "\n".join(suggestion_parts) if suggestion_parts else None + return ConfigMergeError(message, source_location=source_location, suggestion=suggestion) diff --git a/src/sparkwheel/utils/module.py b/src/sparkwheel/utils/module.py index 3eabb65..8a21d0b 100644 --- a/src/sparkwheel/utils/module.py +++ b/src/sparkwheel/utils/module.py @@ -248,13 +248,8 @@ def instantiate(__path: str, __mode: str, *args: Any, **kwargs: Any) -> Any: ) return pdb.runcall(component, *args, **kwargs) except Exception as e: - # Preserve the original exception type and message for better debugging - args_str = f"{len(args)} positional args, " if args else "" - error_msg = ( - f"Failed to instantiate component '{__path}' with {args_str}keywords: {','.join(kwargs.keys())}\n" - f" Original error ({type(e).__name__}): {str(e)}\n" - f" Set '_mode_={CompInitMode.DEBUG}' to enter debugging mode." 
- ) + error_msg = f"Could not instantiate component '{__path}':\n From {type(e).__name__}: {e}" + raise InstantiationError(error_msg) from e warnings.warn(f"Component to instantiate must represent a valid class or function, but got {__path}.", stacklevel=2) diff --git a/tests/test_config.py b/tests/test_config.py index 80b8359..874bbdb 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -106,7 +106,7 @@ def test_merge_nested_paths(self): def test_getitem_invalid_config_type(self): """Test __getitem__ raises error for invalid config type.""" parser = Config({"scalar": 42}) - with pytest.raises(ValueError, match="Config must be dict or list"): + with pytest.raises(TypeError, match="Cannot index int"): _ = parser["scalar::invalid"] def test_getitem_list_indexing(self): @@ -277,45 +277,86 @@ def test_do_resolve_macro_load(self): finally: Path(filepath).unlink() - def test_eager_raw_reference_expansion(self): - """Test that raw references are expanded during update(), not resolve().""" + def test_local_raw_reference_lazy_expansion(self): + """Test that local raw references are expanded lazily (during resolve()). + + This allows CLI overrides to affect values used by local % refs. + """ config = {"original": {"a": 1, "b": 2}, "copy": "%original"} parser = Config().update(config) - # After update(), raw references should be expanded (not resolve() yet!) - # The copy should be the actual value, not the "%original" string - assert parser.get("copy") == {"a": 1, "b": 2} - assert parser.get("copy") is not parser.get("original") # Deep copy + # After update(), LOCAL raw references are NOT expanded yet + # They remain as strings until resolve() is called + assert parser.get("copy") == "%original" + + # After resolve(), local refs are expanded + resolved = parser.resolve("copy") + assert resolved == {"a": 1, "b": 2} + + # Verify it's a deep copy (independent of original) + assert resolved is not parser.resolve("original") + + def test_cli_override_affects_local_raw_ref(self): + """Test that CLI overrides affect values used by local % refs. + + This is the key use case: vars::path can be overridden via CLI + and local % refs will see the overridden value. + """ + parser = Config() + parser.update({"vars": {"path": None}}) + parser.update({"data": {"path": "%vars::path"}}) + + # Before CLI override, local ref is still a string + assert parser.get("data::path") == "%vars::path" + + # Apply CLI override + parser.update("vars::path=/data/features.npz") + + # Now resolve - local ref should see the overridden value + assert parser.resolve("data::path") == "/data/features.npz" + + def test_external_raw_reference_eager_expansion(self, tmp_path): + """Test that external file raw references are expanded eagerly.""" + # Create external file + external = tmp_path / "external.yaml" + external.write_text("value: 42\nnested:\n a: 1\n b: 2") + + parser = Config() + parser.update({"imported": f"%{external}::value", "section": f"%{external}::nested"}) + + # After update(), EXTERNAL raw references ARE expanded (eager) + assert parser.get("imported") == 42 + assert parser.get("section") == {"a": 1, "b": 2} + + def test_pruning_with_external_raw_references(self, tmp_path): + """Test that pruning works with external file raw references. - # Verify it's a real copy by modifying it - parser.set("copy::c", 3) - assert parser.get("copy::c") == 3 - assert "c" not in parser.get("original") + External file refs are expanded eagerly, so copy-then-delete works. 
+ For local refs, use @ references instead (they also support this pattern). + """ + # Create external file with dataloader configs + external = tmp_path / "dataloaders.yaml" + external.write_text("train:\n batch_size: 32\nval:\n batch_size: 64") - def test_pruning_with_raw_references(self): - """Test that pruning works with raw references (the Lighter use case).""" - # This is the critical test: raw refs should be expanded before pruning config = { "system": { - "dataloaders": { - "train": {"batch_size": 32}, - "val": {"batch_size": 64}, - } + "dataloaders": f"%{external}", # External ref - expanded eagerly }, "train": { - "dataloader": "%system::dataloaders::train" # Raw reference + "dataloader": f"%{external}::train", # External ref - expanded eagerly }, } parser = Config().update(config) - # Raw reference should already be expanded + # External raw reference should already be expanded (eager) assert parser.get("train::dataloader") == {"batch_size": 32} + assert parser.get("system::dataloaders") == {"train": {"batch_size": 32}, "val": {"batch_size": 64}} # Now prune the system section (delete it) parser.update("~system") - # The raw reference was already expanded, so train::dataloader should still exist + # The external raw reference was already expanded, so train::dataloader still exists assert parser.get("train::dataloader") == {"batch_size": 32} assert "system" not in parser.get() # system is deleted @@ -323,6 +364,31 @@ def test_pruning_with_raw_references(self): result = parser.resolve("train::dataloader") assert result == {"batch_size": 32} + def test_local_raw_ref_to_deleted_key_fails(self): + """Test that local % ref to deleted key fails clearly. + + With lazy local refs, deleting the source before resolve() will fail. + Use external file refs or @ refs if you need copy-then-delete. 
+ """ + from sparkwheel.utils.exceptions import ConfigKeyError + + config = { + "system": {"train": {"batch_size": 32}}, + "train": {"dataloader": "%system::train"}, # Local ref - expanded lazily + } + + parser = Config().update(config) + + # Local ref is NOT expanded yet + assert parser.get("train::dataloader") == "%system::train" + + # Delete the source + parser.update("~system") + + # Attempting to resolve will fail - source was deleted + with pytest.raises(ConfigKeyError, match="system"): + parser.resolve("train::dataloader") + class TestComponents: """Test component instantiation and handling.""" @@ -430,6 +496,34 @@ def test_split_path_id_no_path(self): assert path == "" assert ids == "key::subkey" + def test_update_from_file_with_nested_paths_merges_locations(self, tmp_path): + """Test that nested-path syntax (::) in YAML files properly merges location tracking.""" + # Create a YAML file with nested-path syntax + config_file = tmp_path / "config.yaml" + config_file.write_text("model::lr: 0.001\nmodel::dropout: 0.5\ntrainer::epochs: 10") + + # Load the file + config = Config() + config.update(str(config_file)) + + # Verify the values were set correctly + assert config["model"]["lr"] == 0.001 + assert config["model"]["dropout"] == 0.5 + assert config["trainer"]["epochs"] == 10 + + # Verify that locations were tracked + # The location registry should have entries for the nested paths + assert "model::lr" in config._locations or "model" in config._locations + assert "model::dropout" in config._locations or "model" in config._locations + assert "trainer::epochs" in config._locations or "trainer" in config._locations + + # Verify the location points to the correct file + if "model::lr" in config._locations: + location = config._locations.get("model::lr") + assert location is not None + assert location.filepath == str(config_file) + assert location.line >= 1 + class TestConfigMerging: """Test merging configurations with composition-by-default and =/~ operators.""" @@ -629,14 +723,16 @@ def test_merge_combined_operators(self): assert "c" not in parser assert parser["d"] == {"new": 4} # Replaced - def test_delete_on_nonexistent_key_idempotent(self): - """Test that ~key is idempotent (no error when key doesn't exist).""" + def test_delete_on_nonexistent_key_raises_error(self): + """Test that ~key raises error when key doesn't exist.""" + from sparkwheel.utils.exceptions import ConfigMergeError + base = {"a": 1} override = {"~b": None} - # NEW: No error! 
Idempotent delete
-        result = apply_operators(base, override)
-        assert result == {"a": 1}
+        # Should raise error when key doesn't exist
+        with pytest.raises(ConfigMergeError, match="Cannot delete key 'b': key does not exist"):
+            apply_operators(base, override)

    def test_delete_directive_with_invalid_value_raises_error(self):
        """Test that ~key raises error when value is not null, empty, or list."""
@@ -888,6 +984,52 @@ def test_merge_with_empty_list(self):

        assert result == {"items": ["a", "b"]}

+    def test_delete_nonexistent_top_level_key_shows_available_keys(self):
+        """Test that deleting a nonexistent top-level key shows available top-level keys in error."""
+        from sparkwheel.utils.exceptions import ConfigMergeError
+
+        config = Config().update({"model": {"lr": 0.001}, "trainer": {"epochs": 10}})
+
+        with pytest.raises(ConfigMergeError) as exc_info:
+            config.update({"~missing": None})
+
+        error_msg = str(exc_info.value)
+        assert "Cannot delete key 'missing'" in error_msg
+        # Should suggest top-level keys
+        assert "'model'" in error_msg
+        assert "'trainer'" in error_msg
+
+    def test_delete_nonexistent_nested_key_shows_parent_keys(self):
+        """Test that deleting a nonexistent nested key shows keys from parent container."""
+        from sparkwheel.utils.exceptions import ConfigMergeError
+
+        config = Config().update({"model": {"lr": 0.001, "dropout": 0.5, "hidden_size": 1024}})
+
+        with pytest.raises(ConfigMergeError) as exc_info:
+            config.update({"~model::missing": None})
+
+        error_msg = str(exc_info.value)
+        assert "Cannot remove non-existent key 'missing' from 'model'" in error_msg
+        # Should suggest keys from the parent (model)
+        assert "'lr'" in error_msg
+        assert "'dropout'" in error_msg
+        assert "'hidden_size'" in error_msg
+        # Should NOT show unrelated top-level keys
+        assert "'trainer'" not in error_msg or "trainer" not in config._data
+
+    def test_delete_nested_key_when_parent_doesnt_exist(self):
+        """Test error when trying to delete nested key but parent doesn't exist."""
+        from sparkwheel.utils.exceptions import ConfigMergeError
+
+        config = Config().update({"model": {"lr": 0.001}})
+
+        with pytest.raises(ConfigMergeError) as exc_info:
+            config.update({"~trainer::epochs": None})
+
+        # Should fail because 'trainer' doesn't exist
+        error_msg = str(exc_info.value)
+        assert "Cannot remove non-existent key 'epochs' from 'trainer'" in error_msg
+

class TestConfigAdvanced:
    """Test advanced Config features."""
diff --git a/tests/test_error_messages.py b/tests/test_error_messages.py
index a50a1d5..a2b6b86 100644
--- a/tests/test_error_messages.py
+++ b/tests/test_error_messages.py
@@ -2,7 +2,6 @@
from sparkwheel import Config
from sparkwheel.errors import (
-    enable_colors,
    format_available_keys,
    format_resolution_chain,
    format_suggestions,
@@ -341,221 +340,6 @@ def test_format_chain_without_reference(self):
        assert "✓" in result


-class TestColorFormatting:
-    """Test color formatting with auto-detection."""
-
-    def test_enable_colors_explicit(self):
-        """Test explicitly enabling colors."""
-        assert enable_colors(True) is True
-        assert enable_colors(False) is False
-
-    def test_enable_colors_auto_detect(self):
-        """Test auto-detection."""
-        result = enable_colors(None)
-        # Should return a boolean
-        assert isinstance(result, bool)
-
-    def test_colors_disabled_in_tests(self):
-        """Test that colors can be disabled."""
-        from sparkwheel.errors.formatters import format_error
-
-        enable_colors(False)
-        formatted = format_error("error message")
-
-        # Should not contain ANSI codes when disabled
-        assert
"\033[" not in formatted - assert formatted == "error message" - - def test_colors_enabled(self): - """Test that colors work when enabled.""" - from sparkwheel.errors.formatters import format_error - - enable_colors(True) - formatted = format_error("error message") - - # Should contain ANSI codes when enabled - assert "\033[" in formatted - assert "error message" in formatted - - def test_format_suggestion(self): - """Test format_suggestion function.""" - from sparkwheel.errors.formatters import format_suggestion - - enable_colors(True) - result = format_suggestion("suggestion text") - assert "suggestion text" in result - assert "\033[" in result # Contains ANSI codes - - enable_colors(False) - result = format_suggestion("suggestion text") - assert result == "suggestion text" - - def test_format_success(self): - """Test format_success function.""" - from sparkwheel.errors.formatters import format_success - - enable_colors(True) - result = format_success("success text") - assert "success text" in result - assert "\033[" in result # Contains ANSI codes - - enable_colors(False) - result = format_success("success text") - assert result == "success text" - - def test_format_code(self): - """Test format_code function.""" - from sparkwheel.errors.formatters import format_code - - enable_colors(True) - result = format_code("code text") - assert "code text" in result - assert "\033[" in result # Contains ANSI codes - - enable_colors(False) - result = format_code("code text") - assert result == "code text" - - def test_format_context(self): - """Test format_context function.""" - from sparkwheel.errors.formatters import format_context - - enable_colors(True) - result = format_context("context text") - assert "context text" in result - assert "\033[" in result # Contains ANSI codes - - enable_colors(False) - result = format_context("context text") - assert result == "context text" - - def test_format_bold(self): - """Test format_bold function.""" - from sparkwheel.errors.formatters import format_bold - - enable_colors(True) - result = format_bold("bold text") - assert "bold text" in result - assert "\033[" in result # Contains ANSI codes - - enable_colors(False) - result = format_bold("bold text") - assert result == "bold text" - - def test_supports_color_with_no_color_env(self, monkeypatch): - """Test that NO_COLOR environment variable disables colors.""" - from sparkwheel.errors.formatters import _supports_color - - monkeypatch.setenv("NO_COLOR", "1") - assert _supports_color() is False - - def test_supports_color_with_sparkwheel_no_color_env(self, monkeypatch): - """Test that SPARKWHEEL_NO_COLOR environment variable disables colors.""" - from sparkwheel.errors.formatters import _supports_color - - monkeypatch.setenv("SPARKWHEEL_NO_COLOR", "1") - assert _supports_color() is False - - def test_get_colors_enabled_lazy_init(self): - """Test that _get_colors_enabled initializes colors if needed.""" - from sparkwheel.errors import formatters - - # Reset global state - formatters._COLORS_ENABLED = None - - # Should auto-detect when None - result = formatters._get_colors_enabled() - assert isinstance(result, bool) - assert formatters._COLORS_ENABLED is not None - - def test_supports_color_force_color(self, monkeypatch): - """Test FORCE_COLOR enables colors even without TTY.""" - import sys - - from sparkwheel.errors.formatters import _supports_color - - # Clear NO_COLOR and set FORCE_COLOR - monkeypatch.delenv("NO_COLOR", raising=False) - monkeypatch.delenv("SPARKWHEEL_NO_COLOR", raising=False) - 
monkeypatch.setenv("FORCE_COLOR", "1") - - # Mock stdout without isatty - class MockStdout: - def isatty(self): - return False - - original_stdout = sys.stdout - try: - sys.stdout = MockStdout() - assert _supports_color() is True - finally: - sys.stdout = original_stdout - - def test_supports_color_no_color_overrides_force_color(self, monkeypatch): - """Test NO_COLOR takes precedence over FORCE_COLOR.""" - import sys - - from sparkwheel.errors.formatters import _supports_color - - # Set both NO_COLOR and FORCE_COLOR - NO_COLOR should win - monkeypatch.setenv("NO_COLOR", "1") - monkeypatch.setenv("FORCE_COLOR", "1") - - # Mock stdout with TTY - class MockStdout: - def isatty(self): - return True - - original_stdout = sys.stdout - try: - sys.stdout = MockStdout() - # NO_COLOR should override both FORCE_COLOR and TTY - assert _supports_color() is False - finally: - sys.stdout = original_stdout - - def test_supports_color_no_isatty(self, monkeypatch): - """Test color detection when stdout has no isatty method.""" - import sys - - from sparkwheel.errors.formatters import _supports_color - - # Disable colors to override any CI environment settings - monkeypatch.setenv("NO_COLOR", "1") - - # Mock stdout without isatty - class MockStdout: - pass - - original_stdout = sys.stdout - try: - sys.stdout = MockStdout() - assert _supports_color() is False - finally: - sys.stdout = original_stdout - - def test_supports_color_isatty_false(self, monkeypatch): - """Test color detection when isatty returns False.""" - import sys - - from sparkwheel.errors.formatters import _supports_color - - # Disable colors to override any CI environment settings - monkeypatch.setenv("NO_COLOR", "1") - - # Mock stdout with isatty that returns False - class MockStdout: - def isatty(self): - return False - - original_stdout = sys.stdout - try: - sys.stdout = MockStdout() - assert _supports_color() is False - finally: - sys.stdout = original_stdout - - class TestConfigKeyErrorEnhanced: """Test enhanced ConfigKeyError with suggestions.""" @@ -642,17 +426,17 @@ class TestExceptionEdgeCases: """Test edge cases in exception handling.""" def test_source_location_without_id(self): - """Test SourceLocation string formatting without ID.""" - from sparkwheel.utils.exceptions import SourceLocation + """Test Location string formatting without ID.""" + from sparkwheel.utils.exceptions import Location - loc = SourceLocation(filepath="test.yaml", line=10) + loc = Location(filepath="test.yaml", line=10) assert str(loc) == "test.yaml:10" def test_base_error_without_source_location_id(self): """Test BaseError formatting when source_location has no ID.""" - from sparkwheel.utils.exceptions import BaseError, SourceLocation + from sparkwheel.utils.exceptions import BaseError, Location - loc = SourceLocation(filepath="test.yaml", line=5) + loc = Location(filepath="test.yaml", line=5) error = BaseError("Test error", source_location=loc) msg = str(error) assert "[test.yaml:5]" in msg @@ -660,10 +444,10 @@ def test_base_error_without_source_location_id(self): def test_base_error_snippet_file_read_error(self, tmp_path): """Test BaseError snippet handling when file can't be read.""" - from sparkwheel.utils.exceptions import BaseError, SourceLocation + from sparkwheel.utils.exceptions import BaseError, Location # Create a source location pointing to non-existent file - loc = SourceLocation(filepath="/nonexistent/file.yaml", line=5) + loc = Location(filepath="/nonexistent/file.yaml", line=5) error = BaseError("Test error", source_location=loc) # Should 
not raise, just skip snippet msg = str(error) diff --git a/tests/test_items.py b/tests/test_items.py index 0a664f6..bbd07da 100644 --- a/tests/test_items.py +++ b/tests/test_items.py @@ -193,7 +193,7 @@ def test_instantiate_instantiation_error(self): config = {"_target_": "builtins.int", "invalid_arg": 123} component = Component(config=config) - with pytest.raises(InstantiationError, match="Failed to instantiate"): + with pytest.raises(InstantiationError): component.instantiate() def test_instantiate_with_kwargs_override(self): @@ -287,7 +287,7 @@ def test_instantiate_args_not_in_non_arg_keys(self): """Test that _args_ is properly excluded from kwargs.""" config = {"_target_": "builtins.dict", "_args_": [], "a": 1} component = Component(config=config) - args, kwargs = component.resolve_args() + _, kwargs = component.resolve_args() # Verify _args_ is not passed as a kwarg assert "_args_" not in kwargs @@ -454,14 +454,14 @@ def test_evaluate_debug_mode_enabled(self): sparkwheel.utils.run_debug = original_debug -class TestItemWithSourceLocation: +class TestItemWithLocation: """Test Item classes with source location tracking.""" def test_item_with_source_location(self): """Test creating item with source location.""" - from sparkwheel.utils.exceptions import SourceLocation + from sparkwheel.utils.exceptions import Location - location = SourceLocation(filepath="/tmp/config.yaml", line=10, column=5, id="model::lr") + location = Location(filepath="/tmp/config.yaml", line=10, column=5, id="model::lr") item = Item(config={"lr": 0.001}, id="model::lr", source_location=location) assert item.source_location == location @@ -470,9 +470,9 @@ def test_item_with_source_location(self): def test_component_error_includes_source_location(self): """Test that component errors include source location.""" - from sparkwheel.utils.exceptions import SourceLocation + from sparkwheel.utils.exceptions import Location - location = SourceLocation(filepath="/tmp/config.yaml", line=15, column=2, id="model") + location = Location(filepath="/tmp/config.yaml", line=15, column=2, id="model") config = {"_target_": "nonexistent.Module"} component = Component(config=config, id="model", source_location=location) @@ -484,9 +484,9 @@ def test_component_error_includes_source_location(self): def test_expression_error_includes_source_location(self): """Test that expression errors include source location.""" - from sparkwheel.utils.exceptions import SourceLocation + from sparkwheel.utils.exceptions import Location - location = SourceLocation(filepath="/tmp/config.yaml", line=20, column=2, id="calc") + location = Location(filepath="/tmp/config.yaml", line=20, column=2, id="calc") expr = Expression(config="$undefined_var", id="calc", source_location=location) with pytest.raises(EvaluationError) as exc_info: @@ -495,6 +495,47 @@ def test_expression_error_includes_source_location(self): error = exc_info.value assert error.source_location == location + def test_instantiation_error_with_location_reraises_unchanged(self): + """Test that InstantiationError with source_location is re-raised unchanged. + + This tests the re-raise path in Component.instantiate() where an InstantiationError + already has source_location set and should be re-raised as-is without modification. 
+ """ + from unittest.mock import patch + + from sparkwheel.utils.exceptions import Location + + # Create a location for the inner error (different from component's location) + inner_location = Location(filepath="/tmp/inner.yaml", line=99, column=1, id="inner_id") + component_location = Location(filepath="/tmp/outer.yaml", line=5, column=2, id="outer_id") + + # Create the exception to be raised + inner_error = InstantiationError( + "Inner error message", + source_location=inner_location, + suggestion="Inner suggestion", + ) + + # Mock the instantiate function to raise InstantiationError with source_location + def mock_instantiate(*args, **kwargs): + raise inner_error + + config = {"_target_": "some.module.Class"} + component = Component(config=config, id="test", source_location=component_location) + + with patch("sparkwheel.items.instantiate", mock_instantiate): + with pytest.raises(InstantiationError) as exc_info: + component.instantiate() + + error = exc_info.value + # The exception should be re-raised unchanged (not wrapped with component's location) + assert error is inner_error + assert error._original_message == "Inner error message" + assert error.suggestion == "Inner suggestion" + assert error.source_location == inner_location + assert error.source_location.filepath == "/tmp/inner.yaml" + assert error.source_location.line == 99 + class TestItemsEdgeCases: """Test edge cases in items module.""" diff --git a/tests/test_loader.py b/tests/test_loader.py index 7b40811..7ec67dd 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -3,8 +3,8 @@ import pytest from sparkwheel.loader import Loader -from sparkwheel.metadata import MetadataRegistry -from sparkwheel.utils.exceptions import SourceLocation +from sparkwheel.locations import LocationRegistry +from sparkwheel.utils.exceptions import Location class TestLoaderBasic: @@ -19,7 +19,7 @@ def test_load_file_basic(self, tmp_path): config, metadata = loader.load_file(str(config_file)) assert config == {"key": "value", "number": 42} - assert isinstance(metadata, MetadataRegistry) + assert isinstance(metadata, LocationRegistry) def test_load_file_empty_filepath(self): """Test loading with empty filepath returns empty config.""" @@ -27,7 +27,7 @@ def test_load_file_empty_filepath(self): config, metadata = loader.load_file("") assert config == {} - assert isinstance(metadata, MetadataRegistry) + assert isinstance(metadata, LocationRegistry) def test_load_file_none_filepath(self): """Test loading with None filepath returns empty config.""" @@ -35,7 +35,7 @@ def test_load_file_none_filepath(self): config, metadata = loader.load_file(None) assert config == {} - assert isinstance(metadata, MetadataRegistry) + assert isinstance(metadata, LocationRegistry) def test_load_file_non_yaml_extension(self, tmp_path): """Test loading non-YAML file raises ValueError.""" @@ -136,7 +136,7 @@ def test_load_files_empty_list(self): config, metadata = loader.load_files([]) assert config == {} - assert isinstance(metadata, MetadataRegistry) + assert isinstance(metadata, LocationRegistry) def test_load_files_single_file(self, tmp_path): """Test loading single file via load_files.""" @@ -236,18 +236,18 @@ def test_strip_metadata_from_dicts_in_lists(self, tmp_path): assert "__sparkwheel_metadata__" not in str(config) -class TestMetadataRegistry: - """Test MetadataRegistry functionality.""" +class TestLocationRegistry: + """Test LocationRegistry functionality.""" def test_create_registry(self): """Test creating empty registry.""" - registry = MetadataRegistry() 
+ registry = LocationRegistry() assert len(registry) == 0 def test_register_and_get(self): """Test registering and getting source locations.""" - registry = MetadataRegistry() - location = SourceLocation(filepath="config.yaml", line=10, column=5, id="model::lr") + registry = LocationRegistry() + location = Location(filepath="config.yaml", line=10, column=5, id="model::lr") registry.register("model::lr", location) @@ -257,13 +257,13 @@ def test_register_and_get(self): def test_get_nonexistent_returns_none(self): """Test getting nonexistent location returns None.""" - registry = MetadataRegistry() + registry = LocationRegistry() assert registry.get("nonexistent") is None def test_len(self): """Test registry length.""" - registry = MetadataRegistry() - location = SourceLocation(filepath="config.yaml", line=10, column=5, id="key") + registry = LocationRegistry() + location = Location(filepath="config.yaml", line=10, column=5, id="key") assert len(registry) == 0 @@ -275,8 +275,8 @@ def test_len(self): def test_contains(self): """Test __contains__ operator.""" - registry = MetadataRegistry() - location = SourceLocation(filepath="config.yaml", line=10, column=5, id="key") + registry = LocationRegistry() + location = Location(filepath="config.yaml", line=10, column=5, id="key") assert "key" not in registry @@ -285,11 +285,11 @@ def test_contains(self): def test_merge(self): """Test merging registries.""" - registry1 = MetadataRegistry() - registry2 = MetadataRegistry() + registry1 = LocationRegistry() + registry2 = LocationRegistry() - location1 = SourceLocation(filepath="file1.yaml", line=5, column=2, id="key1") - location2 = SourceLocation(filepath="file2.yaml", line=10, column=3, id="key2") + location1 = Location(filepath="file1.yaml", line=5, column=2, id="key1") + location2 = Location(filepath="file2.yaml", line=10, column=3, id="key2") registry1.register("key1", location1) registry2.register("key2", location2) @@ -302,8 +302,8 @@ def test_merge(self): def test_copy(self): """Test copying registry.""" - registry = MetadataRegistry() - location = SourceLocation(filepath="config.yaml", line=10, column=5, id="key") + registry = LocationRegistry() + location = Location(filepath="config.yaml", line=10, column=5, id="key") registry.register("key", location) @@ -313,7 +313,7 @@ def test_copy(self): assert copied.get("key") == location # Ensure it's a real copy, not a reference - location2 = SourceLocation(filepath="other.yaml", line=20, column=1, id="other") + location2 = Location(filepath="other.yaml", line=20, column=1, id="other") copied.register("other", location2) assert "other" in copied diff --git a/tests/test_operators.py b/tests/test_operators.py new file mode 100644 index 0000000..93ee68f --- /dev/null +++ b/tests/test_operators.py @@ -0,0 +1,395 @@ +"""Comprehensive tests for operators module. 
+ +This module tests the operators.py module functionality: +- MergeContext class and location tracking +- Operator validation (_validate_delete_operator) +- apply_operators function with all edge cases +""" + +import pytest + +from sparkwheel.locations import LocationRegistry +from sparkwheel.operators import MergeContext, _validate_delete_operator, apply_operators, validate_operators +from sparkwheel.utils.exceptions import ConfigMergeError, Location + + +class TestMergeContext: + """Test MergeContext class for tracking merge operations.""" + + def test_child_path_empty_base(self): + """Test creating child path from empty base.""" + ctx = MergeContext() + child = ctx.child_path("model") + assert child.current_path == "model" + + def test_child_path_with_base(self): + """Test creating child path with existing base.""" + ctx = MergeContext(current_path="model") + child = ctx.child_path("optimizer") + assert child.current_path == "model::optimizer" + + def test_child_path_nested(self): + """Test creating deeply nested child paths.""" + ctx = MergeContext() + child1 = ctx.child_path("model") + child2 = child1.child_path("optimizer") + child3 = child2.child_path("lr") + assert child3.current_path == "model::optimizer::lr" + + def test_get_source_location_no_registry(self): + """Test get_source_location returns None when no registry.""" + ctx = MergeContext() + location = ctx.get_source_location("key") + assert location is None + + def test_get_source_location_with_registry(self): + """Test get_source_location with registry.""" + registry = LocationRegistry() + test_location = Location(filepath="config.yaml", line=10) + registry.register("model::lr", test_location) + + ctx = MergeContext(locations=registry, current_path="model") + location = ctx.get_source_location("lr") + assert location == test_location + + def test_get_source_location_with_remove_operator(self): + """Test get_source_location strips ~ operator prefix.""" + registry = LocationRegistry() + test_location = Location(filepath="config.yaml", line=10) + registry.register("model::lr", test_location) + + ctx = MergeContext(locations=registry, current_path="model") + # Should strip ~ and find "lr" + location = ctx.get_source_location("~lr") + assert location == test_location + + def test_get_source_location_with_replace_operator(self): + """Test get_source_location strips = operator prefix.""" + registry = LocationRegistry() + test_location = Location(filepath="config.yaml", line=10) + registry.register("model::lr", test_location) + + ctx = MergeContext(locations=registry, current_path="model") + # Should strip = and find "lr" + location = ctx.get_source_location("=lr") + assert location == test_location + + def test_get_source_location_operator_key_takes_precedence(self): + """Test that exact operator key match takes precedence.""" + registry = LocationRegistry() + exact_location = Location(filepath="config.yaml", line=5) + fallback_location = Location(filepath="config.yaml", line=10) + + # Register both ~lr and lr + registry.register("model::~lr", exact_location) + registry.register("model::lr", fallback_location) + + ctx = MergeContext(locations=registry, current_path="model") + # Should find exact match first + location = ctx.get_source_location("~lr") + assert location == exact_location + + def test_get_source_location_not_found(self): + """Test get_source_location returns None for missing keys.""" + registry = LocationRegistry() + ctx = MergeContext(locations=registry, current_path="model") + location = 
ctx.get_source_location("missing_key") + assert location is None + + def test_get_source_location_empty_path(self): + """Test get_source_location with empty current path.""" + registry = LocationRegistry() + test_location = Location(filepath="config.yaml", line=10) + registry.register("lr", test_location) + + ctx = MergeContext(locations=registry, current_path="") + location = ctx.get_source_location("lr") + assert location == test_location + + def test_get_source_location_operator_key_not_found(self): + """Test get_source_location returns None when operator key not found.""" + registry = LocationRegistry() + # Register only "other", not "lr" or "~lr" + test_location = Location(filepath="config.yaml", line=10) + registry.register("model::other", test_location) + + ctx = MergeContext(locations=registry, current_path="model") + # Should not find ~lr or lr + location = ctx.get_source_location("~lr") + assert location is None + + def test_get_source_location_regular_key_not_found(self): + """Test get_source_location returns None for regular key not found.""" + registry = LocationRegistry() + test_location = Location(filepath="config.yaml", line=10) + registry.register("model::other", test_location) + + ctx = MergeContext(locations=registry, current_path="model") + # Should not find "missing" (regular key, no operator) + location = ctx.get_source_location("missing") + assert location is None + + +class TestValidateDeleteOperator: + """Test _validate_delete_operator function.""" + + def test_valid_null_value(self): + """Test that null value is valid.""" + _validate_delete_operator("key", None) # Should not raise + + def test_valid_empty_string(self): + """Test that empty string is valid.""" + _validate_delete_operator("key", "") # Should not raise + + def test_valid_list_value(self): + """Test that list value is valid.""" + _validate_delete_operator("key", [0, 1, 2]) # Should not raise + + def test_invalid_string_value(self): + """Test that non-empty string value raises error.""" + with pytest.raises(ConfigMergeError, match="must have null, empty, or list value"): + _validate_delete_operator("key", "invalid") + + def test_invalid_dict_value(self): + """Test that dict value raises error.""" + with pytest.raises(ConfigMergeError, match="must have null, empty, or list value"): + _validate_delete_operator("key", {"nested": "value"}) + + def test_invalid_int_value(self): + """Test that int value raises error.""" + with pytest.raises(ConfigMergeError, match="must have null, empty, or list value"): + _validate_delete_operator("key", 123) + + def test_empty_list_raises_error(self): + """Test that empty list raises error.""" + with pytest.raises(ConfigMergeError, match="cannot be empty"): + _validate_delete_operator("key", []) + + +class TestValidateOperators: + """Test validate_operators function.""" + + def test_validate_non_dict_config(self): + """Test that non-dict config is handled gracefully.""" + # Should not raise for non-dict + validate_operators("not a dict") # type: ignore[arg-type] + validate_operators(123) # type: ignore[arg-type] + validate_operators(None) # type: ignore[arg-type] + + def test_validate_remove_operator(self): + """Test validation of remove operator.""" + config = {"~key": None} + validate_operators(config) # Should not raise + + def test_validate_remove_operator_invalid_value(self): + """Test validation catches invalid remove operator value.""" + config = {"~key": "invalid"} + with pytest.raises(ConfigMergeError, match="must have null, empty, or list value"): + 
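+            # "invalid" is a non-empty string: not null, not empty, and not a list.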
validate_operators(config) + + def test_validate_replace_operator(self): + """Test validation of replace operator.""" + config = {"=key": "value"} + validate_operators(config) # Should not raise + + def test_validate_nested_remove_operator(self): + """Test validation of nested remove operator.""" + config = {"model": {"~lr": None}} + validate_operators(config) # Should not raise + + def test_validate_nested_remove_operator_invalid(self): + """Test validation catches invalid nested remove operator.""" + config = {"model": {"~lr": 123}} + with pytest.raises(ConfigMergeError, match="must have null, empty, or list value"): + validate_operators(config) + + def test_validate_skips_dict_under_remove(self): + """Test that validation catches dict value under remove operator.""" + # A dict value under a remove operator is invalid + # Remove operators must have null, empty, or list values + config = {"~key": {"nested": "value"}} + with pytest.raises(ConfigMergeError, match="must have null, empty, or list value"): + validate_operators(config) + + +class TestApplyOperatorsEdgeCases: + """Test edge cases in apply_operators function.""" + + def test_non_dict_base_returns_override(self): + """Test that non-dict base returns deepcopy of override.""" + result = apply_operators("not a dict", {"key": "value"}) # type: ignore[arg-type] + assert result == {"key": "value"} + + def test_non_dict_override_returns_override(self): + """Test that non-dict override returns deepcopy of override.""" + result = apply_operators({"key": "value"}, "not a dict") # type: ignore[arg-type] + assert result == "not a dict" + + def test_both_non_dict_returns_override(self): + """Test that both non-dict returns override.""" + result = apply_operators("base", "override") # type: ignore[arg-type] + assert result == "override" + + def test_non_string_key_copied_directly(self): + """Test that non-string keys are copied directly.""" + base = {} + override = {123: "numeric_key"} # type: ignore[dict-item] + result = apply_operators(base, override) + assert result[123] == "numeric_key" # type: ignore[index] + + def test_context_propagates_to_nested_merges(self): + """Test that context is properly propagated in nested merges.""" + registry = LocationRegistry() + ctx = MergeContext(locations=registry) + + base = {"model": {"lr": 0.001}} + override = {"model": {"dropout": 0.1}} + + result = apply_operators(base, override, context=ctx) + assert result == {"model": {"lr": 0.001, "dropout": 0.1}} + + def test_delete_with_context_location_tracking(self): + """Test that delete operator uses context for error messages.""" + registry = LocationRegistry() + test_location = Location(filepath="config.yaml", line=10) + registry.register("~missing", test_location) + + ctx = MergeContext(locations=registry) + + base = {"existing": "value"} + override = {"~missing": None} + + with pytest.raises(ConfigMergeError, match="Cannot delete key 'missing'"): + apply_operators(base, override, context=ctx) + + def test_delete_list_items_no_context(self): + """Test delete list items works without context.""" + base = {"items": [1, 2, 3, 4, 5]} + override = {"~items": [0, 2, 4]} + result = apply_operators(base, override) + assert result == {"items": [2, 4]} + + def test_delete_dict_keys_with_context(self): + """Test delete dict keys with context for better error messages.""" + registry = LocationRegistry() + test_location = Location(filepath="config.yaml", line=10) + registry.register("~model", test_location) + + ctx = MergeContext(locations=registry) + + base = 
{"model": {"lr": 0.001, "dropout": 0.1}} + override = {"~model": ["missing_key"]} + + with pytest.raises(ConfigMergeError, match="Cannot remove non-existent key 'missing_key'"): + apply_operators(base, override, context=ctx) + + def test_replace_operator_with_none_value(self): + """Test replace operator can set value to None.""" + base = {"key": "value"} + override = {"=key": None} + result = apply_operators(base, override) + assert result == {"key": None} + + def test_composition_list_extend_preserves_order(self): + """Test that list composition preserves order.""" + base = {"items": [1, 2, 3]} + override = {"items": [4, 5, 6]} + result = apply_operators(base, override) + assert result == {"items": [1, 2, 3, 4, 5, 6]} + + def test_composition_dict_merge_deep(self): + """Test that dict composition merges deeply.""" + base = {"model": {"optimizer": {"lr": 0.001, "momentum": 0.9}}} + override = {"model": {"optimizer": {"lr": 0.01}}} + result = apply_operators(base, override) + assert result == {"model": {"optimizer": {"lr": 0.01, "momentum": 0.9}}} + + def test_scalar_replacement_on_type_mismatch(self): + """Test that type mismatches cause replacement.""" + base = {"value": [1, 2, 3]} + override = {"value": "string"} + result = apply_operators(base, override) + assert result == {"value": "string"} + + def test_new_key_addition(self): + """Test adding new keys to config.""" + base = {"existing": "value"} + override = {"new_key": "new_value"} + result = apply_operators(base, override) + assert result == {"existing": "value", "new_key": "new_value"} + + def test_delete_list_negative_indices_normalized(self): + """Test that negative indices are properly normalized.""" + base = {"items": [1, 2, 3, 4, 5]} + override = {"~items": [-1, -2]} # Remove last two items + result = apply_operators(base, override) + assert result == {"items": [1, 2, 3]} + + def test_delete_list_duplicate_indices_handled(self): + """Test that duplicate indices are handled correctly.""" + base = {"items": [1, 2, 3, 4, 5]} + override = {"~items": [1, 1, 1]} # Duplicate index + result = apply_operators(base, override) + assert result == {"items": [1, 3, 4, 5]} # Only removed once + + def test_delete_items_errors_on_scalar(self): + """Test that deleting items from scalar raises error.""" + base = {"value": "scalar"} + override = {"~value": [0]} + + with pytest.raises(ConfigMergeError, match="expected list or dict"): + apply_operators(base, override) + + def test_multiple_operators_in_one_override(self): + """Test multiple operators in single override dict.""" + base = {"a": 1, "b": 2, "c": 3, "d": {"x": 1, "y": 2}} + override = { + "a": 10, # Compose (replace scalar) + "=b": 20, # Explicit replace + "~c": None, # Delete + "d": {"x": 10}, # Compose dict (merge) + } + result = apply_operators(base, override) + assert result == {"a": 10, "b": 20, "d": {"x": 10, "y": 2}} + + +class TestApplyOperatorsDeepCopy: + """Test that apply_operators properly deep copies values.""" + + def test_base_not_mutated(self): + """Test that base dict is not mutated.""" + base = {"model": {"lr": 0.001}} + override = {"model": {"dropout": 0.1}} + result = apply_operators(base, override) + + # Modify result + result["model"]["lr"] = 0.01 + + # Base should be unchanged + assert base["model"]["lr"] == 0.001 + + def test_override_not_mutated(self): + """Test that override dict is not mutated.""" + base = {"model": {"lr": 0.001}} + override = {"model": {"dropout": 0.1}} + result = apply_operators(base, override) + + # Modify result + 
result["model"]["dropout"] = 0.5 + + # Override should be unchanged + assert override["model"]["dropout"] == 0.1 + + def test_result_is_independent(self): + """Test that result is independent of base and override.""" + base = {"items": [1, 2, 3]} + override = {"items": [4, 5]} + result = apply_operators(base, override) + + # Modify result + result["items"].append(6) + + # Base and override should be unchanged + assert base["items"] == [1, 2, 3] + assert override["items"] == [4, 5] + assert result["items"] == [1, 2, 3, 4, 5, 6] diff --git a/tests/test_path_utils.py b/tests/test_path_utils.py index b90f607..a0c69ac 100644 --- a/tests/test_path_utils.py +++ b/tests/test_path_utils.py @@ -1,6 +1,8 @@ """Tests for path utility functions.""" -from sparkwheel.path_utils import PathPatterns, find_references +import pytest + +from sparkwheel.path_utils import PathPatterns, find_references, get_by_id class TestPathPatterns: @@ -28,3 +30,131 @@ def test_find_references_empty_for_plain_text(self): """Test find_references returns empty for plain text.""" refs = find_references("just plain text") assert refs == [] + + +class TestGetById: + """Test get_by_id function for navigating config structures.""" + + def test_empty_id_returns_whole_config(self): + """Test get_by_id with empty ID returns whole config.""" + config = {"key": "value", "nested": {"item": 123}} + result = get_by_id(config, "") + + assert result == config + + def test_list_indexing(self): + """Test get_by_id with list indexing.""" + config = {"items": [10, 20, 30]} + result = get_by_id(config, "items::1") + + assert result == 20 + + def test_nested_list(self): + """Test get_by_id with nested structures including lists.""" + config = {"data": {"values": [{"x": 1}, {"x": 2}, {"x": 3}]}} + result = get_by_id(config, "data::values::2::x") + + assert result == 3 + + def test_type_error_on_primitive(self): + """Test get_by_id raises TypeError when trying to index a primitive value.""" + config = {"value": 42} + + with pytest.raises(TypeError, match="Cannot index int"): + get_by_id(config, "value::subkey") + + def test_missing_key_first_level(self): + """Test get_by_id with missing key at first level shows non-redundant error.""" + config = {"foo": 1, "bar": 2} + + with pytest.raises(KeyError) as exc_info: + get_by_id(config, "missing") + + error_msg = str(exc_info.value) + assert "Key 'missing' not found" in error_msg + assert "Available keys:" in error_msg + assert "'foo'" in error_msg + assert "'bar'" in error_msg + # Should NOT say "at path 'missing'" for first level + assert "at path" not in error_msg + + def test_missing_key_nested(self): + """Test get_by_id with missing nested key shows parent path.""" + config = {"data": {"train": {"lr": 0.001, "epochs": 10}}} + + with pytest.raises(KeyError) as exc_info: + get_by_id(config, "data::train::missing") + + error_msg = str(exc_info.value) + assert "Key 'missing' not found in 'data::train'" in error_msg + assert "Available keys:" in error_msg + assert "'lr'" in error_msg + assert "'epochs'" in error_msg + + def test_invalid_list_index(self): + """Test get_by_id with invalid list index.""" + config = {"items": [1, 2, 3]} + + with pytest.raises(KeyError) as exc_info: + get_by_id(config, "items::10") + + error_msg = str(exc_info.value) + assert "List index '10' out of range" in error_msg + assert "in 'items'" in error_msg + + def test_invalid_list_index_first_level(self): + """Test get_by_id with invalid list index at first level.""" + config = [1, 2, 3] + + with pytest.raises(KeyError) as 
exc_info: + get_by_id(config, "10") + + error_msg = str(exc_info.value) + assert "List index '10' out of range" in error_msg + # Should not mention parent path for first level + assert "in '" not in error_msg + + def test_invalid_list_index_non_integer(self): + """Test get_by_id with non-integer list index.""" + config = {"items": [1, 2, 3]} + + with pytest.raises(KeyError) as exc_info: + get_by_id(config, "items::abc") + + error_msg = str(exc_info.value) + assert "Invalid list index 'abc'" in error_msg + assert "not an integer" in error_msg + + def test_type_error_first_level(self): + """Test type error at first level shows clean message.""" + config = "string_value" + + with pytest.raises(TypeError) as exc_info: + get_by_id(config, "foo") + + error_msg = str(exc_info.value) + assert "Cannot index str with key 'foo'" in error_msg + # Should not mention parent path for first level + assert "in '" not in error_msg + + def test_type_error_nested(self): + """Test type error in nested path shows parent.""" + config = {"data": {"value": 42}} + + with pytest.raises(TypeError) as exc_info: + get_by_id(config, "data::value::foo") + + error_msg = str(exc_info.value) + assert "Cannot index int with key 'foo'" in error_msg + assert "in 'data::value'" in error_msg + + def test_available_keys_truncated(self): + """Test that available keys list is truncated when > 10 keys.""" + config = {f"key_{i}": i for i in range(20)} + + with pytest.raises(KeyError) as exc_info: + get_by_id(config, "missing") + + error_msg = str(exc_info.value) + assert "Available keys:" in error_msg + assert "..." in error_msg # Should be truncated diff --git a/tests/test_preprocessor.py b/tests/test_preprocessor.py index f33d0f7..121ad38 100644 --- a/tests/test_preprocessor.py +++ b/tests/test_preprocessor.py @@ -4,6 +4,7 @@ from sparkwheel.loader import Loader from sparkwheel.preprocessor import Preprocessor +from sparkwheel.utils.exceptions import CircularReferenceError, ConfigKeyError class TestPreprocessor: @@ -19,35 +20,247 @@ def test_circular_raw_reference(self, tmp_path): preprocessor = Preprocessor(loader) # Load the config and try to process it - config, _ = loader.load_file(str(config_file)) + config, locations = loader.load_file(str(config_file)) - with pytest.raises(ValueError, match="Circular raw reference detected"): - preprocessor.process(config, config) + with pytest.raises(CircularReferenceError, match="Circular raw reference detected"): + preprocessor.process_raw_refs(config, config, locations=locations) - def test_get_by_id_empty_id(self): - """Test _get_by_id with empty ID returns whole config.""" - config = {"key": "value", "nested": {"item": 123}} - result = Preprocessor._get_by_id(config, "") + def test_raw_ref_missing_key_with_location(self, tmp_path): + """Test raw reference error includes source location.""" + # Create a config file with a raw reference to a missing key + config_file = tmp_path / "config.yaml" + config_file.write_text('value: "%missing::key"') + + loader = Loader() + preprocessor = Preprocessor(loader) + + config, locations = loader.load_file(str(config_file)) + + with pytest.raises(ConfigKeyError) as exc_info: + preprocessor.process_raw_refs(config, config, locations=locations) + + error = exc_info.value + assert error.source_location is not None + assert error.source_location.filepath == str(config_file) + assert error.source_location.line == 1 + assert "Error resolving raw reference" in error._original_message + assert "Key 'missing' not found" in error._original_message + + def 
test_raw_ref_external_file_missing_key(self, tmp_path): + """Test raw reference to external file with missing key.""" + # Create external file + external_file = tmp_path / "external.yaml" + external_file.write_text("foo: 1\nbar: 2") + + # Create main config that references missing key in external file + config_file = tmp_path / "config.yaml" + config_file.write_text(f'value: "%{external_file}::missing"') + + loader = Loader() + preprocessor = Preprocessor(loader) + + config, locations = loader.load_file(str(config_file)) + + with pytest.raises(ConfigKeyError) as exc_info: + preprocessor.process_raw_refs(config, config, locations=locations) + + error = exc_info.value + assert error.source_location is not None + assert error.source_location.filepath == str(config_file) + assert f"from '{external_file}'" in error._original_message + assert "Key 'missing' not found" in error._original_message + + def test_raw_ref_nested_missing_key(self, tmp_path): + """Test raw reference with nested path where middle key is missing.""" + config_file = tmp_path / "config.yaml" + config_file.write_text('data:\n foo: 1\nvalue: "%data::missing::key"') + + loader = Loader() + preprocessor = Preprocessor(loader) + + config, locations = loader.load_file(str(config_file)) + + with pytest.raises(ConfigKeyError) as exc_info: + preprocessor.process_raw_refs(config, config, locations=locations) + + error = exc_info.value + assert "Key 'missing' not found in 'data'" in error._original_message + + def test_circular_reference_with_location(self, tmp_path): + """Test circular reference error includes source location.""" + config_file = tmp_path / "config.yaml" + config_file.write_text('a: "%b"\nb: "%a"') + + loader = Loader() + preprocessor = Preprocessor(loader) + + config, locations = loader.load_file(str(config_file)) + + with pytest.raises(CircularReferenceError) as exc_info: + preprocessor.process_raw_refs(config, config, locations=locations) + + error = exc_info.value + assert error.source_location is not None + assert error.source_location.filepath == str(config_file) + assert "Reference chain:" in error._original_message + + def test_raw_ref_expansion_success(self, tmp_path): + """Test successful raw reference expansion.""" + config_file = tmp_path / "config.yaml" + config_file.write_text('base_lr: 0.001\nmodel:\n lr: "%base_lr"') + + loader = Loader() + preprocessor = Preprocessor(loader) + + config, locations = loader.load_file(str(config_file)) + result = preprocessor.process_raw_refs(config, config, locations=locations) + + assert result["model"]["lr"] == 0.001 + + def test_raw_ref_nested_expansion(self, tmp_path): + """Test nested raw reference expansion.""" + config_file = tmp_path / "config.yaml" + config_file.write_text('a:\n b:\n c: 42\nvalue: "%a::b::c"') + + loader = Loader() + preprocessor = Preprocessor(loader) + + config, locations = loader.load_file(str(config_file)) + result = preprocessor.process_raw_refs(config, config, locations=locations) - assert result == config + assert result["value"] == 42 - def test_get_by_id_list_indexing(self): - """Test _get_by_id with list indexing.""" - config = {"items": [10, 20, 30]} - result = Preprocessor._get_by_id(config, "items::1") + def test_raw_ref_external_file_success(self, tmp_path): + """Test successful raw reference from external file.""" + external_file = tmp_path / "base.yaml" + external_file.write_text("learning_rate: 0.001") - assert result == 20 + config_file = tmp_path / "config.yaml" + config_file.write_text(f'lr: 
"%{external_file}::learning_rate"') + + loader = Loader() + preprocessor = Preprocessor(loader) + + config, locations = loader.load_file(str(config_file)) + result = preprocessor.process_raw_refs(config, config, locations=locations) + + assert result["lr"] == 0.001 + + def test_raw_ref_in_list(self, tmp_path): + """Test raw reference inside a list item is expanded.""" + config_file = tmp_path / "config.yaml" + config_file.write_text('base_value: 42\nitems:\n - "%base_value"\n - 100') + + loader = Loader() + preprocessor = Preprocessor(loader) + + config, locations = loader.load_file(str(config_file)) + result = preprocessor.process_raw_refs(config, config, locations=locations) + + assert result["items"][0] == 42 + assert result["items"][1] == 100 + + +class TestPreprocessorExternalOnly: + """Test external_only parameter for two-phase raw reference expansion.""" + + def test_external_only_expands_external_refs(self, tmp_path): + """Test that external_only=True expands external file refs.""" + external_file = tmp_path / "external.yaml" + external_file.write_text("value: 42") + + loader = Loader() + preprocessor = Preprocessor(loader) + + config = {"external_ref": f"%{external_file}::value", "local_ref": "%local_key", "local_key": 100} + + result = preprocessor.process_raw_refs(config, config, external_only=True) + + # External ref should be expanded + assert result["external_ref"] == 42 + # Local ref should remain as string + assert result["local_ref"] == "%local_key" + # Local key unchanged + assert result["local_key"] == 100 + + def test_external_only_false_expands_all_refs(self, tmp_path): + """Test that external_only=False expands all refs including local.""" + external_file = tmp_path / "external.yaml" + external_file.write_text("value: 42") - def test_get_by_id_nested_list(self): - """Test _get_by_id with nested structures including lists.""" - config = {"data": {"values": [{"x": 1}, {"x": 2}, {"x": 3}]}} - result = Preprocessor._get_by_id(config, "data::values::2::x") + loader = Loader() + preprocessor = Preprocessor(loader) + + config = {"external_ref": f"%{external_file}::value", "local_ref": "%local_key", "local_key": 100} + + result = preprocessor.process_raw_refs(config, config, external_only=False) + + # Both should be expanded + assert result["external_ref"] == 42 + assert result["local_ref"] == 100 + + def test_two_phase_expansion_with_override(self, tmp_path): + """Test that two-phase expansion allows overrides to affect local refs.""" + external_file = tmp_path / "external.yaml" + external_file.write_text("external_value: 1") + + loader = Loader() + preprocessor = Preprocessor(loader) + + # Initial config with both external and local refs + config = { + "external_ref": f"%{external_file}::external_value", + "local_ref": "%vars::value", + "vars": {"value": None}, # Will be overridden + } + + # Phase 1: Expand only external refs + config = preprocessor.process_raw_refs(config, config, external_only=True) + assert config["external_ref"] == 1 + assert config["local_ref"] == "%vars::value" # Still string + + # Simulate CLI override + config["vars"]["value"] = "/data/features.npz" + + # Phase 2: Expand local refs (now sees override) + config = preprocessor.process_raw_refs(config, config, external_only=False) + assert config["local_ref"] == "/data/features.npz" + + def test_nested_local_refs_expanded_together(self, tmp_path): + """Test that nested local refs are all expanded in phase 2.""" + loader = Loader() + preprocessor = Preprocessor(loader) + + config = {"a": {"b": {"c": 
42}}, "ref_to_b": "%a::b", "ref_to_c": "%a::b::c"} + + # Phase 1: Nothing to expand (no external refs) + result = preprocessor.process_raw_refs(config, config, external_only=True) + assert result["ref_to_b"] == "%a::b" + assert result["ref_to_c"] == "%a::b::c" + + # Phase 2: Expand all local refs + result = preprocessor.process_raw_refs(result, result, external_only=False) + assert result["ref_to_b"] == {"c": 42} + assert result["ref_to_c"] == 42 + + def test_external_ref_within_local_ref_expanded_correctly(self, tmp_path): + """Test that external refs within locally-referenced values are expanded.""" + external_file = tmp_path / "external.yaml" + external_file.write_text("nested:\n value: 99") + + loader = Loader() + preprocessor = Preprocessor(loader) - assert result == 3 + config = { + "template": {"external": f"%{external_file}::nested"}, + "copy": "%template", + } - def test_get_by_id_type_error_on_primitive(self): - """Test _get_by_id raises TypeError when trying to index a primitive value.""" - config = {"value": 42} + # Phase 1: Expand external ref inside template + result = preprocessor.process_raw_refs(config, config, external_only=True) + assert result["template"]["external"] == {"value": 99} + assert result["copy"] == "%template" # Local ref still string - with pytest.raises(TypeError, match="Cannot index int"): - Preprocessor._get_by_id(config, "value::subkey") + # Phase 2: Expand local ref - should get the already-expanded template + result = preprocessor.process_raw_refs(result, result, external_only=False) + assert result["copy"] == {"external": {"value": 99}} diff --git a/uv.lock b/uv.lock index 43c5e30..08e3a48 100644 --- a/uv.lock +++ b/uv.lock @@ -1714,7 +1714,7 @@ wheels = [ [[package]] name = "sparkwheel" -version = "0.0.7" +version = "0.0.8" source = { editable = "." } dependencies = [ { name = "pyyaml" },
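The `external_only` tests above encode the intended call order for two-phase raw-reference expansion: expand external file references first, apply any overrides, then expand local references so they observe the overridden values. Below is a minimal sketch of that flow, assuming only the `Loader` and `Preprocessor` API exercised in the tests; the config filename and the `vars::value` override are hypothetical.

```python
# Minimal sketch (not part of this diff) of two-phase raw-reference expansion.
# Assumes only the API used in the tests above; file names are hypothetical.
from sparkwheel.loader import Loader
from sparkwheel.preprocessor import Preprocessor

loader = Loader()
preprocessor = Preprocessor(loader)

# config.yaml is assumed to contain both an external reference such as
# "%base.yaml::learning_rate" and a local reference such as "%vars::value".
config, locations = loader.load_file("config.yaml")

# Phase 1: expand only external-file references; local "%key" refs stay strings.
config = preprocessor.process_raw_refs(config, config, external_only=True)

# Apply overrides (e.g. from the CLI) while local refs are still unresolved.
config["vars"]["value"] = "/data/features.npz"  # hypothetical override

# Phase 2: expand the remaining local references against the overridden config.
config = preprocessor.process_raw_refs(config, config, external_only=False)
```

Deferring local expansion until after overrides is what `test_two_phase_expansion_with_override` verifies: a local reference resolved too early would bake in the pre-override value.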