diff --git a/pyproject.toml b/pyproject.toml index 3b37df3..f4c7a5f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "promptix" -version = "0.1.12" +version = "0.1.13" description = "A simple library for managing and using prompts locally with Promptix Studio" readme = "README.md" requires-python = ">=3.9" diff --git a/src/promptix/__init__.py b/src/promptix/__init__.py index 4d452e8..e08b955 100644 --- a/src/promptix/__init__.py +++ b/src/promptix/__init__.py @@ -21,7 +21,7 @@ config = Promptix.builder("template_name").with_variable("value").build() """ -from .core.base import Promptix +from .core.base_refactored import Promptix -__version__ = "0.1.11" +__version__ = "0.1.13" __all__ = ["Promptix"] diff --git a/src/promptix/core/base_refactored.py b/src/promptix/core/base_refactored.py new file mode 100644 index 0000000..162aa65 --- /dev/null +++ b/src/promptix/core/base_refactored.py @@ -0,0 +1,257 @@ +""" +Refactored Promptix main class using dependency injection and focused components. + +This module provides the main Promptix class that has been refactored to use +focused components and dependency injection for better testability and modularity. +""" + +from typing import Any, Dict, Optional, List +from .container import get_container +from .components import ( + PromptLoader, + VariableValidator, + TemplateRenderer, + VersionManager, + ModelConfigBuilder +) +from .exceptions import PromptNotFoundError, ConfigurationError, StorageError + +class Promptix: + """Main class for managing and using prompts with schema validation and template rendering.""" + + def __init__(self, container=None): + """Initialize Promptix with dependency injection. + + Args: + container: Optional container for dependency injection. If None, uses global container. + """ + self._container = container or get_container() + + # Get dependencies from container + self._prompt_loader = self._container.get_typed("prompt_loader", PromptLoader) + self._variable_validator = self._container.get_typed("variable_validator", VariableValidator) + self._template_renderer = self._container.get_typed("template_renderer", TemplateRenderer) + self._version_manager = self._container.get_typed("version_manager", VersionManager) + self._model_config_builder = self._container.get_typed("model_config_builder", ModelConfigBuilder) + self._logger = self._container.get("logger") + + @classmethod + def get_prompt(cls, prompt_template: str, version: Optional[str] = None, **variables) -> str: + """Get a prompt by name and fill in the variables. + + Args: + prompt_template (str): The name of the prompt template to use + version (Optional[str]): Specific version to use (e.g. "v1"). + If None, uses the latest live version. + **variables: Variable key-value pairs to fill in the prompt template + + Returns: + str: The rendered prompt + + Raises: + PromptNotFoundError: If the prompt template is not found + RequiredVariableError: If required variables are missing + VariableValidationError: If a variable doesn't match the schema type + TemplateRenderError: If template rendering fails + """ + instance = cls() + return instance.render_prompt(prompt_template, version, **variables) + + def render_prompt(self, prompt_template: str, version: Optional[str] = None, **variables) -> str: + """Render a prompt with the provided variables. + + Args: + prompt_template: The name of the prompt template to use. + version: Specific version to use. If None, uses the live version. + **variables: Variable key-value pairs to fill in the prompt template. + + Returns: + The rendered prompt string. + + Raises: + PromptNotFoundError: If the prompt template is not found. + RequiredVariableError: If required variables are missing. + VariableValidationError: If a variable doesn't match the schema type. + TemplateRenderError: If template rendering fails. + """ + # Load prompt data + try: + prompt_data = self._prompt_loader.get_prompt_data(prompt_template) + except StorageError as err: + try: + available_prompts = list(self._prompt_loader.get_prompts().keys()) + except StorageError: + available_prompts = [] + raise PromptNotFoundError( + prompt_name=prompt_template, + available_prompts=available_prompts + ) from err + versions = prompt_data.get("versions", {}) + + # Get the appropriate version data + version_data = self._version_manager.get_version_data(versions, version, prompt_template) + + # Get the system instruction template + try: + template_text = self._version_manager.get_system_instruction(version_data, prompt_template) + except ValueError as err: + raise ConfigurationError( + config_issue="Missing 'config.system_instruction'", + config_path=f"{prompt_template}.versions" + ) from err + + # Validate variables against schema + schema = version_data.get("schema", {}) + self._variable_validator.validate_variables(schema, variables, prompt_template) + + # Render the template + result = self._template_renderer.render_template(template_text, variables, prompt_template) + + return result + + @classmethod + def prepare_model_config( + cls, + prompt_template: str, + memory: List[Dict[str, str]], + version: Optional[str] = None, + **variables + ) -> Dict[str, Any]: + """Prepare a model configuration ready for OpenAI chat completion API. + + Args: + prompt_template (str): The name of the prompt template to use + memory (List[Dict[str, str]]): List of previous messages in the conversation + version (Optional[str]): Specific version to use (e.g. "v1"). + If None, uses the latest live version. + **variables: Variable key-value pairs to fill in the prompt template + + Returns: + Dict[str, Any]: Configuration dictionary for OpenAI chat completion API + + Raises: + PromptNotFoundError: If the prompt template is not found + InvalidMemoryFormatError: If memory format is invalid + RequiredVariableError: If required variables are missing + VariableValidationError: If a variable doesn't match the schema type + ConfigurationError: If required configuration is missing + """ + instance = cls() + return instance.build_model_config(prompt_template, memory, version, **variables) + + def build_model_config( + self, + prompt_template: str, + memory: List[Dict[str, str]], + version: Optional[str] = None, + **variables + ) -> Dict[str, Any]: + """Build a model configuration. + + Args: + prompt_template: The name of the prompt template to use. + memory: List of previous messages in the conversation. + version: Specific version to use. If None, uses the live version. + **variables: Variable key-value pairs to fill in the prompt template. + + Returns: + Configuration dictionary for the model API. + + Raises: + PromptNotFoundError: If the prompt template is not found. + InvalidMemoryFormatError: If memory format is invalid. + RequiredVariableError: If required variables are missing. + VariableValidationError: If a variable doesn't match the schema type. + ConfigurationError: If required configuration is missing. + """ + # Get the prompt data for version information + try: + prompt_data = self._prompt_loader.get_prompt_data(prompt_template) + except StorageError as err: + try: + available_prompts = list(self._prompt_loader.get_prompts().keys()) + except StorageError: + available_prompts = [] + raise PromptNotFoundError( + prompt_name=prompt_template, + available_prompts=available_prompts + ) from err + + versions = prompt_data.get("versions", {}) + version_data = self._version_manager.get_version_data(versions, version, prompt_template) + + # Render the system message + system_message = self.render_prompt(prompt_template, version, **variables) + + # Build the model configuration + return self._model_config_builder.build_model_config( + system_message=system_message, + memory=memory, + version_data=version_data, + prompt_name=prompt_template + ) + + @staticmethod + def builder(prompt_template: str, container=None): + """Create a new PromptixBuilder instance for building model configurations. + + Args: + prompt_template (str): The name of the prompt template to use + container: Optional container for dependency injection + + Returns: + PromptixBuilder: A builder instance for configuring the model + """ + from .builder_refactored import PromptixBuilder + return PromptixBuilder(prompt_template, container) + + def list_prompts(self) -> Dict[str, Any]: + """List all available prompts. + + Returns: + Dictionary of all available prompts. + """ + return self._prompt_loader.get_prompts() + + def list_versions(self, prompt_template: str) -> List[Dict[str, Any]]: + """List all versions for a specific prompt. + + Args: + prompt_template: Name of the prompt template. + + Returns: + List of version information. + + Raises: + PromptNotFoundError: If the prompt template is not found. + """ + try: + prompt_data = self._prompt_loader.get_prompt_data(prompt_template) + except StorageError as err: + try: + available_prompts = list(self._prompt_loader.get_prompts().keys()) + except StorageError: + available_prompts = [] + raise PromptNotFoundError( + prompt_name=prompt_template, + available_prompts=available_prompts + ) from err + + versions = prompt_data.get("versions", {}) + return self._version_manager.list_versions(versions) + def validate_template(self, template_text: str) -> bool: + """Validate that a template is syntactically correct. + + Args: + template_text: The template text to validate. + + Returns: + True if the template is valid, False otherwise. + """ + return self._template_renderer.validate_template(template_text) + + def reload_prompts(self) -> None: + """Force reload prompts from storage.""" + self._prompt_loader.reload_prompts() + if self._logger: + self._logger.info("Prompts reloaded successfully") diff --git a/src/promptix/core/builder_refactored.py b/src/promptix/core/builder_refactored.py new file mode 100644 index 0000000..0c6b906 --- /dev/null +++ b/src/promptix/core/builder_refactored.py @@ -0,0 +1,611 @@ +""" +Refactored PromptixBuilder class using dependency injection and focused components. + +This module provides the PromptixBuilder class that has been refactored to use +focused components and dependency injection for better testability and modularity. +""" + +from typing import Any, Dict, List, Optional, Union +from .container import get_container +from .components import ( + PromptLoader, + VariableValidator, + TemplateRenderer, + VersionManager, + ModelConfigBuilder +) +from .adapters._base import ModelAdapter +from .exceptions import ( + PromptNotFoundError, + VersionNotFoundError, + UnsupportedClientError, + ToolNotFoundError, + ToolProcessingError, + ValidationError, + StorageError, + RequiredVariableError, + VariableValidationError +) + + +class PromptixBuilder: + """Builder class for creating model configurations using dependency injection.""" + + def __init__(self, prompt_template: str, container=None): + """Initialize the builder with dependency injection. + + Args: + prompt_template: The name of the prompt template to use. + container: Optional container for dependency injection. If None, uses global container. + + Raises: + PromptNotFoundError: If the prompt template is not found. + """ + self._container = container or get_container() + self.prompt_template = prompt_template + self.custom_version = None + self._data = {} # Holds all variables + self._memory = [] # Conversation history + self._client = "openai" # Default client + self._model_params = {} # Holds direct model parameters + + # Get dependencies from container + self._prompt_loader = self._container.get_typed("prompt_loader", PromptLoader) + self._variable_validator = self._container.get_typed("variable_validator", VariableValidator) + self._template_renderer = self._container.get_typed("template_renderer", TemplateRenderer) + self._version_manager = self._container.get_typed("version_manager", VersionManager) + self._model_config_builder = self._container.get_typed("model_config_builder", ModelConfigBuilder) + self._logger = self._container.get("logger") + self._adapters = self._container.get("adapters") + + # Initialize prompt data + self._initialize_prompt_data() + + def _initialize_prompt_data(self) -> None: + """Initialize prompt data and find live version. + + Raises: + PromptNotFoundError: If the prompt template is not found. + """ + try: + self.prompt_data = self._prompt_loader.get_prompt_data(self.prompt_template) + except StorageError as err: + try: + available_prompts = list(self._prompt_loader.get_prompts().keys()) + except StorageError: + available_prompts = [] + raise PromptNotFoundError( + prompt_name=self.prompt_template, + available_prompts=available_prompts + ) from err + versions = self.prompt_data.get("versions", {}) + live_version_key = self._version_manager.find_live_version(versions, self.prompt_template) + self.version_data = versions[live_version_key] + + # Extract schema properties + schema = self.version_data.get("schema", {}) + self.properties = schema.get("properties", {}) + self.allow_additional = schema.get("additionalProperties", False) + + @classmethod + def register_adapter(cls, client_name: str, adapter: ModelAdapter, container=None) -> None: + """Register a new adapter for a client. + + Args: + client_name: Name of the client. + adapter: The adapter instance. + container: Optional container. If None, uses global container. + + Raises: + InvalidDependencyError: If the adapter is not a ModelAdapter instance. + """ + _container = container or get_container() + _container.register_adapter(client_name, adapter) + + def _validate_type(self, field: str, value: Any) -> None: + """Validate that a value matches its schema-defined type. + + Args: + field: Name of the field to validate. + value: Value to validate. + + Raises: + ValidationError: If validation fails. + """ + if field not in self.properties: + if not self.allow_additional: + raise ValidationError( + f"Field '{field}' is not defined in the schema and additional properties are not allowed.", + details={"field": field, "value": value} + ) + return + + self._variable_validator.validate_builder_type(field, value, self.properties) + + def __getattr__(self, name: str): + """Dynamically handle chainable with_() methods. + + Args: + name: Name of the method being called. + + Returns: + A setter function for chainable method calls. + + Raises: + AttributeError: If the method is not a valid with_* method. + """ + if name.startswith("with_"): + field = name[5:] + + def setter(value: Any): + self._validate_type(field, value) + self._data[field] = value + return self + return setter + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + + def with_data(self, **kwargs: Dict[str, Any]): + """Set multiple variables at once using keyword arguments. + + Args: + **kwargs: Variables to set. + + Returns: + Self for method chaining. + """ + for field, value in kwargs.items(): + self._validate_type(field, value) + self._data[field] = value + return self + + def with_var(self, variables: Dict[str, Any]): + """Set multiple variables at once using a dictionary. + + This method allows passing a dictionary of variables to be used in prompt templates + and tools configuration. All variables are made available to the tools_template + Jinja2 template for conditional tool selection. + + Args: + variables: Dictionary of variable names and their values to be set + + Returns: + Self for method chaining + + Example: + ```python + config = (Promptix.builder("ComplexCodeReviewer") + .with_var({ + 'programming_language': 'Python', + 'severity': 'high', + 'review_focus': 'security and performance' + }) + .build()) + ``` + """ + for field, value in variables.items(): + self._validate_type(field, value) + self._data[field] = value + return self + + def with_extra(self, extra_params: Dict[str, Any]): + """Set additional/extra parameters to be passed directly to the model API. + + Args: + extra_params: Dictionary containing model parameters to be passed directly + to the API (e.g., temperature, top_p, max_tokens). + + Returns: + Self reference for method chaining. + """ + self._model_params.update(extra_params) + return self + + def with_memory(self, memory: List[Dict[str, str]]): + """Set the conversation memory. + + Args: + memory: List of message dictionaries. + + Returns: + Self for method chaining. + + Raises: + InvalidMemoryFormatError: If memory format is invalid. + """ + # Use the model config builder to validate memory format + self._model_config_builder.validate_memory_format(memory) + self._memory = memory + return self + + def for_client(self, client: str): + """Set the client to use for building the configuration. + + Args: + client: Name of the client to use. + + Returns: + Self for method chaining. + + Raises: + UnsupportedClientError: If the client is not supported. + """ + # Check if we have an adapter for this client + if client not in self._adapters: + available_clients = list(self._adapters.keys()) + raise UnsupportedClientError( + client_name=client, + available_clients=available_clients + ) + + # Check compatibility and warn if necessary + self._check_client_compatibility(client) + + self._client = client + return self + + def _check_client_compatibility(self, client: str) -> None: + """Check if the client is compatible with the prompt version. + + Args: + client: Name of the client to check. + """ + provider = self.version_data.get("provider", "").lower() + config_provider = self.version_data.get("config", {}).get("provider", "").lower() + + # Use either provider field + effective_provider = provider or config_provider + + # Issue warning if providers don't match + if effective_provider and effective_provider != client: + warning_msg = ( + f"Client '{client}' may not be fully compatible with this prompt version. " + f"This prompt version is configured for '{effective_provider}'. " + f"Some features may not work as expected." + ) + if self._logger: + self._logger.warning(warning_msg) + + def with_version(self, version: str): + """Set a specific version of the prompt template to use. + + Args: + version: Version identifier to use. + + Returns: + Self for method chaining. + + Raises: + VersionNotFoundError: If the version is not found. + """ + versions = self.prompt_data.get("versions", {}) + if version not in versions: + available_versions = list(versions.keys()) + raise VersionNotFoundError( + version=version, + prompt_name=self.prompt_template, + available_versions=available_versions + ) + + self.custom_version = version + self.version_data = versions[version] + + # Update schema properties for the new version + schema = self.version_data.get("schema", {}) + self.properties = schema.get("properties", {}) + self.allow_additional = schema.get("additionalProperties", False) + + # Set the client based on the provider in version_data + provider = self.version_data.get("provider", "openai").lower() + if provider in self._adapters: + self._client = provider + + return self + + def with_tool(self, tool_name: str, *args, **kwargs) -> "PromptixBuilder": + """Activate a tool by name. + + Args: + tool_name: Name of the tool to activate + *args: Additional tool names to activate + + Returns: + Self for method chaining + """ + # First handle the primary tool_name + self._activate_tool(tool_name) + + # Handle any additional tool names passed as positional arguments + for tool in args: + self._activate_tool(tool) + + return self + + def _activate_tool(self, tool_name: str) -> None: + """Internal helper to activate a single tool. + + Args: + tool_name: Name of the tool to activate. + """ + # Validate tool exists in prompts configuration + tools_config = self.version_data.get("tools_config", {}) + tools = tools_config.get("tools", {}) + + if tool_name in tools: + # Store tool activation as a template variable + tool_var = f"use_{tool_name}" + self._data[tool_var] = True + else: + available_tools = list(tools.keys()) if tools else [] + if self._logger: + self._logger.warning( + f"Tool '{tool_name}' not found. Available tools: {available_tools}" + ) + + def with_tool_parameter(self, tool_name: str, param_name: str, param_value: Any) -> "PromptixBuilder": + """Set a parameter value for a specific tool. + + Args: + tool_name: Name of the tool to configure + param_name: Name of the parameter to set + param_value: Value to set for the parameter + + Returns: + Self for method chaining + """ + # Validate tool exists + tools_config = self.version_data.get("tools_config", {}) + tools = tools_config.get("tools", {}) + + if tool_name not in tools: + available_tools = list(tools.keys()) if tools else [] + if self._logger: + self._logger.warning( + f"Tool '{tool_name}' not found. Available tools: {available_tools}" + ) + return self + + # Make sure the tool is activated + tool_var = f"use_{tool_name}" + if tool_var not in self._data or not self._data[tool_var]: + self._data[tool_var] = True + + # Store parameter in a dedicated location + param_key = f"tool_params_{tool_name}" + if param_key not in self._data: + self._data[param_key] = {} + + self._data[param_key][param_name] = param_value + return self + + def enable_tools(self, *tool_names: str) -> "PromptixBuilder": + """Enable multiple tools at once. + + Args: + *tool_names: Names of tools to enable + + Returns: + Self for method chaining + """ + for tool_name in tool_names: + self.with_tool(tool_name) + return self + + def disable_tools(self, *tool_names: str) -> "PromptixBuilder": + """Disable specific tools. + + Args: + *tool_names: Names of tools to disable + + Returns: + Self for method chaining + """ + for tool_name in tool_names: + tool_var = f"use_{tool_name}" + self._data[tool_var] = False + return self + + def disable_all_tools(self) -> "PromptixBuilder": + """Disable all available tools. + + Returns: + Self for method chaining + """ + tools_config = self.version_data.get("tools_config", {}) + tools = tools_config.get("tools", {}) + + for tool_name in tools.keys(): + tool_var = f"use_{tool_name}" + self._data[tool_var] = False + + return self + + def _process_tools_template(self) -> List[Dict[str, Any]]: + """Process the tools template and return the configured tools. + + Returns: + List of configured tools. + """ + tools_config = self.version_data.get("tools_config", {}) + available_tools = tools_config.get("tools", {}) + + if not tools_config or not available_tools: + return [] + + # Track both template-selected and explicitly activated tools + selected_tools = {} + + # First, find explicitly activated tools (via with_tool) + for tool_name in available_tools.keys(): + prefixed_name = f"use_{tool_name}" + if (tool_name in self._data and self._data[tool_name]) or \ + (prefixed_name in self._data and self._data[prefixed_name]): + selected_tools[tool_name] = available_tools[tool_name] + + # Process tools template if available + tools_template = tools_config.get("tools_template") + if tools_template: + try: + template_result = self._template_renderer.render_tools_template( + tools_template=tools_template, + variables=self._data, + available_tools=available_tools, + prompt_name=self.prompt_template + ) + if template_result: + self._process_template_result(template_result, available_tools, selected_tools) + except TemplateRenderError as e: + if self._logger: + self._logger.warning(f"Error processing tools template: {e!s}") + # Let unexpected exceptions bubble up + # If no tools selected, return empty list + if not selected_tools: + return [] + + try: + # Convert to the format expected by the adapter + adapter = self._adapters[self._client] + return adapter.process_tools(selected_tools) + + except Exception as e: + if self._logger: + self._logger.warning(f"Error processing tools: {str(e)}") + return [] + + def _process_template_result( + self, + template_result: Any, + available_tools: Dict[str, Any], + selected_tools: Dict[str, Any] + ) -> None: + """Process the result from tools template rendering. + + Args: + template_result: Result from template rendering. + available_tools: Available tools configuration. + selected_tools: Dictionary to update with selected tools. + """ + # Handle different return types from template + if isinstance(template_result, list): + # If it's a list of tool names (new format) + if all(isinstance(item, str) for item in template_result): + for tool_name in template_result: + if tool_name in available_tools and tool_name not in selected_tools: + selected_tools[tool_name] = available_tools[tool_name] + # If it's a list of tool objects (old format for backward compatibility) + elif all(isinstance(item, dict) for item in template_result): + for tool in template_result: + if isinstance(tool, dict) and 'name' in tool: + tool_name = tool['name'] + if tool_name in available_tools and tool_name not in selected_tools: + selected_tools[tool_name] = available_tools[tool_name] + # If it's a dictionary of tools (old format for backward compatibility) + elif isinstance(template_result, dict): + for tool_name, tool_config in template_result.items(): + if tool_name not in selected_tools: + selected_tools[tool_name] = tool_config + + def build(self, system_only: bool = False) -> Union[Dict[str, Any], str]: + """Build the final configuration using the appropriate adapter. + + Args: + system_only: If True, returns only the system instruction string. + + Returns: + Either the full model configuration dictionary or just the system instruction string. + """ + # Validate all required fields are present + missing_fields = [] + for field, props in self.properties.items(): + if props.get("required", False) and field not in self._data: + missing_fields.append(field) + if self._logger: + self._logger.warning(f"Required field '{field}' is missing from prompt parameters") + + try: + # Generate the system message using the template renderer + from .base_refactored import Promptix # Import here to avoid circular dependency + promptix_instance = Promptix(self._container) + system_message = promptix_instance.render_prompt(self.prompt_template, self.custom_version, **self._data) + except (ValueError, ImportError, RuntimeError, RequiredVariableError, VariableValidationError) as e: + if self._logger: + self._logger.warning(f"Error generating system message: {e!s}") + # Provide a fallback basic message when template rendering fails + system_message = f"You are an AI assistant for {self.prompt_template}." + + # If system_only is True, just return the system message + if system_only: + return system_message + + # Build configuration based on client type + if self._client == "anthropic": + model_config = self._model_config_builder.prepare_anthropic_config( + system_message=system_message, + memory=self._memory, + version_data=self.version_data, + prompt_name=self.prompt_template + ) + else: + # For OpenAI and others + model_config = self._model_config_builder.build_model_config( + system_message=system_message, + memory=self._memory, + version_data=self.version_data, + prompt_name=self.prompt_template + ) + + # Add any direct model parameters from with_extra + model_config.update(self._model_params) + + # Process tools configuration + try: + tools = self._process_tools_template() + if tools: + model_config["tools"] = tools + except Exception as e: + if self._logger: + self._logger.warning(f"Error processing tools: {str(e)}") + + # Get the appropriate adapter and adapt the configuration + adapter = self._adapters[self._client] + try: + model_config = adapter.adapt_config(model_config, self.version_data) + except Exception as e: + if self._logger: + self._logger.warning(f"Error adapting configuration for client {self._client}: {str(e)}") + + return model_config + + def system_instruction(self) -> str: + """Get only the system instruction/prompt as a string. + + Returns: + The rendered system instruction string + """ + return self.build(system_only=True) + + def debug_tools(self) -> Dict[str, Any]: + """Debug method to inspect the tools configuration. + + Returns: + Dict containing tools configuration information for debugging. + """ + tools_config = self.version_data.get("tools_config", {}) + tools = tools_config.get("tools", {}) + tools_template = tools_config.get("tools_template") if tools_config else None + + # Create context for template rendering + template_context = { + "tools_config": tools_config, + "tools": tools, + **self._data + } + + # Return debug information + return { + "has_tools_config": bool(tools_config), + "has_tools": bool(tools), + "has_tools_template": bool(tools_template), + "available_tools": list(tools.keys()) if tools else [], + "template_context_keys": list(template_context.keys()), + "tool_activation_flags": {k: v for k, v in self._data.items() if k.startswith("use_")} + } diff --git a/src/promptix/core/components/__init__.py b/src/promptix/core/components/__init__.py new file mode 100644 index 0000000..21be061 --- /dev/null +++ b/src/promptix/core/components/__init__.py @@ -0,0 +1,20 @@ +""" +Focused components for Promptix architecture. + +This module contains focused, single-responsibility components that together +provide the functionality of the Promptix system. +""" + +from .prompt_loader import PromptLoader +from .variable_validator import VariableValidator +from .template_renderer import TemplateRenderer +from .version_manager import VersionManager +from .model_config_builder import ModelConfigBuilder + +__all__ = [ + "PromptLoader", + "VariableValidator", + "TemplateRenderer", + "VersionManager", + "ModelConfigBuilder" +] diff --git a/src/promptix/core/components/model_config_builder.py b/src/promptix/core/components/model_config_builder.py new file mode 100644 index 0000000..ec24289 --- /dev/null +++ b/src/promptix/core/components/model_config_builder.py @@ -0,0 +1,217 @@ +""" +ModelConfigBuilder component for building model configurations. + +This component handles building model configurations for different +providers, including message formatting and parameter validation. +""" + +from typing import Any, Dict, List, Optional, Union +from ..exceptions import InvalidMemoryFormatError, ConfigurationError + + +class ModelConfigBuilder: + """Handles building model configurations for API calls.""" + + def __init__(self, logger=None): + """Initialize the model config builder. + + Args: + logger: Optional logger instance for dependency injection. + """ + self._logger = logger + + def build_model_config( + self, + system_message: str, + memory: List[Dict[str, str]], + version_data: Dict[str, Any], + prompt_name: str + ) -> Dict[str, Any]: + """Build a complete model configuration. + + Args: + system_message: The system message/instruction. + memory: List of conversation history messages. + version_data: Version configuration data. + prompt_name: Name of the prompt for error reporting. + + Returns: + Complete model configuration dictionary. + + Raises: + InvalidMemoryFormatError: If memory format is invalid. + ConfigurationError: If required configuration is missing. + """ + # Validate memory format + self.validate_memory_format(memory) + + # Validate system message + if not system_message.strip(): + raise ConfigurationError("System message cannot be empty") + + # Initialize the base configuration + model_config = { + "messages": [{"role": "system", "content": system_message}] + } + model_config["messages"].extend(memory) + + # Get configuration from version data + config = version_data.get("config", {}) + + # Model is required + if "model" not in config: + raise ConfigurationError( + f"Model must be specified in the version data config for prompt '{prompt_name}'" + ) + model_config["model"] = config["model"] + + # Add optional configuration parameters + self._add_optional_parameters(model_config, config) + + # Add tools configuration if present + self._add_tools_configuration(model_config, config) + + return model_config + + def validate_memory_format(self, memory: List[Dict[str, str]]) -> None: + """Validate the format of conversation memory. + + Args: + memory: List of message dictionaries to validate. + + Raises: + InvalidMemoryFormatError: If memory format is invalid. + """ + if not isinstance(memory, list): + raise InvalidMemoryFormatError("Memory must be a list of message dictionaries") + + for i, msg in enumerate(memory): + if not isinstance(msg, dict): + raise InvalidMemoryFormatError( + f"Message at index {i} must be a dictionary", + invalid_message=msg + ) + + if "role" not in msg or "content" not in msg: + raise InvalidMemoryFormatError( + f"Message at index {i} must have 'role' and 'content' keys", + invalid_message=msg + ) + + if msg["role"] not in ["user", "assistant", "system"]: + raise InvalidMemoryFormatError( + f"Message role at index {i} must be 'user', 'assistant', or 'system'", + invalid_message=msg + ) + + if not isinstance(msg["content"], str): + raise InvalidMemoryFormatError( + f"Message content at index {i} must be a string", + invalid_message=msg + ) + + if not msg["content"].strip(): + raise InvalidMemoryFormatError( + f"Message content at index {i} cannot be empty", + invalid_message=msg + ) + + def _add_optional_parameters(self, model_config: Dict[str, Any], config: Dict[str, Any]) -> None: + """Add optional parameters to model configuration. + + Args: + model_config: The model configuration to update. + config: The source configuration data. + """ + optional_params = [ + ("temperature", (int, float)), + ("max_tokens", int), + ("top_p", (int, float)), + ("frequency_penalty", (int, float)), + ("presence_penalty", (int, float)) + ] + + for param_name, expected_type in optional_params: + if param_name in config and config[param_name] is not None: + value = config[param_name] + if not isinstance(value, expected_type): + if self._logger: + self._logger.warning( + f"{param_name} must be of type {expected_type}, got {type(value)}" + ) + continue + model_config[param_name] = value + + def _add_tools_configuration(self, model_config: Dict[str, Any], config: Dict[str, Any]) -> None: + """Add tools configuration to model configuration. + + Args: + model_config: The model configuration to update. + config: The source configuration data. + """ + if "tools" in config and config["tools"]: + tools = config["tools"] + if not isinstance(tools, list): + if self._logger: + self._logger.warning("Tools configuration must be a list") + return + + model_config["tools"] = tools + + # If tools are present, also set tool_choice if specified + if "tool_choice" in config: + model_config["tool_choice"] = config["tool_choice"] + + def prepare_anthropic_config( + self, + system_message: str, + memory: List[Dict[str, str]], + version_data: Dict[str, Any], + prompt_name: str + ) -> Dict[str, Any]: + """Build configuration specifically for Anthropic API. + + Args: + system_message: The system message/instruction. + memory: List of conversation history messages. + version_data: Version configuration data. + prompt_name: Name of the prompt for error reporting. + + Returns: + Anthropic-specific model configuration dictionary. + """ + # Validate inputs + self.validate_memory_format(memory) + + if not system_message.strip(): + raise ConfigurationError("System message cannot be empty") + + config = version_data.get("config", {}) + + if "model" not in config: + raise ConfigurationError( + f"Model must be specified in the version data config for prompt '{prompt_name}'" + ) + + # Anthropic format uses separate system parameter + model_config = { + "model": config["model"], + "system": system_message, + "messages": memory + } + + # Add optional parameters + self._add_optional_parameters(model_config, config) + + return model_config + + def build_system_only_config(self, system_message: str) -> str: + """Build a configuration that returns only the system message. + + Args: + system_message: The system message to return. + + Returns: + The system message string. + """ + return system_message diff --git a/src/promptix/core/components/prompt_loader.py b/src/promptix/core/components/prompt_loader.py new file mode 100644 index 0000000..151ebef --- /dev/null +++ b/src/promptix/core/components/prompt_loader.py @@ -0,0 +1,141 @@ +""" +PromptLoader component for loading and managing prompts from storage. + +This component is responsible for loading prompts from the storage system +and managing the prompt data in memory. +""" + +from pathlib import Path +from typing import Any, Dict, Optional +from ..exceptions import StorageError, StorageFileNotFoundError, UnsupportedFormatError +from ..storage.loaders import PromptLoaderFactory +from ..storage.utils import create_default_prompts_file +from ..config import config + + +class PromptLoader: + """Handles loading and managing prompts from storage.""" + + def __init__(self, logger=None): + """Initialize the prompt loader. + + Args: + logger: Optional logger instance for dependency injection. + """ + self._prompts: Dict[str, Any] = {} + self._logger = logger + self._loaded = False + + def load_prompts(self, force_reload: bool = False) -> Dict[str, Any]: + """Load prompts from storage. + + Args: + force_reload: If True, reload prompts even if already loaded. + + Returns: + Dictionary containing all loaded prompts. + + Raises: + StorageError: If loading fails. + UnsupportedFormatError: If JSON format is detected. + """ + if self._loaded and not force_reload: + return self._prompts + + try: + # Check for unsupported JSON files first + unsupported_files = config.check_for_unsupported_files() + if unsupported_files: + json_file = unsupported_files[0] # Get the first JSON file found + raise UnsupportedFormatError( + file_path=str(json_file), + unsupported_format="json", + supported_formats=["yaml"] + ) + + # Use centralized configuration to find prompt file + prompt_file = config.get_prompt_file_path() + + if prompt_file is None: + # No existing prompts file found, create default + prompt_file = config.get_default_prompt_file_path() + self._prompts = create_default_prompts_file(prompt_file) + if self._logger: + self._logger.info(f"Created new prompts file at {prompt_file} with a sample prompt") + self._loaded = True + return self._prompts + + loader = PromptLoaderFactory.get_loader(prompt_file) + self._prompts = loader.load(prompt_file) + if self._logger: + self._logger.info(f"Successfully loaded prompts from {prompt_file}") + self._loaded = True + return self._prompts + + except UnsupportedFormatError: + # Bubble up as-is per public contract. + raise + except StorageError: + # Already a Promptix storage error; preserve type. + raise + except ValueError as e: + # Normalize unknown-extension errors from factory into a structured error. + if 'Unsupported file format' in str(e) and 'prompt_file' in locals(): + ext = str(getattr(prompt_file, "suffix", "")).lstrip('.') + raise UnsupportedFormatError( + file_path=str(prompt_file), + unsupported_format=ext or "unknown", + supported_formats=["yaml", "yml"] + ) from e + raise StorageError("Failed to load prompts", {"cause": str(e)}) from e + except Exception as e: + # Catch-all for anything else, with proper chaining. + raise StorageError("Failed to load prompts", {"cause": str(e)}) from e + + def get_prompts(self) -> Dict[str, Any]: + """Get the loaded prompts. + + Returns: + Dictionary containing all loaded prompts. + """ + if not self._loaded: + return self.load_prompts() + return self._prompts + + def get_prompt_data(self, prompt_template: str) -> Dict[str, Any]: + """Get data for a specific prompt template. + + Args: + prompt_template: Name of the prompt template. + + Returns: + Dictionary containing the prompt data. + + Raises: + StorageError: If prompt is not found. + """ + prompts = self.get_prompts() + if prompt_template not in prompts: + from ..exceptions import PromptNotFoundError + available_prompts = list(prompts.keys()) + raise PromptNotFoundError(prompt_template, available_prompts) + return prompts[prompt_template] + + def is_loaded(self) -> bool: + """Check if prompts have been loaded. + + Returns: + True if prompts are loaded, False otherwise. + """ + return self._loaded + + def reload_prompts(self) -> Dict[str, Any]: + """Force reload prompts from storage. + + Returns: + Dictionary containing all reloaded prompts. + + Raises: + StorageError: If reloading fails. + """ + return self.load_prompts(force_reload=True) diff --git a/src/promptix/core/components/template_renderer.py b/src/promptix/core/components/template_renderer.py new file mode 100644 index 0000000..ea3f73a --- /dev/null +++ b/src/promptix/core/components/template_renderer.py @@ -0,0 +1,146 @@ +""" +TemplateRenderer component for rendering templates with Jinja2. + +This component handles rendering of prompt templates using Jinja2, +including variable substitution and template processing. +""" + +from typing import Any, Dict +from jinja2 import BaseLoader, Environment, TemplateError +from ..exceptions import TemplateRenderError + + +class TemplateRenderer: + """Handles template rendering with Jinja2.""" + + def __init__(self, logger=None): + """Initialize the template renderer. + + Args: + logger: Optional logger instance for dependency injection. + """ + self._jinja_env = Environment( + loader=BaseLoader(), + trim_blocks=True, + lstrip_blocks=True + ) + self._logger = logger + + def render_template( + self, + template_text: str, + variables: Dict[str, Any], + prompt_name: str = "unknown" + ) -> str: + """Render a template with the provided variables. + + Args: + template_text: The template text to render. + variables: Variables to substitute in the template. + prompt_name: Name of the prompt for error reporting. + + Returns: + The rendered template as a string. + + Raises: + TemplateRenderError: If template rendering fails. + """ + try: + template_obj = self._jinja_env.from_string(template_text) + result = template_obj.render(**variables) + + # Convert escaped newlines (\n) to actual line breaks + result = result.replace("\\n", "\n") + + return result + + except TemplateError as e: + raise TemplateRenderError( + prompt_name=prompt_name, + template_error=str(e), + variables=variables + ) + except Exception as e: + # Catch any other rendering errors + raise TemplateRenderError( + prompt_name=prompt_name, + template_error=f"Unexpected error: {str(e)}", + variables=variables + ) + + def render_tools_template( + self, + tools_template: str, + variables: Dict[str, Any], + available_tools: Dict[str, Any], + prompt_name: str = "unknown" + ) -> Any: + """Render a tools template and parse the result. + + Args: + tools_template: The tools template to render. + variables: Variables available to the template. + available_tools: Available tools configuration. + prompt_name: Name of the prompt for error reporting. + + Returns: + Parsed template result (typically a list or dict). + + Raises: + TemplateRenderError: If template rendering or parsing fails. + """ + try: + # Make a copy of variables to avoid modifying the original + template_vars = dict(variables) + + # Add the tools configuration to the template variables + template_vars['tools'] = available_tools + + # Render the template with the variables + template = self._jinja_env.from_string(tools_template) + rendered_template = template.render(**template_vars) + + # Skip empty template output + if not rendered_template.strip(): + return None + + # Parse the rendered template (assuming it returns JSON-like string) + import json + try: + return json.loads(rendered_template) + except json.JSONDecodeError as json_error: + raise TemplateRenderError( + prompt_name=prompt_name, + template_error=f"Tools template rendered invalid JSON: {str(json_error)}", + variables=template_vars + ) + + except TemplateError as e: + raise TemplateRenderError( + prompt_name=prompt_name, + template_error=f"Tools template rendering failed: {str(e)}", + variables=variables + ) + except Exception as e: + raise TemplateRenderError( + prompt_name=prompt_name, + template_error=f"Unexpected error in tools template: {str(e)}", + variables=variables + ) + + def validate_template(self, template_text: str) -> bool: + """Validate that a template is syntactically correct. + + Args: + template_text: The template text to validate. + + Returns: + True if the template is valid, False otherwise. + """ + try: + self._jinja_env.from_string(template_text) + return True + except TemplateError: + return False + except Exception: + return False diff --git a/src/promptix/core/components/variable_validator.py b/src/promptix/core/components/variable_validator.py new file mode 100644 index 0000000..c17284d --- /dev/null +++ b/src/promptix/core/components/variable_validator.py @@ -0,0 +1,210 @@ +""" +VariableValidator component for validating variables against schemas. + +This component handles validation of user-provided variables against +prompt schemas, including type checking and required field validation. +""" + +from typing import Any, Dict, List +from ..exceptions import ( + VariableValidationError, + RequiredVariableError, + create_validation_error +) + + +class VariableValidator: + """Handles validation of variables against prompt schemas.""" + + def __init__(self, logger=None): + """Initialize the variable validator. + + Args: + logger: Optional logger instance for dependency injection. + """ + self._logger = logger + + def validate_variables( + self, + schema: Dict[str, Any], + user_vars: Dict[str, Any], + prompt_name: str + ) -> None: + """ + Validate user variables against the prompt's schema. + + Performs the following validations: + 1. Check required variables are present + 2. Check variable types match expected types + 3. Check enumeration constraints + + Args: + schema: The prompt schema definition. + user_vars: Variables provided by the user. + prompt_name: Name of the prompt template for error reporting. + + Raises: + RequiredVariableError: If required variables are missing. + VariableValidationError: If variable validation fails. + """ + required = schema.get("required", []) + optional = schema.get("optional", []) + types_dict = schema.get("types", {}) + + # --- 1) Check required variables --- + missing_required = [r for r in required if r not in user_vars] + if missing_required: + provided_vars = list(user_vars.keys()) + raise RequiredVariableError( + prompt_name=prompt_name, + missing_variables=missing_required, + provided_variables=provided_vars + ) + + # --- 2) Check for unknown variables (optional check) --- + # Currently disabled to allow flexibility, but can be enabled by uncommenting: + # allowed_vars = set(required + optional) + # unknown_vars = [k for k in user_vars if k not in allowed_vars] + # if unknown_vars: + # raise VariableValidationError( + # prompt_name=prompt_name, + # variable_name=','.join(unknown_vars), + # error_message=f"unknown variables not allowed", + # provided_value=unknown_vars + # ) + + # --- 3) Type checking and enumeration checks --- + for var_name, var_value in user_vars.items(): + if var_name not in types_dict: + # Not specified in the schema, skip type check + continue + + expected_type = types_dict[var_name] + + # 3.1) If it's a list, treat it as enumeration of allowed values + if isinstance(expected_type, list): + if var_value not in expected_type: + raise create_validation_error( + prompt_name=prompt_name, + field=var_name, + value=var_value, + enum_values=expected_type + ) + + # 3.2) If it's a string specifying a type name + elif isinstance(expected_type, str): + self._validate_type_constraint( + var_name=var_name, + var_value=var_value, + expected_type=expected_type, + prompt_name=prompt_name + ) + + def _validate_type_constraint( + self, + var_name: str, + var_value: Any, + expected_type: str, + prompt_name: str + ) -> None: + """Validate a single variable against its type constraint. + + Args: + var_name: Name of the variable. + var_value: Value of the variable. + expected_type: Expected type as string. + prompt_name: Name of the prompt template for error reporting. + + Raises: + VariableValidationError: If type validation fails. + """ + type_checks = { + "string": lambda v: isinstance(v, str), + "integer": lambda v: isinstance(v, int) and not isinstance(v, bool), + "boolean": lambda v: isinstance(v, bool), + "array": lambda v: isinstance(v, (list, tuple)), + "object": lambda v: isinstance(v, dict), + "number": lambda v: isinstance(v, (int, float)) and not isinstance(v, bool) + } + + if expected_type in type_checks: + if not type_checks[expected_type](var_value): + raise create_validation_error( + prompt_name=prompt_name, + field=var_name, + value=var_value, + expected_type=expected_type + ) + # If type is not recognized, we skip validation with a warning + elif self._logger: + self._logger.warning( + f"Unknown type constraint '{expected_type}' for variable '{var_name}' " + f"in prompt '{prompt_name}'. Skipping type validation." + ) + + def validate_builder_type(self, field: str, value: Any, properties: Dict[str, Any]) -> None: + """Validate a single field against its schema properties (for builder pattern). + + Args: + field: Name of the field. + value: Value to validate. + properties: Schema properties definition. + + Raises: + VariableValidationError: If validation fails. + """ + if field not in properties: + # If additional properties are not allowed, this should be checked + # at the builder level + return + + prop = properties[field] + expected_type = prop.get("type") + enum_values = prop.get("enum") + + # Type validation + if expected_type == "string": + if not isinstance(value, str): + raise VariableValidationError( + prompt_name="builder", + variable_name=field, + error_message=f"must be a string, got {type(value).__name__}", + provided_value=value, + expected_type="string" + ) + elif expected_type == "number": + if not (isinstance(value, (int, float)) and not isinstance(value, bool)): + raise VariableValidationError( + prompt_name="builder", + variable_name=field, + error_message=f"must be a number, got {type(value).__name__}", + provided_value=value, + expected_type="number" + ) + elif expected_type == "integer": + if not (isinstance(value, int) and not isinstance(value, bool)): + raise VariableValidationError( + prompt_name="builder", + variable_name=field, + error_message=f"must be an integer, got {type(value).__name__}", + provided_value=value, + expected_type="integer" + ) + elif expected_type == "boolean": + if not isinstance(value, bool): + raise VariableValidationError( + prompt_name="builder", + variable_name=field, + error_message=f"must be a boolean, got {type(value).__name__}", + provided_value=value, + expected_type="boolean" + ) + + # Enumeration validation + if enum_values is not None and value not in enum_values: + raise create_validation_error( + prompt_name="builder", + field=field, + value=value, + enum_values=enum_values + ) diff --git a/src/promptix/core/components/version_manager.py b/src/promptix/core/components/version_manager.py new file mode 100644 index 0000000..e881b4f --- /dev/null +++ b/src/promptix/core/components/version_manager.py @@ -0,0 +1,185 @@ +""" +VersionManager component for handling prompt version management. + +This component is responsible for finding live versions, managing +version data, and handling version-related operations. +""" + +from typing import Any, Dict, List, Optional +from ..exceptions import ( + NoLiveVersionError, + MultipleLiveVersionsError, + VersionNotFoundError +) + + +class VersionManager: + """Handles prompt version management operations.""" + + def __init__(self, logger=None): + """Initialize the version manager. + + Args: + logger: Optional logger instance for dependency injection. + """ + self._logger = logger + + def find_live_version(self, versions: Dict[str, Any], prompt_name: str) -> str: + """Find the live version from a versions dictionary. + + Args: + versions: Dictionary of version data. + prompt_name: Name of the prompt for error reporting. + + Returns: + The key of the live version. + + Raises: + NoLiveVersionError: If no live version is found. + MultipleLiveVersionsError: If multiple live versions are found. + """ + # Find versions where is_live == True + live_versions = [k for k, v in versions.items() if v.get("is_live", False)] + + if not live_versions: + available_versions = list(versions.keys()) + raise NoLiveVersionError( + prompt_name=prompt_name, + available_versions=available_versions + ) + + if len(live_versions) > 1: + raise MultipleLiveVersionsError( + prompt_name=prompt_name, + live_versions=live_versions + ) + + return live_versions[0] + + def get_version_data( + self, + versions: Dict[str, Any], + version: Optional[str], + prompt_name: str + ) -> Dict[str, Any]: + """Get version data, either specific version or live version. + + Args: + versions: Dictionary of all versions. + version: Specific version to get, or None for live version. + prompt_name: Name of the prompt for error reporting. + + Returns: + The version data dictionary. + + Raises: + VersionNotFoundError: If the specified version is not found. + NoLiveVersionError: If no live version is found when version is None. + MultipleLiveVersionsError: If multiple live versions are found. + """ + if version: + # Use explicitly requested version + if version not in versions: + available_versions = list(versions.keys()) + raise VersionNotFoundError( + version=version, + prompt_name=prompt_name, + available_versions=available_versions + ) + return versions[version] + else: + # Find the live version + live_version_key = self.find_live_version(versions, prompt_name) + return versions[live_version_key] + + def get_system_instruction(self, version_data: Dict[str, Any], prompt_name: str) -> str: + """Extract system instruction from version data. + + Args: + version_data: The version data dictionary. + prompt_name: Name of the prompt for error reporting. + + Returns: + The system instruction text. + + Raises: + ValueError: If system instruction is not found in version data. + """ + template_text = version_data.get("config", {}).get("system_instruction") + if not template_text: + raise ValueError( + f"Version data for '{prompt_name}' does not contain 'config.system_instruction'." + ) + return template_text + + def list_versions(self, versions: Dict[str, Any]) -> List[Dict[str, Any]]: + """List all versions with their metadata. + + Args: + versions: Dictionary of version data. + + Returns: + List of version information dictionaries. + """ + version_list = [] + for version_key, version_data in versions.items(): + version_info = { + "version": version_key, + "is_live": version_data.get("is_live", False), + "provider": version_data.get("provider", "unknown"), + "model": version_data.get("config", {}).get("model", "unknown"), + "has_tools": bool(version_data.get("tools_config", {})), + "description": version_data.get("description", "") + } + version_list.append(version_info) + + return version_list + + def validate_version_data(self, version_data: Dict[str, Any], prompt_name: str, version: str) -> bool: + """Validate that version data contains required fields. + + Args: + version_data: The version data to validate. + prompt_name: Name of the prompt for error reporting. + version: Version identifier for error reporting. + + Returns: + True if version data is valid. + + Raises: + ValueError: If required fields are missing. + """ + required_fields = [ + ("config", "Configuration section is missing"), + ("config.system_instruction", "System instruction is missing from config"), + ("config.model", "Model is missing from config") + ] + + for field_path, error_msg in required_fields: + if self._get_nested_field(version_data, field_path) is None: + raise ValueError( + f"Invalid version data for '{prompt_name}' version '{version}': {error_msg}" + ) + + return True + + def _get_nested_field(self, data: Dict[str, Any], field_path: str) -> Any: + """Get a nested field from a dictionary using dot notation. + + Args: + data: The dictionary to search in. + field_path: Dot-separated path to the field (e.g., "config.model"). + + Returns: + The value at the field path, or None if not found. + """ + keys = field_path.split('.') + current = data + + for key in keys: + if isinstance(current, dict) and key in current: + current = current[key] + else: + return None + + return current diff --git a/src/promptix/core/container.py b/src/promptix/core/container.py new file mode 100644 index 0000000..5de4f1d --- /dev/null +++ b/src/promptix/core/container.py @@ -0,0 +1,282 @@ +""" +Dependency Injection Container for Promptix. + +This module provides a simple dependency injection container to manage +dependencies and improve testability of the Promptix system. +""" + +from typing import Any, Dict, Optional, Type, TypeVar, Generic, Callable +from ..enhancements.logging import setup_logging +from .components import ( + PromptLoader, + VariableValidator, + TemplateRenderer, + VersionManager, + ModelConfigBuilder +) +from .adapters.openai import OpenAIAdapter +from .adapters.anthropic import AnthropicAdapter +from .adapters._base import ModelAdapter +from .exceptions import MissingDependencyError, InvalidDependencyError + +T = TypeVar('T') + + +class Container: + """Simple dependency injection container for Promptix components.""" + + def __init__(self): + """Initialize the container with default dependencies.""" + self._services: Dict[str, Any] = {} + self._factories: Dict[str, Callable[[], Any]] = {} + self._singletons: Dict[str, Any] = {} + self._setup_defaults() + + def _setup_defaults(self) -> None: + """Setup default dependencies.""" + # Setup logger as singleton + self.register_singleton("logger", setup_logging()) + + # Register component factories + self.register_factory("prompt_loader", lambda: PromptLoader( + logger=self.get("logger") + )) + + self.register_factory("variable_validator", lambda: VariableValidator( + logger=self.get("logger") + )) + + self.register_factory("template_renderer", lambda: TemplateRenderer( + logger=self.get("logger") + )) + + self.register_factory("version_manager", lambda: VersionManager( + logger=self.get("logger") + )) + + self.register_factory("model_config_builder", lambda: ModelConfigBuilder( + logger=self.get("logger") + )) + + # Register adapters as singletons + self.register_singleton("openai_adapter", OpenAIAdapter()) + self.register_singleton("anthropic_adapter", AnthropicAdapter()) + + # Register adapter registry + self.register_singleton("adapters", { + "openai": self.get("openai_adapter"), + "anthropic": self.get("anthropic_adapter") + }) + + def register_singleton(self, name: str, instance: Any) -> None: + """Register a singleton instance. + + Args: + name: Name of the service. + instance: The singleton instance. + """ + self._singletons[name] = instance + + def register_factory(self, name: str, factory: Callable[[], Any]) -> None: + """Register a factory function that creates instances. + + Args: + name: Name of the service. + factory: Function that creates the service instance. + """ + self._factories[name] = factory + + def register_transient(self, name: str, service_type: Type[T], *args, **kwargs) -> None: + """Register a transient service (new instance each time). + + Args: + name: Name of the service. + service_type: Type of the service to instantiate. + *args: Arguments to pass to the constructor. + **kwargs: Keyword arguments to pass to the constructor. + """ + self._services[name] = (service_type, args, kwargs) + + def get(self, name: str) -> Any: + """Get a service instance by name. + + Args: + name: Name of the service to retrieve. + + Returns: + The service instance. + + Raises: + MissingDependencyError: If the service is not registered. + """ + # Check singletons first + if name in self._singletons: + return self._singletons[name] + + # Check factories + if name in self._factories: + return self._factories[name]() + + # Check transient services + if name in self._services: + service_type, args, kwargs = self._services[name] + return service_type(*args, **kwargs) + + raise MissingDependencyError( + dependency_name=name, + component="Container" + ) + + def get_typed(self, name: str, expected_type: Type[T]) -> T: + """Get a service instance with type checking. + + Args: + name: Name of the service to retrieve. + expected_type: Expected type of the service. + + Returns: + The service instance cast to the expected type. + + Raises: + MissingDependencyError: If the service is not registered. + InvalidDependencyError: If the service is not of the expected type. + """ + service = self.get(name) + if not isinstance(service, expected_type): + raise InvalidDependencyError( + dependency_name=name, + expected_type=expected_type.__name__, + actual_type=type(service).__name__ + ) + return service + + def override(self, name: str, instance: Any) -> None: + """Override a service with a new instance (useful for testing). + + Args: + name: Name of the service to override. + instance: The new service instance. + """ + self._singletons[name] = instance + + def clear_overrides(self) -> None: + """Clear all service overrides and restore defaults.""" + self._services.clear() + self._factories.clear() + self._singletons.clear() + self._setup_defaults() + + def register_adapter(self, name: str, adapter: ModelAdapter) -> None: + """Register a new model adapter. + + Args: + name: Name of the adapter (e.g., "claude", "gpt4"). + adapter: The adapter instance. + + Raises: + InvalidDependencyError: If the adapter is not a ModelAdapter instance. + """ + if not isinstance(adapter, ModelAdapter): + raise InvalidDependencyError( + dependency_name=name, + expected_type="ModelAdapter", + actual_type=type(adapter).__name__ + ) + + # Add to adapters registry + adapters = self.get("adapters") + adapters[name] = adapter + + def create_scope(self) -> 'ContainerScope': + """Create a new scope with isolated overrides. + + Returns: + A new container scope. + """ + return ContainerScope(self) + + +class ContainerScope: + """A scoped container that can have isolated overrides.""" + + def __init__(self, parent_container: Container): + """Initialize the scoped container. + + Args: + parent_container: The parent container to inherit from. + """ + self._parent = parent_container + self._overrides: Dict[str, Any] = {} + + def override(self, name: str, instance: Any) -> None: + """Override a service in this scope only. + + Args: + name: Name of the service to override. + instance: The new service instance. + """ + self._overrides[name] = instance + + def get(self, name: str) -> Any: + """Get a service, checking overrides first. + + Args: + name: Name of the service to retrieve. + + Returns: + The service instance. + """ + if name in self._overrides: + return self._overrides[name] + return self._parent.get(name) + + def get_typed(self, name: str, expected_type: Type[T]) -> T: + """Get a service with type checking. + + Args: + name: Name of the service to retrieve. + expected_type: Expected type of the service. + + Returns: + The service instance cast to the expected type. + """ + service = self.get(name) + if not isinstance(service, expected_type): + raise InvalidDependencyError( + dependency_name=name, + expected_type=expected_type.__name__, + actual_type=type(service).__name__ + ) + return service + + +# Global container instance +_container: Optional[Container] = None + + +def get_container() -> Container: + """Get the global container instance. + + Returns: + The global container instance. + """ + global _container + if _container is None: + _container = Container() + return _container + + +def set_container(container: Container) -> None: + """Set the global container instance (useful for testing). + + Args: + container: The container instance to set as global. + """ + global _container + _container = container + + +def reset_container() -> None: + """Reset the global container to defaults.""" + global _container + _container = None diff --git a/src/promptix/core/exceptions.py b/src/promptix/core/exceptions.py new file mode 100644 index 0000000..6316cac --- /dev/null +++ b/src/promptix/core/exceptions.py @@ -0,0 +1,333 @@ +""" +Custom exception classes for Promptix. + +This module provides a standardized exception hierarchy for consistent error handling +throughout the Promptix library. All custom exceptions inherit from PromptixError. +""" + +from typing import Any, Dict, List, Optional, Union + + +class PromptixError(Exception): + """Base exception class for all Promptix errors.""" + + def __init__(self, message: str, details: Optional[Dict[str, Any]] = None): + super().__init__(message) + self.message = message + self.details = details or {} + + def __str__(self) -> str: + if self.details: + return f"{self.message}. Details: {self.details}" + return self.message + + +# === Template and Prompt Related Errors === + +class PromptNotFoundError(PromptixError): + """Raised when a requested prompt template is not found.""" + + def __init__(self, prompt_name: str, available_prompts: Optional[List[str]] = None): + message = f"Prompt template '{prompt_name}' not found" + details = { + "prompt_name": prompt_name, + "available_prompts": available_prompts or [] + } + super().__init__(message, details) + + +class VersionNotFoundError(PromptixError): + """Raised when a requested prompt version is not found.""" + + def __init__(self, version: str, prompt_name: str, available_versions: Optional[List[str]] = None): + message = f"Version '{version}' not found for prompt '{prompt_name}'" + details = { + "version": version, + "prompt_name": prompt_name, + "available_versions": available_versions or [] + } + super().__init__(message, details) + + +class NoLiveVersionError(PromptixError): + """Raised when no live version is found for a prompt.""" + + def __init__(self, prompt_name: str, available_versions: Optional[List[str]] = None): + message = f"No live version found for prompt '{prompt_name}'" + details = { + "prompt_name": prompt_name, + "available_versions": available_versions or [] + } + super().__init__(message, details) + + +class MultipleLiveVersionsError(PromptixError): + """Raised when multiple live versions are found for a prompt.""" + + def __init__(self, prompt_name: str, live_versions: List[str]): + message = f"Multiple live versions found for prompt '{prompt_name}': {live_versions}. Only one version can be live at a time" + details = { + "prompt_name": prompt_name, + "live_versions": live_versions + } + super().__init__(message, details) + + +class TemplateRenderError(PromptixError): + """Raised when template rendering fails.""" + + def __init__(self, prompt_name: str, template_error: str, variables: Optional[Dict[str, Any]] = None): + message = f"Error rendering template for '{prompt_name}': {template_error}" + details = { + "prompt_name": prompt_name, + "template_error": template_error, + "variables": variables or {} + } + super().__init__(message, details) + + +# === Validation Errors === + +class ValidationError(PromptixError): + """Base class for validation-related errors.""" + pass + + +class VariableValidationError(ValidationError): + """Raised when variable validation fails.""" + + def __init__(self, prompt_name: str, variable_name: str, error_message: str, + provided_value: Any = None, expected_type: Optional[str] = None): + message = f"Variable '{variable_name}' validation failed for prompt '{prompt_name}': {error_message}" + details = { + "prompt_name": prompt_name, + "variable_name": variable_name, + "provided_value": provided_value, + "expected_type": expected_type, + "error_message": error_message + } + super().__init__(message, details) + + +class RequiredVariableError(ValidationError): + """Raised when required variables are missing.""" + + def __init__(self, prompt_name: str, missing_variables: List[str], + provided_variables: Optional[List[str]] = None): + message = f"Prompt '{prompt_name}' is missing required variables: {', '.join(missing_variables)}" + details = { + "prompt_name": prompt_name, + "missing_variables": missing_variables, + "provided_variables": provided_variables or [] + } + super().__init__(message, details) + + +class SchemaValidationError(ValidationError): + """Raised when schema validation fails.""" + + def __init__(self, prompt_name: str, schema_errors: List[str]): + message = f"Schema validation failed for prompt '{prompt_name}': {'; '.join(schema_errors)}" + details = { + "prompt_name": prompt_name, + "schema_errors": schema_errors + } + super().__init__(message, details) + + +# === Configuration and Adapter Errors === + +class ConfigurationError(PromptixError): + """Raised when configuration is invalid or missing.""" + + def __init__(self, config_issue: str, config_path: Optional[str] = None): + message = f"Configuration error: {config_issue}" + details = {"config_issue": config_issue} + if config_path: + details["config_path"] = config_path + super().__init__(message, details) + + +class AdapterError(PromptixError): + """Base class for adapter-related errors.""" + pass + + +class UnsupportedClientError(AdapterError): + """Raised when an unsupported client is requested.""" + + def __init__(self, client_name: str, available_clients: List[str]): + message = f"Unsupported client: {client_name}. Available clients: {available_clients}" + details = { + "client_name": client_name, + "available_clients": available_clients + } + super().__init__(message, details) + + +class AdapterConfigurationError(AdapterError): + """Raised when adapter configuration fails.""" + + def __init__(self, client_name: str, configuration_error: str): + message = f"Adapter configuration failed for client '{client_name}': {configuration_error}" + details = { + "client_name": client_name, + "configuration_error": configuration_error + } + super().__init__(message, details) + + +# === Storage and File Errors === + +class StorageError(PromptixError): + """Base class for storage-related errors.""" + pass + + +class StorageFileNotFoundError(StorageError): + """Raised when a required file is not found.""" + + def __init__(self, file_path: str, file_type: str = "file"): + message = f"{file_type.capitalize()} not found: {file_path}" + details = { + "file_path": file_path, + "file_type": file_type + } + super().__init__(message, details) + + +class UnsupportedFormatError(StorageError): + """Raised when an unsupported file format is encountered.""" + + def __init__(self, file_path: str, unsupported_format: str, supported_formats: List[str]): + message = f"Unsupported format '{unsupported_format}' for file: {file_path}" + details = { + "file_path": file_path, + "unsupported_format": unsupported_format, + "supported_formats": supported_formats + } + super().__init__(message, details) + + +class FileParsingError(StorageError): + """Raised when file parsing fails.""" + + def __init__(self, file_path: str, parsing_error: str): + message = f"Error parsing file '{file_path}': {parsing_error}" + details = { + "file_path": file_path, + "parsing_error": parsing_error + } + super().__init__(message, details) + + +# === Memory and Message Errors === + +class PromptixMemoryError(PromptixError): + """Base class for memory/message-related errors.""" + pass + + +class InvalidMemoryFormatError(PromptixMemoryError): + """Raised when memory/message format is invalid.""" + + def __init__(self, error_description: str, invalid_message: Optional[Dict[str, Any]] = None): + message = f"Invalid memory format: {error_description}" + details = { + "error_description": error_description, + "invalid_message": invalid_message + } + super().__init__(message, details) + + +# === Tool-related Errors === + +class ToolError(PromptixError): + """Base class for tool-related errors.""" + pass + + +class ToolNotFoundError(ToolError): + """Raised when a requested tool is not found.""" + + def __init__(self, tool_name: str, available_tools: Optional[List[str]] = None): + message = f"Tool '{tool_name}' not found in configuration" + details = { + "tool_name": tool_name, + "available_tools": available_tools or [] + } + super().__init__(message, details) + + +class ToolProcessingError(ToolError): + """Raised when tool processing fails.""" + + def __init__(self, tool_name: str, processing_error: str): + message = f"Error processing tool '{tool_name}': {processing_error}" + details = { + "tool_name": tool_name, + "processing_error": processing_error + } + super().__init__(message, details) + + +# === Dependency Injection Errors === + +class DependencyError(PromptixError): + """Base class for dependency injection errors.""" + pass + + +class MissingDependencyError(DependencyError): + """Raised when a required dependency is not provided.""" + + def __init__(self, dependency_name: str, component: str): + message = f"Missing required dependency '{dependency_name}' for component '{component}'" + details = { + "dependency_name": dependency_name, + "component": component + } + super().__init__(message, details) + + +class InvalidDependencyError(DependencyError): + """Raised when a dependency does not meet requirements.""" + + def __init__(self, dependency_name: str, expected_type: str, actual_type: str): + message = f"Invalid dependency '{dependency_name}': expected {expected_type}, got {actual_type}" + details = { + "dependency_name": dependency_name, + "expected_type": expected_type, + "actual_type": actual_type + } + super().__init__(message, details) + + +# Convenience function to create appropriate exceptions +def create_validation_error(prompt_name: str, field: str, value: Any, + expected_type: Optional[str] = None, + enum_values: Optional[List[Any]] = None) -> ValidationError: + """Create appropriate validation error based on the validation failure type.""" + if enum_values and value not in enum_values: + return VariableValidationError( + prompt_name=prompt_name, + variable_name=field, + error_message=f"must be one of {enum_values}", + provided_value=value, + expected_type=f"enum: {enum_values}" + ) + elif expected_type: + return VariableValidationError( + prompt_name=prompt_name, + variable_name=field, + error_message=f"must be of type {expected_type}", + provided_value=value, + expected_type=expected_type + ) + else: + return VariableValidationError( + prompt_name=prompt_name, + variable_name=field, + error_message="validation failed", + provided_value=value + ) diff --git a/src/promptix/core/storage/loaders.py b/src/promptix/core/storage/loaders.py index 78dfb97..b15e048 100644 --- a/src/promptix/core/storage/loaders.py +++ b/src/promptix/core/storage/loaders.py @@ -4,6 +4,7 @@ from pathlib import Path from typing import Dict, Any from jsonschema import Draft7Validator, ValidationError +from ..exceptions import UnsupportedFormatError class InvalidPromptSchemaError(ValueError): """Raised when prompt data fails schema validation""" @@ -94,14 +95,6 @@ def validate_loaded(self, data: Dict[str, Any]) -> None: """Validate loaded data against schema""" pass -class UnsupportedFormatError(ValueError): - """Raised when trying to use an unsupported or deprecated file format""" - def __init__(self, file_path: Path, message: str = None): - if message is None: - message = f"JSON format is no longer supported. Please convert {file_path} to YAML format (.yaml or .yml)" - super().__init__(message) - self.file_path = file_path - class YAMLPromptLoader(PromptLoader): def load(self, file_path: Path) -> Dict[str, Any]: with open(file_path, 'r', encoding='utf-8') as f: @@ -124,9 +117,9 @@ def get_loader(file_path: Path) -> PromptLoader: return YAMLPromptLoader() elif file_path.suffix.lower() == '.json': raise UnsupportedFormatError( - file_path, - f"JSON format is no longer supported. Please convert '{file_path}' to YAML format. " - f"You can rename it to '{file_path.with_suffix('.yaml')}' and update the syntax if needed." + str(file_path), + "json", + ["yaml", "yml"] ) else: raise ValueError(f"Unsupported file format: {file_path.suffix}. Only YAML (.yaml, .yml) files are supported.") \ No newline at end of file diff --git a/src/promptix/tools/studio/data.py b/src/promptix/tools/studio/data.py index a23e58d..6ce1b25 100644 --- a/src/promptix/tools/studio/data.py +++ b/src/promptix/tools/studio/data.py @@ -2,7 +2,8 @@ from typing import Dict, List, Optional from datetime import datetime from pathlib import Path -from promptix.core.storage.loaders import PromptLoaderFactory, InvalidPromptSchemaError, UnsupportedFormatError +from promptix.core.storage.loaders import PromptLoaderFactory, InvalidPromptSchemaError +from promptix.core.exceptions import UnsupportedFormatError from promptix.core.storage.utils import create_default_prompts_file from promptix.core.config import config import traceback @@ -15,10 +16,9 @@ def __init__(self): if unsupported_files: json_file = unsupported_files[0] raise UnsupportedFormatError( - json_file, - f"Promptix Studio no longer supports JSON format. " - f"Please convert '{json_file}' to YAML format. " - f"You can rename it to '{json_file.with_suffix('.yaml')}' and ensure YAML syntax is correct." + str(json_file), + "json", + ["yaml", "yml"] ) # Get the prompt file path from configuration diff --git a/tests/test_01_basic.py b/tests/test_01_basic.py index c3a13df..25e3f34 100644 --- a/tests/test_01_basic.py +++ b/tests/test_01_basic.py @@ -28,7 +28,8 @@ def test_get_prompt_specific_version(): def test_get_prompt_invalid_template(): """Test error handling for invalid template.""" - with pytest.raises(ValueError): + from promptix.core.exceptions import PromptNotFoundError + with pytest.raises(PromptNotFoundError): Promptix.get_prompt( prompt_template="NonExistentTemplate", user_name="Test User" diff --git a/tests/test_02_builder.py b/tests/test_02_builder.py index b897adb..426137f 100644 --- a/tests/test_02_builder.py +++ b/tests/test_02_builder.py @@ -84,11 +84,13 @@ def test_template_demo_builder(): def test_builder_validation(): """Test builder validation and error cases.""" - with pytest.raises(ValueError): + from promptix.core.exceptions import PromptNotFoundError, UnsupportedClientError + + with pytest.raises(PromptNotFoundError): # Should raise error for invalid template name Promptix.builder("NonExistentTemplate").build() - with pytest.raises(ValueError): + with pytest.raises(UnsupportedClientError): # Should raise error for invalid client type (Promptix.builder("SimpleChat") .for_client("invalid_client") diff --git a/tests/test_03_template_features.py b/tests/test_03_template_features.py index f7f1f88..e05349f 100644 --- a/tests/test_03_template_features.py +++ b/tests/test_03_template_features.py @@ -111,7 +111,8 @@ def test_template_parameter_validation(): """Test template parameter validation.""" # In the current implementation, it seems the template doesn't validate difficulty levels # So instead, we'll test with a non-existent template name - with pytest.raises(ValueError): + from promptix.core.exceptions import PromptNotFoundError + with pytest.raises(PromptNotFoundError): Promptix.get_prompt( prompt_template="NonExistentTemplate", content_type="tutorial", diff --git a/tests/test_07_architecture_refactor.py b/tests/test_07_architecture_refactor.py new file mode 100644 index 0000000..89cb7b7 --- /dev/null +++ b/tests/test_07_architecture_refactor.py @@ -0,0 +1,546 @@ +""" +Tests for the refactored architecture with dependency injection and focused components. +""" + +import pytest +from unittest.mock import Mock, patch, MagicMock + +# Import components and exceptions +from promptix.core.components import ( + PromptLoader, + VariableValidator, + TemplateRenderer, + VersionManager, + ModelConfigBuilder +) +from promptix.core.exceptions import ( + PromptixError, + PromptNotFoundError, + VersionNotFoundError, + NoLiveVersionError, + MultipleLiveVersionsError, + TemplateRenderError, + VariableValidationError, + RequiredVariableError, + ConfigurationError, + UnsupportedClientError, + InvalidMemoryFormatError +) +from promptix.core.container import Container, get_container, reset_container +from promptix.core.base_refactored import Promptix +from promptix.core.builder_refactored import PromptixBuilder + + +class TestExceptions: + """Test the custom exception hierarchy.""" + + def test_promptix_error_base(self): + """Test the base PromptixError class.""" + error = PromptixError("Test error", {"key": "value"}) + assert str(error) == "Test error. Details: {'key': 'value'}" + assert error.message == "Test error" + assert error.details == {"key": "value"} + + def test_prompt_not_found_error(self): + """Test PromptNotFoundError.""" + error = PromptNotFoundError("TestPrompt", ["Prompt1", "Prompt2"]) + assert "TestPrompt" in str(error) + assert error.details["prompt_name"] == "TestPrompt" + assert error.details["available_prompts"] == ["Prompt1", "Prompt2"] + + def test_version_not_found_error(self): + """Test VersionNotFoundError.""" + error = VersionNotFoundError("v2", "TestPrompt", ["v1", "v3"]) + assert "v2" in str(error) + assert "TestPrompt" in str(error) + assert error.details["version"] == "v2" + assert error.details["prompt_name"] == "TestPrompt" + + def test_variable_validation_error(self): + """Test VariableValidationError.""" + error = VariableValidationError("TestPrompt", "test_var", "must be string", 123, "string") + assert "test_var" in str(error) + assert "TestPrompt" in str(error) + assert error.details["variable_name"] == "test_var" + assert error.details["provided_value"] == 123 + + def test_required_variable_error(self): + """Test RequiredVariableError.""" + error = RequiredVariableError("TestPrompt", ["var1", "var2"], ["var3"]) + assert "var1" in str(error) + assert "var2" in str(error) + assert error.details["missing_variables"] == ["var1", "var2"] + + +class TestPromptLoader: + """Test the PromptLoader component.""" + + def test_prompt_loader_initialization(self): + """Test PromptLoader initialization.""" + logger = Mock() + loader = PromptLoader(logger) + assert loader._logger == logger + assert not loader.is_loaded() + + @patch('promptix.core.components.prompt_loader.config') + @patch('promptix.core.components.prompt_loader.PromptLoaderFactory') + def test_load_prompts_success(self, mock_factory, mock_config): + """Test successful prompt loading.""" + # Setup mocks + mock_config.check_for_unsupported_files.return_value = [] + mock_config.get_prompt_file_path.return_value = "/path/to/prompts.yaml" + + mock_loader = Mock() + mock_loader.load.return_value = {"TestPrompt": {"versions": {}}} + mock_factory.get_loader.return_value = mock_loader + + # Test + loader = PromptLoader() + prompts = loader.load_prompts() + + assert prompts == {"TestPrompt": {"versions": {}}} + assert loader.is_loaded() + + @patch('promptix.core.components.prompt_loader.config') + def test_load_prompts_json_error(self, mock_config): + """Test error when JSON files are detected.""" + from pathlib import Path + mock_config.check_for_unsupported_files.return_value = [Path("/path/to/prompts.json")] + + loader = PromptLoader() + with pytest.raises(Exception) as exc_info: + loader.load_prompts() + + assert "Unsupported format 'json'" in str(exc_info.value) + + def test_get_prompt_data_not_found(self): + """Test getting prompt data for non-existent prompt.""" + loader = PromptLoader() + loader._prompts = {"ExistingPrompt": {}} + loader._loaded = True + + with pytest.raises(Exception) as exc_info: + loader.get_prompt_data("NonExistentPrompt") + + assert "not found" in str(exc_info.value) + + +class TestVariableValidator: + """Test the VariableValidator component.""" + + def test_validate_variables_success(self): + """Test successful variable validation.""" + validator = VariableValidator() + schema = { + "required": ["name", "age"], + "types": {"name": "string", "age": "integer"} + } + user_vars = {"name": "John", "age": 25} + + # Should not raise any exception + validator.validate_variables(schema, user_vars, "TestPrompt") + + def test_validate_variables_missing_required(self): + """Test validation with missing required variables.""" + validator = VariableValidator() + schema = {"required": ["name", "age"]} + user_vars = {"name": "John"} + + with pytest.raises(RequiredVariableError) as exc_info: + validator.validate_variables(schema, user_vars, "TestPrompt") + + assert "age" in str(exc_info.value) + + def test_validate_variables_type_mismatch(self): + """Test validation with type mismatch.""" + validator = VariableValidator() + schema = { + "required": ["name"], + "types": {"name": "string"} + } + user_vars = {"name": 123} + + with pytest.raises(VariableValidationError) as exc_info: + validator.validate_variables(schema, user_vars, "TestPrompt") + + assert "must be of type string" in str(exc_info.value) + + def test_validate_variables_enum_violation(self): + """Test validation with enum constraint violation.""" + validator = VariableValidator() + schema = { + "required": ["status"], + "types": {"status": ["active", "inactive", "pending"]} + } + user_vars = {"status": "unknown"} + + with pytest.raises(VariableValidationError) as exc_info: + validator.validate_variables(schema, user_vars, "TestPrompt") + + assert "must be one of" in str(exc_info.value) + + +class TestTemplateRenderer: + """Test the TemplateRenderer component.""" + + def test_render_template_success(self): + """Test successful template rendering.""" + renderer = TemplateRenderer() + template = "Hello {{ name }}!" + variables = {"name": "World"} + + result = renderer.render_template(template, variables, "TestPrompt") + assert result == "Hello World!" + + def test_render_template_with_newlines(self): + """Test template rendering with escaped newlines.""" + renderer = TemplateRenderer() + template = "Line 1\\nLine 2" + + result = renderer.render_template(template, {}, "TestPrompt") + assert result == "Line 1\nLine 2" + + def test_render_template_error(self): + """Test template rendering error handling.""" + renderer = TemplateRenderer() + template = "Hello {{ undefined_variable.missing_attr }}!" + + with pytest.raises(TemplateRenderError) as exc_info: + renderer.render_template(template, {}, "TestPrompt") + + assert "TestPrompt" in str(exc_info.value) + + def test_render_tools_template_success(self): + """Test successful tools template rendering.""" + renderer = TemplateRenderer() + tools_template = '["tool1", "tool2"]' + + result = renderer.render_tools_template( + tools_template, {}, {}, "TestPrompt" + ) + assert result == ["tool1", "tool2"] + + def test_validate_template(self): + """Test template validation.""" + renderer = TemplateRenderer() + + assert renderer.validate_template("Hello {{ name }}!") + assert not renderer.validate_template("Hello {{ unclosed") + + +class TestVersionManager: + """Test the VersionManager component.""" + + def test_find_live_version_success(self): + """Test finding live version successfully.""" + manager = VersionManager() + versions = { + "v1": {"is_live": False}, + "v2": {"is_live": True}, + "v3": {"is_live": False} + } + + live_version = manager.find_live_version(versions, "TestPrompt") + assert live_version == "v2" + + def test_find_live_version_none(self): + """Test error when no live version found.""" + manager = VersionManager() + versions = { + "v1": {"is_live": False}, + "v2": {"is_live": False} + } + + with pytest.raises(NoLiveVersionError) as exc_info: + manager.find_live_version(versions, "TestPrompt") + + assert "TestPrompt" in str(exc_info.value) + + def test_find_live_version_multiple(self): + """Test error when multiple live versions found.""" + manager = VersionManager() + versions = { + "v1": {"is_live": True}, + "v2": {"is_live": True} + } + + with pytest.raises(MultipleLiveVersionsError) as exc_info: + manager.find_live_version(versions, "TestPrompt") + + assert "v1" in str(exc_info.value) + assert "v2" in str(exc_info.value) + + def test_get_version_data_specific(self): + """Test getting specific version data.""" + manager = VersionManager() + versions = { + "v1": {"config": {"model": "gpt-3.5-turbo"}}, + "v2": {"config": {"model": "gpt-4"}} + } + + version_data = manager.get_version_data(versions, "v2", "TestPrompt") + assert version_data["config"]["model"] == "gpt-4" + + def test_get_version_data_not_found(self): + """Test error when specific version not found.""" + manager = VersionManager() + versions = {"v1": {}} + + with pytest.raises(VersionNotFoundError) as exc_info: + manager.get_version_data(versions, "v3", "TestPrompt") + + assert "v3" in str(exc_info.value) + + def test_get_system_instruction_success(self): + """Test getting system instruction from version data.""" + manager = VersionManager() + version_data = { + "config": {"system_instruction": "You are a helpful assistant."} + } + + instruction = manager.get_system_instruction(version_data, "TestPrompt") + assert instruction == "You are a helpful assistant." + + def test_get_system_instruction_missing(self): + """Test error when system instruction is missing.""" + manager = VersionManager() + version_data = {"config": {}} + + with pytest.raises(ValueError) as exc_info: + manager.get_system_instruction(version_data, "TestPrompt") + + assert "system_instruction" in str(exc_info.value) + + +class TestModelConfigBuilder: + """Test the ModelConfigBuilder component.""" + + def test_validate_memory_format_success(self): + """Test successful memory format validation.""" + builder = ModelConfigBuilder() + memory = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"} + ] + + # Should not raise any exception + builder.validate_memory_format(memory) + + def test_validate_memory_format_invalid_type(self): + """Test memory validation with invalid type.""" + builder = ModelConfigBuilder() + + with pytest.raises(InvalidMemoryFormatError): + builder.validate_memory_format("not a list") + + def test_validate_memory_format_invalid_message(self): + """Test memory validation with invalid message format.""" + builder = ModelConfigBuilder() + memory = [{"role": "user"}] # Missing content + + with pytest.raises(InvalidMemoryFormatError) as exc_info: + builder.validate_memory_format(memory) + + assert "content" in str(exc_info.value) + + def test_validate_memory_format_invalid_role(self): + """Test memory validation with invalid role.""" + builder = ModelConfigBuilder() + memory = [{"role": "invalid", "content": "test"}] + + with pytest.raises(InvalidMemoryFormatError) as exc_info: + builder.validate_memory_format(memory) + + assert "role" in str(exc_info.value) + + def test_build_model_config_success(self): + """Test successful model config building.""" + builder = ModelConfigBuilder() + system_message = "You are helpful." + memory = [{"role": "user", "content": "Hello"}] + version_data = {"config": {"model": "gpt-3.5-turbo"}} + + config = builder.build_model_config(system_message, memory, version_data, "TestPrompt") + + assert config["model"] == "gpt-3.5-turbo" + assert config["messages"][0]["role"] == "system" + assert config["messages"][0]["content"] == "You are helpful." + assert config["messages"][1]["role"] == "user" + + def test_build_model_config_missing_model(self): + """Test error when model is missing from config.""" + builder = ModelConfigBuilder() + version_data = {"config": {}} + + with pytest.raises(ConfigurationError) as exc_info: + builder.build_model_config("test", [], version_data, "TestPrompt") + + assert "Model must be specified" in str(exc_info.value) + + def test_prepare_anthropic_config(self): + """Test Anthropic-specific config preparation.""" + builder = ModelConfigBuilder() + system_message = "You are helpful." + memory = [{"role": "user", "content": "Hello"}] + version_data = {"config": {"model": "claude-3-sonnet-20240229"}} + + config = builder.prepare_anthropic_config(system_message, memory, version_data, "TestPrompt") + + assert config["model"] == "claude-3-sonnet-20240229" + assert config["system"] == "You are helpful." + assert config["messages"] == memory + + +class TestContainer: + """Test the dependency injection container.""" + + def test_container_initialization(self): + """Test container initialization with defaults.""" + container = Container() + + # Should have default services + logger = container.get("logger") + assert logger is not None + + adapters = container.get("adapters") + assert "openai" in adapters + assert "anthropic" in adapters + + def test_container_register_singleton(self): + """Test registering and retrieving singleton services.""" + container = Container() + test_service = Mock() + + container.register_singleton("test_service", test_service) + retrieved = container.get("test_service") + + assert retrieved is test_service + + def test_container_register_factory(self): + """Test registering and using factory services.""" + container = Container() + + def create_service(): + return Mock(name="factory_service") + + container.register_factory("factory_service", create_service) + service1 = container.get("factory_service") + service2 = container.get("factory_service") + + # Factory should create new instances each time + assert service1 is not service2 + + def test_container_missing_dependency(self): + """Test error when dependency is missing.""" + container = Container() + + with pytest.raises(Exception) as exc_info: + container.get("nonexistent_service") + + assert "nonexistent_service" in str(exc_info.value) + + def test_container_scope(self): + """Test container scoping functionality.""" + container = Container() + original_service = Mock(name="original") + override_service = Mock(name="override") + + container.register_singleton("test_service", original_service) + + # Create scope and override + scope = container.create_scope() + scope.override("test_service", override_service) + + # Original container should return original service + assert container.get("test_service") is original_service + + # Scope should return override + assert scope.get("test_service") is override_service + + +class TestRefactoredIntegration: + """Integration tests for the refactored architecture.""" + + def setup_method(self): + """Setup for each test method.""" + reset_container() + + @patch('promptix.core.components.prompt_loader.config') + @patch('promptix.core.components.prompt_loader.PromptLoaderFactory') + def test_promptix_integration(self, mock_factory, mock_config): + """Test integration of refactored Promptix class.""" + # Setup mocks for prompt loading + mock_config.check_for_unsupported_files.return_value = [] + mock_config.get_prompt_file_path.return_value = "/path/to/prompts.yaml" + + mock_loader = Mock() + mock_loader.load.return_value = { + "TestPrompt": { + "versions": { + "v1": { + "is_live": True, + "config": {"system_instruction": "Hello {{ name }}!"}, + "schema": {"required": ["name"]} + } + } + } + } + mock_factory.get_loader.return_value = mock_loader + + # Test + result = Promptix.get_prompt("TestPrompt", name="World") + assert result == "Hello World!" + + @patch('promptix.core.components.prompt_loader.config') + @patch('promptix.core.components.prompt_loader.PromptLoaderFactory') + def test_builder_integration(self, mock_factory, mock_config): + """Test integration of refactored PromptixBuilder class.""" + # Setup mocks + mock_config.check_for_unsupported_files.return_value = [] + mock_config.get_prompt_file_path.return_value = "/path/to/prompts.yaml" + + mock_loader = Mock() + mock_loader.load.return_value = { + "TestPrompt": { + "versions": { + "v1": { + "is_live": True, + "config": { + "system_instruction": "You are {{ role }}.", + "model": "gpt-3.5-turbo" + }, + "schema": { + "required": ["role"], + "properties": { + "role": {"type": "string"} + }, + "additionalProperties": True + } + } + } + } + } + mock_factory.get_loader.return_value = mock_loader + + # Test builder + config = (Promptix.builder("TestPrompt") + .with_role("a helpful assistant") + .with_memory([{"role": "user", "content": "Hello"}]) + .build()) + + assert config["model"] == "gpt-3.5-turbo" + assert "helpful assistant" in config["messages"][0]["content"] + assert config["messages"][1]["content"] == "Hello" + + def test_custom_container_usage(self): + """Test using custom container for dependency injection.""" + # Create custom container with mock logger + custom_container = Container() + mock_logger = Mock() + custom_container.override("logger", mock_logger) + + # Create Promptix instance with custom container + promptix = Promptix(custom_container) + + # Verify it uses the custom logger + assert promptix._logger is mock_logger