Commit f6df52b

Merge pull request #5 from NicholasGoh/feat/langgraph
Feat/langgraph

2 parents 5c11809 + 3184ab1, commit f6df52b

File tree

18 files changed: +406 −94 lines

backend/api/config.py

Lines changed: 0 additions & 20 deletions
This file was deleted.
Lines changed: 76 additions & 0 deletions

```python
import functools

from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.base import RunnableSequence
from langchain_core.tools import StructuredTool
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from langgraph.graph import MessagesState, StateGraph
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import ToolNode, tools_condition

from api.core.agent.prompts import SYSTEM_PROMPT


class State(MessagesState):
    next: str


def agent_factory(
    llm: ChatOpenAI, tools: list[StructuredTool], system_prompt: str
) -> RunnableSequence:
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages"),
        ]
    )
    if tools:
        agent = prompt | llm.bind_tools(tools)
    else:
        agent = prompt | llm
    return agent


def agent_node_factory(
    state: State,
    agent: RunnableSequence,
) -> State:
    result = agent.invoke(state)
    return dict(messages=[result])


def graph_factory(
    agent_node: functools.partial,
    tools: list[StructuredTool],
    checkpointer: AsyncPostgresSaver | None = None,
    name: str = "agent_node",
) -> CompiledStateGraph:
    graph_builder = StateGraph(State)
    graph_builder.add_node(name, agent_node)
    graph_builder.add_node("tools", ToolNode(tools))

    graph_builder.add_conditional_edges(name, tools_condition)
    graph_builder.add_edge("tools", name)

    graph_builder.set_entry_point(name)
    graph = graph_builder.compile(checkpointer=checkpointer)
    return graph


def get_graph(
    llm: ChatOpenAI,
    tools: list[StructuredTool] = [],
    system_prompt: str = SYSTEM_PROMPT,
    name: str = "agent_node",
    checkpointer: AsyncPostgresSaver | None = None,
) -> CompiledStateGraph:
    agent = agent_factory(llm, tools, system_prompt)
    worker_node = functools.partial(agent_node_factory, agent=agent)
    return graph_factory(worker_node, tools, checkpointer, name)


def get_config():
    return dict(
        configurable=dict(thread_id="1"),
    )
```
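
For orientation, here is a minimal usage sketch (not part of the commit) of how the compiled graph could be driven; the import path `api.core.agent.graph` and the message payload are assumptions for illustration:

```python
# Hypothetical usage sketch; the import path below is an assumption.
import asyncio

from langchain_openai import ChatOpenAI

from api.core.agent.graph import get_config, get_graph  # assumed module path


async def main() -> None:
    llm = ChatOpenAI(model="gpt-4o-mini-2024-07-18", temperature=0)
    graph = get_graph(llm)  # no tools, no checkpointer: a plain chat graph
    state = await graph.ainvoke(
        {"messages": [("user", "Hello!")]},
        config=get_config(),  # fixed thread_id="1" from get_config()
    )
    print(state["messages"][-1].content)


asyncio.run(main())
```

When `tools` are supplied, `tools_condition` routes tool calls to the `ToolNode` and the edge from "tools" loops back to the agent node until the model stops requesting tools.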
Lines changed: 43 additions & 0 deletions

```python
from contextlib import asynccontextmanager
from typing import AsyncGenerator

import psycopg
import psycopg.errors
import uvicorn
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from psycopg_pool import AsyncConnectionPool

from api.core.logs import uvicorn


@asynccontextmanager
async def checkpointer_context(
    conn_str: str,
) -> AsyncGenerator[AsyncPostgresSaver]:
    """
    Async context manager that sets up and yields a LangGraph checkpointer.

    Uses a psycopg async connection pool to initialize AsyncPostgresSaver.
    Skips setup if checkpointer is already configured.

    Args:
        conn_str (str): PostgreSQL connection string.

    Yields:
        AsyncPostgresSaver: The initialized checkpointer.
    """
    # NOTE: LangGraph AsyncPostgresSaver does not support SQLAlchemy ORM Connections.
    # A compatible psycopg connection is created via the connection pool to connect to the checkpointer.
    async with AsyncConnectionPool(
        conninfo=conn_str,
        kwargs=dict(prepare_threshold=None),
    ) as pool:
        checkpointer = AsyncPostgresSaver(pool)
        try:
            await checkpointer.setup()
        except (
            psycopg.errors.DuplicateColumn,
            psycopg.errors.ActiveSqlTransaction,
        ):
            uvicorn.warning("Skipping checkpointer setup — already configured.")
        yield checkpointer
```
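
A rough sketch (not from this commit) of how `checkpointer_context` could wrap graph construction so that conversation state survives across invocations of the same `thread_id`; the DSN and import paths are placeholders:

```python
# Hypothetical wiring; DSN and import paths are placeholders.
import asyncio

from langchain_openai import ChatOpenAI

from api.core.agent.graph import get_config, get_graph  # assumed module path
from api.core.agent.persistence import checkpointer_context  # assumed module path


async def main() -> None:
    dsn = "postgresql://postgres:postgres@localhost:5432/postgres"  # placeholder
    async with checkpointer_context(dsn) as checkpointer:
        llm = ChatOpenAI(model="gpt-4o-mini-2024-07-18", temperature=0)
        graph = get_graph(llm, checkpointer=checkpointer)
        cfg = get_config()
        await graph.ainvoke({"messages": [("user", "My name is Ada.")]}, cfg)
        # A second call on the same thread_id resumes from the stored checkpoint.
        out = await graph.ainvoke({"messages": [("user", "What is my name?")]}, cfg)
        print(out["messages"][-1].content)


asyncio.run(main())
```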
Lines changed: 9 additions & 0 deletions

```python
import os


def read_system_prompt():
    with open(os.path.join(os.path.dirname(__file__), "system.md"), "r") as f:
        return f.read()


SYSTEM_PROMPT = read_system_prompt()
```
Lines changed: 1 addition & 0 deletions

You are a helpful assistant.

backend/api/core/config.py

Lines changed: 33 additions & 0 deletions

```python
from pydantic import PostgresDsn, computed_field
from pydantic_settings import BaseSettings, SettingsConfigDict


class Settings(BaseSettings):
    model_config = SettingsConfigDict(
        env_file="/opt/.env",
        env_ignore_empty=True,
        extra="ignore",
    )

    model: str = "gpt-4o-mini-2024-07-18"
    openai_api_key: str = ""
    mcp_server_port: int = 8050

    postgres_dsn: PostgresDsn = (
        "postgresql+psycopg://postgres:[email protected]:6543/postgres"
    )

    @computed_field
    @property
    def orm_conn_str(self) -> str:
        return self.postgres_dsn.encoded_string()

    @computed_field
    @property
    def checkpoint_conn_str(self) -> str:
        # NOTE: LangGraph AsyncPostgresSaver has some issues
        # with specifying psycopg driver explicitly
        return self.postgres_dsn.encoded_string().replace("+psycopg", "")


settings = Settings()
```
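
Since `pydantic-settings` matches environment variables to field names case-insensitively, the defaults above can be overridden from `/opt/.env` or the process environment. A small, hypothetical check of the derived connection strings:

```python
# Hypothetical check; output values depend on the environment.
from api.core.config import settings

print(settings.model)                # e.g. "gpt-4o-mini-2024-07-18"
print(settings.orm_conn_str)         # keeps the "+psycopg" driver suffix for SQLAlchemy
print(settings.checkpoint_conn_str)  # same DSN with "+psycopg" stripped for AsyncPostgresSaver
```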

backend/api/core/dependencies.py

Lines changed: 49 additions & 0 deletions

```python
from contextlib import asynccontextmanager
from typing import Annotated, AsyncGenerator

from fastapi import Depends
from langchain_mcp_adapters.tools import load_mcp_tools
from langchain_openai import ChatOpenAI
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine

from api.core.agent.persistence import checkpointer_context
from api.core.config import settings
from api.core.mcps import mcp_sse_client
from api.core.models import Resource


def get_llm() -> ChatOpenAI:
    return ChatOpenAI(
        streaming=True,
        model=settings.model,
        temperature=0,
        api_key=settings.openai_api_key,
        stream_usage=True,
    )


LLMDep = Annotated[ChatOpenAI, Depends(get_llm)]


engine: AsyncEngine = create_async_engine(settings.orm_conn_str)


def get_engine() -> AsyncEngine:
    return engine


EngineDep = Annotated[AsyncEngine, Depends(get_engine)]


@asynccontextmanager
async def setup_graph() -> AsyncGenerator[Resource]:
    async with checkpointer_context(
        settings.checkpoint_conn_str
    ) as checkpointer:
        async with mcp_sse_client() as session:
            tools = await load_mcp_tools(session)
            yield Resource(
                checkpointer=checkpointer,
                tools=tools,
                session=session,
            )
```
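
`setup_graph` is an async context manager rather than a `Depends`-style provider, which suggests it is meant to run once per application lifetime. A sketch (not from this commit) of how it could hook into a FastAPI lifespan; the state attribute name is an assumption:

```python
# Hypothetical wiring; not part of this commit.
from contextlib import asynccontextmanager

from fastapi import FastAPI

from api.core.dependencies import setup_graph


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Open the checkpointer and MCP session once at startup; close them at shutdown.
    async with setup_graph() as resource:
        app.state.resource = resource  # attribute name is an assumption
        yield


app = FastAPI(lifespan=lifespan)
```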

backend/api/core/logs.py

Lines changed: 7 additions & 0 deletions

```python
from logging import getLogger

from rich.pretty import pprint as print

print  # facade

uvicorn = getLogger("uvicorn")
```

backend/api/core/mcps.py

Lines changed: 27 additions & 0 deletions

```python
from contextlib import asynccontextmanager
from typing import AsyncGenerator

from mcp import ClientSession
from mcp.client.sse import sse_client

from api.core.config import settings


@asynccontextmanager
async def mcp_sse_client() -> AsyncGenerator[ClientSession]:
    """
    Creates and initializes an MCP client session over SSE.

    Establishes an SSE connection to the MCP server and yields an initialized
    `ClientSession` for communication.

    Yields:
        ClientSession: An initialized MCP client session.
    """
    async with sse_client(f"http://mcp:{settings.mcp_server_port}/sse") as (
        read_stream,
        write_stream,
    ):
        async with ClientSession(read_stream, write_stream) as session:
            await session.initialize()
            yield session
```
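
The client connects to host `mcp` on `settings.mcp_server_port`, which points at a separate MCP server (likely a Docker Compose service of that name). A minimal, hypothetical counterpart server that such a client could talk to, sketched with the MCP Python SDK's `FastMCP` helper; the tool is purely illustrative and not part of this commit:

```python
# Hypothetical counterpart server; the tool is purely illustrative.
from mcp.server.fastmcp import FastMCP

mcp = FastMCP("demo", port=8050)  # port should match settings.mcp_server_port


@mcp.tool()
def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b


if __name__ == "__main__":
    mcp.run(transport="sse")
```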

backend/api/core/models.py

Lines changed: 13 additions & 0 deletions

```python
from langchain_core.tools import StructuredTool
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from mcp import ClientSession
from pydantic import BaseModel


class Resource(BaseModel):
    checkpointer: AsyncPostgresSaver
    tools: list[StructuredTool]
    session: ClientSession

    class Config:
        arbitrary_types_allowed = True
```
