Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 66 additions & 10 deletions libs/core/langchain_core/messages/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,7 +730,25 @@ def merge_message_runs(

# TODO: Update so validation errors (for token_counter, for example) are raised on
# init not at runtime.
# Overload: when `messages` is omitted (None), `trim_messages` returns a
# Runnable that can be composed into a chain and later invoked on a message
# sequence, rather than trimming eagerly.
# NOTE(review): `@_runnable_support` stacked on an `@overload` stub looks like
# diff residue (the decorator line may be a deletion whose marker was stripped
# from this view) — confirm against the merged file.
@_runnable_support
@overload
def trim_messages(
    messages: None = None,
    *,
    max_tokens: int,
    token_counter: Callable[[list[BaseMessage]], int]
    | Callable[[BaseMessage], int]
    | BaseLanguageModel
    | Literal["approximate"],
    strategy: Literal["first", "last"] = "last",
    allow_partial: bool = False,
    end_on: str | type[BaseMessage] | Sequence[str | type[BaseMessage]] | None = None,
    start_on: str | type[BaseMessage] | Sequence[str | type[BaseMessage]] | None = None,
    include_system: bool = False,
    text_splitter: Callable[[str], list[str]] | TextSplitter | None = None,
) -> Runnable[Sequence[MessageLikeRepresentation], list[BaseMessage]]: ...


@overload
def trim_messages(
messages: Iterable[MessageLikeRepresentation] | PromptValue,
*,
Expand All @@ -745,7 +763,26 @@ def trim_messages(
start_on: str | type[BaseMessage] | Sequence[str | type[BaseMessage]] | None = None,
include_system: bool = False,
text_splitter: Callable[[str], list[str]] | TextSplitter | None = None,
) -> list[BaseMessage]:
) -> list[BaseMessage]: ...


def trim_messages(
messages: Iterable[MessageLikeRepresentation] | PromptValue | None = None,
*,
max_tokens: int,
token_counter: Callable[[list[BaseMessage]], int]
| Callable[[BaseMessage], int]
| BaseLanguageModel
| Literal["approximate"],
strategy: Literal["first", "last"] = "last",
allow_partial: bool = False,
end_on: str | type[BaseMessage] | Sequence[str | type[BaseMessage]] | None = None,
start_on: str | type[BaseMessage] | Sequence[str | type[BaseMessage]] | None = None,
include_system: bool = False,
text_splitter: Callable[[str], list[str]] | TextSplitter | None = None,
) -> (
list[BaseMessage] | Runnable[Sequence[MessageLikeRepresentation], list[BaseMessage]]
):
r"""Trim messages to be below a token count.

`trim_messages` can be used to reduce the size of a chat history to a specified
Expand Down Expand Up @@ -1047,8 +1084,6 @@ def dummy_token_counter(messages: list[BaseMessage]) -> int:
msg = "include_system parameter is only valid with strategy='last'"
raise ValueError(msg)

messages = convert_to_messages(messages)

# Handle string shortcuts for token counter
if isinstance(token_counter, str):
if token_counter in _TOKEN_COUNTER_SHORTCUTS:
Expand Down Expand Up @@ -1084,11 +1119,32 @@ def list_token_counter(messages: Sequence[BaseMessage]) -> int:
else:
msg = (
f"'token_counter' expected to be a model that implements "
f"'get_num_tokens_from_messages()' or a function. Received object of type "
f"{type(actual_token_counter)}."
f"'get_num_tokens_from_messages()' or a function. "
f"Received object of type {type(actual_token_counter)}."
)
raise ValueError(msg)

if messages is None:
from langchain_core.runnables import RunnableLambda # noqa: PLC0415

# Avoid circular import.
return RunnableLambda(
partial(
trim_messages,
max_tokens=max_tokens,
token_counter=token_counter,
strategy=strategy,
allow_partial=allow_partial,
end_on=end_on,
start_on=start_on,
include_system=include_system,
text_splitter=text_splitter,
),
name="trim_messages",
)

messages = convert_to_messages(messages)

if _HAS_LANGCHAIN_TEXT_SPLITTERS and isinstance(text_splitter, TextSplitter):
text_splitter_fn = text_splitter.split_text
elif text_splitter:
Expand Down Expand Up @@ -1203,7 +1259,7 @@ def convert_to_openai_messages(
{
"role": "user",
"content": [
{"type": "text", "text": "whats in this"},
{"type": "text", "text": "what's in this"},
{
"type": "image_url",
"image_url": {"url": "data:image/png;base64,'/9j/4AAQSk'"},
Expand All @@ -1222,15 +1278,15 @@ def convert_to_openai_messages(
],
),
ToolMessage("foobar", tool_call_id="1", name="bar"),
{"role": "assistant", "content": "thats nice"},
{"role": "assistant", "content": "that's nice"},
]
oai_messages = convert_to_openai_messages(messages)
# -> [
# {'role': 'system', 'content': 'foo'},
# {'role': 'user', 'content': [{'type': 'text', 'text': 'whats in this'}, {'type': 'image_url', 'image_url': {'url': "data:image/png;base64,'/9j/4AAQSk'"}}]},
# {'role': 'user', 'content': [{'type': 'text', 'text': "what's in this"}, {'type': 'image_url', 'image_url': {'url': "data:image/png;base64,'/9j/4AAQSk'"}}]},
# {'role': 'assistant', 'tool_calls': [{'type': 'function', 'id': '1','function': {'name': 'analyze', 'arguments': '{"baz": "buz"}'}}], 'content': ''},
# {'role': 'tool', 'name': 'bar', 'content': 'foobar'},
# {'role': 'assistant', 'content': 'thats nice'}
    # {'role': 'assistant', 'content': "that's nice"}
# ]
```

Expand Down
5 changes: 3 additions & 2 deletions libs/core/tests/unit_tests/messages/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,6 @@ def test_trim_messages_bound_model_token_counter() -> None:


def test_trim_messages_bad_token_counter() -> None:
trimmer = trim_messages(max_tokens=10, token_counter={}) # type: ignore[call-overload]
with pytest.raises(
ValueError,
match=re.escape(
Expand All @@ -535,7 +534,7 @@ def test_trim_messages_bad_token_counter() -> None:
"Received object of type <class 'dict'>."
),
):
trimmer.invoke([HumanMessage("foobar")])
trim_messages(max_tokens=10, token_counter={}) # type: ignore[call-overload]


def dummy_token_counter(messages: list[BaseMessage]) -> int:
Expand Down Expand Up @@ -670,6 +669,7 @@ def test_trim_messages_start_on_with_allow_partial() -> None:
)

assert len(result) == 1
assert isinstance(result, list)
assert result[0].content == "Second human message"
assert messages == messages_copy

Expand Down Expand Up @@ -744,6 +744,7 @@ def test_trim_messages_token_counter_shortcut_with_options() -> None:
)

# Should include system message and start on human
assert isinstance(result, list)
assert len(result) >= 2
assert isinstance(result[0], SystemMessage)
assert any(isinstance(msg, HumanMessage) for msg in result[1:])
Expand Down