|
36 | 36 | key = file.read() |
37 | 37 |
|
38 | 38 | def load_conversation_old(conversation_token: str): |
| 39 | + """ |
| 40 | + Load a checkpointer with the conversation state from the conversation token |
| 41 | +
|
| 42 | + This is the old way which is only used as a fallback anymore |
| 43 | + """ |
39 | 44 | checkpointer = MemorySaver() |
40 | 45 | if conversation_token == '' or conversation_token == '{}': |
| 46 | + # return an empty checkpointer |
41 | 47 | return checkpointer |
42 | | - checkpointer.storage = checkpointer.serde.loads(verify_signature(conversation_token, key).encode()) |
| 48 | + # Verify whether this conversation token was signed by this instance of context_agent |
| 49 | + serialized_checkpointer = verify_signature(conversation_token, key) |
| 50 | + # Deserialize the checkpointer storage |
| 51 | + checkpointer.storage = checkpointer.serde.loads(serialized_checkpointer.encode()) |
| 52 | + # return the prepared checkpointer |
43 | 53 | return checkpointer |
44 | 54 |
|
45 | 55 | def load_conversation(conversation_token: str): |
| 56 | + """ |
| 57 | + Load a checkpointer with the conversation state from the conversation token |
| 58 | +
|
| 59 | + This is the new way which only restores that last checkpoint of the checkpointer instead of the whole checkpointer history |
| 60 | + """ |
46 | 61 | checkpointer = MemorySaver() |
47 | 62 | if conversation_token == '' or conversation_token == '{}': |
| 63 | + # return an empty checkpointer |
48 | 64 | return checkpointer |
49 | 65 |
|
50 | | - conversation = checkpointer.serde.loads(verify_signature(conversation_token, key).encode()) |
| 66 | + # Verify whether this was signed by this instance of context_agent |
| 67 | + serialized_state = verify_signature(conversation_token, key) |
| 68 | + # Deserialize the saved state |
| 69 | + conversation = checkpointer.serde.loads(serialized_state.encode()) |
| 70 | + # Get the last checkpoint state |
51 | 71 | last_checkpoint = conversation['last_checkpoint'] |
| 72 | + # get the last checkpointer config |
52 | 73 | last_config = conversation['last_config'] |
| 74 | + # insert the last checkpoint state at the right spot in the checkpointer storage |
53 | 75 | checkpointer.storage[last_config['configurable']['thread_id']][last_config['configurable']['checkpoint_ns']][last_config['configurable']['checkpoint_id']] = last_checkpoint |
| 76 | + # return the prepared checkpointer |
54 | 77 | return checkpointer |
55 | 78 |
|
56 | 79 | def export_conversation(checkpointer): |
| 80 | + """ |
| 81 | + Prepare and sign a conversation token from a checkpointer |
| 82 | +
|
| 83 | + This only uses the new way which only saves the last checkpoint of the checkpointer instead of the whole checkpointer history |
| 84 | + """ |
| 85 | + # get the last config which holds the last written checkpoint |
57 | 86 | last_config = checkpointer.last_config |
| 87 | + # Select the last written checkpoint |
58 | 88 | last_checkpoint = checkpointer.storage[last_config['configurable']['thread_id']][last_config['configurable']['checkpoint_ns']][last_config['configurable']['checkpoint_id']] |
59 | | - conversation_token = {"last_config": last_config, "last_checkpoint": last_checkpoint} |
60 | | - return add_signature(checkpointer.serde.dumps(conversation_token).decode('utf-8'), key) |
61 | | - |
| 89 | + # prepare the to-serialize blob |
| 90 | + state = {"last_config": last_config, "last_checkpoint": last_checkpoint} |
| 91 | + serialized_state = checkpointer.serde.dumps(state) |
| 92 | + # sign the serialized state |
| 93 | + conversation_token = add_signature(serialized_state.decode('utf-8'), key) |
| 94 | + return conversation_token |
62 | 95 |
|
63 | 96 | async def react(task, nc: Nextcloud): |
64 | 97 | safe_tools, dangerous_tools = await get_tools(nc) |
@@ -111,8 +144,11 @@ async def call_model( |
111 | 144 | return {"messages": [response]} |
112 | 145 |
|
113 | 146 | try: |
| 147 | + # Try to load state using the new conversation_token type |
114 | 148 | checkpointer = load_conversation(task['input']['conversation_token']) |
115 | 149 | except Exception as e: |
| 150 | + # fallback to trying to load the state using the old conversation_token type |
| 151 | + # if this fails, we fail the whole task |
116 | 152 | checkpointer = load_conversation_old(task['input']['conversation_token']) |
117 | 153 |
|
118 | 154 | graph = await get_graph(call_model, safe_tools, dangerous_tools, checkpointer) |
|
0 commit comments