Skip to content

Commit dd3badd

Browse files
committed
chore: More comments
Signed-off-by: Marcel Klehr <[email protected]>
1 parent 2290a47 commit dd3badd

File tree

1 file changed

+41
-5
lines changed

1 file changed

+41
-5
lines changed

ex_app/lib/agent.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,29 +36,62 @@
3636
key = file.read()
3737

3838
def load_conversation_old(conversation_token: str):
39+
"""
40+
Load a checkpointer with the conversation state from the conversation token
41+
42+
This is the old way which is only used as a fallback anymore
43+
"""
3944
checkpointer = MemorySaver()
4045
if conversation_token == '' or conversation_token == '{}':
46+
# return an empty checkpointer
4147
return checkpointer
42-
checkpointer.storage = checkpointer.serde.loads(verify_signature(conversation_token, key).encode())
48+
# Verify whether this conversation token was signed by this instance of context_agent
49+
serialized_checkpointer = verify_signature(conversation_token, key)
50+
# Deserialize the checkpointer storage
51+
checkpointer.storage = checkpointer.serde.loads(serialized_checkpointer.encode())
52+
# return the prepared checkpointer
4353
return checkpointer
4454

4555
def load_conversation(conversation_token: str):
56+
"""
57+
Load a checkpointer with the conversation state from the conversation token
58+
59+
This is the new way which only restores that last checkpoint of the checkpointer instead of the whole checkpointer history
60+
"""
4661
checkpointer = MemorySaver()
4762
if conversation_token == '' or conversation_token == '{}':
63+
# return an empty checkpointer
4864
return checkpointer
4965

50-
conversation = checkpointer.serde.loads(verify_signature(conversation_token, key).encode())
66+
# Verify whether this was signed by this instance of context_agent
67+
serialized_state = verify_signature(conversation_token, key)
68+
# Deserialize the saved state
69+
conversation = checkpointer.serde.loads(serialized_state.encode())
70+
# Get the last checkpoint state
5171
last_checkpoint = conversation['last_checkpoint']
72+
# get the last checkpointer config
5273
last_config = conversation['last_config']
74+
# insert the last checkpoint state at the right spot in the checkpointer storage
5375
checkpointer.storage[last_config['configurable']['thread_id']][last_config['configurable']['checkpoint_ns']][last_config['configurable']['checkpoint_id']] = last_checkpoint
76+
# return the prepared checkpointer
5477
return checkpointer
5578

5679
def export_conversation(checkpointer):
80+
"""
81+
Prepare and sign a conversation token from a checkpointer
82+
83+
This only uses the new way which only saves the last checkpoint of the checkpointer instead of the whole checkpointer history
84+
"""
85+
# get the last config which holds the last written checkpoint
5786
last_config = checkpointer.last_config
87+
# Select the last written checkpoint
5888
last_checkpoint = checkpointer.storage[last_config['configurable']['thread_id']][last_config['configurable']['checkpoint_ns']][last_config['configurable']['checkpoint_id']]
59-
conversation_token = {"last_config": last_config, "last_checkpoint": last_checkpoint}
60-
return add_signature(checkpointer.serde.dumps(conversation_token).decode('utf-8'), key)
61-
89+
# prepare the to-serialize blob
90+
state = {"last_config": last_config, "last_checkpoint": last_checkpoint}
91+
serialized_state = checkpointer.serde.dumps(state)
92+
# sign the serialized state
93+
conversation_token = add_signature(serialized_state.decode('utf-8'), key)
94+
return conversation_token
6295

6396
async def react(task, nc: Nextcloud):
6497
safe_tools, dangerous_tools = await get_tools(nc)
@@ -111,8 +144,11 @@ async def call_model(
111144
return {"messages": [response]}
112145

113146
try:
147+
# Try to load state using the new conversation_token type
114148
checkpointer = load_conversation(task['input']['conversation_token'])
115149
except Exception as e:
150+
# fallback to trying to load the state using the old conversation_token type
151+
# if this fails, we fail the whole task
116152
checkpointer = load_conversation_old(task['input']['conversation_token'])
117153

118154
graph = await get_graph(call_model, safe_tools, dangerous_tools, checkpointer)

0 commit comments

Comments
 (0)