|
3 | 3 | import psycopg.errors |
4 | 4 | from fastapi import APIRouter |
5 | 5 | from langchain_core.messages import HumanMessage |
| 6 | +from langchain_mcp_adapters.tools import load_mcp_tools |
6 | 7 | from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver |
7 | 8 | from psycopg_pool import AsyncConnectionPool |
8 | 9 | from sse_starlette.sse import EventSourceResponse |
|
12 | 13 | from api.core.config import settings |
13 | 14 | from api.core.dependencies import LLMDep |
14 | 15 | from api.core.logs import print, uvicorn |
| 16 | +from api.routers.mcps import mcp_sse_client |
15 | 17 |
|
16 | 18 | router = APIRouter(tags=["chat"]) |
17 | 19 |
|
@@ -68,9 +70,15 @@ async def stream_graph( |
68 | 70 | ) as pool: |
69 | 71 | checkpointer = await checkpointer_setup(pool) |
70 | 72 |
|
71 | | - graph = get_graph(llm, checkpointer=checkpointer) |
72 | | - config = get_config() |
73 | | - events = dict(messages=[HumanMessage(content=query)]) |
74 | | - |
75 | | - async for event in graph.astream_events(events, config, version="v2"): |
76 | | - yield dict(data=event) |
| 73 | + async with mcp_sse_client() as session: |
| 74 | + tools = await load_mcp_tools(session) |
| 75 | + graph = get_graph(llm, tools=tools, checkpointer=checkpointer) |
| 76 | + config = get_config() |
| 77 | + events = dict(messages=[HumanMessage(content=query)]) |
| 78 | + |
| 79 | + async for event in graph.astream_events( |
| 80 | + events, config, version="v2" |
| 81 | + ): |
| 82 | + if event.get("event").endswith("end"): |
| 83 | + print(event) |
| 84 | + yield dict(data=event) |
0 commit comments