@@ -89,6 +89,7 @@ def __init__(
         self.provider = provider
         self._turns: list[Turn] = list(turns or [])
         self._tools: dict[str, Tool] = {}
+        self.token_limits: Optional[tuple[int, int]] = None
         self._echo_options: EchoOptions = {
             "rich_markdown": {},
             "rich_console": {},
@@ -381,6 +382,121 @@ async def token_count_async(
             data_model=data_model,
         )

+    def set_token_limits(
+        self,
+        context_window: int,
+        max_tokens: int,
+    ):
390+ """
391+ Set a limit on the number of tokens that can be sent to the model.
392+
393+ By default, the size of the chat history is unbounded -- it keeps
394+ growing as you submit more input. This can be wasteful if you don't
395+ need to keep the entire chat history around, and can also lead to
396+ errors if the chat history gets too large for the model to handle.
397+
398+ This method allows you to set a limit to the number of tokens that can
399+ be sent to the model. If the limit is exceeded, the chat history will be
400+ truncated to fit within the limit (i.e., the oldest turns will be
401+ dropped).
402+
403+ Note that many models publish a context window as well as a maximum
404+ output token limit. For example,
405+
406+ <https://platform.openai.com/docs/models/gp#gpt-4o-realtime>
407+ <https://docs.anthropic.com/en/docs/about-claude/models#model-comparison-table>
408+
409+ Also, since the context window is the maximum number of input + output
410+ tokens, the maximum number of tokens that can be sent to the model in a
411+ single request is `context_window - max_tokens`.
412+
+        Parameters
+        ----------
+        context_window
+            The size of the model's context window, i.e., the maximum number
+            of input + output tokens it can handle.
+        max_tokens
+            The maximum number of tokens that the model is allowed to generate
+            in a single response.
+
+        Note
+        ----
+        This method uses `.token_count()` to estimate the token count for new
+        input before truncating the chat history. This is an estimate, so it
+        may not be perfect. Moreover, chat models based on `ChatOpenAI()`
+        currently do not take the tool loop into account when estimating token
+        counts. This means that, if your input is likely to trigger many tool
+        calls and/or the tool results are large, it's recommended to set a
+        conservative `context_window` limit.
+
+        Examples
+        --------
+        ```python
+        from chatlas import ChatAnthropic
+
+        chat = ChatAnthropic(model="claude-3-5-sonnet-20241022")
+        chat.set_token_limits(200000, 8192)
+        ```
+        """
+        if max_tokens >= context_window:
+            raise ValueError("`max_tokens` must be less than the `context_window`.")
+        self.token_limits = (context_window, max_tokens)
+
+    def _maybe_drop_turns(
+        self,
+        *args: Content | str,
+        data_model: Optional[type[BaseModel]] = None,
+    ):
+        """
+        Drop turns from the chat history if they exceed the token limits.
+        """
+
+        # Do nothing if token limits are not set
+        if self.token_limits is None:
+            return None
+
+        turns = self.get_turns(include_system_prompt=False)
+
+        # Do nothing if this is the first turn
+        if len(turns) == 0:
+            return None
+
+        last_turn = turns[-1]
+
+        # Sanity checks (i.e., when about to submit new input, the last turn
+        # should be from the assistant and should contain token counts)
+        if last_turn.role != "assistant":
+            raise ValueError(
+                "Expected the last turn to be from the assistant. Please report this issue."
+            )
+
+        if last_turn.tokens is None:
+            raise ValueError(
+                "Can't impose token limits since the last assistant turn doesn't contain "
+                "token counts. Please report this issue and consider setting "
+                "`.token_limits` to `None`."
+            )
+
+        context_window, max_tokens = self.token_limits
+        max_input_size = context_window - max_tokens
+
+        # Estimate the token count for the (new) user turn
+        input_tokens = self.token_count(*args, data_model=data_model)
+
+        # Do nothing if current history size plus input size is within the limit
+        remaining_tokens = max_input_size - input_tokens
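+        # The last assistant turn's tokens are (input, output) counts, so their
+        # sum reflects the total size of the chat history up to this point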
+        if sum(last_turn.tokens) < remaining_tokens:
+            return None
+
+        tokens = self.tokens()
+
+        # Drop turns until they (plus the new input) fit within the token limits
+        # TODO: we also need to account for the fact that dropping part of a tool loop is problematic
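+        # Turns are dropped two at a time so that each oldest (user, assistant)
+        # pair goes together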
+        while tokens and sum(tokens) >= remaining_tokens:
+            del turns[:2]
+            del tokens[:2]
+
+        self.set_turns(turns)
+
+        return None
+
     def app(
         self,
         *,
@@ -531,6 +647,8 @@ def chat(
         A (consumed) response from the chat. Apply `str()` to this object to
         get the text content of the response.
         """
+        self._maybe_drop_turns(*args)
+
         turn = user_turn(*args)

         display = self._markdown_display(echo=echo)
@@ -581,6 +699,9 @@ async def chat_async(
         A (consumed) response from the chat. Apply `str()` to this object to
         get the text content of the response.
         """
+        # TODO: async version?
+        self._maybe_drop_turns(*args)
+
         turn = user_turn(*args)

         display = self._markdown_display(echo=echo)
@@ -627,6 +748,8 @@ def stream(
         An (unconsumed) response from the chat. Iterate over this object to
         consume the response.
         """
+        self._maybe_drop_turns(*args)
+
         turn = user_turn(*args)

         display = self._markdown_display(echo=echo)
@@ -672,6 +795,9 @@ async def stream_async(
         An (unconsumed) response from the chat. Iterate over this object to
         consume the response.
         """
+        # TODO: async version?
+        self._maybe_drop_turns(*args)
+
         turn = user_turn(*args)

         display = self._markdown_display(echo=echo)
@@ -715,6 +841,7 @@ def extract_data(
         dict[str, Any]
             The extracted data.
         """
+        self._maybe_drop_turns(*args, data_model=data_model)

         display = self._markdown_display(echo=echo)

@@ -775,6 +902,8 @@ async def extract_data_async(
         dict[str, Any]
             The extracted data.
         """
+        # TODO: async version?
+        self._maybe_drop_turns(*args, data_model=data_model)

         display = self._markdown_display(echo=echo)
