Skip to content

Commit

Permalink
[Improvement] update LLM memory get function (#1162)
Browse files Browse the repository at this point in the history
Co-authored-by: Deven Patel <[email protected]>
  • Loading branch information
deven298 and Deven Patel authored Jan 12, 2024
1 parent f582c1f commit c020e65
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 12 deletions.
7 changes: 5 additions & 2 deletions embedchain/embedchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,8 +667,11 @@ def reset(self):
# Send anonymous telemetry
self.telemetry.capture(event_name="reset", properties=self._telemetry_props)

def get_history(self, num_rounds: int = 10, display_format: bool = True):
return self.llm.memory.get(app_id=self.config.id, num_rounds=num_rounds, display_format=display_format)
def get_history(self, num_rounds: int = 10, display_format: bool = True, session_id: Optional[str] = "default"):
history = self.llm.memory.get(
app_id=self.config.id, session_id=session_id, num_rounds=num_rounds, display_format=display_format
)
return history

def delete_session_chat_history(self, session_id: str = "default"):
self.llm.memory.delete(app_id=self.config.id, session_id=session_id)
Expand Down
45 changes: 36 additions & 9 deletions embedchain/memory/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,21 +73,40 @@ def delete(self, app_id: str, session_id: Optional[str] = None):
self.cursor.execute(DELETE_CHAT_HISTORY_QUERY, params)
self.connection.commit()

def get(self, app_id, session_id, num_rounds=10, display_format=False) -> list[ChatMessage]:
def get(
self, app_id, session_id: str = "default", num_rounds=10, fetch_all: bool = False, display_format=False
) -> list[ChatMessage]:
"""
Get the most recent num_rounds rounds of conversations
between human and AI, for a given app_id.
Get the chat history for a given app_id.
param: app_id - The app_id to get chat history
param: session_id (optional) - The session_id to get chat history. Defaults to "default"
param: num_rounds (optional) - The number of rounds to get chat history. Defaults to 10
param: fetch_all (optional) - Whether to fetch all chat history or not. Defaults to False
param: display_format (optional) - Whether to return the chat history in display format. Defaults to False
"""

QUERY = """
base_query = """
SELECT * FROM ec_chat_history
WHERE app_id=? AND session_id=?
ORDER BY created_at DESC
LIMIT ?
WHERE app_id=?
"""

if fetch_all:
additional_query = "ORDER BY created_at DESC"
params = (app_id,)
else:
additional_query = """
AND session_id=?
ORDER BY created_at DESC
LIMIT ?
"""
params = (app_id, session_id, num_rounds)

QUERY = base_query + additional_query

self.cursor.execute(
QUERY,
(app_id, session_id, num_rounds),
params,
)

results = self.cursor.fetchall()
Expand All @@ -97,7 +116,15 @@ def get(self, app_id, session_id, num_rounds=10, display_format=False) -> list[C
metadata = self._deserialize_json(metadata=metadata)
# Return list of dict if display_format is True
if display_format:
history.append({"human": question, "ai": answer, "metadata": metadata, "timestamp": timestamp})
history.append(
{
"session_id": session_id,
"human": question,
"ai": answer,
"metadata": metadata,
"timestamp": timestamp,
}
)
else:
memory = ChatMessage()
memory.add_user_message(question, metadata=metadata)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "embedchain"
version = "0.1.62"
version = "0.1.63"
description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
authors = [
"Taranjeet Singh <[email protected]>",
Expand Down
4 changes: 4 additions & 0 deletions tests/memory/test_chat_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ def test_get(chat_memory_instance):

assert len(recent_memories) == 5

all_memories = chat_memory_instance.get(app_id, fetch_all=True)

assert len(all_memories) == 6


def test_delete_chat_history(chat_memory_instance):
app_id = "test_app"
Expand Down

0 comments on commit c020e65

Please sign in to comment.