diff --git a/.github/workflows/check_pull_request_title.yml b/.github/workflows/check_pull_request_title.yml index 891264f..2c68fbc 100644 --- a/.github/workflows/check_pull_request_title.yml +++ b/.github/workflows/check_pull_request_title.yml @@ -43,5 +43,5 @@ jobs: Invalid: "add feature", "Adds feature", "Add.", "Fix" # Disable type prefixes (we don't use conventional commits format) requireScope: false - ignoreLabels: - - ignore-title-check + ignoreLabels: | + ignore-title-check diff --git a/README.md b/README.md index f44e8a9..118cd0e 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,8 @@ dataset: ```python from sparkwheel import Config -config = Config.load("config.yaml") +config = Config() +config.update("config.yaml") model = config.resolve("model") # Actual torch.nn.Linear(784, 10) instance! ``` @@ -88,7 +89,6 @@ model: - [Full Documentation](https://project-lighter.github.io/sparkwheel/) - [Quick Start Guide](https://project-lighter.github.io/sparkwheel/getting-started/quickstart/) - [Core Concepts](https://project-lighter.github.io/sparkwheel/user-guide/basics/) -- [Examples](https://project-lighter.github.io/sparkwheel/examples/simple/) - [API Reference](https://project-lighter.github.io/sparkwheel/reference/) ## Community diff --git a/docs/getting-started/quickstart.md b/docs/getting-started/quickstart.md index 0694fc6..4bcf488 100644 --- a/docs/getting-started/quickstart.md +++ b/docs/getting-started/quickstart.md @@ -129,6 +129,5 @@ python train.py training::learning_rate=0.01 dataset::batch_size=64 Now that you've seen the basics: - **[Core Concepts](../user-guide/basics.md)** - Learn more about references, expressions, and instantiation -- **[Examples](../examples/simple.md)** - See complete real-world examples - **[Composition & Operators](../user-guide/operators.md)** - Master config composition with `=` and `~` - **[Schema Validation](../user-guide/schema-validation.md)** - Validate configs with dataclasses diff --git a/docs/index.md 
b/docs/index.md index faf2249..db33c86 100644 --- a/docs/index.md +++ b/docs/index.md @@ -233,14 +233,6 @@ Sparkwheel has two types of references with distinct purposes: [:octicons-arrow-right-24: Core Concepts](user-guide/basics.md) -- :material-lightbulb-on-outline:{ .lg .middle } __Examples__ - - --- - - See complete real-world configuration patterns - - [:octicons-arrow-right-24: View Examples](examples/simple.md) - - :material-code-tags:{ .lg .middle } __API Reference__ --- diff --git a/docs/user-guide/advanced.md b/docs/user-guide/advanced.md index 21f572c..1555750 100644 --- a/docs/user-guide/advanced.md +++ b/docs/user-guide/advanced.md @@ -171,7 +171,8 @@ model: # Merges by default! ```python from sparkwheel import Config -config = Config.load("base.yaml") +config = Config() +config.update("base.yaml") config.update("override.yaml") # Result: @@ -213,7 +214,8 @@ Use `~key: null` to delete a key, or `~key: [items]` to delete specific items fr ``` ```python -config = Config.load("base.yaml") +config = Config() +config.update("base.yaml") config.update({"~model::dropout": None}) # Remove entire key config.update({"~plugins": [0, 2]}) # Remove list items config.update({"~dataloaders": ["train", "test"]}) # Remove dict keys @@ -226,7 +228,8 @@ config.update({"~dataloaders": ["train", "test"]}) # Remove dict keys Apply operators programmatically: ```python -config = Config.load("config.yaml") +config = Config() +config.update("config.yaml") # Set individual values config.set("model::hidden_size", 1024) @@ -267,7 +270,8 @@ Sparkwheel provides helpful error messages with suggestions: ```python from sparkwheel import Config, ConfigKeyError -config = Config.load({ +config = Config() +config.update({ "model": {"hidden_size": 512, "num_layers": 4}, "training": {"batch_size": 32} }) @@ -295,7 +299,8 @@ Pre-import modules for use in expressions: from sparkwheel import Config # Pre-import torch for all expressions -config = Config.load("config.yaml", globals={"torch": 
"torch", "np": "numpy"}) +config = Config(globals={"torch": "torch", "np": "numpy"}) +config.update("config.yaml") # Now expressions can use torch and np without importing ``` @@ -312,7 +317,8 @@ data: "$np.array([1, 2, 3])" ```python from sparkwheel import Config -config: Config = Config.load("config.yaml") +config: Config = Config() +config.update("config.yaml") resolved: dict = config.resolve() ``` diff --git a/docs/user-guide/cli.md b/docs/user-guide/cli.md index 8cae776..dc3474d 100644 --- a/docs/user-guide/cli.md +++ b/docs/user-guide/cli.md @@ -190,4 +190,4 @@ python train.py base.yaml optimizer::lr=0.001 trainer::epochs=100 - **[Configuration Basics](basics.md)** - Loading and accessing configs - **[Operators](operators.md)** - Composition, replacement, and deletion - **[Schema Validation](schema-validation.md)** - Type-safe configs with dataclasses -- **[API Reference](references.md)** - Full API documentation +- **[API Reference](reference/)** - Full API documentation diff --git a/docs/user-guide/expressions.md b/docs/user-guide/expressions.md index f1c9c6f..6468efa 100644 --- a/docs/user-guide/expressions.md +++ b/docs/user-guide/expressions.md @@ -223,7 +223,8 @@ learning_rate: "$0.001 * (@training::batch_size ** 0.5)" from sparkwheel import Config try: - config = Config.load("config.yaml") + config = Config() + config.update("config.yaml") resolved = config.resolve() except SyntaxError as e: print(f"Expression syntax error: {e}") @@ -247,4 +248,3 @@ except Exception as e: - [Instantiation](instantiation.md) - Create objects with expressions - [Advanced Features](advanced.md) - Complex expression patterns -- [Examples](../examples/deep-learning.md) - Real-world expression usage diff --git a/docs/user-guide/instantiation.md b/docs/user-guide/instantiation.md index 73ac816..aaf10e3 100644 --- a/docs/user-guide/instantiation.md +++ b/docs/user-guide/instantiation.md @@ -14,7 +14,8 @@ model: ```python from sparkwheel import Config -config = 
Config.load("config.yaml") +config = Config() +config.update("config.yaml") # Instantiate the object model = config.resolve("model") diff --git a/docs/user-guide/operators.md b/docs/user-guide/operators.md index 4fa01fc..6baf2d1 100644 --- a/docs/user-guide/operators.md +++ b/docs/user-guide/operators.md @@ -535,4 +535,3 @@ Sparkwheel goes beyond Hydra with: - **[Configuration Basics](basics.md)** - Core config management - **[Advanced Features](advanced.md)** - Macros and power features -- **[Examples](../examples/simple.md)** - Real-world patterns diff --git a/docs/user-guide/references.md b/docs/user-guide/references.md index e23505b..be2f42d 100644 --- a/docs/user-guide/references.md +++ b/docs/user-guide/references.md @@ -10,6 +10,7 @@ Sparkwheel provides two types of references for linking configuration values: | Feature | `@ref` (Resolved) | `%ref` (Raw) | `$expr` (Expression) | |---------|-------------------|--------------|----------------------| | **Returns** | Final computed value | Raw YAML content | Evaluated expression result | +| **When processed** | Lazy (`resolve()`) | Eager (`update()`) | Lazy (`resolve()`) | | **Instantiates objects** | ✅ Yes | ❌ No | ✅ Yes (if referenced) | | **Evaluates expressions** | ✅ Yes | ❌ No | ✅ Yes | | **Use in dataclass validation** | ✅ Yes | ⚠️ Limited | ✅ Yes | @@ -17,6 +18,53 @@ Sparkwheel provides two types of references for linking configuration values: | **Cross-file references** | ✅ Yes | ✅ Yes | ❌ No | | **When to use** | Get computed results | Copy config structures | Compute new values | +## Two-Stage Processing Model + +Sparkwheel processes references at different times to enable safe config composition: + +!!! 
abstract "When References Are Processed" + + **Stage 1: Eager Processing (during `update()`)** + + - **Raw References (`%`)** are expanded immediately when configs are merged + - Enables safe config composition and pruning workflows + - External file references resolved at load time + + **Stage 2: Lazy Processing (during `resolve()`)** + + - **Resolved References (`@`)** are processed on-demand + - **Expressions (`$`)** are evaluated when needed + - **Components (`_target_`)** are instantiated only when requested + - Supports complex dependency graphs and deferred instantiation + +**Why two stages?** + +This separation enables powerful workflows like config pruning: + +```yaml +# base.yaml +system: + lr: 0.001 + batch_size: 32 + +experiment: + model: + optimizer: + lr: "%system::lr" # Copies raw value 0.001 eagerly + +~system: null # Delete system section after copying +``` + +```python +config = Config() +config.update("base.yaml") +# % references already expanded during update() +# ~system deletion applied after expansion +# Result: experiment::model::optimizer::lr = 0.001 (system deleted safely) +``` + +With `@` references, this would fail because they resolve lazily after deletion. + ## Resolution Flow !!! 
abstract "How References Are Resolved" @@ -25,10 +73,10 @@ Sparkwheel provides two types of references for linking configuration values: **Step 2: Determine Type** - - **`@key`** → Proceed to dependency resolution - - **`%key`** → Return raw YAML immediately ✅ + - **`%key`** → Expanded eagerly during `update()` ✅ + - **`@key`** → Proceed to dependency resolution (lazy) - **Step 3: Resolve Dependencies** (for `@` references) + **Step 3: Resolve Dependencies** (for `@` references during `resolve()`) - Check for circular references → ❌ **Error if found** - Resolve all dependencies first diff --git a/src/sparkwheel/__init__.py b/src/sparkwheel/__init__.py index 059c476..dd7cda3 100644 --- a/src/sparkwheel/__init__.py +++ b/src/sparkwheel/__init__.py @@ -19,8 +19,8 @@ EvaluationError, FrozenConfigError, InstantiationError, - ModuleNotFoundError, SourceLocation, + TargetNotFoundError, ) __version__ = "0.0.6" @@ -47,7 +47,7 @@ "REMOVE_KEY", "REPLACE_KEY", "BaseError", - "ModuleNotFoundError", + "TargetNotFoundError", "CircularReferenceError", "InstantiationError", "ConfigKeyError", diff --git a/src/sparkwheel/coercion.py b/src/sparkwheel/coercion.py index d0b16e9..6b25ab5 100644 --- a/src/sparkwheel/coercion.py +++ b/src/sparkwheel/coercion.py @@ -1,4 +1,10 @@ -"""Type coercion for schema validation.""" +"""Type coercion for schema validation. + +Automatically converts values between compatible types when safe and unambiguous, +making configs more flexible while maintaining type safety. Supports string to +numeric/bool conversions, int to float, and recursive coercion through nested +structures. Enabled by default via `Config(coerce=True)`. +""" import dataclasses import types diff --git a/src/sparkwheel/config.py b/src/sparkwheel/config.py index e51092c..9adf20b 100644 --- a/src/sparkwheel/config.py +++ b/src/sparkwheel/config.py @@ -1,4 +1,97 @@ -"""Main configuration management API.""" +"""Main configuration management API. 
+ +Sparkwheel is a YAML-based configuration system with references, expressions, and dynamic instantiation. +This module provides the main Config class for loading, managing, and resolving configurations. + +## Two-Stage Processing + +Sparkwheel uses a two-stage processing model to handle different reference types at appropriate times: + +### Stage 1: Eager Processing (during update()) +- **Raw References (`%`)**: Expanded immediately when configs are merged +- **Purpose**: Enables safe config composition or deletion +- **Example**: `%base.yaml::lr` is replaced with the actual value from base.yaml + +### Stage 2: Lazy Processing (during resolve()) +- **Resolved References (`@`)**: Resolved on-demand to support forward references +- **Expressions (`$`)**: Evaluated when needed using Python's eval() +- **Components (`_target_`)**: Instantiated only when requested +- **Purpose**: Supports deferred instantiation and complex dependency graphs + +## Reference Types + +| Symbol | Name | When Expanded | Purpose | Example | +|--------|------|---------------|---------|---------| +| `%` | Raw Reference | Eager (update()) | Copy/paste YAML sections | `%base.yaml::lr` | +| `@` | Resolved Reference | Lazy (resolve()) | Reference config values | `@model::lr` | +| `$` | Expression | Lazy (resolve()) | Compute values dynamically | `$@lr * 2` | + +## Key Methods + +- **`get(id)`**: Returns raw config value (unresolved, from `_data`) +- **`resolve(id)`**: Follows references, evaluates expressions, instantiates components +- **`update(source)`**: Loads and merges configuration from file, dict, or CLI string +- **`set(id, value)`**: Sets a config value at the given path + +## Important: get() vs resolve() + +These methods serve different purposes: + +- `get()` always returns raw values from the internal `_data` dict + - Returns `"@model::lr"` (the string) + - Fast, no resolution or caching + - Always returns raw data, even after resolve() has been called + +- `resolve()` follows 
references and instantiates objects + - Returns `0.001` (the actual value) + - Uses a separate resolution cache (`_resolver._resolved`) + - Evaluates expressions, instantiates components + +Example: + ```python + config = Config() + config.update({"lr": 0.001, "ref": "@lr"}) + + config.get("ref") # "@lr" (always raw, from _data) + config.resolve("ref") # 0.001 (follows reference, from cache) + config.get("ref") # Still "@lr" (get never uses cache) + ``` + +This separation ensures that: +1. Raw config structure is always accessible +2. Resolution happens lazily and is cached separately +3. Multiple resolve() calls are efficient (uses cache) +4. You can inspect raw references without triggering resolution + +## Quick Start + +```python +from sparkwheel import Config + +# Load and merge configs +config = Config() +config.update("base.yaml") +config.update("experiment.yaml") + +# Get raw values +lr = config.get("model::lr") # Raw value (may be "@base::lr") + +# Resolve references and instantiate +model = config.resolve("model") # Actual instantiated model object +lr_resolved = config.resolve("model::lr") # Resolved value (e.g., 0.001) +``` + +## CLI Overrides + +```python +# Auto-detects override syntax +config.update("model::lr=0.001") # Compose (merge) +config.update("=model::lr=0.001") # Replace +config.update("~model::old_param") # Delete +``` + +See Config class docstring for full API details. 
+""" from pathlib import Path from typing import Any @@ -29,13 +122,14 @@ class Config: from sparkwheel import Config # Create and load from file - config = Config(schema=MySchema).update("config.yaml") + config = Config(schema=MySchema) + config.update("config.yaml") # Or chain multiple sources - config = (Config(schema=MySchema) - .update("base.yaml") - .update("override.yaml") - .update({"model::lr": 0.001})) + config = Config(schema=MySchema) + config.update("base.yaml") + config.update("override.yaml") + config.update({"model::lr": 0.001}) # Access raw values lr = config.get("model::lr") @@ -112,20 +206,54 @@ def __init__( def get(self, id: str = "", default: Any = None) -> Any: """Get raw config value (unresolved). + IMPORTANT: This method ALWAYS returns raw values from the internal `_data` dict, + even after resolve() has been called. It never uses the resolution cache. + + - Returns `@` references as strings (e.g., "@model::lr") + - Returns `$` expressions as strings (e.g., "$@lr * 2") + - Returns `%` raw references already expanded (eager expansion during update()) + - Fast, no resolution overhead + + Use this when you need to: + - Inspect the raw config structure + - Check what references exist + - Access config before resolution + - Avoid triggering expensive instantiation + Args: id: Configuration path (use :: for nesting, e.g., "model::lr") Empty string returns entire config default: Default value if id not found Returns: - Raw configuration value (resolved references not resolved, raw references not expanded) + Raw configuration value from _data (@ and $ unresolved, % already expanded) - Example: - >>> config = Config.load({"model": {"lr": 0.001, "ref": "@model::lr"}}) - >>> config.get("model::lr") + Examples: + >>> # Basic usage + >>> config = Config() + >>> config.update({"lr": 0.001, "ref": "@lr"}) + >>> config.get("lr") 0.001 - >>> config.get("model::ref") - "@model::lr" # Unresolved resolved reference + >>> config.get("ref") + "@lr" # Raw @ 
reference string + + >>> # get() vs resolve() comparison + >>> config = Config() + >>> config.update({ + ... "lr": 0.001, + ... "doubled": "$@lr * 2", + ... "ref": "@lr" + ... }) + >>> config.get("doubled") + "$@lr * 2" # Raw expression string + >>> config.resolve("doubled") + 0.002 # Evaluated result + >>> config.get("doubled") # Still raw after resolve()! + "$@lr * 2" + + >>> # With default value + >>> config.get("nonexistent", default=999) + 999 """ try: return self._get_by_id(id) @@ -194,9 +322,9 @@ def validate(self, schema: type) -> None: ... class ModelConfig: ... hidden_size: int ... dropout: float - >>> config = Config.load({"hidden_size": 512, "dropout": 0.1}) + >>> config = Config().update({"hidden_size": 512, "dropout": 0.1}) >>> config.validate(ModelConfig) # Passes - >>> bad_config = Config.load({"hidden_size": "not an int"}) + >>> bad_config = Config().update({"hidden_size": "not an int"}) >>> bad_config.validate(ModelConfig) # Raises ValidationError """ from .schema import validate as validate_schema @@ -295,7 +423,13 @@ def update(self, source: PathLike | dict[str, Any] | "Config" | str) -> "Config" else: self._update_from_file(source) - # Validate after update if schema exists + # Eagerly expand raw references (%) immediately after update + # This matches MONAI's behavior and allows safe pruning with delete operator (~) + # Must happen BEFORE validation so schema sees final structure, not raw ref strings + self._data = self._preprocessor.process_raw_refs(self._data, self._data, id="") + + # Validate after raw ref expansion if schema exists + # This validates the final structure, not intermediate raw reference strings if self._schema: from .schema import validate as validate_schema @@ -374,8 +508,17 @@ def _update_from_file(self, source: PathLike) -> None: """Load and update from a file.""" new_data, new_metadata = self._loader.load_file(source) validate_operators(new_data) - self._data = apply_operators(self._data, new_data) - 
self._metadata.merge(new_metadata) + + # Check if loaded data uses :: path syntax + if self._uses_nested_paths(new_data): + # Expand nested paths using path updates + self._metadata.merge(new_metadata) + self._apply_path_updates(new_data) + else: + # Normal structural update + self._data = apply_operators(self._data, new_data) + self._metadata.merge(new_metadata) + self._invalidate_resolution() def _update_from_override_string(self, override: str) -> None: @@ -391,39 +534,73 @@ def resolve( lazy: bool = True, default: Any = None, ) -> Any: - """Resolve resolved references (@) and return parsed config. + """Resolve references, evaluate expressions, and instantiate components. + + This is the main method for getting fully resolved config values. It: + 1. Follows `@` references to their target values + 2. Evaluates `$` expressions using Python eval() + 3. Instantiates components with `_target_` keys + 4. Caches results in a separate resolution cache (`_resolver._resolved`) - Automatically parses config on first call. Resolves @ resolved references (follows - them to get instantiated/evaluated values), evaluates $ expressions, and - instantiates _target_ components. Note: % raw references are expanded during - preprocessing (before this stage). + Unlike get(), which always returns raw `_data`, resolve() performs full processing + and uses a separate cache for efficiency. 
+ + Processing stages: + - `%` raw references: Already expanded during update() (eager) + - `@` resolved references: Resolved now (lazy, supports circular deps) + - `$` expressions: Evaluated now (lazy) + - `_target_` components: Instantiated now (lazy) Args: id: Config path to resolve (empty string for entire config) - instantiate: Whether to instantiate components with _target_ - eval_expr: Whether to evaluate $ expressions - lazy: Whether to use cached resolution + instantiate: Whether to instantiate components with _target_ (default: True) + eval_expr: Whether to evaluate $ expressions (default: True) + lazy: Whether to use cached resolution (default: True) default: Default value if id not found (returns default.get_config() if Item) Returns: - Resolved value (instantiated objects, evaluated expressions, etc.) + Resolved value (could be primitive, object, or complex structure) - Example: - >>> config = Config.load({ + Examples: + >>> # Basic reference resolution + >>> config = Config() + >>> config.update({ ... "lr": 0.001, - ... "doubled": "$@lr * 2", + ... "ref": "@lr" + ... }) + >>> config.get("ref") + "@lr" # Raw string + >>> config.resolve("ref") + 0.001 # Followed reference + + >>> # Expression evaluation + >>> config = Config() + >>> config.update({ + ... "lr": 0.001, + ... "doubled": "$@lr * 2" + ... }) + >>> config.resolve("doubled") + 0.002 + + >>> # Component instantiation + >>> config = Config() + >>> config.update({ ... "optimizer": { ... "_target_": "torch.optim.Adam", - ... "lr": "@lr" + ... "lr": 0.001 ... } ... 
}) - >>> config.resolve("lr") - 0.001 - >>> config.resolve("doubled") - 0.002 >>> optimizer = config.resolve("optimizer") >>> type(optimizer).__name__ 'Adam' + + >>> # Disable instantiation (useful for inspection) + >>> config.resolve("optimizer", instantiate=False) + {'_target_': 'torch.optim.Adam', 'lr': 0.001} + + >>> # With default value + >>> config.resolve("nonexistent", default=None) + None """ # Parse if needed if not self._is_parsed or not lazy: @@ -446,6 +623,7 @@ def _parse(self, reset: bool = True) -> None: """Parse config tree and prepare for resolution. Internal method called automatically by resolve(). + Note: % raw references are already expanded during update(). Args: reset: Whether to reset the resolver before parsing (default: True) @@ -454,7 +632,8 @@ def _parse(self, reset: bool = True) -> None: if reset: self._resolver.reset() - # Stage 1: Preprocess (% raw references, @:: relative resolved IDs) + # Stage 1: Preprocess (@:: relative resolved IDs) + # Note: % raw references were already expanded in update() self._data = self._preprocessor.process(self._data, self._data, id="") # Stage 2: Parse config tree to create Items @@ -507,7 +686,7 @@ def __getitem__(self, id: str) -> Any: Config value at that path Example: - >>> config = Config.load({"model": {"lr": 0.001}}) + >>> config = Config().update({"model": {"lr": 0.001}}) >>> config["model::lr"] 0.001 """ @@ -521,7 +700,7 @@ def __setitem__(self, id: str, value: Any) -> None: value: Value to set Example: - >>> config = Config.load({}) + >>> config = Config().update({}) >>> config["model::lr"] = 0.001 """ self.set(id, value) diff --git a/src/sparkwheel/items.py b/src/sparkwheel/items.py index 0f320b3..b095084 100644 --- a/src/sparkwheel/items.py +++ b/src/sparkwheel/items.py @@ -7,7 +7,7 @@ from .utils import CompInitMode, first, instantiate, optional_import, run_debug, run_eval from .utils.constants import EXPR_KEY -from .utils.exceptions import EvaluationError, InstantiationError, 
ModuleNotFoundError, SourceLocation +from .utils.exceptions import EvaluationError, InstantiationError, SourceLocation, TargetNotFoundError __all__ = ["Item", "Expression", "Component", "Instantiable"] @@ -185,10 +185,10 @@ def instantiate(self, **kwargs: Any) -> object: try: return instantiate(modname, mode, **args) - except ModuleNotFoundError as e: + except TargetNotFoundError as e: # Re-raise with source location and suggestions suggestion = self._suggest_similar_modules(modname) if isinstance(modname, str) else None - raise ModuleNotFoundError( + raise TargetNotFoundError( f"Cannot locate class or function: '{modname}'", source_location=self.source_location, suggestion=suggestion, diff --git a/src/sparkwheel/loader.py b/src/sparkwheel/loader.py index a7a5ef4..948da8f 100644 --- a/src/sparkwheel/loader.py +++ b/src/sparkwheel/loader.py @@ -8,7 +8,7 @@ import yaml # type: ignore[import-untyped] from .metadata import MetadataRegistry -from .path_patterns import is_yaml_file +from .path_utils import is_yaml_file from .utils import CheckKeyDuplicatesYamlLoader, PathLike from .utils.constants import ID_SEP_KEY from .utils.exceptions import SourceLocation diff --git a/src/sparkwheel/path_patterns.py b/src/sparkwheel/path_patterns.py deleted file mode 100644 index 31ebe8b..0000000 --- a/src/sparkwheel/path_patterns.py +++ /dev/null @@ -1,214 +0,0 @@ -"""Centralized regex patterns for config path parsing. - -This module contains all regex patterns used across sparkwheel for parsing -configuration paths, references, and file paths. Patterns are compiled once -at module load and documented with examples. - -Why regex here? -- Complex patterns (lookahead, Unicode support) -- Performance (C regex engine) -- Correctness (battle-tested patterns) - -Patterns are centralized here instead of scattered across multiple files -for easier maintenance, testing, and documentation. 
-""" - -import re - -from .utils.constants import RAW_REF_KEY, RESOLVED_REF_KEY - -__all__ = [ - "PathPatterns", - "is_yaml_file", -] - - -def is_yaml_file(filepath: str) -> bool: - """Check if filepath is a YAML file (.yaml or .yml). - - Simple string check - no regex needed for this. - - Args: - filepath: Path to check - - Returns: - True if filepath ends with .yaml or .yml (case-insensitive) - - Examples: - >>> is_yaml_file("config.yaml") - True - >>> is_yaml_file("CONFIG.YAML") - True - >>> is_yaml_file("data.json") - False - """ - lower = filepath.lower() - return lower.endswith(".yaml") or lower.endswith(".yml") - - -class PathPatterns: - """Collection of compiled regex patterns for config path parsing. - - All patterns are compiled once at class definition time and reused. - Each pattern includes documentation with examples of what it matches. - """ - - # File path and config ID splitting - # Example: "config.yaml::model::lr" -> captures "config.yaml" - # Uses lookahead (?=...) to find extension without consuming :: separator - FILE_AND_ID = re.compile(r"(.*\.(yaml|yml))(?=(?:::.*)|$)", re.IGNORECASE) - """Split combined file path and config ID. - - The pattern uses lookahead to find the file extension without consuming - the :: separator that follows. - - Matches: - - "config.yaml::model::lr" -> group 1: "config.yaml" - - "path/to/file.yml::key" -> group 1: "path/to/file.yml" - - "/abs/path/cfg.yaml::a::b" -> group 1: "/abs/path/cfg.yaml" - - Non-matches: - - "model::lr" -> no .yaml/.yml extension - - "data.json::key" -> wrong extension - - Edge cases handled: - - Case insensitive: "Config.YAML::key" works - - Multiple extensions: "backup.yaml.old" stops at first .yaml - - Absolute paths: "/etc/config.yaml::key" works - """ - - RELATIVE_REFERENCE = re.compile(rf"(?:{RESOLVED_REF_KEY}|{RAW_REF_KEY})(::)+") - """Match relative reference prefixes: @::, @::::, %::, etc. - - Used to find relative navigation patterns in config references. 
- The number of :: pairs indicates how many levels to go up. - - Matches: - - "@::" -> resolved reference one level up (parent) - - "@::::" -> resolved reference two levels up (grandparent) - - "%::" -> raw reference one level up - - "%::::" -> raw reference two levels up - - Examples in context: - - In "model::optimizer", "@::lr" means "@model::lr" - - In "a::b::c", "@::::x" means "@a::x" - - Pattern breakdown: - - (?:@|%) -> @ or % symbol (non-capturing group) - - (::)+ -> one or more :: pairs (captured) - """ - - ABSOLUTE_REFERENCE = re.compile(rf"{RESOLVED_REF_KEY}(\w+(?:::\w+)*)") - r"""Match absolute resolved reference patterns: @id::path::to::value - - Finds @ resolved references in config values and expressions. Handles nested - paths with :: separators and list indices (numbers). - - Matches: - - "@model::lr" -> captures "model::lr" - - "@data::0::value" -> captures "data::0::value" - - "@x" -> captures "x" - - Examples in expressions: - - "$@model::lr * 2" -> matches "@model::lr" - - "$@x + @y" -> matches "@x" and "@y" - - Pattern breakdown: - - @ -> literal @ symbol - - (\w+(?:::\w+)*) -> captures word chars followed by optional :: and more word chars - - Note: \w includes [a-zA-Z0-9_] plus Unicode word characters, - so this handles international characters correctly. - """ - - @classmethod - def split_file_and_id(cls, src: str) -> tuple[str, str]: - """Split combined file path and config ID using FILE_AND_ID pattern. 
- - Args: - src: String like "config.yaml::model::lr" - - Returns: - Tuple of (filepath, config_id) - - Examples: - >>> PathPatterns.split_file_and_id("config.yaml::model::lr") - ("config.yaml", "model::lr") - >>> PathPatterns.split_file_and_id("model::lr") - ("", "model::lr") - >>> PathPatterns.split_file_and_id("/path/to/file.yml::key") - ("/path/to/file.yml", "key") - """ - src = src.strip() - match = cls.FILE_AND_ID.search(src) - - if not match: - return "", src # Pure ID, no file path - - filepath = match.group(1) - remainder = src[match.end() :] - - # Strip leading :: from config ID part - config_id = remainder[2:] if remainder.startswith("::") else remainder - - return filepath, config_id - - @classmethod - def find_relative_references(cls, text: str) -> list[str]: - """Find all relative reference patterns in text. - - Args: - text: String to search - - Returns: - List of relative reference patterns found (e.g., ['@::', '@::::']) - - Examples: - >>> PathPatterns.find_relative_references("value: @::sibling") - ['@::'] - >>> PathPatterns.find_relative_references("@::::parent and @::sibling") - ['@::::', '@::'] - """ - # Use finditer to get full matches instead of just captured groups - return [match.group(0) for match in cls.RELATIVE_REFERENCE.finditer(text)] - - @classmethod - def find_absolute_references(cls, text: str) -> list[str]: - """Find all absolute reference patterns in text. - - Only searches in expressions ($...) or pure reference values. 
- - Args: - text: String to search - - Returns: - List of reference IDs found (without @ prefix) - - Examples: - >>> PathPatterns.find_absolute_references("@model::lr") - ['model::lr'] - >>> PathPatterns.find_absolute_references("$@x + @y") - ['x', 'y'] - >>> PathPatterns.find_absolute_references("normal text") - [] - """ - is_expr = text.startswith("$") - is_pure_ref = text.startswith("@") - - if not (is_expr or is_pure_ref): - return [] - - return cls.ABSOLUTE_REFERENCE.findall(text) - - -# Utility functions that delegate to PathPatterns - - -def split_file_and_id(src: str) -> tuple[str, str]: - """Convenience function wrapping PathPatterns.split_file_and_id().""" - return PathPatterns.split_file_and_id(src) - - -def find_references(text: str) -> list[str]: - """Convenience function wrapping PathPatterns.find_absolute_references().""" - return PathPatterns.find_absolute_references(text) diff --git a/src/sparkwheel/path_utils.py b/src/sparkwheel/path_utils.py index 708c2c1..b0f5b2e 100644 --- a/src/sparkwheel/path_utils.py +++ b/src/sparkwheel/path_utils.py @@ -1,13 +1,14 @@ """Path parsing and manipulation utilities. -Provides helper functions for working with config paths, building on -the regex patterns from path_patterns.py. +Centralized regex patterns and helper functions for parsing configuration paths, +references, and file paths. Handles ID splitting, relative reference resolution, +and file path extraction from combined strings like "config.yaml::model::lr". 
# ============================================================================
# YAML File Detection
# ============================================================================


def is_yaml_file(filepath: str) -> bool:
    """Check if filepath is a YAML file (.yaml or .yml).

    Simple string check - no regex needed for this.

    Args:
        filepath: Path to check

    Returns:
        True if filepath ends with .yaml or .yml (case-insensitive)

    Examples:
        >>> is_yaml_file("config.yaml")
        True
        >>> is_yaml_file("CONFIG.YAML")
        True
        >>> is_yaml_file("data.json")
        False
    """
    # str.endswith accepts a tuple of suffixes, so one call covers both.
    return filepath.lower().endswith((".yaml", ".yml"))


# ============================================================================
# Compiled Regex Patterns
# ============================================================================


class PathPatterns:
    """Collection of compiled regex patterns for config path parsing.

    All patterns are compiled once at class definition time and reused.
    Each pattern includes documentation with examples of what it matches.
    """

    # File path and config ID splitting
    # Example: "config.yaml::model::lr" -> captures "config.yaml"
    # Uses lookahead (?=...) to find extension without consuming :: separator
    FILE_AND_ID = re.compile(r"(.*\.(yaml|yml))(?=(?:::.*)|$)", re.IGNORECASE)
    """Split combined file path and config ID.

    The pattern uses lookahead to find the file extension without consuming
    the :: separator that follows.

    Matches:
        - "config.yaml::model::lr" -> group 1: "config.yaml"
        - "path/to/file.yml::key" -> group 1: "path/to/file.yml"
        - "/abs/path/cfg.yaml::a::b" -> group 1: "/abs/path/cfg.yaml"

    Non-matches:
        - "model::lr" -> no .yaml/.yml extension
        - "data.json::key" -> wrong extension
        - "backup.yaml.old" -> the extension must be immediately followed by
          "::" or end of string, so trailing suffixes prevent a match

    Edge cases handled:
        - Case insensitive: "Config.YAML::key" works
        - Absolute paths: "/etc/config.yaml::key" works
    """

    RELATIVE_REFERENCE = re.compile(rf"(?:{RESOLVED_REF_KEY}|{RAW_REF_KEY})(::)+")
    """Match relative reference prefixes: @::, @::::, %::, etc.

    Used to find relative navigation patterns in config references.
    The number of :: pairs indicates how many levels to go up.

    Matches:
        - "@::" -> resolved reference one level up (parent)
        - "@::::" -> resolved reference two levels up (grandparent)
        - "%::" -> raw reference one level up
        - "%::::" -> raw reference two levels up

    Examples in context:
        - In "model::optimizer", "@::lr" means "@model::lr"
        - In "a::b::c", "@::::x" means "@a::x"

    Pattern breakdown:
        - (?:@|%) -> @ or % symbol (non-capturing group)
        - (::)+ -> one or more :: pairs (captured)
    """

    ABSOLUTE_REFERENCE = re.compile(rf"{RESOLVED_REF_KEY}(\w+(?:::\w+)*)")
    r"""Match absolute resolved reference patterns: @id::path::to::value

    Finds @ resolved references in config values and expressions. Handles nested
    paths with :: separators and list indices (numbers).

    Matches:
        - "@model::lr" -> captures "model::lr"
        - "@data::0::value" -> captures "data::0::value"
        - "@x" -> captures "x"

    Examples in expressions:
        - "$@model::lr * 2" -> matches "@model::lr"
        - "$@x + @y" -> matches "@x" and "@y"

    Pattern breakdown:
        - @ -> literal @ symbol
        - (\w+(?:::\w+)*) -> captures word chars followed by optional :: and more word chars

    Note: \w includes [a-zA-Z0-9_] plus Unicode word characters,
    so this handles international characters correctly.
    """

    @classmethod
    def split_file_and_id(cls, src: str) -> tuple[str, str]:
        """Split combined file path and config ID using FILE_AND_ID pattern.

        Args:
            src: String like "config.yaml::model::lr"

        Returns:
            Tuple of (filepath, config_id)

        Examples:
            >>> PathPatterns.split_file_and_id("config.yaml::model::lr")
            ('config.yaml', 'model::lr')
            >>> PathPatterns.split_file_and_id("model::lr")
            ('', 'model::lr')
            >>> PathPatterns.split_file_and_id("/path/to/file.yml::key")
            ('/path/to/file.yml', 'key')
        """
        src = src.strip()
        match = cls.FILE_AND_ID.search(src)

        if not match:
            return "", src  # Pure ID, no file path

        filepath = match.group(1)
        remainder = src[match.end() :]

        # Strip leading :: from config ID part
        config_id = remainder[2:] if remainder.startswith("::") else remainder

        return filepath, config_id

    @classmethod
    def find_relative_references(cls, text: str) -> list[str]:
        """Find all relative reference patterns in text.

        Args:
            text: String to search

        Returns:
            List of relative reference patterns found (e.g., ['@::', '@::::'])

        Examples:
            >>> PathPatterns.find_relative_references("value: @::sibling")
            ['@::']
            >>> PathPatterns.find_relative_references("@::::parent and @::sibling")
            ['@::::', '@::']
        """
        # Use finditer to get full matches instead of just captured groups
        return [match.group(0) for match in cls.RELATIVE_REFERENCE.finditer(text)]

    @classmethod
    def find_absolute_references(cls, text: str) -> list[str]:
        """Find all absolute reference patterns in text.

        Only searches in expressions ($...) or pure reference values.

        Args:
            text: String to search

        Returns:
            List of reference IDs found (without @ prefix)

        Examples:
            >>> PathPatterns.find_absolute_references("@model::lr")
            ['model::lr']
            >>> PathPatterns.find_absolute_references("$@x + @y")
            ['x', 'y']
            >>> PathPatterns.find_absolute_references("normal text")
            []
        """
        is_expr = text.startswith("$")
        is_pure_ref = text.startswith("@")

        if not (is_expr or is_pure_ref):
            return []

        return cls.ABSOLUTE_REFERENCE.findall(text)


# ============================================================================
# Convenience Functions (delegate to PathPatterns)
# ============================================================================


def split_file_and_id(src: str) -> tuple[str, str]:
    """Convenience function wrapping PathPatterns.split_file_and_id().

    Args:
        src: String like "config.yaml::model::lr"

    Returns:
        Tuple of (filepath, config_id)

    Examples:
        >>> split_file_and_id("config.yaml::model::lr")
        ('config.yaml', 'model::lr')
    """
    return PathPatterns.split_file_and_id(src)


def find_references(text: str) -> list[str]:
    """Convenience function wrapping PathPatterns.find_absolute_references().

    Args:
        text: String to search

    Returns:
        List of reference IDs found (without @ prefix)
    """
    return PathPatterns.find_absolute_references(text)
diff --git a/src/sparkwheel/preprocessor.py b/src/sparkwheel/preprocessor.py index 996cafe..cfc70ab 100644 --- a/src/sparkwheel/preprocessor.py +++ b/src/sparkwheel/preprocessor.py @@ -8,8 +8,7 @@ from copy import deepcopy from typing import Any -from .path_patterns import split_file_and_id -from .path_utils import resolve_relative_ids, split_id +from .path_utils import resolve_relative_ids, split_file_and_id, split_id from .utils.constants import ID_SEP_KEY, RAW_REF_KEY __all__ = ["Preprocessor"] @@ -59,11 +58,36 @@ def __init__(self, loader, globals: dict[str, Any] | None = None): # type: igno self.loader = loader self.globals = globals or {} + def process_raw_refs(self, config: Any, base_data: dict[str, Any], id: str = "") -> Any: + """Preprocess config tree - expand only % raw references. + + This is the first preprocessing stage that runs eagerly during update(). + It expands all % raw references including those with relative syntax: + - Local: %key + - External: %file.yaml::key + - Relative: %::key, %::::key (converted to absolute before expansion) + + Leaves @ resolved references untouched (they're processed lazily during resolve()). + + Args: + config: Raw config structure to process + base_data: Root config dict (for resolving local raw references) + id: Current ID path in tree + + Returns: + Config with raw references expanded + + Raises: + ValueError: If circular raw reference detected + """ + return self._process_raw_refs_recursive(config, base_data, id, set()) + def process(self, config: Any, base_data: dict[str, Any], id: str = "") -> Any: """Preprocess entire config tree. Main entry point - walks config tree recursively and applies - all preprocessing transformations. + all preprocessing transformations. This is the second preprocessing stage + that runs lazily during resolve(), handling relative IDs and @ references. 
Args: config: Raw config structure to process @@ -78,6 +102,57 @@ def process(self, config: Any, base_data: dict[str, Any], id: str = "") -> Any: """ return self._process_recursive(config, base_data, id, set()) + def _process_raw_refs_recursive( + self, + config: Any, + base_data: dict[str, Any], + id: str, + raw_ref_stack: set[str], + ) -> Any: + """Internal recursive implementation for expanding only raw references. + + This method only expands % raw references and leaves @ references untouched. + + Performance optimization: Skips recursion for nodes that don't contain any + raw reference strings, avoiding unnecessary tree traversal. + + Args: + config: Current config node + base_data: Root config dict + id: Current ID path + raw_ref_stack: Circular reference detection + + Returns: + Config with raw references expanded + """ + # Early exit optimization: Skip processing if this subtree has no raw references + # This avoids unnecessary recursion for large config sections without % refs + if not self._contains_raw_refs(config): + return config + + # Recursively process nested structures + if isinstance(config, dict): + for key in list(config.keys()): + sub_id = f"{id}{ID_SEP_KEY}{key}" if id else str(key) + config[key] = self._process_raw_refs_recursive(config[key], base_data, sub_id, raw_ref_stack) + + elif isinstance(config, list): + for idx in range(len(config)): + sub_id = f"{id}{ID_SEP_KEY}{idx}" if id else str(idx) + config[idx] = self._process_raw_refs_recursive(config[idx], base_data, sub_id, raw_ref_stack) + + # Process string values - only expand raw references (%) + if isinstance(config, str): + # First resolve relative IDs in raw references (e.g., %::key -> %parent::key) + # This is necessary because raw references can use relative syntax + config = resolve_relative_ids(id, config) + + # Then expand raw references + if config.startswith(RAW_REF_KEY): + config = self._expand_raw_ref(config, base_data, raw_ref_stack) + + return config + def 
_process_recursive( self, config: Any, @@ -112,7 +187,7 @@ def _process_recursive( # Step 1: Resolve relative IDs (@::, @::::) to absolute (@) config = resolve_relative_ids(id, config) - # Step 2: Expand raw references (%) + # Step 2: Expand raw references (%) - should already be expanded, but keep for safety if config.startswith(RAW_REF_KEY): config = self._expand_raw_ref(config, base_data, raw_ref_stack) @@ -152,8 +227,8 @@ def _expand_raw_ref(self, raw_ref: str, base_data: dict[str, Any], raw_ref_stack # Navigate to referenced value result = self._get_by_id(loaded_config, ids) - # Recursively preprocess the loaded value - result = self._process_recursive(result, loaded_config, ids, raw_ref_stack) + # Recursively preprocess the loaded value (expand nested raw references only) + result = self._process_raw_refs_recursive(result, loaded_config, ids, raw_ref_stack) # Deep copy for independence return deepcopy(result) @@ -161,6 +236,26 @@ def _expand_raw_ref(self, raw_ref: str, base_data: dict[str, Any], raw_ref_stack finally: raw_ref_stack.discard(raw_ref) + @staticmethod + def _contains_raw_refs(config: Any) -> bool: + """Check if a config node or its descendants contain any raw references. + + Performance optimization to skip processing subtrees without % references. + + Args: + config: Config node to check + + Returns: + True if any raw references found, False otherwise + """ + if isinstance(config, str): + return config.startswith(RAW_REF_KEY) + elif isinstance(config, dict): + return any(Preprocessor._contains_raw_refs(v) for v in config.values()) + elif isinstance(config, list): + return any(Preprocessor._contains_raw_refs(item) for item in config) + return False + @staticmethod def _get_by_id(config: dict[str, Any], id: str) -> Any: """Navigate config dict by ID path. 
diff --git a/src/sparkwheel/schema.py b/src/sparkwheel/schema.py index f408fc2..973cfa6 100644 --- a/src/sparkwheel/schema.py +++ b/src/sparkwheel/schema.py @@ -25,11 +25,13 @@ class ModelConfig: optimizer: OptimizerConfig # Load and validate config - config = Config.load("config.yaml") + config = Config() + config.update("config.yaml") validate(config.get(), ModelConfig) # Raises error if invalid # Or validate during load - config = Config.load("config.yaml", schema=ModelConfig) + config = Config(schema=ModelConfig) + config.update("config.yaml") ``` """ @@ -566,6 +568,12 @@ def _validate_field( source_location=source_loc, ) + # Handle references and expressions early + # Accept resolved references (@), raw references (%), and expressions ($) as strings + # since they'll be resolved/expanded later - we can't validate their type until resolution + if isinstance(value, str) and (value.startswith("@") or value.startswith("$") or value.startswith("%")): + return + # Handle list[T] if origin is list: if not isinstance(value, list): @@ -661,13 +669,6 @@ def _validate_field( # Handle basic types (int, str, float, bool, etc.) 
class TargetNotFoundError(BaseError):
    """Raised when a _target_ module/class/function cannot be located."""
False) or run_debug: diff --git a/tests/test_config.py b/tests/test_config.py index b67763b..80b8359 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -19,8 +19,7 @@ import yaml from sparkwheel import Config, apply_operators -from sparkwheel.path_patterns import split_file_and_id -from sparkwheel.path_utils import resolve_relative_ids +from sparkwheel.path_utils import resolve_relative_ids, split_file_and_id class TestConfigBasics: @@ -278,6 +277,52 @@ def test_do_resolve_macro_load(self): finally: Path(filepath).unlink() + def test_eager_raw_reference_expansion(self): + """Test that raw references are expanded during update(), not resolve().""" + config = {"original": {"a": 1, "b": 2}, "copy": "%original"} + parser = Config().update(config) + + # After update(), raw references should be expanded (not resolve() yet!) + # The copy should be the actual value, not the "%original" string + assert parser.get("copy") == {"a": 1, "b": 2} + assert parser.get("copy") is not parser.get("original") # Deep copy + + # Verify it's a real copy by modifying it + parser.set("copy::c", 3) + assert parser.get("copy::c") == 3 + assert "c" not in parser.get("original") + + def test_pruning_with_raw_references(self): + """Test that pruning works with raw references (the Lighter use case).""" + # This is the critical test: raw refs should be expanded before pruning + config = { + "system": { + "dataloaders": { + "train": {"batch_size": 32}, + "val": {"batch_size": 64}, + } + }, + "train": { + "dataloader": "%system::dataloaders::train" # Raw reference + }, + } + + parser = Config().update(config) + + # Raw reference should already be expanded + assert parser.get("train::dataloader") == {"batch_size": 32} + + # Now prune the system section (delete it) + parser.update("~system") + + # The raw reference was already expanded, so train::dataloader should still exist + assert parser.get("train::dataloader") == {"batch_size": 32} + assert "system" not in parser.get() # system is 
deleted + + # Verify we can still resolve after pruning + result = parser.resolve("train::dataloader") + assert result == {"batch_size": 32} + class TestComponents: """Test component instantiation and handling.""" diff --git a/tests/test_items.py b/tests/test_items.py index 10bd160..f46e113 100644 --- a/tests/test_items.py +++ b/tests/test_items.py @@ -3,7 +3,7 @@ import pytest from sparkwheel.items import Component, Expression, Item -from sparkwheel.utils.exceptions import EvaluationError, InstantiationError, ModuleNotFoundError +from sparkwheel.utils.exceptions import EvaluationError, InstantiationError, TargetNotFoundError class TestItem: @@ -126,20 +126,20 @@ def test_instantiate_disabled(self): assert result is None def test_instantiate_module_not_found(self): - """Test instantiate raises ModuleNotFoundError for missing module.""" + """Test instantiate raises TargetNotFoundError for missing module.""" config = {"_target_": "nonexistent.module.Class"} component = Component(config=config) - with pytest.raises(ModuleNotFoundError, match="Cannot locate class or function"): + with pytest.raises(TargetNotFoundError, match="Cannot locate class or function"): component.instantiate() def test_instantiate_with_suggestions(self): - """Test ModuleNotFoundError includes suggestions for typos.""" + """Test TargetNotFoundError includes suggestions for typos.""" # Use a real module with a typo config = {"_target_": "collections.Counterfeit"} # Should suggest "Counter" component = Component(config=config) - with pytest.raises(ModuleNotFoundError) as exc_info: + with pytest.raises(TargetNotFoundError) as exc_info: component.instantiate() error_msg = str(exc_info.value) @@ -387,7 +387,7 @@ def test_component_error_includes_source_location(self): config = {"_target_": "nonexistent.Module"} component = Component(config=config, id="model", source_location=location) - with pytest.raises(ModuleNotFoundError) as exc_info: + with pytest.raises(TargetNotFoundError) as exc_info: 
component.instantiate() error = exc_info.value @@ -416,7 +416,7 @@ def test_component_suggestion_exception_handling(self): component = Component(config={"_target_": "nonexistent.BadModule"}, id="test") # Instantiate should fail, but suggestion generation shouldn't crash - with pytest.raises(ModuleNotFoundError): + with pytest.raises(TargetNotFoundError): component.instantiate() def test_expression_multiple_import_aliases(self): diff --git a/tests/test_path_utils.py b/tests/test_path_utils.py index f3b2f11..b90f607 100644 --- a/tests/test_path_utils.py +++ b/tests/test_path_utils.py @@ -1,6 +1,6 @@ """Tests for path utility functions.""" -from sparkwheel.path_patterns import PathPatterns, find_references +from sparkwheel.path_utils import PathPatterns, find_references class TestPathPatterns: diff --git a/tests/test_utils.py b/tests/test_utils.py index 90f8172..1f7adaf 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -332,9 +332,9 @@ def test_instantiate_with_callable_object(self): def test_instantiate_not_found(self): """Test instantiate raises error for non-existent path.""" - from sparkwheel.utils.exceptions import ModuleNotFoundError + from sparkwheel.utils.exceptions import TargetNotFoundError - with pytest.raises(ModuleNotFoundError, match="Cannot locate"): + with pytest.raises(TargetNotFoundError, match="Cannot locate"): instantiate("nonexistent.module.Class", "default") def test_instantiate_not_callable_warning(self): diff --git a/uv.lock b/uv.lock index aa78a76..be68920 100644 --- a/uv.lock +++ b/uv.lock @@ -1714,7 +1714,7 @@ wheels = [ [[package]] name = "sparkwheel" -version = "0.0.5" +version = "0.0.6" source = { editable = "." } dependencies = [ { name = "pyyaml" },