Streaming model outputs#1236
Conversation
aaba9d4 to
fc4546a
Compare
| self, | ||
| tools: List[Tool], | ||
| model: Callable[[List[Dict[str, str]]], ChatMessage], | ||
| model: Model, |
There was a problem hiding this comment.
If we add the option to stream model outputs, model can not just be a Callable returning a ChatMessage.
This means we'l have to edit the parts of the doc that show how to create a Model, to explain how to inherit from the base Model class instead of directly creating a Callable.
| def _run( | ||
| self, task: str, max_steps: int, images: List["PIL.Image.Image"] | None = None | ||
| ) -> Generator[ActionStep | AgentType, None, None]: | ||
| ) -> Generator[ActionStep | FinalAnswerStep, None, None]: |
There was a problem hiding this comment.
Fixing this type hint.
| self.step_number == 1 or (self.step_number - 1) % self.planning_interval == 0 | ||
| ): | ||
| planning_step = self._create_planning_step( | ||
| planning_step = self._generate_planning_step( |
There was a problem hiding this comment.
"Generate" is better IMO since an LLM output is generated in this function: it's not simply about creating an empty object.
| yield action_step | ||
| yield FinalAnswerStep(handle_agent_output_types(final_answer)) | ||
|
|
||
| def _create_action_step(self, step_start_time: float, images: List["PIL.Image.Image"] | None) -> ActionStep: |
There was a problem hiding this comment.
This method is not useful anymore and obscures the workflow
| except Exception as e: | ||
| raise AgentParsingError(f"Error while generating or parsing output:\n{e}", self.logger) from e | ||
| if self.stream_outputs: | ||
| raise NotImplementedError("Stream outputs are not yet implemented for ToolCallingAgent") |
There was a problem hiding this comment.
Streaming output with ToolCallingAgent implies streaming ChoiceDeltaToolCallFunction objects from various APIs, which is worth another PR.
| return asdict(self) | ||
|
|
||
| def to_messages(self, **kwargs) -> List[Dict[str, Any]]: | ||
| def to_messages(self, summary_mode: bool = False) -> List[Message]: |
There was a problem hiding this comment.
Harmonize the API for all to_messages methods
| ) | ||
| return self.postprocess_message(first_message, tools_to_call_from) | ||
|
|
||
| def generate_stream( |
There was a problem hiding this comment.
New generate_stream methods. once we've setup streaming for ToolCallingAgent, the generate method will simply be able to call generate_stream and return the final completion.
albertvillanova
left a comment
There was a problem hiding this comment.
Thanks for the contributions! There's a lot of great work here. Having so many changes bundled into a single PR does make it a bit challenging to review thoroughly, but I appreciate the effort.
These are just some initial comments, I’ll continue reviewing the rest of the PR shortly, once you tell me no more changes are coming in...
| `ChatMessage`: A chat message object containing the model's response. | ||
| """ | ||
| pass # To be implemented in child classes! | ||
| raise NotImplementedError("This method must be implemented in child classes") |
There was a problem hiding this comment.
What about defining Model as an abstract class and decorating this method as abstractmethod?
There was a problem hiding this comment.
For generate we could! For generate_stream however, it will sometimes be implemented by child classes, sometimes not, so making it an abstract method would prevent proper intialization. Do we prefer to make only generate an abstract method, or keep the common implementation by only raising NotImplementedError in both methods?
There was a problem hiding this comment.
I think we can delete generate_stream here (see comment above).
Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
|
@albertvillanova it's only minor changes now, you can review |
albertvillanova
left a comment
There was a problem hiding this comment.
Some comments to maintain backward compatibility: users may pass a Callable as model.
Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
albertvillanova
left a comment
There was a problem hiding this comment.
Another batch of comments...
Sorry, difficult to go trough more than 1,000 modified lines...
|
|
||
|
|
||
| def has_implemented_method(instance, parent_class, method_name: str) -> bool: | ||
| instance_method = getattr(instance.__class__, method_name, None) | ||
| parent_method = getattr(parent_class, method_name, None) | ||
| return instance_method is not parent_method |
There was a problem hiding this comment.
No longer used:
| def has_implemented_method(instance, parent_class, method_name: str) -> bool: | |
| instance_method = getattr(instance.__class__, method_name, None) | |
| parent_method = getattr(parent_class, method_name, None) | |
| return instance_method is not parent_method |
There was a problem hiding this comment.
It was a mistake to not be using it, we do need it as a check in the init: just reintroduced it!
There was a problem hiding this comment.
I think this is a very hacky way to check if the model has the generate_stream method.
My suggestion:
- as this method is optional, the parent
Modelshould not have it (see discussion about settinggenerateas abstractmethod, but notgenerate_stream: Streaming model outputs #1236 (comment)). - we can remove this hacky method
- we can just check if the model hast
generate_streammethod:hasattr(self.model, "generate_stream")
| **completion_kwargs, stream=True, stream_options={"include_usage": True} | ||
| ): | ||
| if event.choices: | ||
| if event.choices[0].delta is None: |
There was a problem hiding this comment.
Have you tested this? I'm wondering if event.choices[0].delta can be None or it is always a class instance.
Anyway, maybe we could add some tests for generate_stream.
There was a problem hiding this comment.
Just aded tests for generate_stream in LiteLLMModel, InferenceClientModel, and TransformersModel.
There was a problem hiding this comment.
I have manually checked your tests for Transformers and InferenceClient and the condition event.choices[0].delta is None is never fulfilled.
There was a problem hiding this comment.
Have you checked it with LiteLLMModel, gpt-4?
| def parse_tool_calls(self, message: ChatMessage) -> ChatMessage: | ||
| """Sometimes APIs do not return the tool call as a specific object, so we need to parse it.""" | ||
| message.role = MessageRole.ASSISTANT # Overwrite role if needed | ||
| if not message.tool_calls: | ||
| assert message.content is not None, "Message contains no content and no tool calls" | ||
| message.tool_calls = [ | ||
| get_tool_call_from_text(message.content, self.tool_name_key, self.tool_arguments_key) | ||
| ] | ||
| assert len(message.tool_calls) > 0, "No tool call was found in the model output" | ||
| for tool_call in message.tool_calls: | ||
| tool_call.function.arguments = parse_json_if_needed(tool_call.function.arguments) | ||
| return message |
There was a problem hiding this comment.
Why do we need this for streaming and we didn't need before? Maybe I'm missing something...
There was a problem hiding this comment.
It's more a simplification: cf this comment
| from vllm import LLM # type: ignore | ||
| from vllm.transformers_utils.tokenizer import get_tokenizer # type: ignore | ||
|
|
||
| self.model_kwargs = { | ||
| **(model_kwargs or {}), | ||
| "model": model_id, | ||
| } | ||
| self.model_kwargs = model_kwargs or {} | ||
| super().__init__(**kwargs) | ||
| self.model_id = model_id | ||
| self.model = LLM(**self.model_kwargs) | ||
| self.model = LLM(model=model_id, **self.model_kwargs) | ||
| assert self.model is not None |
There was a problem hiding this comment.
I think these changes are not related to streaming. But is it necessary the assert here? I mean, any model is prone to receiving a None as model_id...
|
|
||
| import torch | ||
| from vllm.distributed.parallel_state import destroy_distributed_environment, destroy_model_parallel | ||
| from vllm.distributed.parallel_state import ( # type: ignore |
There was a problem hiding this comment.
Why this # type: ignore? We are not enforcing static type checking...
There was a problem hiding this comment.
It will improve readability for everyone using static type checking. If you're against that we can also remove it!
| for message in messages: | ||
| if not isinstance(message["content"], str): | ||
| message["content"] = message["content"][0]["text"] |
There was a problem hiding this comment.
Why do we need this now and not before?
There was a problem hiding this comment.
It was a dirty fix for an error that I missed: just fixed it.
albertvillanova
left a comment
There was a problem hiding this comment.
Another batch of reviews.
Thanks for your contribution.
| for event in self.client.completion(**completion_kwargs, stream=True, stream_options={"include_usage": True}): | ||
| if event.choices: | ||
| if event.choices[0].delta is None: | ||
| if not event.choices[0].finish_reason == "stop": |
There was a problem hiding this comment.
Simplify the logic:
| if not event.choices[0].finish_reason == "stop": | |
| if event.choices[0].finish_reason != "stop": |
| yield CompletionDelta( | ||
| content=event.choices[0].delta.content, | ||
| ) | ||
| if getattr(event, "usage", None): |
There was a problem hiding this comment.
This condition can only happen if the condition above is False, is this assumption right?
| if getattr(event, "usage", None): | |
| elif getattr(event, "usage", None): |
There was a problem hiding this comment.
Maybe some messages contain both some content and usage, so we would need to catch both using the double if instead of if/elif.
| if tools_to_call_from: | ||
| chat_message.tool_calls = [ | ||
| get_tool_call_from_text(output_text, self.tool_name_key, self.tool_arguments_key) | ||
| ] | ||
| return chat_message |
There was a problem hiding this comment.
Why do we no longer need to set .tool_calls attribute?
There was a problem hiding this comment.
Because now this will be handled directly in the ToolCallingAgent.step method by parse_tool_calls!
| self.model_id = model_id | ||
|
|
||
| default_max_tokens = 5000 | ||
| default_max_tokens = 4096 |
There was a problem hiding this comment.
Any reason fir this change?
There was a problem hiding this comment.
Powers of 2 are always better!
| or kwargs.get("max_tokens") | ||
| or self.kwargs.get("max_new_tokens") | ||
| or self.kwargs.get("max_tokens") | ||
| or 1024 |
There was a problem hiding this comment.
Do we want to hardcode this value?
There was a problem hiding this comment.
I'm actually not sure: in case it's not filled, should we leave this to the underlying model/API?
| """Sometimes APIs do not return the tool call as a specific object, so we need to parse it.""" | ||
| message.role = MessageRole.ASSISTANT # Overwrite role if needed | ||
| if not message.tool_calls: | ||
| assert message.content is not None, "Message contains no content and no tool calls" |
There was a problem hiding this comment.
Differently from before, now we can raise an error here. Is this intended?
There was a problem hiding this comment.
Yes: either the model returns a tool call, either it returns some text, but it should at least return one.
| message.tool_calls = [ | ||
| get_tool_call_from_text(message.content, self.tool_name_key, self.tool_arguments_key) | ||
| ] | ||
| assert len(message.tool_calls) > 0, "No tool call was found in the model output" |
There was a problem hiding this comment.
Differently from before, now we can raise an error here. Is this intended?
There was a problem hiding this comment.
Yes: it will help the model correct its output!
| def __call__(self, *args, **kwargs): | ||
| return self.generate(*args, **kwargs) | ||
|
|
||
| def parse_tool_calls(self, message: ChatMessage) -> ChatMessage: |
There was a problem hiding this comment.
This function seems to replace the previous postprocess_message. However, this new function is only called by ToolCallingAgent.step, whereas the previous postprocess_message was called by all API models (__call__ method). Is this intended?
There was a problem hiding this comment.
Yes: the idea is that we now more clearly separate:
- Generation: the
Modeljust generates text. Sometimes, depending on the API/Model, it can contain pre-defined tool_calls in the dedicated attribute. - Parsing:
postprocess_message, which will if there's no tool call so far, fill the tool_calls attribute using tool calls parsed from the text.
albertvillanova
left a comment
There was a problem hiding this comment.
Another batch of reviews done before your today modifications.
| executor_type: str | None = "local", | ||
| executor_kwargs: Optional[Dict[str, Any]] = None, | ||
| max_print_outputs_length: Optional[int] = None, | ||
| stream_outputs: bool = False, |
There was a problem hiding this comment.
What about calling the param just stream, as in the OpenAI spec for Chat completion create?
There was a problem hiding this comment.
This is a difficult question: for a chat completion, stream is obviously about streaming model outputs.
For an agent, what do you stream: agent steps? (as in agent.run() with stream=True)
Since here it's about streaming outputs, I put that in the name stream_outputs. but maybe there's an even more intuitive API.
| executor_type (`str`, default `"local"`): Which executor type to use between `"local"`, `"e2b"`, or `"docker"`. | ||
| executor_kwargs (`dict`, *optional*): Additional arguments to pass to initialize the executor. | ||
| max_print_outputs_length (`int`, *optional*): Maximum length of the print outputs. | ||
| stream_outputs (`bool`, *optional*, default `False`): Whether to stream outputs during execution. |
There was a problem hiding this comment.
In docstrings, optional means default None.
| stream_outputs (`bool`, *optional*, default `False`): Whether to stream outputs during execution. | |
| stream_outputs (`bool`, default `False`): Whether to stream outputs during execution. |
|
|
||
|
|
||
| def has_implemented_method(instance, parent_class, method_name: str) -> bool: | ||
| instance_method = getattr(instance.__class__, method_name, None) | ||
| parent_method = getattr(parent_class, method_name, None) | ||
| return instance_method is not parent_method |
There was a problem hiding this comment.
I think this is a very hacky way to check if the model has the generate_stream method.
My suggestion:
- as this method is optional, the parent
Modelshould not have it (see discussion about settinggenerateas abstractmethod, but notgenerate_stream: Streaming model outputs #1236 (comment)). - we can remove this hacky method
- we can just check if the model hast
generate_streammethod:hasattr(self.model, "generate_stream")
| `ChatMessage`: A chat message object containing the model's response. | ||
| """ | ||
| pass # To be implemented in child classes! | ||
| raise NotImplementedError("This method must be implemented in child classes") |
There was a problem hiding this comment.
I think we can delete generate_stream here (see comment above).
| def generate_stream(self, *args, **kwargs) -> Generator[CompletionDelta, None, None]: | ||
| raise NotImplementedError("This method must be implemented in child classes") | ||
|
|
There was a problem hiding this comment.
| def generate_stream(self, *args, **kwargs) -> Generator[CompletionDelta, None, None]: | |
| raise NotImplementedError("This method must be implemented in child classes") |
| self.stream_outputs = stream_outputs | ||
| can_stream = has_implemented_method(self.model, Model, "generate_stream") | ||
| if self.stream_outputs and not can_stream: | ||
| raise ValueError( | ||
| "`stream_outputs` is set to True, but the model class implements no `generate_stream` method." | ||
| ) |
There was a problem hiding this comment.
| self.stream_outputs = stream_outputs | |
| can_stream = has_implemented_method(self.model, Model, "generate_stream") | |
| if self.stream_outputs and not can_stream: | |
| raise ValueError( | |
| "`stream_outputs` is set to True, but the model class implements no `generate_stream` method." | |
| ) | |
| if stream_outputs and not hasattr(self.model, "generate_stream"): | |
| raise ValueError( | |
| "`stream_outputs` is set to True, but the model class implements no `generate_stream` method." | |
| ) | |
| self.stream_outputs = stream_outputs |
| **completion_kwargs, stream=True, stream_options={"include_usage": True} | ||
| ): | ||
| if event.choices: | ||
| if event.choices[0].delta is None: |
There was a problem hiding this comment.
I have manually checked your tests for Transformers and InferenceClient and the condition event.choices[0].delta is None is never fulfilled.
Implement streaming model outputs, to let user see the thoughts of their model displaying live.
Tested for:
Streaming was not implemented, left for future PRs: