Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/modaic/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def _from_precompiled(repo_dir: Path, hub_path: str = None, **kwargs) -> Precomp
cfg = json.load(fp)

ConfigClass = _load_auto_class(repo_dir, "AutoConfig", hub_path=hub_path) # noqa: N806
return ConfigClass(**{**cfg, **kwargs})
return ConfigClass.from_dict(cfg, **kwargs)


class AutoProgram:
Expand Down
87 changes: 83 additions & 4 deletions src/modaic/precompiled.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import importlib
import inspect
import json
import os
import pathlib
import sys
import warnings
from abc import ABC, abstractmethod
from pathlib import Path
Expand All @@ -25,16 +27,88 @@
from .hub import load_repo, push_folder_to_hub

if TYPE_CHECKING:
from modaic.context.base import Context

Check failure on line 30 in src/modaic/precompiled.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

src/modaic/precompiled.py:30:37: F401 `modaic.context.base.Context` imported but unused

C = TypeVar("C", bound="PrecompiledConfig")
A = TypeVar("A", bound="PrecompiledProgram")
R = TypeVar("R", bound="Retriever")

# Special marker for serialized DSPy signatures
_DSPY_SIGNATURE_PREFIX = "__dspy_signature__:"


class PrecompiledConfig(BaseModel):
model: Optional[str] = None

@staticmethod
def _is_dspy_signature(obj: Any) -> bool:
"""Check if an object is a DSPy Signature class (not instance)."""
try:
# Check if it's a class that inherits from dspy.Signature
return isinstance(obj, type) and issubclass(obj, dspy.Signature)
except (TypeError, AttributeError):
return False

@classmethod
def _get_signature_module_path(cls, sig_class: Type) -> str:
"""Get the module path for a DSPy signature class."""
from .module_utils import resolve_project_root

module_name = sig_class.__module__

# If it's defined in __main__, try to resolve the actual module path
if module_name == "__main__":
module = sys.modules[module_name]
if hasattr(module, "__file__") and module.__file__:
file_path = Path(module.__file__)
try:
project_root = resolve_project_root()
rel_path = file_path.relative_to(project_root).with_suffix("")
module_path = str(rel_path).replace("/", ".")
return f"{module_path}.{sig_class.__name__}"
except (ValueError, FileNotFoundError):
# Fallback to just using the class name
return sig_class.__name__

return f"{module_name}.{sig_class.__name__}"

@classmethod
def _serialize_dspy_signatures(cls, obj: Any) -> Any:
"""Recursively serialize DSPy Signature classes in nested structures."""
if cls._is_dspy_signature(obj):
module_path = cls._get_signature_module_path(obj)
return f"{_DSPY_SIGNATURE_PREFIX}{module_path}"
elif isinstance(obj, dict):
return {key: cls._serialize_dspy_signatures(value) for key, value in obj.items()}
elif isinstance(obj, (list, tuple)):
return type(obj)(cls._serialize_dspy_signatures(item) for item in obj)
else:
return obj

@classmethod
def _deserialize_dspy_signatures(cls, obj: Any) -> Any:
"""Recursively deserialize DSPy Signature classes from nested structures."""
if isinstance(obj, str) and obj.startswith(_DSPY_SIGNATURE_PREFIX):
# Extract the module path
module_path = obj[len(_DSPY_SIGNATURE_PREFIX) :]
# Import and return the signature class
module_name, _, class_name = module_path.rpartition(".")
try:
module = importlib.import_module(module_name)
return getattr(module, class_name)
except (ImportError, AttributeError) as e:
warnings.warn(
f"Failed to import DSPy signature '{module_path}': {e}. Returning the serialized string instead.",
stacklevel=2,
)
return obj
elif isinstance(obj, dict):
return {key: cls._deserialize_dspy_signatures(value) for key, value in obj.items()}
elif isinstance(obj, (list, tuple)):
return type(obj)(cls._deserialize_dspy_signatures(item) for item in obj)
else:
return obj

def save_precompiled(
self,
path: str | Path,
Expand Down Expand Up @@ -87,7 +161,7 @@
path = local_dir / "config.json"
with open(path, "r") as f:
config_dict = json.load(f)
return cls(**{**config_dict, **kwargs})
return cls.from_dict(config_dict, **kwargs)

@classmethod
def from_dict(cls: Type[C], dict: Dict, **kwargs) -> C:
Expand All @@ -101,7 +175,10 @@
Returns:
An instance of the PrecompiledConfig class.
"""
instance = cls(**{**dict, **kwargs})
# Deserialize any DSPy signatures
deserialized_dict = cls._deserialize_dspy_signatures(dict)
deserialized_kwargs = cls._deserialize_dspy_signatures(kwargs)
instance = cls(**{**deserialized_dict, **deserialized_kwargs})
return instance

@classmethod
Expand All @@ -122,9 +199,11 @@

def to_dict(self) -> Dict:
"""
Converts the config to a dictionary.
Converts the config to a dictionary, handling DSPy signatures.
"""
return self.model_dump()
result = self.model_dump()
# Serialize any DSPy signatures to importable module paths
return self._serialize_dspy_signatures(result)

def to_json(self) -> str:
"""
Expand Down Expand Up @@ -168,7 +247,7 @@
# TODO: throw a warning if the config of the retriever has different values than the config of the program

# def __init_subclass__(cls, **kwargs):
# super().__init_subclass__(**kwargs)

Check failure on line 250 in src/modaic/precompiled.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (ERA001)

src/modaic/precompiled.py:250:5: ERA001 Found commented-out code
# # Make sure subclasses have an annotated config attribute
# if not (config_class := cls.__annotations__.get("config")) or config_class is PrecompiledConfig:
# raise ValueError(
Expand All @@ -176,9 +255,9 @@
# Hint: Please add an annotation for config to your subclass.
# Example:
# class {cls.__name__}(PrecompiledProgram):
# config: YourConfigClass

Check failure on line 258 in src/modaic/precompiled.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (ERA001)

src/modaic/precompiled.py:258:5: ERA001 Found commented-out code
# def __init__(self, config: YourConfigClass, **kwargs):
# super().__init__(config, **kwargs)

Check failure on line 260 in src/modaic/precompiled.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (ERA001)

src/modaic/precompiled.py:260:5: ERA001 Found commented-out code
# ...
# """
# )
Expand Down
74 changes: 74 additions & 0 deletions tests/test_precompiled.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,40 @@ class Summarize(dspy.Signature):
answer: str = dspy.OutputField(desc="Answer to the question, based on the passage")


class ClassifyEmotion(dspy.Signature):
"""Classify the emotion in the sentence."""

sentence: str = dspy.InputField()
emotion: str = dspy.OutputField()


class ExampleConfig(PrecompiledConfig):
output_type: Literal["bool", "str"]
lm: str = "openai/gpt-4o-mini"
number: int = 1


class ConfigWithSignature(PrecompiledConfig):
"""Config that includes a DSPy signature as a field."""

signature: Type[dspy.Signature]
lm: str = "openai/gpt-4o-mini"


class ProgramWithSignatureConfig(PrecompiledProgram):
"""Program that uses a config with a DSPy signature."""

config: ConfigWithSignature

def __init__(self, config: ConfigWithSignature, **kwargs):
super().__init__(config, **kwargs)
self.predictor = dspy.Predict(config.signature)
self.predictor.lm = dspy.LM(config.lm)

def forward(self, **kwargs) -> str:
return self.predictor(**kwargs)


class ExampleProgram(PrecompiledProgram):
config: ExampleConfig

Expand Down Expand Up @@ -506,3 +534,49 @@ def test_no_config_w_retriever_hub(hub_repo: str):
assert len(os.listdir(temp_dir)) == 4
loaded_program = NoConfigWhRetrieverProgram.from_precompiled(hub_repo, runtime_param="wassuhh", retriever=retriever)
assert loaded_program.runtime_param == "wassuhh"


def test_config_with_dspy_signature_local(clean_folder: Path):
"""Test that configs with DSPy signatures can be serialized and deserialized."""
config = ConfigWithSignature(signature=ClassifyEmotion)
config.save_precompiled(clean_folder)

assert os.path.exists(clean_folder / "config.json")

# Verify the signature was serialized correctly
with open(clean_folder / "config.json", "r") as f:
config_json = json.load(f)
assert "signature" in config_json
assert config_json["signature"].startswith("__dspy_signature__:")

# Load the config back
loaded_config = ConfigWithSignature.from_precompiled(clean_folder)
assert loaded_config.signature == ClassifyEmotion
assert loaded_config.lm == "openai/gpt-4o-mini"

# Test with different signature
config2 = ConfigWithSignature(signature=Summarize, lm="openai/gpt-4o")
config2.save_precompiled(clean_folder)
loaded_config2 = ConfigWithSignature.from_precompiled(clean_folder)
assert loaded_config2.signature == Summarize
assert loaded_config2.lm == "openai/gpt-4o"


def test_program_with_dspy_signature_local(clean_folder: Path):
"""Test that programs with DSPy signature configs can be saved and loaded."""
config = ConfigWithSignature(signature=ClassifyEmotion)
program = ProgramWithSignatureConfig(config=config)
program.save_precompiled(clean_folder)

assert os.path.exists(clean_folder / "config.json")
assert os.path.exists(clean_folder / "program.json")

# Verify the signature was serialized correctly
with open(clean_folder / "config.json", "r") as f:
config_json = json.load(f)
assert config_json["signature"].startswith("__dspy_signature__:")

# Load the program back
loaded_program = ProgramWithSignatureConfig.from_precompiled(clean_folder)
assert loaded_program.config.signature == ClassifyEmotion
assert loaded_program.config.lm == "openai/gpt-4o-mini"
Loading