diff --git a/libs/langchain_v1/langchain/agents/middleware/summarization.py b/libs/langchain_v1/langchain/agents/middleware/summarization.py index b3533c2eb6f6c..ad015757bceca 100644 --- a/libs/langchain_v1/langchain/agents/middleware/summarization.py +++ b/libs/langchain_v1/langchain/agents/middleware/summarization.py @@ -4,7 +4,7 @@ import warnings from collections.abc import Callable, Iterable, Mapping from functools import partial -from typing import Any, Literal, cast +from typing import Any, Literal, TypedDict, cast from langchain_core.messages import ( AIMessage, @@ -124,6 +124,37 @@ """ +class TriggerClause(TypedDict, total=False): + """Dictionary-based trigger specification for AND conditions. + + All specified thresholds in a single `TriggerClause` must be met for the clause to + trigger summarization (AND semantics). When multiple clauses are provided in a list, + summarization triggers if any clause is met (OR semantics). + + Attributes: + tokens: Trigger when token count reaches or exceeds this value. + messages: Trigger when message count reaches or exceeds this value. + fraction: Trigger when token count reaches or exceeds this fraction of the + model's maximum input tokens. + + Example: + ```python + # AND: Trigger when tokens >= 4000 AND messages >= 10 + trigger_clause: TriggerClause = {"tokens": 4000, "messages": 10} + + # Use in a list for OR semantics: + trigger_list: list[TriggerClause] = [ + {"tokens": 5000, "messages": 3}, + {"tokens": 3000, "messages": 6}, + ] + ``` + """ + + tokens: int + messages: int + fraction: float + + def _get_approximate_token_counter(model: BaseChatModel) -> TokenCounter: """Tune parameters of approximate token counter based on model type.""" if model._llm_type == "anthropic-chat": # noqa: SLF001 @@ -145,7 +176,9 @@ def __init__( self, model: str | BaseChatModel, *, - trigger: ContextSize | list[ContextSize] | None = None, + trigger: ( + ContextSize | list[ContextSize] | TriggerClause | list[TriggerClause] | None + ) = None, keep: ContextSize = ("messages", _DEFAULT_MESSAGES_TO_KEEP), token_counter: TokenCounter = count_tokens_approximately, summary_prompt: str = DEFAULT_SUMMARY_PROMPT, @@ -175,6 +208,13 @@ def __init__( # Trigger summarization either when 80% of model's max input tokens # is reached or when 100 messages is reached (whichever comes first) [("fraction", 0.8), ("messages", 100)] + + # Trigger when tokens >= 4000 AND messages >= 10 + {"tokens": 4000, "messages": 10} + + # Trigger when (tokens >= 5000 AND messages >= 3) OR + # (tokens >= 3000 AND messages >= 6) + [{"tokens": 5000, "messages": 3}, {"tokens": 3000, "messages": 6}] ``` See [`ContextSize`][langchain.agents.middleware.summarization.ContextSize] @@ -234,18 +274,15 @@ def __init__( model = init_chat_model(model) self.model = model - if trigger is None: - self.trigger: ContextSize | list[ContextSize] | None = None - trigger_conditions: list[ContextSize] = [] - elif isinstance(trigger, list): - validated_list = [self._validate_context_size(item, "trigger") for item in trigger] - self.trigger = validated_list - trigger_conditions = validated_list - else: - validated = self._validate_context_size(trigger, "trigger") - self.trigger = validated - trigger_conditions = [validated] - self._trigger_conditions = trigger_conditions + + # Store the original trigger for backward compatibility + self.trigger: ( + ContextSize | list[ContextSize] | TriggerClause | list[TriggerClause] | None + ) = trigger + + # Normalize trigger into a list of TriggerClause + # (AND inside a 
TriggerClause, OR across items)
+        self._trigger_conditions = self._normalize_trigger(trigger)
 
         self.keep = self._validate_context_size(keep, "keep")
         if token_counter is count_tokens_approximately:
@@ -255,7 +292,7 @@ def __init__(
         self.summary_prompt = summary_prompt
         self.trim_tokens_to_summarize = trim_tokens_to_summarize
 
-        requires_profile = any(condition[0] == "fraction" for condition in self._trigger_conditions)
+        requires_profile = any("fraction" in clause for clause in self._trigger_conditions)
         if self.keep[0] == "fraction":
             requires_profile = True
         if requires_profile and self._get_profile_limits() is None:
@@ -280,8 +317,16 @@ def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] |
 
         cutoff_index = self._determine_cutoff_index(messages)
 
+        # If cutoff_index <= 0, we may still want to summarize when the keep policy
+        # is message-count based (the default) and a trigger clause matched; in that
+        # case, summarize at least one message to preserve backward compatibility
+        # (e.g., with tuple-based triggers).
        if cutoff_index <= 0:
-            return None
+            kind, _ = self.keep
+            if kind == "messages" and len(messages) > 0:
+                cutoff_index = 1
+            else:
+                return None
 
         messages_to_summarize, preserved_messages = self._partition_messages(messages, cutoff_index)
 
@@ -308,8 +353,14 @@ async def abefore_model(self, state: AgentState, runtime: Runtime) -> dict[str,
 
         cutoff_index = self._determine_cutoff_index(messages)
 
+        # If cutoff_index <= 0, mirror the sync behavior: allow summarization for
+        # message-count keeps to preserve backward compatibility.
         if cutoff_index <= 0:
-            return None
+            kind, _ = self.keep
+            if kind == "messages" and len(messages) > 0:
+                cutoff_index = 1
+            else:
+                return None
 
         messages_to_summarize, preserved_messages = self._partition_messages(messages, cutoff_index)
 
@@ -324,6 +375,71 @@ async def abefore_model(self, state: AgentState, runtime: Runtime) -> dict[str,
             ]
         }
 
+    def _normalize_trigger(
+        self,
+        trigger: (ContextSize | list[ContextSize] | TriggerClause | list[TriggerClause] | None),
+    ) -> list[TriggerClause]:
+        """Normalize supported trigger inputs into a list of `TriggerClause` dicts.
+ + - tuple ("tokens", 3000) -> [{"tokens": 3000}] + - dict {"tokens": 4000, "messages": 10} -> [{"tokens": 4000, "messages": 10}] + - list of either -> OR across items + """ + if trigger is None: + return [] + + def _validate_and_convert_tuple(t: tuple) -> TriggerClause: + kind, value = self._validate_context_size(t, "trigger") + return cast("TriggerClause", {kind: value}) + + def _validate_mapping(m: Mapping) -> TriggerClause: + """Validate and convert a mapping to a TriggerClause.""" + out: dict[str, float | int] = {} + for k, v in m.items(): + if k not in {"tokens", "messages", "fraction"}: + msg = f"Unsupported trigger metric: {k!r}" + raise ValueError(msg) + if k == "fraction": + try: + fv = float(v) + except Exception as err: + msg = f"Fraction trigger values must be numeric, got {v!r}" + raise ValueError(msg) from err + if not 0 < fv <= 1: + msg = "fraction must be > 0 and <= 1" + raise ValueError(msg) + out[k] = fv + else: + try: + iv = int(v) + except Exception as err: + msg = f"{k} trigger values must be integer-like, got {v!r}" + raise ValueError(msg) from err + if iv <= 0: + msg = f"{k} threshold must be > 0" + raise ValueError(msg) + out[k] = iv + return cast("TriggerClause", out) + + clauses: list[TriggerClause] = [] + if isinstance(trigger, Mapping): + clauses.append(_validate_mapping(trigger)) + elif isinstance(trigger, tuple): + clauses.append(_validate_and_convert_tuple(trigger)) + elif isinstance(trigger, list): + for item in trigger: + if isinstance(item, Mapping): + clauses.append(_validate_mapping(item)) + elif isinstance(item, tuple): + clauses.append(_validate_and_convert_tuple(item)) + else: + msg = f"Unsupported trigger item type: {type(item)}" + raise TypeError(msg) + else: + msg = f"Unsupported trigger type: {type(trigger)}" + raise TypeError(msg) + return clauses + def _should_summarize_based_on_reported_tokens( self, messages: list[AnyMessage], threshold: float ) -> bool: @@ -348,27 +464,36 @@ def _should_summarize(self, messages: list[AnyMessage], total_tokens: int) -> bo if not self._trigger_conditions: return False - for kind, value in self._trigger_conditions: - if kind == "messages" and len(messages) >= value: - return True - if kind == "tokens" and total_tokens >= value: - return True - if kind == "tokens" and self._should_summarize_based_on_reported_tokens( - messages, value - ): + for clause in self._trigger_conditions: + clause_met = True + for kind, value in clause.items(): + if kind == "messages" and len(messages) < cast("int", value): + clause_met = False + break + if kind == "tokens": + threshold_tokens = cast("int", value) + # Trigger if total tokens exceed threshold OR reported tokens do + if ( + total_tokens < threshold_tokens + and not self._should_summarize_based_on_reported_tokens( + messages, float(threshold_tokens) + ) + ): + clause_met = False + break + if kind == "fraction": + max_input_tokens = self._get_profile_limits() + if max_input_tokens is None: + clause_met = False + break + threshold = int(max_input_tokens * cast("float", value)) + if threshold <= 0: + threshold = 1 + if total_tokens < threshold: + clause_met = False + break + if clause_met: return True - if kind == "fraction": - max_input_tokens = self._get_profile_limits() - if max_input_tokens is None: - continue - threshold = int(max_input_tokens * value) - if threshold <= 0: - threshold = 1 - if total_tokens >= threshold: - return True - - if self._should_summarize_based_on_reported_tokens(messages, threshold): - return True return False def _determine_cutoff_index(self, 
messages: list[AnyMessage]) -> int: diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_summarization.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_summarization.py index 728c6c97dfe25..eb3c5d095b1d0 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_summarization.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_summarization.py @@ -962,6 +962,308 @@ def test_summarization_middleware_cutoff_at_start_of_tool_sequence() -> None: assert cutoff == 2 +def test_and_trigger_conditions() -> None: + """Test AND-capable trigger conditions (all conditions in dict must be met).""" + model = FakeToolCallingModel() + + # Create middleware with AND condition: tokens >= 1000 AND messages >= 5 + middleware = SummarizationMiddleware( + model=model, + trigger={"tokens": 1000, "messages": 5}, + keep=("messages", 2), # Explicitly set a smaller keep value + ) + + # Test case 1: Only tokens threshold met (messages = 3 < 5) + # Should NOT trigger summarization + def token_counter_high(messages): + return 1500 # Above token threshold + + middleware.token_counter = token_counter_high + state = { + "messages": [ + HumanMessage(content="1"), + AIMessage(content="2"), + HumanMessage(content="3"), + ] + } + result = middleware.before_model(state, None) + assert result is None, "Should not summarize when only tokens condition is met" + + # Test case 2: Only messages threshold met (tokens = 500 < 1000) + # Should NOT trigger summarization + def token_counter_low(messages): + return 500 # Below token threshold + + middleware.token_counter = token_counter_low + state = { + "messages": [ + HumanMessage(content="1"), + AIMessage(content="2"), + HumanMessage(content="3"), + AIMessage(content="4"), + HumanMessage(content="5"), + AIMessage(content="6"), + ] + } + result = middleware.before_model(state, None) + assert result is None, "Should not summarize when only messages condition is met" + + # Test case 3: Both conditions met (tokens >= 1000 AND messages >= 5) + # Should trigger summarization + middleware.token_counter = token_counter_high + result = middleware.before_model(state, None) + assert result is not None, "Should summarize when both conditions are met" + assert isinstance(result["messages"][0], RemoveMessage) + + +def test_or_trigger_conditions_with_and_clauses() -> None: + """Test OR across multiple AND clauses.""" + model = FakeToolCallingModel() + + # Create middleware with OR of AND conditions: + # (tokens >= 5000 AND messages >= 3) OR (tokens >= 3000 AND messages >= 6) + middleware = SummarizationMiddleware( + model=model, + trigger=[ + {"tokens": 5000, "messages": 3}, + {"tokens": 3000, "messages": 6}, + ], + ) + + # Test case 1: First clause met (tokens = 5500, messages = 4) + # Should trigger summarization + def token_counter_5500(messages): + return 5500 + + middleware.token_counter = token_counter_5500 + state = { + "messages": [ + HumanMessage(content="1"), + AIMessage(content="2"), + HumanMessage(content="3"), + AIMessage(content="4"), + ] + } + result = middleware.before_model(state, None) + assert result is not None, "Should summarize when first OR clause is met" + + # Test case 2: Second clause met (tokens = 3500, messages = 7) + # Should trigger summarization + def token_counter_3500(messages): + return 3500 + + middleware.token_counter = token_counter_3500 + state = {"messages": [HumanMessage(content=str(i)) for i in range(7)]} + result = 
middleware.before_model(state, None) + assert result is not None, "Should summarize when second OR clause is met" + + # Test case 3: Neither clause fully met + # (tokens = 4500 meets second token threshold but not message count) + # (messages = 4 meets first message threshold but not token count) + # Should NOT trigger summarization + def token_counter_4500(messages): + return 4500 + + middleware.token_counter = token_counter_4500 + state = { + "messages": [ + HumanMessage(content="1"), + AIMessage(content="2"), + HumanMessage(content="3"), + AIMessage(content="4"), + ] + } + result = middleware.before_model(state, None) + assert result is None, "Should not summarize when no complete clause is met" + + +def test_backward_compatibility_tuple_trigger() -> None: + """Test backward compatibility with existing tuple-based triggers.""" + model = FakeToolCallingModel() + + # Single tuple trigger + middleware_single = SummarizationMiddleware( + model=model, + trigger=("tokens", 1000), + ) + + def token_counter_high(messages): + return 1500 + + middleware_single.token_counter = token_counter_high + state = {"messages": [HumanMessage(content="test")]} + result = middleware_single.before_model(state, None) + assert result is not None, "Single tuple trigger should work" + + # List of tuples trigger + middleware_list = SummarizationMiddleware( + model=model, + trigger=[("tokens", 1000), ("messages", 5)], + ) + + # Should trigger with high tokens (first condition met) + middleware_list.token_counter = token_counter_high + state = {"messages": [HumanMessage(content="test")]} + result = middleware_list.before_model(state, None) + assert result is not None, "List of tuples should trigger when any condition met" + + # Should trigger with many messages (second condition met) + def token_counter_low(messages): + return 100 + + middleware_list.token_counter = token_counter_low + state = {"messages": [HumanMessage(content=str(i)) for i in range(6)]} + result = middleware_list.before_model(state, None) + assert result is not None, "List of tuples should trigger when second condition met" + + +def test_mixed_and_or_conditions() -> None: + """Test mixing dict (AND) and tuple (single condition) triggers in a list (OR).""" + model = FakeToolCallingModel() + + # (tokens >= 4000 AND messages >= 10) OR (messages >= 50) + middleware = SummarizationMiddleware( + model=model, + trigger=[ + {"tokens": 4000, "messages": 10}, + ("messages", 50), + ], + ) + + # Test case 1: First AND clause met + def token_counter_high(messages): + return 4500 + + middleware.token_counter = token_counter_high + state = {"messages": [HumanMessage(content=str(i)) for i in range(12)]} + result = middleware.before_model(state, None) + assert result is not None, "Should trigger when AND clause is met" + + # Test case 2: Second simple condition met + def token_counter_low(messages): + return 1000 + + middleware.token_counter = token_counter_low + state = {"messages": [HumanMessage(content=str(i)) for i in range(55)]} + result = middleware.before_model(state, None) + assert result is not None, "Should trigger when simple messages condition is met" + + # Test case 3: Neither condition met + middleware.token_counter = token_counter_low + state = {"messages": [HumanMessage(content=str(i)) for i in range(8)]} + result = middleware.before_model(state, None) + assert result is None, "Should not trigger when no condition is met" + + +def test_fraction_in_and_trigger() -> None: + """Test using fraction threshold in AND conditions.""" + # Create middleware 
with AND condition: fraction >= 0.8 AND messages >= 5
+    middleware = SummarizationMiddleware(
+        model=ProfileChatModel(),
+        trigger={"fraction": 0.8, "messages": 5},
+    )
+
+    def token_counter(messages):
+        return len(messages) * 200  # Each message = 200 tokens
+
+    middleware.token_counter = token_counter
+
+    # Test case 1: Both conditions met
+    # 5 messages * 200 = 1000 tokens (profile max is 1000)
+    # 1000 / 1000 = 1.0 >= 0.8 AND messages = 5 >= 5
+    state = {"messages": [HumanMessage(content=str(i)) for i in range(5)]}
+    result = middleware.before_model(state, None)
+    assert result is not None, "Should trigger when both fraction and messages conditions met"
+
+    # Test case 2: Neither condition met
+    # 3 messages * 200 = 600 tokens
+    # 600 / 1000 = 0.6 < 0.8 and messages = 3 < 5
+    state = {"messages": [HumanMessage(content=str(i)) for i in range(3)]}
+    result = middleware.before_model(state, None)
+    assert result is None, "Should not trigger when neither condition is fully met"
+
+    # Test case 3: High fraction but not enough messages
+    # 4 messages * 200 = 800 tokens
+    # 800 / 1000 = 0.8 >= 0.8 but messages = 4 < 5
+    state = {"messages": [HumanMessage(content=str(i)) for i in range(4)]}
+    result = middleware.before_model(state, None)
+    assert result is None, "Should not trigger when only fraction condition is met"
+
+
+def test_trigger_validation_errors() -> None:
+    """Test validation errors for invalid trigger configurations."""
+    model = FakeToolCallingModel()
+
+    # Invalid metric name
+    with pytest.raises(ValueError, match="Unsupported trigger metric"):
+        SummarizationMiddleware(
+            model=model,
+            trigger={"invalid_metric": 100},
+        )
+
+    # Invalid fraction value (> 1)
+    with pytest.raises(ValueError, match="fraction must be > 0 and <= 1"):
+        SummarizationMiddleware(
+            model=model,
+            trigger={"fraction": 1.5},
+        )
+
+    # Invalid fraction value (<= 0)
+    with pytest.raises(ValueError, match="fraction must be > 0 and <= 1"):
+        SummarizationMiddleware(
+            model=model,
+            trigger={"fraction": 0},
+        )
+
+    # Invalid token threshold (<= 0)
+    with pytest.raises(ValueError, match="tokens threshold must be > 0"):
+        SummarizationMiddleware(
+            model=model,
+            trigger={"tokens": 0},
+        )
+
+    # Invalid message threshold (<= 0)
+    with pytest.raises(ValueError, match="messages threshold must be > 0"):
+        SummarizationMiddleware(
+            model=model,
+            trigger={"messages": -5},
+        )
+
+    # Non-numeric fraction value
+    with pytest.raises(ValueError, match="Fraction trigger values must be numeric"):
+        SummarizationMiddleware(
+            model=model,
+            trigger={"fraction": "invalid"},
+        )
+
+    # Invalid list item type
+    with pytest.raises(TypeError, match="Unsupported trigger item type"):
+        SummarizationMiddleware(
+            model=model,
+            trigger=["invalid"],
+        )
+
+
+def test_empty_and_condition() -> None:
+    """Test that an empty dict trigger clause is accepted and triggers vacuously."""
+    model = FakeToolCallingModel()
+
+    # An empty dict is allowed; with no conditions that can fail, the clause is
+    # vacuously met, so summarization always triggers.
+    middleware = SummarizationMiddleware(
+        model=model,
+        trigger={},
+    )
+
+    def token_counter_high(messages):
+        return 5000
+
+    middleware.token_counter = token_counter_high
+    state = {"messages": [HumanMessage(content=str(i)) for i in range(100)]}
+    # The empty clause is vacuously true (all zero conditions are met)
+    result = middleware.before_model(state, None)
+    assert result is not None, "Empty trigger clause should trigger"
+
+
 def test_create_summary_uses_get_buffer_string_format() -> None:
     """Test that `_create_summary` formats 
messages using `get_buffer_string`.
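For reference, a minimal usage sketch of the trigger semantics this diff introduces: a dict is one AND clause, and a list ORs its items. The import path follows this diff; the model string is a hypothetical `init_chat_model` identifier, so this is illustrative, not part of the patch.

```python
from langchain.agents.middleware.summarization import SummarizationMiddleware

# Summarize when (tokens >= 5000 AND messages >= 3)
# OR (tokens >= 3000 AND messages >= 6)
# OR the conversation reaches 200 messages outright.
middleware = SummarizationMiddleware(
    model="openai:gpt-4o-mini",  # hypothetical model identifier
    trigger=[
        {"tokens": 5000, "messages": 3},  # AND within a clause
        {"tokens": 3000, "messages": 6},
        ("messages", 200),  # tuples remain supported (single condition)
    ],
    keep=("messages", 20),  # always preserve the 20 most recent messages
)
```

Mixing dicts and tuples in one list is exercised by `test_mixed_and_or_conditions` above; `_normalize_trigger` converts a bare tuple into a single-key clause, so the OR/AND evaluation in `_should_summarize` is uniform.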