-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Allow stream listener to work on any type #8833
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
9b584d8
0454cd5
146f62f
45facc4
67a1095
839b939
d562879
43e5fb0
7714b4f
7ebc89d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,6 +3,7 @@ | |
| from queue import Queue | ||
| from typing import TYPE_CHECKING, Any | ||
|
|
||
| import jiter | ||
| from litellm import ModelResponseStream | ||
|
|
||
| from dspy.adapters.chat_adapter import ChatAdapter | ||
|
|
@@ -49,6 +50,8 @@ def __init__( | |
| self.cache_hit = False | ||
| self.allow_reuse = allow_reuse | ||
|
|
||
| self.json_adapter_state = {"field_accumulated_messages": ""} | ||
|
|
||
| self.adapter_identifiers = { | ||
| "ChatAdapter": { | ||
| "start_identifier": f"[[ ## {self.signature_field_name} ## ]]", | ||
|
|
@@ -62,7 +65,7 @@ def __init__( | |
| "end_identifier": re.compile(r"\w*\"(,|\s*})"), | ||
| "start_indicator": '"', | ||
| "end_pattern_prefixes": ['"', '",', '" ', '"}'], | ||
| "end_pattern_contains": None, | ||
| "end_pattern_contains": "}", | ||
| }, | ||
| "XMLAdapter": { | ||
| "start_identifier": f"<{self.signature_field_name}>", | ||
|
|
@@ -126,6 +129,7 @@ def receive(self, chunk: ModelResponseStream): | |
| self.cache_hit = False | ||
| self.field_start_queue = [] | ||
| self.field_end_queue = Queue() | ||
| self.json_adapter_state["field_accumulated_messages"] = "" | ||
| self.stream_start = False | ||
| else: | ||
| return | ||
|
|
@@ -147,7 +151,7 @@ def receive(self, chunk: ModelResponseStream): | |
| is_last_chunk=self.stream_end, | ||
| ) | ||
|
|
||
| if chunk_message and start_identifier in chunk_message: | ||
| if chunk_message and start_identifier in chunk_message and not isinstance(settings.adapter, JSONAdapter): | ||
| # If the cache is hit, the chunk_message could be the full response. When it happens we can | ||
| # directly end the stream listening. In some models like gemini, each stream chunk can be multiple | ||
| # tokens, so it's possible that response only has one chunk, we also fall back to this logic. | ||
|
|
@@ -180,10 +184,13 @@ def receive(self, chunk: ModelResponseStream): | |
| # Keep the part after the start_identifier from the concat_message, we need to write it to the buffer. | ||
| value_start_index = concat_message.find(start_identifier) + len(start_identifier) | ||
| chunk_message = concat_message[value_start_index:].lstrip() | ||
| if isinstance(settings.adapter, JSONAdapter) and chunk_message.startswith('"'): | ||
| # For JSONAdapter, we need to remove the leading ". We cannot do this with the start_identifier | ||
| # because there could be a few splitters between ':' and '"', e.g., '"name": "value"'. | ||
| chunk_message = chunk_message[1:] | ||
|
|
||
| if isinstance(settings.adapter, JSONAdapter): | ||
| # For JSONAdapter, we rely on partial json parsing to detect the end of the field we are listening | ||
| # to, so we need to maintain a few extra states to help us with that. | ||
| # We add an extra "{" to the beginning of the field_accumulated_messages, so we can detect the | ||
| # appearance of the next key. | ||
| self.json_adapter_state["field_accumulated_messages"] += "{" + start_identifier | ||
|
|
||
| elif self._buffered_message_end_with_start_identifier(concat_message.strip(), start_identifier): | ||
| # If the buffered message ends with part of the start_identifier, we keep looking for the | ||
|
|
@@ -196,30 +203,98 @@ def receive(self, chunk: ModelResponseStream): | |
|
|
||
| if self.stream_start and chunk_message: | ||
| # The stream is started, we keep returning the token until we see the start of the next field. | ||
| token = None | ||
| self.field_end_queue.put(chunk_message) | ||
|
|
||
| token = None | ||
| concat_message = "".join(self.field_end_queue.queue).strip() | ||
| if re.search(end_identifier, concat_message): | ||
| # The next field is identified, we can end the stream and flush out all tokens in the buffer. | ||
| self.stream_end = True | ||
| token = self.flush() | ||
| token = token.rstrip() # Remove the trailing \n\n | ||
| elif not self._could_form_end_identifier(concat_message, adapter_name): | ||
|
|
||
| if not self._could_form_end_identifier(concat_message, adapter_name): | ||
| # Buffer cannot form end identifier, safe to flush out the tokens in the buffer. | ||
| token = self.flush() | ||
| elif self.field_end_queue.qsize() > 10: | ||
| # Buffer could form end identifier, but we've exceeded max buffer size | ||
| # Yield the oldest token to prevent unbounded buffering | ||
| # We keep the last 10 tokens in the buffer if they can potentially form the end_identifier to avoid | ||
| # sending the DSPy boilerplate tokens to users. 10 is a heuristic number that is sufficient to capture | ||
| # the end_identifier for all LMs. | ||
| token = self.field_end_queue.get() | ||
|
|
||
| if token: | ||
| if isinstance(settings.adapter, JSONAdapter): | ||
| # JSONAdapter uses partial json parsing to detect the end of the field we are listening to, instead of | ||
| # relying on the end_identifier. | ||
| return self._json_adapter_handle_stream_chunk(token, chunk_message) | ||
| else: | ||
| # Other adapters rely on the end_identifier to detect the end of the field we are listening to. | ||
| return self._default_handle_stream_chunk(token, end_identifier) | ||
|
|
||
| def _json_adapter_handle_stream_chunk(self, token: str, chunk_message: str) -> StreamResponse | None: | ||
| self.json_adapter_state["field_accumulated_messages"] += chunk_message | ||
| if self.json_adapter_state["field_accumulated_messages"].rstrip().endswith("}"): | ||
| # When the accumulated tokens end with a curly bracket, that means the streaming for the `dspy.Predict` we | ||
| # are listening to is probably finished, we need to run a check and decide whether to end the stream. | ||
| try: | ||
| # If the parse doesn't raise an error, that means the accumulated tokens is a valid json object. Because | ||
| # we add an extra "{" to the beginning of the field_accumulated_messages, so we know the streaming is | ||
| # finished. | ||
| jiter.from_json(self.json_adapter_state["field_accumulated_messages"].encode("utf-8")) | ||
| self.stream_end = True | ||
| last_token = self.flush() | ||
| right_curly_bracket_index = last_token.rfind("}") | ||
| token = ( | ||
| token + last_token[:right_curly_bracket_index] if token else last_token[:right_curly_bracket_index] | ||
| ) | ||
| return StreamResponse( | ||
| self.predict_name, | ||
| self.signature_field_name, | ||
| token, | ||
| is_last_chunk=self.stream_end, | ||
| self.predict_name, self.signature_field_name, token, is_last_chunk=self.stream_end | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. So overall we will return a raw string chunk, so the deserialization needs to happen on the caller side?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes! It should be pretty simple for the caller to accumulate. |
||
| ) | ||
| except ValueError: | ||
| pass | ||
|
|
||
| try: | ||
| parsed = jiter.from_json( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is interesting — can't we just count the number of { and }?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Discussed offline; please see the new implementation for a more robust solution. |
||
| self.json_adapter_state["field_accumulated_messages"].encode("utf-8"), | ||
| partial_mode="trailing-strings", | ||
| ) | ||
| if len(parsed) > 1: | ||
| # If partial json parsing finds a second key, that means the streaming for the field we are listening to | ||
| # is finished. | ||
| self.stream_end = True | ||
| last_token = self.flush() | ||
|
|
||
| keys = list(parsed.keys()) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think so. It shouldn't affect the logic here, though — we just need the key name to cut off the extra characters. |
||
| next_field_name = None | ||
| for key in keys: | ||
| if key != self.signature_field_name: | ||
| next_field_name = key | ||
| break | ||
|
|
||
| last_token_index = last_token.find(next_field_name) | ||
| token = token + last_token[:last_token_index] if token else last_token[:last_token_index] | ||
| except ValueError: | ||
chenmoneygithub marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| pass | ||
|
|
||
| if token: | ||
| return StreamResponse( | ||
| self.predict_name, | ||
| self.signature_field_name, | ||
| token, | ||
| is_last_chunk=self.stream_end, | ||
| ) | ||
|
|
||
| def _default_handle_stream_chunk(self, token: str, end_identifier: str) -> StreamResponse | None: | ||
chenmoneygithub marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| concat_message = "".join(self.field_end_queue.queue).strip() | ||
|
|
||
| if re.search(end_identifier, concat_message): | ||
| # The next field is identified, we can end the stream and flush out all tokens in the buffer. | ||
| self.stream_end = True | ||
| last_token = self.flush() | ||
| token = token + last_token if token else last_token | ||
| token = token.rstrip() # Remove the trailing \n\n | ||
|
|
||
| if token: | ||
| return StreamResponse( | ||
| self.predict_name, | ||
| self.signature_field_name, | ||
| token, | ||
| is_last_chunk=self.stream_end, | ||
| ) | ||
|
|
||
| def flush(self) -> str: | ||
| """Flush all tokens in the field end queue. | ||
|
|
@@ -231,12 +306,7 @@ def flush(self) -> str: | |
| last_tokens = "".join(self.field_end_queue.queue) | ||
| self.field_end_queue = Queue() | ||
| if isinstance(settings.adapter, JSONAdapter): | ||
| match = re.search(r'",|"\s*}', last_tokens) | ||
| if match: | ||
| boundary_index = match.start() | ||
| else: | ||
| boundary_index = len(last_tokens) | ||
| return last_tokens[:boundary_index] | ||
| return last_tokens | ||
| elif isinstance(settings.adapter, XMLAdapter): | ||
| boundary_index = last_tokens.find(f"</{self.signature_field_name}>") | ||
| if boundary_index == -1: | ||
|
|
@@ -314,13 +384,6 @@ def find_predictor_for_stream_listeners( | |
| f"Signature field {field_name} is not unique in the program, cannot automatically determine which " | ||
| "predictor to use for streaming. Please specify the predictor to listen to." | ||
| ) | ||
|
|
||
| if not _is_streamable(field_info.annotation): | ||
| raise ValueError( | ||
| f"Stream listener can only be applied to string or subclass of `dspy.Type` that has `is_streamable() == True`, " | ||
| f"but your field {field_name} is of type {field_info.annotation}." | ||
| ) | ||
|
|
||
| field_name_to_named_predictor[field_name] = (name, predictor) | ||
|
|
||
| predict_id_to_listener = defaultdict(list) | ||
|
|
@@ -337,13 +400,3 @@ def find_predictor_for_stream_listeners( | |
| listener.predict_name, listener.predict = field_name_to_named_predictor[listener.signature_field_name] | ||
| predict_id_to_listener[id(listener.predict)].append(listener) | ||
| return predict_id_to_listener | ||
|
|
||
|
|
||
| def _is_streamable(field_type: type | None) -> bool: | ||
| if field_type is None: | ||
| return False | ||
| if field_type is str: | ||
| return True | ||
| if issubclass(field_type, Type): | ||
| return field_type.is_streamable() | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shall we delete
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, good question — I did think about it, and my take is that it's still useful to allow the streaming listener to hit this path on certain typed fields like Citation. For a custom type that's not streamable, it will use the normal streaming handling. |
||
| return False | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
Mixing implicit and explicit returns may indicate an error, as implicit returns always return None.