Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 62 additions & 8 deletions ex_app/lib/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,63 @@
print(f"Reading file '{key_file_path}'.")
key = file.read()

def load_conversation(checkpointer, conversation_token: str):
def load_conversation_old(conversation_token: str):
"""
Load a checkpointer with the conversation state from the conversation token

This is the old way which is only used as a fallback anymore
"""
checkpointer = MemorySaver()
if conversation_token == '' or conversation_token == '{}':
# return an empty checkpointer
return checkpointer
# Verify whether this conversation token was signed by this instance of context_agent
serialized_checkpointer = verify_signature(conversation_token, key)
# Deserialize the checkpointer storage
checkpointer.storage = checkpointer.serde.loads(serialized_checkpointer.encode())
# return the prepared checkpointer
return checkpointer

def load_conversation(conversation_token: str):
"""
Load a checkpointer with the conversation state from the conversation token

This is the new way which only restores that last checkpoint of the checkpointer instead of the whole checkpointer history
"""
checkpointer = MemorySaver()
if conversation_token == '' or conversation_token == '{}':
return
checkpointer.storage = checkpointer.serde.loads(verify_signature(conversation_token, key).encode())
# return an empty checkpointer
return checkpointer

# Verify whether this was signed by this instance of context_agent
serialized_state = verify_signature(conversation_token, key)
# Deserialize the saved state
conversation = checkpointer.serde.loads(serialized_state.encode())
# Get the last checkpoint state
last_checkpoint = conversation['last_checkpoint']
# get the last checkpointer config
last_config = conversation['last_config']
# insert the last checkpoint state at the right spot in the checkpointer storage
checkpointer.storage[last_config['configurable']['thread_id']][last_config['configurable']['checkpoint_ns']][last_config['configurable']['checkpoint_id']] = last_checkpoint
# return the prepared checkpointer
return checkpointer

def export_conversation(checkpointer):
return add_signature(checkpointer.serde.dumps(checkpointer.storage).decode('utf-8'), key)

"""
Prepare and sign a conversation token from a checkpointer

This only uses the new way which only saves the last checkpoint of the checkpointer instead of the whole checkpointer history
"""
# get the last config which holds the last written checkpoint
last_config = checkpointer.last_config
# Select the last written checkpoint
last_checkpoint = checkpointer.storage[last_config['configurable']['thread_id']][last_config['configurable']['checkpoint_ns']][last_config['configurable']['checkpoint_id']]
# prepare the to-serialize blob
state = {"last_config": last_config, "last_checkpoint": last_checkpoint}
serialized_state = checkpointer.serde.dumps(state)
# sign the serialized state
conversation_token = add_signature(serialized_state.decode('utf-8'), key)
return conversation_token

async def react(task, nc: Nextcloud):
safe_tools, dangerous_tools = await get_tools(nc)
Expand Down Expand Up @@ -94,10 +143,15 @@ async def call_model(
# We return a list, because this will get added to the existing list
return {"messages": [response]}

checkpointer = MemorySaver()
graph = await get_graph(call_model, safe_tools, dangerous_tools, checkpointer)
try:
# Try to load state using the new conversation_token type
checkpointer = load_conversation(task['input']['conversation_token'])
except Exception as e:
# fallback to trying to load the state using the old conversation_token type
# if this fails, we fail the whole task
checkpointer = load_conversation_old(task['input']['conversation_token'])

load_conversation(checkpointer, task['input']['conversation_token'])
graph = await get_graph(call_model, safe_tools, dangerous_tools, checkpointer)

state_snapshot = graph.get_state(thread)

Expand Down
11 changes: 10 additions & 1 deletion ex_app/lib/memorysaver.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ class InMemorySaver(
dict[tuple[str, int], tuple[str, str, tuple[str, bytes], str]],
]

last_config: Optional[dict] = None

def __init__(
self,
*,
Expand Down Expand Up @@ -352,7 +354,7 @@ def put(
RunnableConfig: The updated config containing the saved checkpoint's timestamp.
"""
c = checkpoint.copy()
c.pop("pending_sends") # type: ignore[misc]
#c.pop("pending_sends") # type: ignore[misc]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pending_sends seems to no longer be part of the checkpoints in langgraph

thread_id = config["configurable"]["thread_id"]
checkpoint_ns = config["configurable"]["checkpoint_ns"]
self.storage[thread_id][checkpoint_ns].update(
Expand All @@ -364,6 +366,13 @@ def put(
)
}
)
self.last_config = {
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": checkpoint["id"],
}
}
return {
"configurable": {
"thread_id": thread_id,
Expand Down
Loading
Loading