agent short & long term memory with langgraph. #851
New file added by this PR:

@@ -0,0 +1,68 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import json
import uuid
from datetime import datetime
from typing import List, Optional

from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph
from langgraph.store.memory import InMemoryStore
from pydantic import BaseModel


class PersistenceConfig(BaseModel):
    checkpointer: bool = False
    store: bool = False


class PersistenceInfo(BaseModel):
    user_id: str = None
Collaborator: What is the relationship between user_id and assistant_id?

Author: Will confirm.
    thread_id: str = None
    started_at: datetime


class AgentPersistence:
    def __init__(self, config: PersistenceConfig):
        # for short-term memory
        self.checkpointer = None
        # for long-term memory
        self.store = None
        self.config = config
        print(f"Initializing AgentPersistence: {config}")
        self.initialize()

    def initialize(self) -> None:
        if self.config.checkpointer:
            self.checkpointer = MemorySaver()
        if self.config.store:
            self.store = InMemoryStore()

    def save(
        self,
        config: RunnableConfig,
        content: str,
        context: str,
        memory_id: Optional[str] = None,
    ):
        """This function is only for long-term memory."""
        mem_id = memory_id or uuid.uuid4()
        user_id = config["configurable"]["user_id"]
        self.store.put(
            ("memories", user_id),
            key=str(mem_id),
            value={"content": content, "context": context},
        )
        return f"Stored memory {content}"

    def get(self, config: RunnableConfig):
        """This function is only for long-term memory."""
        user_id = config["configurable"]["user_id"]
        namespace = ("memories", user_id)
        memories = self.store.search(namespace)
        return memories

    def update_state(self, config, graph: StateGraph):
        pass
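For context, a minimal wiring sketch (not part of this diff; the MessagesState schema, node name, and IDs are illustrative assumptions) of how the checkpointer (short-term, per-thread memory) and store (long-term, per-user memory) created by AgentPersistence could be attached to a compiled LangGraph graph and driven through the same RunnableConfig keys that save()/get() read:

# Sketch only: assumes a langgraph version where compile() accepts both a
# checkpointer and a store; "chatbot" and the IDs below are placeholders.
from langgraph.graph import END, START, MessagesState, StateGraph

persistence = AgentPersistence(PersistenceConfig(checkpointer=True, store=True))

def chatbot(state: MessagesState):
    # Placeholder node; a real agent would call the LLM / tools here.
    return {"messages": []}

builder = StateGraph(MessagesState)
builder.add_node("chatbot", chatbot)
builder.add_edge(START, "chatbot")
builder.add_edge("chatbot", END)

# checkpointer -> per-thread (short-term) state; store -> cross-thread memory.
graph = builder.compile(checkpointer=persistence.checkpointer, store=persistence.store)

# thread_id scopes the checkpointer; user_id scopes the ("memories", user_id)
# namespace used by AgentPersistence.save() / get().
config = {"configurable": {"thread_id": "thread-1", "user_id": "user-1"}}
graph.invoke({"messages": [("user", "hi")]}, config)
persistence.save(config, content="user said hi", context="greeting")
print(persistence.get(config))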
Second changed file (existing file, modified):

@@ -57,9 +57,15 @@ def setup_chat_model(args):
     }
     if args.llm_engine == "vllm" or args.llm_engine == "tgi":
         openai_endpoint = f"{args.llm_endpoint_url}/v1"
-        llm = ChatOpenAI(openai_api_key="EMPTY", openai_api_base=openai_endpoint, model_name=args.model, **params)
+        llm = ChatOpenAI(
+            openai_api_key="EMPTY",
+            openai_api_base=openai_endpoint,
+            model_name=args.model,
+            request_timeout=args.timeout,
+            **params,
+        )
     elif args.llm_engine == "openai":
-        llm = ChatOpenAI(model_name=args.model, **params)
+        llm = ChatOpenAI(model_name=args.model, request_timeout=args.timeout, **params)
     else:
         raise ValueError("llm_engine must be vllm, tgi or openai")
     return llm

@@ -129,6 +135,9 @@ def get_args():
     parser.add_argument("--repetition_penalty", type=float, default=1.03)
     parser.add_argument("--return_full_text", type=bool, default=False)
     parser.add_argument("--custom_prompt", type=str, default=None)
+    parser.add_argument("--with_memory", type=bool, default=False)
+    parser.add_argument("--with_store", type=bool, default=False)
+    parser.add_argument("--timeout", type=int, default=60)
Collaborator: So for v1.1 the timeout only applies to waiting for the LLM response. Can we add a timeout for tools in a later release?

Author: Will confirm.

     sys_args, unknown_args = parser.parse_known_args()
     # print("env_config: ", env_config)
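The new --with_memory and --with_store flags presumably map onto the PersistenceConfig fields from the new persistence module above; a hypothetical glue function (not shown in this diff) could look like:

def build_persistence(sys_args) -> AgentPersistence:
    # Hypothetical helper, not part of this PR: turn the CLI flags into the
    # PersistenceConfig consumed by AgentPersistence.
    config = PersistenceConfig(
        checkpointer=bool(sys_args.with_memory),  # short-term memory (MemorySaver)
        store=bool(sys_args.with_store),          # long-term memory (InMemoryStore)
    )
    return AgentPersistence(config)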
Review comment: I think calling instantiate_agent when the microservice starts makes sense for chat_completion, but it does not quite make sense for the assistants API. Shall we instantiate the agent only when the user sends a create_assistant request? Even then, we would not materialize the agent but only record its config (like llama-stack create_agent); the agent is then materialized later, when the user sends a request to the thread API (like llama-stack get_agent).

The benefit of this approach is that one microservice can support multiple configs, and therefore multiple different types of agents, instead of just one config, so it is more scalable.
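A rough sketch of this suggestion (all names below are illustrative, not taken from the codebase): create_assistant would only record a config, and the agent would be materialized lazily, and cached, the first time the thread API needs it.

import uuid
from typing import Any, Dict

assistant_configs: Dict[str, Dict[str, Any]] = {}  # assistant_id -> recorded config
materialized_agents: Dict[str, Any] = {}           # assistant_id -> built agent

def create_assistant(agent_config: Dict[str, Any]) -> str:
    """Record the config only; do not build the agent yet."""
    assistant_id = str(uuid.uuid4())
    assistant_configs[assistant_id] = agent_config
    return assistant_id

def get_agent(assistant_id: str):
    """Materialize (and cache) the agent when the thread API first uses it."""
    if assistant_id not in materialized_agents:
        config = assistant_configs[assistant_id]
        # instantiate_agent is the existing factory mentioned above (assumed signature).
        materialized_agents[assistant_id] = instantiate_agent(config)
    return materialized_agents[assistant_id]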