Skip to content
Open
80 changes: 65 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ Feel free to use the abstraction as provided or else modify them / extend them a

## Requirements

The package currently only supports the [psycogp3](https://www.psycopg.org/psycopg3/) driver.
- [psycopg3](https://www.psycopg.org/psycopg3/): The PostgreSQL driver.
- [psycopg_pool](https://www.psycopg.org/psycopg3/docs/advanced/pool.html): For connection pooling support.

## Installation

Expand All @@ -25,24 +26,23 @@ pip install -U langchain-postgres

## Change Log

0.0.6:
- Remove langgraph as a dependency as it was causing dependency conflicts.
- Base interface for checkpointer changed in langgraph, so existing implementation would've broken regardless.
**0.0.7:**

- Added support for asynchronous connection pooling in `PostgresChatMessageHistory`.
- Adjusted parameter order in `PostgresChatMessageHistory` to make `session_id` the first parameter.

## Usage

### ChatMessageHistory

The chat message history abstraction helps to persist chat message history
in a postgres table.
The chat message history abstraction helps to persist chat message history in a Postgres table.

PostgresChatMessageHistory is parameterized using a `table_name` and a `session_id`.
`PostgresChatMessageHistory` is parameterized using a `session_id` and an optional `table_name` (default is `"chat_history"`).

The `table_name` is the name of the table in the database where
the chat messages will be stored.
- **`session_id`:** A unique identifier for the chat session. It can be assigned using `uuid.uuid4()`.
- **`table_name`:** The name of the table in the database where the chat messages will be stored.

The `session_id` is a unique identifier for the chat session. It can be assigned
by the caller using `uuid.uuid4()`.
#### **Synchronous Usage**

```python
import uuid
Expand All @@ -52,8 +52,7 @@ from langchain_postgres import PostgresChatMessageHistory
import psycopg

# Establish a synchronous connection to the database
# (or use psycopg.AsyncConnection for async)
conn_info = ... # Fill in with your connection info
conn_info = "postgresql://user:password@host:port/dbname" # Replace with your connection info
sync_connection = psycopg.connect(conn_info)

# Create the table schema (only needs to be done once)
Expand All @@ -64,8 +63,8 @@ session_id = str(uuid.uuid4())

# Initialize the chat history manager
chat_history = PostgresChatMessageHistory(
table_name,
session_id,
session_id=session_id,
table_name=table_name,
sync_connection=sync_connection
)

Expand All @@ -79,6 +78,57 @@ chat_history.add_messages([
print(chat_history.messages)
```

#### **Asynchronous Usage with Connection Pooling**

```python
import uuid
import asyncio

from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
from langchain_postgres import PostgresChatMessageHistory
from psycopg_pool import AsyncConnectionPool

# Asynchronous main function
async def main():
# Database connection string
conn_info = "postgresql://user:password@host:port/dbname" # Replace with your connection info

# Initialize the connection pool
pool = AsyncConnectionPool(conninfo=conn_info)

try:
# Create the table schema (only needs to be done once)
async with pool.connection() as async_connection:
table_name = "chat_history"
await PostgresChatMessageHistory.adrop_table(async_connection, table_name)
await PostgresChatMessageHistory.acreate_tables(async_connection, table_name)

session_id = str(uuid.uuid4())

# Initialize the chat history manager with the connection pool
chat_history = PostgresChatMessageHistory(
session_id=session_id,
table_name=table_name,
conn_pool=pool
)

# Add messages to the chat history asynchronously
await chat_history.aadd_messages([
SystemMessage(content="System message"),
AIMessage(content="AI response"),
HumanMessage(content="Human message"),
])

# Retrieve messages from the chat history
messages = await chat_history.aget_messages()
print(messages)
finally:
# Close the connection pool
await pool.close()

# Run the async main function
asyncio.run(main())
```

### Vectorstore

Expand Down
82 changes: 54 additions & 28 deletions langchain_postgres/chat_message_histories.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from typing import List, Optional, Sequence

import psycopg
from psycopg_pool import AsyncConnectionPool
from typing import Optional, Union, AsyncGenerator
from contextlib import asynccontextmanager
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage, message_to_dict, messages_from_dict
from psycopg import sql
Expand Down Expand Up @@ -77,12 +80,12 @@ def _insert_message_query(table_name: str) -> sql.Composed:
class PostgresChatMessageHistory(BaseChatMessageHistory):
def __init__(
self,
table_name: str,
session_id: str,
/,
table_name: str = "chat_history",
*,
sync_connection: Optional[psycopg.Connection] = None,
async_connection: Optional[psycopg.AsyncConnection] = None,
conn_pool: Optional[AsyncConnectionPool] = None,
) -> None:
"""Client for persisting chat message history in a Postgres database,

Expand Down Expand Up @@ -132,6 +135,8 @@ def __init__(
table_name: The name of the database table to use
sync_connection: An existing psycopg connection instance
async_connection: An existing psycopg async connection instance
conn_pool: AsyncConnectionPool instance for managing async connections.


Usage:
- Use the create_tables or acreate_tables method to set up the table
Expand Down Expand Up @@ -181,11 +186,14 @@ def __init__(

print(chat_history.messages)
"""
if not sync_connection and not async_connection:
raise ValueError("Must provide sync_connection or async_connection")
if not sync_connection and not async_connection and not conn_pool:
raise ValueError(
"Must provide sync_connection, async_connection, or conn_pool."
)

self._connection = sync_connection
self._aconnection = async_connection
self._conn_pool = conn_pool

# Validate that session id is a UUID
try:
Expand Down Expand Up @@ -290,23 +298,33 @@ def add_messages(self, messages: Sequence[BaseMessage]) -> None:
self._connection.commit()

async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
"""Add messages to the chat message history."""
if self._aconnection is None:
"""Add messages to the chat message history asynchronously."""
if self._conn_pool is not None:
values = [
(self._session_id, json.dumps(message_to_dict(message)))
for message in messages
]
async with self._conn_pool.connection() as async_connection:
query = self._insert_message_query(self._table_name)
async with async_connection.cursor() as cursor:
await cursor.executemany(query, values)
await async_connection.commit()
elif self._aconnection is not None:
# Existing code using self._aconnection
values = [
(self._session_id, json.dumps(message_to_dict(message)))
for message in messages
]
query = self._insert_message_query(self._table_name)
async with self._aconnection.cursor() as cursor:
await cursor.executemany(query, values)
await self._aconnection.commit()
else:
raise ValueError(
"Please initialize the PostgresChatMessageHistory "
"with an async connection or use the sync add_messages method instead."
"with an async connection or connection pool."
)

values = [
(self._session_id, json.dumps(message_to_dict(message)))
for message in messages
]

query = _insert_message_query(self._table_name)
async with self._aconnection.cursor() as cursor:
await cursor.executemany(query, values)
await self._aconnection.commit()

def get_messages(self) -> List[BaseMessage]:
"""Retrieve messages from the chat message history."""
if self._connection is None:
Expand All @@ -325,21 +343,29 @@ def get_messages(self) -> List[BaseMessage]:
return messages

async def aget_messages(self) -> List[BaseMessage]:
"""Retrieve messages from the chat message history."""
if self._aconnection is None:
"""Retrieve messages from the chat message history asynchronously."""
if self._conn_pool is not None:
async with self._conn_pool.connection() as async_connection:
query = self._get_messages_query(self._table_name)
async with async_connection.cursor() as cursor:
await cursor.execute(query, {"session_id": self._session_id})
items = [record[0] for record in await cursor.fetchall()]
messages = messages_from_dict(items)
return messages
elif self._aconnection is not None:
# Existing code using self._aconnection
query = self._get_messages_query(self._table_name)
async with self._aconnection.cursor() as cursor:
await cursor.execute(query, {"session_id": self._session_id})
items = [record[0] for record in await cursor.fetchall()]
messages = messages_from_dict(items)
return messages
else:
raise ValueError(
"Please initialize the PostgresChatMessageHistory "
"with an async connection or use the sync get_messages method instead."
"with an async connection or connection pool."
)

query = _get_messages_query(self._table_name)
async with self._aconnection.cursor() as cursor:
await cursor.execute(query, {"session_id": self._session_id})
items = [record[0] for record in await cursor.fetchall()]

messages = messages_from_dict(items)
return messages

@property # type: ignore[override]
def messages(self) -> List[BaseMessage]:
"""The abstraction required a property."""
Expand Down
75 changes: 74 additions & 1 deletion tests/unit_tests/test_chat_histories.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

from langchain_postgres.chat_message_histories import PostgresChatMessageHistory
from tests.utils import asyncpg_client, syncpg_client
from psycopg_pool import AsyncConnectionPool
from tests.utils import asyncpg_client, syncpg_client, DSN


def test_sync_chat_history() -> None:
Expand Down Expand Up @@ -121,3 +122,75 @@ async def test_async_chat_history() -> None:
# clear
await chat_history.aclear()
assert await chat_history.aget_messages() == []


async def test_async_chat_history_with_pool() -> None:
"""Test the async chat history using a connection pool."""
# Initialize the connection pool
pool = AsyncConnectionPool(conninfo=DSN)
try:
table_name = "chat_history"
session_id = str(uuid.uuid4())

# Create tables using a connection from the pool
async with pool.connection() as async_connection:
await PostgresChatMessageHistory.adrop_table(async_connection, table_name)
await PostgresChatMessageHistory.acreate_tables(async_connection, table_name)

# Create PostgresChatMessageHistory with conn_pool
chat_history = PostgresChatMessageHistory(
session_id=session_id,
table_name=table_name,
conn_pool=pool,
)

# Ensure the chat history is empty
messages = await chat_history.aget_messages()
assert messages == []

# Add messages to the chat history
await chat_history.aadd_messages(
[
SystemMessage(content="System message"),
AIMessage(content="AI response"),
HumanMessage(content="Human message"),
]
)

# Retrieve messages from the chat history
messages = await chat_history.aget_messages()
assert len(messages) == 3
assert messages == [
SystemMessage(content="System message"),
AIMessage(content="AI response"),
HumanMessage(content="Human message"),
]

# Add more messages
await chat_history.aadd_messages(
[
SystemMessage(content="Another system message"),
AIMessage(content="Another AI response"),
HumanMessage(content="Another human message"),
]
)

# Verify all messages are retrieved
messages = await chat_history.aget_messages()
assert len(messages) == 6
assert messages == [
SystemMessage(content="System message"),
AIMessage(content="AI response"),
HumanMessage(content="Human message"),
SystemMessage(content="Another system message"),
AIMessage(content="Another AI response"),
HumanMessage(content="Another human message"),
]

# Clear the chat history
await chat_history.aclear()
messages = await chat_history.aget_messages()
assert messages == []
finally:
# Close the connection pool
await pool.close()