|
3 | 3 | from collections.abc import Sequence |
4 | 4 | from typing import Any |
5 | 5 |
|
| 6 | +from pydantic_ai.capabilities import AbstractCapability, PrefixTools, Toolset |
6 | 7 | from pydantic_ai.mcp import MCPServerSSE, MCPServerStdio, MCPServerStreamableHTTP |
7 | | -from pydantic_ai.toolsets import AbstractToolset, PrefixedToolset |
8 | 8 | from sqlalchemy.ext.asyncio import AsyncSession |
9 | 9 |
|
10 | 10 | from backend.common.exception import errors |
11 | 11 | from backend.common.pagination import paging_data |
12 | 12 | from backend.plugin.ai.crud.crud_mcp import mcp_dao |
| 13 | +from backend.plugin.ai.dataclasses import ChatAgentDeps |
13 | 14 | from backend.plugin.ai.enums import McpType |
14 | 15 | from backend.plugin.ai.model import Mcp |
15 | 16 | from backend.plugin.ai.schema.mcp import CreateMcpParam, UpdateMcpParam |
16 | 17 |
|
17 | | -McpToolset = AbstractToolset[Any] |
18 | | - |
19 | 18 |
|
20 | 19 | class McpService: |
21 | 20 | """MCP 服务类""" |
@@ -45,16 +44,16 @@ async def get_all(*, db: AsyncSession) -> Sequence[Mcp]: |
45 | 44 | return await mcp_dao.get_all(db) |
46 | 45 |
|
47 | 46 | @staticmethod |
48 | | - async def get_toolsets(*, db: AsyncSession, mcp_ids: list[int]) -> list[McpToolset]: |
| 47 | + async def get_capabilities(*, db: AsyncSession, mcp_ids: list[int]) -> list[AbstractCapability[ChatAgentDeps]]: |
49 | 48 | """ |
50 | | - 获取 MCP 工具集 |
| 49 | + 获取 MCP 能力 |
51 | 50 |
|
52 | 51 | :param db: 数据库会话 |
53 | 52 | :param mcp_ids: MCP ID 列表 |
54 | 53 | :return: |
55 | 54 | """ |
56 | 55 | mcps = await mcp_dao.get_by_ids(db, mcp_ids) |
57 | | - toolsets: list[McpToolset] = [] |
| 56 | + capabilities: list[AbstractCapability[ChatAgentDeps]] = [] |
58 | 57 | for mcp in mcps: |
59 | 58 | headers = json.loads(mcp.headers) if isinstance(mcp.headers, str) else (mcp.headers or {}) |
60 | 59 | if not isinstance(headers, dict): |
@@ -91,9 +90,8 @@ async def get_toolsets(*, db: AsyncSession, mcp_ids: list[int]) -> list[McpTools |
91 | 90 | timeout=mcp.timeout, |
92 | 91 | read_timeout=mcp.read_timeout, |
93 | 92 | ) |
94 | | - # 此举是为了为避免 MCP 工具名称冲突 |
95 | | - toolsets.append(PrefixedToolset(toolset, prefix=f'mcp_{mcp.id}')) |
96 | | - return toolsets |
| 93 | + capabilities.append(PrefixTools(Toolset(toolset), prefix=f'mcp_{mcp.id}')) |
| 94 | + return capabilities |
97 | 95 |
|
98 | 96 | @staticmethod |
99 | 97 | async def get_list(*, db: AsyncSession, name: str | None, type: int | None) -> dict[str, Any]: |
|
0 commit comments