Skip to content

Commit 8998f89

Browse files
Perform self-routing and handle the RAG response within the stream to avoid a double loop
1 parent d7947ff commit 8998f89

File tree

10 files changed

+144
-100
lines changed

10 files changed

+144
-100
lines changed

Diff for: django_app/tests/test_consumers.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
from websockets import WebSocketClientProtocol
1717
from websockets.legacy.client import Connect
1818

19-
from redbox.graph.root import FINAL_RESPONSE_TAG, ROUTE_NAME_TAG, SOURCE_DOCUMENTS_TAG
2019
from redbox.models.chain import LLMCallMetadata, RequestMetadata
20+
from redbox.models.graph import FINAL_RESPONSE_TAG, ROUTE_NAME_TAG, SOURCE_DOCUMENTS_TAG
2121
from redbox_app.redbox_core import error_messages
2222
from redbox_app.redbox_core.consumers import ChatConsumer
2323
from redbox_app.redbox_core.models import Chat, ChatMessage, ChatMessageTokenUse, ChatRoleEnum, File, User
@@ -484,7 +484,7 @@ def mocked_connect(uploaded_file: File) -> Connect:
484484
"data": {"chunk": Token(content="Good afternoon, ")},
485485
},
486486
{"event": "on_chat_model_stream", "tags": [FINAL_RESPONSE_TAG], "data": {"chunk": Token(content="Mr. Amor.")}},
487-
{"event": "on_chain_end", "tags": [ROUTE_NAME_TAG], "data": {"output": "gratitude"}},
487+
{"event": "on_chain_end", "tags": [ROUTE_NAME_TAG], "data": {"output": {"route_name": "gratitude"}}},
488488
{
489489
"event": "on_retriever_end",
490490
"tags": [SOURCE_DOCUMENTS_TAG],
@@ -529,7 +529,7 @@ def mocked_connect_with_naughty_citation(uploaded_file: File) -> CannedGraphLLM:
529529
"tags": [FINAL_RESPONSE_TAG],
530530
"data": {"chunk": Token(content="Good afternoon, Mr. Amor.")},
531531
},
532-
{"event": "on_chain_end", "tags": [ROUTE_NAME_TAG], "data": {"output": "gratitude"}},
532+
{"event": "on_chain_end", "tags": [ROUTE_NAME_TAG], "data": {"output": {"route_name": "gratitude"}}},
533533
{
534534
"event": "on_retriever_end",
535535
"tags": [SOURCE_DOCUMENTS_TAG],

Diff for: redbox-core/redbox/app.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
from langchain_core.vectorstores import VectorStoreRetriever
22

3+
from redbox.models.graph import FINAL_RESPONSE_TAG, ROUTE_NAME_TAG, SOURCE_DOCUMENTS_TAG
34
from redbox.graph.root import get_root_graph
45
from redbox.models.chain import RedboxState
56
from redbox.models.chat import ChatRoute
7+
from redbox.models.graph import RedboxEventType
68
from redbox.models.settings import Settings
79
from redbox.chains.components import get_all_chunks_retriever, get_metadata_retriever, get_parameterised_retriever
8-
from redbox.graph.root import (
10+
from redbox.models.graph import (
911
ROUTABLE_KEYWORDS,
10-
ROUTE_NAME_TAG,
11-
FINAL_RESPONSE_TAG,
12-
SOURCE_DOCUMENTS_TAG,
1312
)
1413

1514

@@ -25,7 +24,6 @@ def __init__(
2524
metadata_retriever: VectorStoreRetriever | None = None,
2625
env: Settings | None = None,
2726
debug: bool = False,
28-
interrupt_after: list[str] = [],
2927
):
3028
_env = env or Settings()
3129
_all_chunks_retriever = all_chunks_retriever or get_all_chunks_retriever(_env)
@@ -54,11 +52,13 @@ async def run(
5452
content = event["data"]["output"]
5553
if isinstance(content, str):
5654
await response_tokens_callback(content)
55+
elif kind == "on_custom_event" and event["name"] == RedboxEventType.response_tokens.value:
56+
await response_tokens_callback(event["data"])
5757
elif kind == "on_chain_end" and ROUTE_NAME_TAG in tags:
58-
await route_name_callback(event["data"]["output"])
58+
await route_name_callback(event["data"]["output"]["route_name"])
5959
elif kind == "on_retriever_end" and SOURCE_DOCUMENTS_TAG in tags:
6060
await documents_callback(event["data"]["output"])
61-
elif kind == "on_custom_event" and event["name"] == "on_metadata_generation":
61+
elif kind == "on_custom_event" and event["name"] == RedboxEventType.on_metadata_generation.value:
6262
await metadata_tokens_callback(event["data"])
6363
elif kind == "on_chain_end" and event["name"] == "LangGraph":
6464
final_state = RedboxState(**event["data"]["output"])

Diff for: redbox-core/redbox/chains/runnables.py

+38-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import logging
2-
from typing import Any, Iterator
2+
from typing import Any, Callable, Iterator, Iterable
33
import re
44
from operator import itemgetter
55

@@ -9,9 +9,13 @@
99
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
1010
from langchain_core.output_parsers import StrOutputParser
1111
from langchain_core.prompts import ChatPromptTemplate
12-
from langchain_core.runnables import Runnable, chain, RunnableLambda
12+
from langchain_core.runnables import Runnable, chain, RunnableLambda, RunnableGenerator
13+
from langchain_core.callbacks.manager import dispatch_custom_event
14+
1315
from tiktoken import Encoding
1416

17+
from redbox.models.graph import RedboxEventType
18+
1519
from redbox.api.format import format_documents
1620
from redbox.chains.components import get_tokeniser
1721
from redbox.models.chain import ChainChatMessage, RedboxState
@@ -67,25 +71,54 @@ def _chat_prompt_from_messages(state: RedboxState) -> Runnable:
6771
return _chat_prompt_from_messages
6872

6973

70-
def build_llm_chain(prompt_set: PromptSet, llm: BaseChatModel, final_response_chain: bool = False) -> Runnable:
74+
def build_llm_chain(
75+
prompt_set: PromptSet,
76+
llm: BaseChatModel,
77+
output_parser: Runnable | Callable = None,
78+
final_response_chain: bool = False,
79+
) -> Runnable:
7180
"""Builds a chain that correctly forms a text and metadata state update.
7281
7382
Permits both invoke and astream_events.
7483
"""
7584
model_name = getattr(llm, "model_name", "unknown-model")
7685
_llm = llm.with_config(tags=["response_flag"]) if final_response_chain else llm
77-
86+
_output_parser = output_parser if output_parser else StrOutputParser()
7887
return (
7988
build_chat_prompt_from_messages_runnable(prompt_set)
8089
| {
8190
"prompt": RunnableLambda(lambda prompt: prompt.to_string()),
82-
"response": _llm | StrOutputParser(),
91+
"response": _llm | _output_parser,
8392
"model": lambda x: model_name,
8493
}
8594
| {"text": itemgetter("response"), "metadata": to_request_metadata}
8695
)
8796

8897

98+
def build_self_route_output_parser(final_response_chain: bool = False):
99+
def _self_route_output_parser(chunks: Iterable[AIMessageChunk]) -> Iterable[str]:
100+
current_content = ""
101+
tokens_to_pass = 4
102+
token_count = 0
103+
for chunk in chunks:
104+
current_content += chunk.content
105+
token_count += 1
106+
if "unanswerable" in current_content:
107+
yield current_content
108+
return
109+
elif token_count > tokens_to_pass:
110+
break
111+
if final_response_chain:
112+
dispatch_custom_event(RedboxEventType.response_tokens, current_content)
113+
yield current_content
114+
for chunk in chunks:
115+
if final_response_chain:
116+
dispatch_custom_event(RedboxEventType.response_tokens, chunk.content)
117+
yield chunk.content
118+
119+
return RunnableGenerator(_self_route_output_parser)
120+
121+
89122
class CannedChatLLM(BaseChatModel):
90123
"""A custom chat model that returns its text as if an LLM returned it.
91124

Diff for: redbox-core/redbox/graph/nodes/processes.py

+42-30
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
import logging
2+
from operator import add
23
import re
34
import json
4-
from typing import Any, Callable
5+
from typing import Any
56
from uuid import uuid4
67
from functools import reduce
78

89
from langchain.schema import StrOutputParser
9-
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
10+
from langchain_core.runnables import Runnable, RunnableLambda, RunnableParallel
1011
from langchain_core.vectorstores import VectorStoreRetriever
1112

1213
from redbox.chains.components import get_tokeniser, get_chat_llm
1314
from redbox.chains.runnables import build_llm_chain, CannedChatLLM
15+
from redbox.models.graph import ROUTE_NAME_TAG
1416
from redbox.models import ChatRoute, Settings
1517
from redbox.models.chain import RedboxState, RequestMetadata
1618
from redbox.transform import combine_documents, structure_documents
@@ -29,7 +31,7 @@
2931

3032
def build_retrieve_pattern(
3133
retriever: VectorStoreRetriever, final_source_chain: bool = False
32-
) -> Callable[[RedboxState], dict[str, Any]]:
34+
) -> Runnable[RedboxState, dict[str, Any]]:
3335
"""Returns a function that uses state["request"] and state["text"] to set state["documents"]."""
3436
retriever_chain = RunnableParallel({"documents": retriever | structure_documents})
3537

@@ -38,16 +40,13 @@ def build_retrieve_pattern(
3840
else:
3941
_retriever = retriever_chain
4042

41-
def _retrieve(state: RedboxState) -> dict[str, Any]:
42-
return _retriever.invoke(state)
43-
44-
return _retrieve
43+
return _retriever
4544

4645

4746
def build_chat_pattern(
4847
prompt_set: PromptSet,
4948
final_response_chain: bool = False,
50-
) -> Callable[[RedboxState], dict[str, Any]]:
49+
) -> Runnable[RedboxState, dict[str, Any]]:
5150
"""Returns a function that uses state["request"] to set state["text"]."""
5251

5352
def _chat(state: RedboxState) -> dict[str, Any]:
@@ -64,7 +63,7 @@ def _chat(state: RedboxState) -> dict[str, Any]:
6463
def build_merge_pattern(
6564
prompt_set: PromptSet,
6665
final_response_chain: bool = False,
67-
) -> Callable[[RedboxState], dict[str, Any]]:
66+
) -> Runnable[RedboxState, dict[str, Any]]:
6867
"""Returns a function that uses state["request"] and state["documents"] to return one item in state["documents"].
6968
7069
When combined with chunk send, will replace each Document with what's returned from the LLM.
@@ -75,6 +74,7 @@ def build_merge_pattern(
7574
"""
7675
tokeniser = get_tokeniser()
7776

77+
@RunnableLambda
7878
def _merge(state: RedboxState) -> dict[str, Any]:
7979
llm = get_chat_llm(Settings(), state["request"].ai_settings)
8080

@@ -117,46 +117,55 @@ def _merge(state: RedboxState) -> dict[str, Any]:
117117

118118
def build_stuff_pattern(
119119
prompt_set: PromptSet,
120+
output_parser: Runnable = None,
120121
final_response_chain: bool = False,
121-
) -> Callable[[RedboxState], dict[str, Any]]:
122+
) -> Runnable[RedboxState, dict[str, Any]]:
122123
"""Returns a function that uses state["request"] and state["documents"] to set state["text"]."""
123124

125+
@RunnableLambda
124126
def _stuff(state: RedboxState) -> dict[str, Any]:
125127
llm = get_chat_llm(Settings(), state["request"].ai_settings)
126128

127-
return build_llm_chain(prompt_set=prompt_set, llm=llm, final_response_chain=final_response_chain).invoke(state)
129+
events = []
130+
131+
for event in build_llm_chain(
132+
prompt_set=prompt_set, llm=llm, output_parser=output_parser, final_response_chain=final_response_chain
133+
).stream(state):
134+
events.append(event)
135+
136+
if len(events) == 0:
137+
return None
138+
else:
139+
return reduce(add, events)
128140

129141
return _stuff
130142

131143

132144
## Utility patterns
133145

134146

135-
def build_set_route_pattern(route: ChatRoute) -> Callable[[RedboxState], dict[str, Any]]:
136-
"""Returns a function that sets state["route_name"]."""
147+
def build_set_route_pattern(route: ChatRoute) -> Runnable[RedboxState, dict[str, Any]]:
148+
"""Returns a Runnable that sets state["route_name"]."""
137149

138150
def _set_route(state: RedboxState) -> dict[str, Any]:
139-
set_route_chain = (RunnablePassthrough() | StrOutputParser()).with_config(tags=["route_flag"])
140-
141-
return {"route_name": set_route_chain.invoke(route.value)}
151+
return {"route_name": route}
142152

143-
return _set_route
153+
return RunnableLambda(_set_route).with_config(tags=[ROUTE_NAME_TAG])
144154

145155

156+
@RunnableLambda
146157
def set_self_route_from_llm_answer(state: RedboxState):
147-
llm_response = state["text"].lower()
148-
if llm_response == "true":
149-
route = ChatRoute.search
150-
elif llm_response == "false":
151-
route = ChatRoute.chat_with_docs_map_reduce
158+
llm_response = state["text"]
159+
if "unanswerable" in llm_response[: min(20, len(llm_response))]:
160+
return {"route_name": ChatRoute.chat_with_docs_map_reduce}
152161
else:
153-
route = ChatRoute.search
154-
return {"route_name": route.value}
162+
return {"route_name": ChatRoute.search}
155163

156164

157-
def build_passthrough_pattern() -> Callable[[RedboxState], dict[str, Any]]:
158-
"""Returns a function that uses state["request"] to set state["text"]."""
165+
def build_passthrough_pattern() -> Runnable[RedboxState, dict[str, Any]]:
166+
"""Returns a Runnable that uses state["request"] to set state["text"]."""
159167

168+
@RunnableLambda
160169
def _passthrough(state: RedboxState) -> dict[str, Any]:
161170
return {
162171
"text": state["request"].question,
@@ -166,10 +175,11 @@ def _passthrough(state: RedboxState) -> dict[str, Any]:
166175

167176

168177
def build_set_text_pattern(text: str, final_response_chain: bool = False):
169-
"""Returns a function that can arbitrarily set state["text"] to a value."""
178+
"""Returns a Runnable that can arbitrarily set state["text"] to a value."""
170179
llm = CannedChatLLM(text=text)
171180
_llm = llm.with_config(tags=["response_flag"]) if final_response_chain else llm
172181

182+
@RunnableLambda
173183
def _set_text(state: RedboxState) -> dict[str, Any]:
174184
set_text_chain = _llm | StrOutputParser()
175185

@@ -179,6 +189,7 @@ def _set_text(state: RedboxState) -> dict[str, Any]:
179189

180190

181191
def build_set_metadata_pattern():
192+
@RunnableLambda
182193
def _set_metadata_pattern(state: RedboxState):
183194
flat_docs = flatten_document_state(state.get("documents", {}))
184195
return {
@@ -192,10 +203,11 @@ def _set_metadata_pattern(state: RedboxState):
192203

193204

194205
def build_error_pattern(text: str, route_name: str | None):
206+
@RunnableLambda
195207
def _error_pattern(state: RedboxState):
196-
return build_set_text_pattern(text, final_response_chain=True)(state) | build_set_route_pattern(route_name)(
197-
state
198-
)
208+
return build_set_text_pattern(text, final_response_chain=True).invoke(state) | build_set_route_pattern(
209+
route_name
210+
).invoke(state)
199211

200212
return _error_pattern
201213

0 commit comments

Comments (0)