Skip to content

Commit aaf03b1

Browse files
committed
Add query conditions to query all models
1 parent 24eaa5e commit aaf03b1

3 files changed

Lines changed: 11 additions & 7 deletions

File tree

api/v1/model.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Annotated
22

3-
from fastapi import APIRouter, Depends, Path
3+
from fastapi import APIRouter, Depends, Path, Query
44

55
from backend.common.pagination import DependsPagination, PageData
66
from backend.common.response.response_schema import ResponseModel, ResponseSchemaModel, response_base
@@ -20,8 +20,10 @@
2020

2121

2222
@router.get('/all', summary='获取所有模型', dependencies=[DependsJwtAuth])
23-
async def get_all_ai_models(db: CurrentSession) -> ResponseSchemaModel[list[GetAIModelDetail]]:
24-
data = await ai_model_service.get_all(db=db)
23+
async def get_all_ai_models(
24+
db: CurrentSession, provider_id: Annotated[int, Query(description='供应商 ID')]
25+
) -> ResponseSchemaModel[list[GetAIModelDetail]]:
26+
data = await ai_model_service.get_all(db=db, provider_id=provider_id)
2527
return response_base.success(data=data)
2628

2729

crud/crud_model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,15 @@ async def get_select(self) -> Select:
3535
"""获取模型列表查询表达式"""
3636
return await self.select_order('id', 'desc')
3737

38-
async def get_all(self, db: AsyncSession) -> Sequence[AIModel]:
38+
async def get_all(self, db: AsyncSession, provider_id: int) -> Sequence[AIModel]:
3939
"""
4040
获取所有模型
4141
4242
:param db: 数据库会话
43+
:param provider_id: 供应商 ID
4344
:return:
4445
"""
45-
return await self.select_models(db)
46+
return await self.select_models(db, provider_id=provider_id)
4647

4748
async def create(self, db: AsyncSession, obj: CreateAIModelParam) -> None:
4849
"""

service/model_service.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,15 @@ async def get_list(db: AsyncSession) -> dict[str, Any]:
3939
return await paging_data(db, ai_model_select)
4040

4141
@staticmethod
42-
async def get_all(*, db: AsyncSession) -> Sequence[AIModel]:
42+
async def get_all(*, db: AsyncSession, provider_id: int) -> Sequence[AIModel]:
4343
"""
4444
获取所有 AI 模型
4545
4646
:param db: 数据库会话
47+
:param provider_id: 供应商 ID
4748
:return:
4849
"""
49-
ai_models = await ai_model_dao.get_all(db)
50+
ai_models = await ai_model_dao.get_all(db, provider_id=provider_id)
5051
return ai_models
5152

5253
@staticmethod

0 commit comments

Comments
 (0)