diff --git a/dapr_agents/types/llm.py b/dapr_agents/types/llm.py index 65c03369..0d5b442c 100644 --- a/dapr_agents/types/llm.py +++ b/dapr_agents/types/llm.py @@ -1,4 +1,4 @@ -from typing import List, Union, Optional, Dict, Any, Literal, IO, Tuple +from typing import List, Union, Optional, Dict, Any, Literal, IO, Tuple, cast from pydantic import BaseModel, Field, model_validator, field_validator, ConfigDict from pydantic_core import PydanticUseDefault from pathlib import Path @@ -113,7 +113,7 @@ class OpenAIModelConfig(OpenAIClientConfig): type: Literal["openai"] = Field( "openai", description="Type of the model, must always be 'openai'" ) - name: str = Field(default=None, description="Name of the OpenAI model") + name: str = Field(default="", description="Name of the OpenAI model") class AzureOpenAIModelConfig(AzureOpenAIClientConfig): @@ -127,7 +127,7 @@ class HFHubModelConfig(HFInferenceClientConfig): "huggingface", description="Type of the model, must always be 'huggingface'" ) name: str = Field( - default=None, description="Name of the model available through Hugging Face" + default="", description="Name of the model available through Hugging Face" ) @@ -136,7 +136,7 @@ class NVIDIAModelConfig(NVIDIAClientConfig): "nvidia", description="Type of the model, must always be 'nvidia'" ) name: str = Field( - default=None, description="Name of the model available through NVIDIA" + default="", description="Name of the model available through NVIDIA" ) @@ -340,6 +340,14 @@ def sync_model_name(cls, values: dict): elif configuration.get("type") == "nvidia": configuration = NVIDIAModelConfig(**configuration) + configuration = cast( + OpenAIModelConfig + | AzureOpenAIModelConfig + | HFHubModelConfig + | NVIDIAModelConfig, + configuration, + ) + # Ensure 'parameters' is properly validated as a model, not a dict if isinstance(parameters, dict): if configuration and isinstance(configuration, OpenAIModelConfig): @@ -351,12 +359,27 @@ def sync_model_name(cls, values: dict): elif configuration and isinstance(configuration, NVIDIAModelConfig): parameters = NVIDIAChatCompletionParams(**parameters) + parameters = cast( + OpenAIChatCompletionParams + | HFHubChatCompletionParams + | NVIDIAChatCompletionParams, + parameters, + ) + if configuration and parameters: # Check if 'name' or 'azure_deployment' is explicitly set if "name" in configuration.model_fields_set: - parameters.model = configuration.name + parameters.model = ( + configuration.name + if not isinstance(configuration, AzureOpenAIModelConfig) + else None + ) elif "azure_deployment" in configuration.model_fields_set: - parameters.model = configuration.azure_deployment + parameters.model = ( + configuration.azure_deployment + if isinstance(configuration, AzureOpenAIModelConfig) + else None + ) values["configuration"] = configuration values["parameters"] = parameters @@ -471,7 +494,7 @@ def validate_file( elif isinstance(value, BufferedReader) or ( hasattr(value, "read") and callable(value.read) ): - if value.closed: + if hasattr(value, "closed") and value.closed: raise ValueError("File-like object must remain open during request.") return value elif isinstance(value, tuple): @@ -535,7 +558,7 @@ def validate_file( elif isinstance(value, BufferedReader) or ( hasattr(value, "read") and callable(value.read) ): - if value.closed: # Reopen if closed + if hasattr(value, "closed") and value.closed: # Reopen if closed raise ValueError("File-like object must remain open during request.") return value elif isinstance(value, tuple): diff --git a/dapr_agents/types/message.py b/dapr_agents/types/message.py index eac34892..8b6c9c81 100644 --- a/dapr_agents/types/message.py +++ b/dapr_agents/types/message.py @@ -1,3 +1,4 @@ +from typing import Any from pydantic import ( BaseModel, field_validator, @@ -5,7 +6,7 @@ model_validator, ConfigDict, ) -from typing import List, Optional, Dict, Any +from typing import List, Optional, Dict import json @@ -239,6 +240,7 @@ def get_tool_calls(self) -> Optional[List[ToolCall]]: return self.tool_calls if isinstance(self.tool_calls, ToolCall): return [self.tool_calls] + return None def has_tool_calls(self) -> bool: """ @@ -250,6 +252,7 @@ def has_tool_calls(self) -> bool: return True if isinstance(self.tool_calls, ToolCall): return True + return False class ToolMessage(BaseMessage): diff --git a/mypy.ini b/mypy.ini index d16537b4..9fe5198f 100644 --- a/mypy.ini +++ b/mypy.ini @@ -52,9 +52,6 @@ ignore_errors = True [mypy-dapr_agents.tool.base.*] ignore_errors = True -[mypy-dapr_agents.types.*] -ignore_errors = True - [mypy-dapr_agents.workflow.*] ignore_errors = True diff --git a/tox.ini b/tox.ini index 6c2c0a4a..a2d294e6 100644 --- a/tox.ini +++ b/tox.ini @@ -1,11 +1,10 @@ [tox] skipsdist = False -minversion = 3.9.0 envlist = - py{39,310,311,312,313} + py{310,311,312,313} flake8, ruff, - mypy, + type, pytest [testenv]