diff --git a/docs/api-reference/pipeline/delete.mdx b/docs/api-reference/pipeline/delete.mdx index 2c65e5f452..96d79554f8 100644 --- a/docs/api-reference/pipeline/delete.mdx +++ b/docs/api-reference/pipeline/delete.mdx @@ -2,7 +2,7 @@ title: 🗑 delete --- -`delete_chat_history()` method allows you to delete all previous messages in a chat history. +`delete_session_chat_history()` method allows you to delete all previous messages in a chat history. ## Usage @@ -15,5 +15,5 @@ app.add("https://www.forbes.com/profile/elon-musk") app.chat("What is the net worth of Elon Musk?") -app.delete_chat_history() +app.delete_session_chat_history() ``` \ No newline at end of file diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py index 4d8e49768f..fbd1a5a125 100644 --- a/embedchain/embedchain.py +++ b/embedchain/embedchain.py @@ -663,13 +663,17 @@ def reset(self): self.db.reset() self.cursor.execute("DELETE FROM data_sources WHERE pipeline_id = ?", (self.config.id,)) self.connection.commit() - self.delete_chat_history() + self.delete_all_chat_history(app_id=self.config.id) # Send anonymous telemetry self.telemetry.capture(event_name="reset", properties=self._telemetry_props) def get_history(self, num_rounds: int = 10, display_format: bool = True): return self.llm.memory.get(app_id=self.config.id, num_rounds=num_rounds, display_format=display_format) - def delete_chat_history(self, session_id: str = "default"): + def delete_session_chat_history(self, session_id: str = "default"): self.llm.memory.delete(app_id=self.config.id, session_id=session_id) self.llm.update_history(app_id=self.config.id) + + def delete_all_chat_history(self, app_id: str): + self.llm.memory.delete(app_id=app_id) + self.llm.update_history(app_id=app_id) diff --git a/embedchain/memory/base.py b/embedchain/memory/base.py index c453a35143..9bfa04f2de 100644 --- a/embedchain/memory/base.py +++ b/embedchain/memory/base.py @@ -53,7 +53,7 @@ def add(self, app_id, session_id, chat_message: ChatMessage) -> Optional[str]: logging.info(f"Added chat memory to db with id: {memory_id}") return memory_id - def delete(self, app_id: str, session_id: str): + def delete(self, app_id: str, session_id: Optional[str] = None): """ Delete all chat history for a given app_id and session_id. This is useful for deleting chat history for a given user. @@ -63,8 +63,14 @@ def delete(self, app_id: str, session_id: str): :return: None """ - DELETE_CHAT_HISTORY_QUERY = "DELETE FROM ec_chat_history WHERE app_id=? AND session_id=?" - self.cursor.execute(DELETE_CHAT_HISTORY_QUERY, (app_id, session_id)) + if session_id: + DELETE_CHAT_HISTORY_QUERY = "DELETE FROM ec_chat_history WHERE app_id=? AND session_id=?" + params = (app_id, session_id) + else: + DELETE_CHAT_HISTORY_QUERY = "DELETE FROM ec_chat_history WHERE app_id=?" + params = (app_id,) + + self.cursor.execute(DELETE_CHAT_HISTORY_QUERY, params) self.connection.commit() def get(self, app_id, session_id, num_rounds=10, display_format=False) -> list[ChatMessage]: @@ -99,7 +105,7 @@ def get(self, app_id, session_id, num_rounds=10, display_format=False) -> list[C history.append(memory) return history - def count(self, app_id: str, session_id: str): + def count(self, app_id: str, session_id: Optional[str] = None): """ Count the number of chat messages for a given app_id and session_id. @@ -108,8 +114,14 @@ def count(self, app_id: str, session_id: str): :return: The number of chat messages for a given app_id and session_id """ - QUERY = "SELECT COUNT(*) FROM ec_chat_history WHERE app_id=? AND session_id=?" - self.cursor.execute(QUERY, (app_id, session_id)) + if session_id: + QUERY = "SELECT COUNT(*) FROM ec_chat_history WHERE app_id=? AND session_id=?" + params = (app_id, session_id) + else: + QUERY = "SELECT COUNT(*) FROM ec_chat_history WHERE app_id=?" + params = (app_id,) + + self.cursor.execute(QUERY, params) count = self.cursor.fetchone()[0] return count diff --git a/pyproject.toml b/pyproject.toml index fc1305a4f2..28c0979a1a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "embedchain" -version = "0.1.61" +version = "0.1.62" description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data" authors = [ "Taranjeet Singh ", diff --git a/tests/memory/test_chat_memory.py b/tests/memory/test_chat_memory.py index d41bfbd1ec..bcf6601268 100644 --- a/tests/memory/test_chat_memory.py +++ b/tests/memory/test_chat_memory.py @@ -59,9 +59,26 @@ def test_delete_chat_history(chat_memory_instance): chat_memory_instance.add(app_id, session_id, chat_message) + session_id_2 = "test_session_2" + + for i in range(1, 6): + human_message = f"Question {i}" + ai_message = f"Answer {i}" + + chat_message = ChatMessage() + chat_message.add_user_message(human_message) + chat_message.add_ai_message(ai_message) + + chat_memory_instance.add(app_id, session_id_2, chat_message) + chat_memory_instance.delete(app_id, session_id) assert chat_memory_instance.count(app_id, session_id) == 0 + assert chat_memory_instance.count(app_id) == 5 + + chat_memory_instance.delete(app_id) + + assert chat_memory_instance.count(app_id) == 0 @pytest.fixture