|
35 | 35 | print(f"Reading file '{key_file_path}'.") |
36 | 36 | key = file.read() |
37 | 37 |
|
38 | | -def load_conversation(checkpointer, conversation_token: str): |
| 38 | +def load_conversation_old(conversation_token: str): |
| 39 | + """ |
| 40 | + Load a checkpointer with the conversation state from the conversation token |
| 41 | +
|
| 42 | + This is the old way which is only used as a fallback anymore |
| 43 | + """ |
| 44 | + checkpointer = MemorySaver() |
| 45 | + if conversation_token == '' or conversation_token == '{}': |
| 46 | + # return an empty checkpointer |
| 47 | + return checkpointer |
| 48 | + # Verify whether this conversation token was signed by this instance of context_agent |
| 49 | + serialized_checkpointer = verify_signature(conversation_token, key) |
| 50 | + # Deserialize the checkpointer storage |
| 51 | + checkpointer.storage = checkpointer.serde.loads(serialized_checkpointer.encode()) |
| 52 | + # return the prepared checkpointer |
| 53 | + return checkpointer |
| 54 | + |
| 55 | +def load_conversation(conversation_token: str): |
| 56 | + """ |
| 57 | + Load a checkpointer with the conversation state from the conversation token |
| 58 | +
|
| 59 | + This is the new way which only restores that last checkpoint of the checkpointer instead of the whole checkpointer history |
| 60 | + """ |
| 61 | + checkpointer = MemorySaver() |
39 | 62 | if conversation_token == '' or conversation_token == '{}': |
40 | | - return |
41 | | - checkpointer.storage = checkpointer.serde.loads(verify_signature(conversation_token, key).encode()) |
| 63 | + # return an empty checkpointer |
| 64 | + return checkpointer |
| 65 | + |
| 66 | + # Verify whether this was signed by this instance of context_agent |
| 67 | + serialized_state = verify_signature(conversation_token, key) |
| 68 | + # Deserialize the saved state |
| 69 | + conversation = checkpointer.serde.loads(serialized_state.encode()) |
| 70 | + # Get the last checkpoint state |
| 71 | + last_checkpoint = conversation['last_checkpoint'] |
| 72 | + # get the last checkpointer config |
| 73 | + last_config = conversation['last_config'] |
| 74 | + # insert the last checkpoint state at the right spot in the checkpointer storage |
| 75 | + checkpointer.storage[last_config['configurable']['thread_id']][last_config['configurable']['checkpoint_ns']][last_config['configurable']['checkpoint_id']] = last_checkpoint |
| 76 | + # return the prepared checkpointer |
| 77 | + return checkpointer |
42 | 78 |
|
43 | 79 | def export_conversation(checkpointer): |
44 | | - return add_signature(checkpointer.serde.dumps(checkpointer.storage).decode('utf-8'), key) |
45 | | - |
| 80 | + """ |
| 81 | + Prepare and sign a conversation token from a checkpointer |
| 82 | +
|
| 83 | + This only uses the new way which only saves the last checkpoint of the checkpointer instead of the whole checkpointer history |
| 84 | + """ |
| 85 | + # get the last config which holds the last written checkpoint |
| 86 | + last_config = checkpointer.last_config |
| 87 | + # Select the last written checkpoint |
| 88 | + last_checkpoint = checkpointer.storage[last_config['configurable']['thread_id']][last_config['configurable']['checkpoint_ns']][last_config['configurable']['checkpoint_id']] |
| 89 | + # prepare the to-serialize blob |
| 90 | + state = {"last_config": last_config, "last_checkpoint": last_checkpoint} |
| 91 | + serialized_state = checkpointer.serde.dumps(state) |
| 92 | + # sign the serialized state |
| 93 | + conversation_token = add_signature(serialized_state.decode('utf-8'), key) |
| 94 | + return conversation_token |
46 | 95 |
|
47 | 96 | async def react(task, nc: Nextcloud): |
48 | 97 | safe_tools, dangerous_tools = await get_tools(nc) |
@@ -94,10 +143,15 @@ async def call_model( |
94 | 143 | # We return a list, because this will get added to the existing list |
95 | 144 | return {"messages": [response]} |
96 | 145 |
|
97 | | - checkpointer = MemorySaver() |
98 | | - graph = await get_graph(call_model, safe_tools, dangerous_tools, checkpointer) |
| 146 | + try: |
| 147 | + # Try to load state using the new conversation_token type |
| 148 | + checkpointer = load_conversation(task['input']['conversation_token']) |
| 149 | + except Exception as e: |
| 150 | + # fallback to trying to load the state using the old conversation_token type |
| 151 | + # if this fails, we fail the whole task |
| 152 | + checkpointer = load_conversation_old(task['input']['conversation_token']) |
99 | 153 |
|
100 | | - load_conversation(checkpointer, task['input']['conversation_token']) |
| 154 | + graph = await get_graph(call_model, safe_tools, dangerous_tools, checkpointer) |
101 | 155 |
|
102 | 156 | state_snapshot = graph.get_state(thread) |
103 | 157 |
|
|
0 commit comments