33from queue import Queue
44from typing import TYPE_CHECKING , Any
55
6+ import jiter
67from litellm import ModelResponseStream
78
89from dspy .adapters .chat_adapter import ChatAdapter
@@ -49,6 +50,8 @@ def __init__(
4950 self .cache_hit = False
5051 self .allow_reuse = allow_reuse
5152
53+ self .json_adapter_state = {"field_accumulated_messages" : "" }
54+
5255 self .adapter_identifiers = {
5356 "ChatAdapter" : {
5457 "start_identifier" : f"[[ ## { self .signature_field_name } ## ]]" ,
@@ -62,7 +65,7 @@ def __init__(
6265 "end_identifier" : re .compile (r"\w*\"(,|\s*})" ),
6366 "start_indicator" : '"' ,
6467 "end_pattern_prefixes" : ['"' , '",' , '" ' , '"}' ],
65- "end_pattern_contains" : None ,
68+ "end_pattern_contains" : "}" ,
6669 },
6770 "XMLAdapter" : {
6871 "start_identifier" : f"<{ self .signature_field_name } >" ,
@@ -126,6 +129,7 @@ def receive(self, chunk: ModelResponseStream):
126129 self .cache_hit = False
127130 self .field_start_queue = []
128131 self .field_end_queue = Queue ()
132+ self .json_adapter_state ["field_accumulated_messages" ] = ""
129133 self .stream_start = False
130134 else :
131135 return
@@ -147,7 +151,7 @@ def receive(self, chunk: ModelResponseStream):
147151 is_last_chunk = self .stream_end ,
148152 )
149153
150- if chunk_message and start_identifier in chunk_message :
154+ if chunk_message and start_identifier in chunk_message and not isinstance ( settings . adapter , JSONAdapter ) :
151155 # If the cache is hit, the chunk_message could be the full response. When it happens we can
152156 # directly end the stream listening. In some models like gemini, each stream chunk can be multiple
153157 # tokens, so it's possible that response only has one chunk, we also fall back to this logic.
@@ -180,10 +184,13 @@ def receive(self, chunk: ModelResponseStream):
180184 # Keep the part after the start_identifier from the concat_message, we need to write it to the buffer.
181185 value_start_index = concat_message .find (start_identifier ) + len (start_identifier )
182186 chunk_message = concat_message [value_start_index :].lstrip ()
183- if isinstance (settings .adapter , JSONAdapter ) and chunk_message .startswith ('"' ):
184- # For JSONAdapter, we need to remove the leading ". We cannot do this with the start_identifier
185- # because there could be a few splitters between ':' and '"', e.g., '"name": "value"'.
186- chunk_message = chunk_message [1 :]
187+
188+ if isinstance (settings .adapter , JSONAdapter ):
189+ # For JSONAdapter, we rely on partial json parsing to detect the end of the field we are listening
190+ # to, so we need to maintain a few extra states to help us with that.
191+ # We add an extra "{" to the beginning of the field_accumulated_messages, so we can detect the
192+ # appearance of the next key.
193+ self .json_adapter_state ["field_accumulated_messages" ] += "{" + start_identifier
187194
188195 elif self ._buffered_message_end_with_start_identifier (concat_message .strip (), start_identifier ):
189196 # If the buffered message ends with part of the start_identifier, we keep looking for the
@@ -196,30 +203,101 @@ def receive(self, chunk: ModelResponseStream):
196203
197204 if self .stream_start and chunk_message :
198205 # The stream is started, we keep returning the token until we see the start of the next field.
199- token = None
200206 self .field_end_queue .put (chunk_message )
201207
208+ token = None
202209 concat_message = "" .join (self .field_end_queue .queue ).strip ()
203- if re .search (end_identifier , concat_message ):
204- # The next field is identified, we can end the stream and flush out all tokens in the buffer.
205- self .stream_end = True
206- token = self .flush ()
207- token = token .rstrip () # Remove the trailing \n\n
208- elif not self ._could_form_end_identifier (concat_message , adapter_name ):
210+
211+ if not self ._could_form_end_identifier (concat_message , adapter_name ):
209212 # Buffer cannot form end identifier, safe to flush out the tokens in the buffer.
210213 token = self .flush ()
211214 elif self .field_end_queue .qsize () > 10 :
212- # Buffer could form end identifier, but we've exceeded max buffer size
213- # Yield the oldest token to prevent unbounded buffering
215+ # We keep the last 10 tokens in the buffer if they can potentially form the end_identifier to avoid
216+ # sending the DSPy boilerplate tokens to users. 10 is a heuristic number that is sufficient to capture
217+ # the end_identifier for all LMs.
214218 token = self .field_end_queue .get ()
215219
216- if token :
220+ # TODO: Put adapter streaming handling into individial classes, e.g., `JSONAdapterStreamListener`,
221+ # `ChatAdapterStreamListener`, `XMLAdapterStreamListener` instead of having many adhoc code in the
222+ # `StreamListener` class.
223+ if isinstance (settings .adapter , JSONAdapter ):
224+ # JSONAdapter uses partial json parsing to detect the end of the field we are listening to, instead of
225+ # relying on the end_identifier.
226+ return self ._json_adapter_handle_stream_chunk (token , chunk_message )
227+ else :
228+ # Other adapters rely on the end_identifier to detect the end of the field we are listening to.
229+ return self ._default_handle_stream_chunk (token , end_identifier )
230+
def _json_adapter_handle_stream_chunk(self, token: str | None, chunk_message: str) -> StreamResponse | None:
    """Decide what to emit for a JSONAdapter stream after the listened field has started.

    Unlike the other adapters, JSONAdapter cannot rely on a literal end_identifier:
    the end of the field is detected by (partial) JSON parsing of everything
    accumulated so far in ``self.json_adapter_state["field_accumulated_messages"]``
    (which was seeded with an extra "{" plus the field's start identifier when
    streaming began, so the accumulator parses as a JSON object).

    Args:
        token: Token text already released from the bounded buffer this call, or None.
        chunk_message: The newest chunk of model output for this field.

    Returns:
        A StreamResponse carrying whatever should be emitted (flagged as the last
        chunk when the field or stream end is detected), or None when nothing is
        ready to emit yet.
    """
    self.json_adapter_state["field_accumulated_messages"] += chunk_message
    if self.json_adapter_state["field_accumulated_messages"].rstrip().endswith("}"):
        # When the accumulated tokens end with a curly bracket, that means the streaming for the `dspy.Predict` we
        # are listening to is probably finished, we need to run a check and decide whether to end the stream.
        try:
            # If the parse doesn't raise an error, that means the accumulated tokens is a valid json object. Because
            # we add an extra "{" to the beginning of the field_accumulated_messages, so we know the streaming is
            # finished.
            jiter.from_json(self.json_adapter_state["field_accumulated_messages"].encode("utf-8"))
            self.stream_end = True
            last_token = self.flush()
            # Drop the trailing "}" (and anything after it): it closes the JSON
            # envelope and is not part of the field value.
            right_curly_bracket_index = last_token.rfind("}")
            token = (
                token + last_token[:right_curly_bracket_index] if token else last_token[:right_curly_bracket_index]
            )
            return StreamResponse(
                self.predict_name, self.signature_field_name, token, is_last_chunk=self.stream_end
            )
        except ValueError:
            # Not valid (complete) JSON yet — keep streaming.
            pass

    try:
        parsed = jiter.from_json(
            self.json_adapter_state["field_accumulated_messages"].encode("utf-8"),
            partial_mode="trailing-strings",
        )
        if len(parsed) > 1:
            # If partial json parsing finds a second key, that means the streaming for the field we are listening to
            # is finished.
            self.stream_end = True
            last_token = self.flush()

            keys = list(parsed.keys())
            next_field_name = None
            for key in keys:
                if key != self.signature_field_name:
                    next_field_name = key
                    break

            # Emit only the buffered text that precedes the next field's key.
            last_token_index = last_token.find(next_field_name)
            token = token + last_token[:last_token_index] if token else last_token[:last_token_index]
    except ValueError:
        # Accumulator is not parseable even in partial mode — keep buffering.
        pass

    if token:
        return StreamResponse(
            self.predict_name,
            self.signature_field_name,
            token,
            is_last_chunk=self.stream_end,
        )
283+
def _default_handle_stream_chunk(self, token: str, end_identifier: str) -> StreamResponse | None:
    """Emit tokens for adapters that mark field boundaries with an end_identifier.

    Scans the buffered tokens for the end_identifier; once it appears, the
    stream for this field is closed and the remaining buffered text is flushed.
    """
    buffered = "".join(self.field_end_queue.queue).strip()

    if re.search(end_identifier, buffered):
        # The next field is identified, we can end the stream and flush out all tokens in the buffer.
        self.stream_end = True
        remainder = self.flush()
        token = (token + remainder) if token else remainder
        token = token.rstrip()  # Remove the trailing \n\n

    if not token:
        return None
    return StreamResponse(
        self.predict_name,
        self.signature_field_name,
        token,
        is_last_chunk=self.stream_end,
    )
223301
224302 def flush (self ) -> str :
225303 """Flush all tokens in the field end queue.
@@ -231,12 +309,7 @@ def flush(self) -> str:
231309 last_tokens = "" .join (self .field_end_queue .queue )
232310 self .field_end_queue = Queue ()
233311 if isinstance (settings .adapter , JSONAdapter ):
234- match = re .search (r'",|"\s*}' , last_tokens )
235- if match :
236- boundary_index = match .start ()
237- else :
238- boundary_index = len (last_tokens )
239- return last_tokens [:boundary_index ]
312+ return last_tokens
240313 elif isinstance (settings .adapter , XMLAdapter ):
241314 boundary_index = last_tokens .find (f"</{ self .signature_field_name } >" )
242315 if boundary_index == - 1 :
@@ -314,13 +387,6 @@ def find_predictor_for_stream_listeners(
314387 f"Signature field { field_name } is not unique in the program, cannot automatically determine which "
315388 "predictor to use for streaming. Please specify the predictor to listen to."
316389 )
317-
318- if not _is_streamable (field_info .annotation ):
319- raise ValueError (
320- f"Stream listener can only be applied to string or subclass of `dspy.Type` that has `is_streamable() == True`, "
321- f"but your field { field_name } is of type { field_info .annotation } ."
322- )
323-
324390 field_name_to_named_predictor [field_name ] = (name , predictor )
325391
326392 predict_id_to_listener = defaultdict (list )
@@ -337,13 +403,3 @@ def find_predictor_for_stream_listeners(
337403 listener .predict_name , listener .predict = field_name_to_named_predictor [listener .signature_field_name ]
338404 predict_id_to_listener [id (listener .predict )].append (listener )
339405 return predict_id_to_listener
340-
341-
342- def _is_streamable (field_type : type | None ) -> bool :
343- if field_type is None :
344- return False
345- if field_type is str :
346- return True
347- if issubclass (field_type , Type ):
348- return field_type .is_streamable ()
349- return False
0 commit comments