Skip to content

Commit 86d9917

Browse files
authored
Merge pull request #92 from nextcloud/fix/token-size
fix: Reduce size of conversation_token
2 parents a0e375b + dd3badd commit 86d9917

File tree

4 files changed

+164
-97
lines changed

4 files changed

+164
-97
lines changed

ex_app/lib/agent.py

Lines changed: 62 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,63 @@
3535
print(f"Reading file '{key_file_path}'.")
3636
key = file.read()
3737

38-
def load_conversation(checkpointer, conversation_token: str):
38+
def load_conversation_old(conversation_token: str):
39+
"""
40+
Load a checkpointer with the conversation state from the conversation token
41+
42+
This is the old way which is only used as a fallback anymore
43+
"""
44+
checkpointer = MemorySaver()
45+
if conversation_token == '' or conversation_token == '{}':
46+
# return an empty checkpointer
47+
return checkpointer
48+
# Verify whether this conversation token was signed by this instance of context_agent
49+
serialized_checkpointer = verify_signature(conversation_token, key)
50+
# Deserialize the checkpointer storage
51+
checkpointer.storage = checkpointer.serde.loads(serialized_checkpointer.encode())
52+
# return the prepared checkpointer
53+
return checkpointer
54+
55+
def load_conversation(conversation_token: str):
56+
"""
57+
Load a checkpointer with the conversation state from the conversation token
58+
59+
This is the new way which only restores that last checkpoint of the checkpointer instead of the whole checkpointer history
60+
"""
61+
checkpointer = MemorySaver()
3962
if conversation_token == '' or conversation_token == '{}':
40-
return
41-
checkpointer.storage = checkpointer.serde.loads(verify_signature(conversation_token, key).encode())
63+
# return an empty checkpointer
64+
return checkpointer
65+
66+
# Verify whether this was signed by this instance of context_agent
67+
serialized_state = verify_signature(conversation_token, key)
68+
# Deserialize the saved state
69+
conversation = checkpointer.serde.loads(serialized_state.encode())
70+
# Get the last checkpoint state
71+
last_checkpoint = conversation['last_checkpoint']
72+
# get the last checkpointer config
73+
last_config = conversation['last_config']
74+
# insert the last checkpoint state at the right spot in the checkpointer storage
75+
checkpointer.storage[last_config['configurable']['thread_id']][last_config['configurable']['checkpoint_ns']][last_config['configurable']['checkpoint_id']] = last_checkpoint
76+
# return the prepared checkpointer
77+
return checkpointer
4278

4379
def export_conversation(checkpointer):
44-
return add_signature(checkpointer.serde.dumps(checkpointer.storage).decode('utf-8'), key)
45-
80+
"""
81+
Prepare and sign a conversation token from a checkpointer
82+
83+
This only uses the new way which only saves the last checkpoint of the checkpointer instead of the whole checkpointer history
84+
"""
85+
# get the last config which holds the last written checkpoint
86+
last_config = checkpointer.last_config
87+
# Select the last written checkpoint
88+
last_checkpoint = checkpointer.storage[last_config['configurable']['thread_id']][last_config['configurable']['checkpoint_ns']][last_config['configurable']['checkpoint_id']]
89+
# prepare the to-serialize blob
90+
state = {"last_config": last_config, "last_checkpoint": last_checkpoint}
91+
serialized_state = checkpointer.serde.dumps(state)
92+
# sign the serialized state
93+
conversation_token = add_signature(serialized_state.decode('utf-8'), key)
94+
return conversation_token
4695

4796
async def react(task, nc: Nextcloud):
4897
safe_tools, dangerous_tools = await get_tools(nc)
@@ -94,10 +143,15 @@ async def call_model(
94143
# We return a list, because this will get added to the existing list
95144
return {"messages": [response]}
96145

97-
checkpointer = MemorySaver()
98-
graph = await get_graph(call_model, safe_tools, dangerous_tools, checkpointer)
146+
try:
147+
# Try to load state using the new conversation_token type
148+
checkpointer = load_conversation(task['input']['conversation_token'])
149+
except Exception as e:
150+
# fallback to trying to load the state using the old conversation_token type
151+
# if this fails, we fail the whole task
152+
checkpointer = load_conversation_old(task['input']['conversation_token'])
99153

100-
load_conversation(checkpointer, task['input']['conversation_token'])
154+
graph = await get_graph(call_model, safe_tools, dangerous_tools, checkpointer)
101155

102156
state_snapshot = graph.get_state(thread)
103157

ex_app/lib/memorysaver.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ class InMemorySaver(
7373
dict[tuple[str, int], tuple[str, str, tuple[str, bytes], str]],
7474
]
7575

76+
last_config: Optional[dict] = None
77+
7678
def __init__(
7779
self,
7880
*,
@@ -352,7 +354,7 @@ def put(
352354
RunnableConfig: The updated config containing the saved checkpoint's timestamp.
353355
"""
354356
c = checkpoint.copy()
355-
c.pop("pending_sends") # type: ignore[misc]
357+
#c.pop("pending_sends") # type: ignore[misc]
356358
thread_id = config["configurable"]["thread_id"]
357359
checkpoint_ns = config["configurable"]["checkpoint_ns"]
358360
self.storage[thread_id][checkpoint_ns].update(
@@ -364,6 +366,13 @@ def put(
364366
)
365367
}
366368
)
369+
self.last_config = {
370+
"configurable": {
371+
"thread_id": thread_id,
372+
"checkpoint_ns": checkpoint_ns,
373+
"checkpoint_id": checkpoint["id"],
374+
}
375+
}
367376
return {
368377
"configurable": {
369378
"thread_id": thread_id,

0 commit comments

Comments
 (0)