1
1
import logging
2
+ from operator import add
2
3
import re
3
4
import json
4
- from typing import Any , Callable
5
+ from typing import Any
5
6
from uuid import uuid4
6
7
from functools import reduce
7
8
8
9
from langchain .schema import StrOutputParser
9
- from langchain_core .runnables import RunnableParallel , RunnablePassthrough
10
+ from langchain_core .runnables import Runnable , RunnableLambda , RunnableParallel
10
11
from langchain_core .vectorstores import VectorStoreRetriever
11
12
12
13
from redbox .chains .components import get_tokeniser , get_chat_llm
13
14
from redbox .chains .runnables import build_llm_chain , CannedChatLLM
15
+ from redbox .models .graph import ROUTE_NAME_TAG
14
16
from redbox .models import ChatRoute , Settings
15
17
from redbox .models .chain import RedboxState , RequestMetadata
16
18
from redbox .transform import combine_documents , structure_documents
29
31
30
32
def build_retrieve_pattern (
31
33
retriever : VectorStoreRetriever , final_source_chain : bool = False
32
- ) -> Callable [[ RedboxState ] , dict [str , Any ]]:
34
+ ) -> Runnable [ RedboxState , dict [str , Any ]]:
33
35
"""Returns a function that uses state["request"] and state["text"] to set state["documents"]."""
34
36
retriever_chain = RunnableParallel ({"documents" : retriever | structure_documents })
35
37
@@ -38,16 +40,13 @@ def build_retrieve_pattern(
38
40
else :
39
41
_retriever = retriever_chain
40
42
41
- def _retrieve (state : RedboxState ) -> dict [str , Any ]:
42
- return _retriever .invoke (state )
43
-
44
- return _retrieve
43
+ return _retriever
45
44
46
45
47
46
def build_chat_pattern (
48
47
prompt_set : PromptSet ,
49
48
final_response_chain : bool = False ,
50
- ) -> Callable [[ RedboxState ] , dict [str , Any ]]:
49
+ ) -> Runnable [ RedboxState , dict [str , Any ]]:
51
50
"""Returns a function that uses state["request"] to set state["text"]."""
52
51
53
52
def _chat (state : RedboxState ) -> dict [str , Any ]:
@@ -64,7 +63,7 @@ def _chat(state: RedboxState) -> dict[str, Any]:
64
63
def build_merge_pattern (
65
64
prompt_set : PromptSet ,
66
65
final_response_chain : bool = False ,
67
- ) -> Callable [[ RedboxState ] , dict [str , Any ]]:
66
+ ) -> Runnable [ RedboxState , dict [str , Any ]]:
68
67
"""Returns a function that uses state["request"] and state["documents"] to return one item in state["documents"].
69
68
70
69
When combined with chunk send, will replace each Document with what's returned from the LLM.
@@ -75,6 +74,7 @@ def build_merge_pattern(
75
74
"""
76
75
tokeniser = get_tokeniser ()
77
76
77
+ @RunnableLambda
78
78
def _merge (state : RedboxState ) -> dict [str , Any ]:
79
79
llm = get_chat_llm (Settings (), state ["request" ].ai_settings )
80
80
@@ -117,46 +117,55 @@ def _merge(state: RedboxState) -> dict[str, Any]:
117
117
118
118
def build_stuff_pattern (
119
119
prompt_set : PromptSet ,
120
+ output_parser : Runnable = None ,
120
121
final_response_chain : bool = False ,
121
- ) -> Callable [[ RedboxState ] , dict [str , Any ]]:
122
+ ) -> Runnable [ RedboxState , dict [str , Any ]]:
122
123
"""Returns a function that uses state["request"] and state["documents"] to set state["text"]."""
123
124
125
+ @RunnableLambda
124
126
def _stuff (state : RedboxState ) -> dict [str , Any ]:
125
127
llm = get_chat_llm (Settings (), state ["request" ].ai_settings )
126
128
127
- return build_llm_chain (prompt_set = prompt_set , llm = llm , final_response_chain = final_response_chain ).invoke (state )
129
+ events = []
130
+
131
+ for event in build_llm_chain (
132
+ prompt_set = prompt_set , llm = llm , output_parser = output_parser , final_response_chain = final_response_chain
133
+ ).stream (state ):
134
+ events .append (event )
135
+
136
+ if len (events ) == 0 :
137
+ return None
138
+ else :
139
+ return reduce (add , events )
128
140
129
141
return _stuff
130
142
131
143
132
144
## Utility patterns
133
145
134
146
135
- def build_set_route_pattern (route : ChatRoute ) -> Callable [[ RedboxState ] , dict [str , Any ]]:
136
- """Returns a function that sets state["route_name"]."""
147
+ def build_set_route_pattern (route : ChatRoute ) -> Runnable [ RedboxState , dict [str , Any ]]:
148
+ """Returns a Runnable that sets state["route_name"]."""
137
149
138
150
def _set_route (state : RedboxState ) -> dict [str , Any ]:
139
- set_route_chain = (RunnablePassthrough () | StrOutputParser ()).with_config (tags = ["route_flag" ])
140
-
141
- return {"route_name" : set_route_chain .invoke (route .value )}
151
+ return {"route_name" : route }
142
152
143
- return _set_route
153
+ return RunnableLambda ( _set_route ). with_config ( tags = [ ROUTE_NAME_TAG ])
144
154
145
155
156
+ @RunnableLambda
146
157
def set_self_route_from_llm_answer (state : RedboxState ):
147
- llm_response = state ["text" ].lower ()
148
- if llm_response == "true" :
149
- route = ChatRoute .search
150
- elif llm_response == "false" :
151
- route = ChatRoute .chat_with_docs_map_reduce
158
+ llm_response = state ["text" ]
159
+ if "unanswerable" in llm_response [: min (20 , len (llm_response ))]:
160
+ return {"route_name" : ChatRoute .chat_with_docs_map_reduce }
152
161
else :
153
- route = ChatRoute .search
154
- return {"route_name" : route .value }
162
+ return {"route_name" : ChatRoute .search }
155
163
156
164
157
- def build_passthrough_pattern () -> Callable [[ RedboxState ] , dict [str , Any ]]:
158
- """Returns a function that uses state["request"] to set state["text"]."""
165
+ def build_passthrough_pattern () -> Runnable [ RedboxState , dict [str , Any ]]:
166
+ """Returns a Runnable that uses state["request"] to set state["text"]."""
159
167
168
+ @RunnableLambda
160
169
def _passthrough (state : RedboxState ) -> dict [str , Any ]:
161
170
return {
162
171
"text" : state ["request" ].question ,
@@ -166,10 +175,11 @@ def _passthrough(state: RedboxState) -> dict[str, Any]:
166
175
167
176
168
177
def build_set_text_pattern (text : str , final_response_chain : bool = False ):
169
- """Returns a function that can arbitrarily set state["text"] to a value."""
178
+ """Returns a Runnable that can arbitrarily set state["text"] to a value."""
170
179
llm = CannedChatLLM (text = text )
171
180
_llm = llm .with_config (tags = ["response_flag" ]) if final_response_chain else llm
172
181
182
+ @RunnableLambda
173
183
def _set_text (state : RedboxState ) -> dict [str , Any ]:
174
184
set_text_chain = _llm | StrOutputParser ()
175
185
@@ -179,6 +189,7 @@ def _set_text(state: RedboxState) -> dict[str, Any]:
179
189
180
190
181
191
def build_set_metadata_pattern ():
192
+ @RunnableLambda
182
193
def _set_metadata_pattern (state : RedboxState ):
183
194
flat_docs = flatten_document_state (state .get ("documents" , {}))
184
195
return {
@@ -192,10 +203,11 @@ def _set_metadata_pattern(state: RedboxState):
192
203
193
204
194
205
def build_error_pattern (text : str , route_name : str | None ):
206
+ @RunnableLambda
195
207
def _error_pattern (state : RedboxState ):
196
- return build_set_text_pattern (text , final_response_chain = True )(state ) | build_set_route_pattern ( route_name ) (
197
- state
198
- )
208
+ return build_set_text_pattern (text , final_response_chain = True ). invoke (state ) | build_set_route_pattern (
209
+ route_name
210
+ ). invoke ( state )
199
211
200
212
return _error_pattern
201
213
0 commit comments