Skip to content

Commit 6e4f3dd

Browse files
committed
perf: Add schema type hint.
1 parent 178fc76 commit 6e4f3dd

2 files changed

Lines changed: 67 additions & 39 deletions

File tree

fastapi_amis_admin/crud/_sqlmodel.py

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,23 @@
22
import logging
33
import re
44
from enum import Enum
5-
from typing import Any, Callable, Dict, List, Optional, Pattern, Tuple, Type, Union
5+
from typing import (
6+
Any,
7+
Callable,
8+
Dict,
9+
Generic,
10+
List,
11+
Optional,
12+
Pattern,
13+
Tuple,
14+
Type,
15+
TypeVar,
16+
Union,
17+
)
618

719
from fastapi import APIRouter, Body, Depends, Query
820
from fastapi.encoders import DictIntStrAny, SetIntStr
9-
from pydantic import BaseModel, Extra, Json
21+
from pydantic import Extra, Json
1022
from sqlalchemy import Column, Table, delete, func, insert
1123
from sqlalchemy.future import select
1224
from sqlalchemy.orm import InstrumentedAttribute, Session
@@ -17,7 +29,14 @@
1729

1830
from fastapi_amis_admin.utils.functools import cached_property
1931

20-
from .base import BaseCrud
32+
from .base import (
33+
BaseCrud,
34+
SchemaCreateT,
35+
SchemaFilterT,
36+
SchemaListT,
37+
SchemaReadT,
38+
SchemaUpdateT,
39+
)
2140
from .parser import (
2241
SqlField,
2342
SQLModelField,
@@ -51,9 +70,11 @@
5170
"-": "between",
5271
}
5372

73+
ModelT = TypeVar("ModelT", bound=SQLModel)
74+
5475

55-
class SQLModelSelector:
56-
model: Type[SQLModel] = None # SQLModel模型
76+
class SQLModelSelector(Generic[ModelT]):
77+
model: Type[ModelT] = None # SQLModel模型
5778
fields: List[SQLModelListField] = [] # 需要查询的字段
5879
list_filter: List[SQLModelListField] = [] # 查询可过滤的字段
5980
exclude: List[SQLModelField] = [] # 不需要查询的字段
@@ -71,7 +92,7 @@ class SQLModelSelector:
7192
"""
7293
pk_name: str = "id" # 主键名称
7394

74-
def __init__(self, model: Type[SQLModel] = None, fields: List[SQLModelListField] = None) -> None:
95+
def __init__(self, model: Type[ModelT] = None, fields: List[SQLModelListField] = None) -> None:
7596
self.model = model or self.model
7697
assert self.model, "model is None"
7798
self.pk_name: str = self.pk_name or self.model.__table__.primary_key.columns.keys()[0]
@@ -215,7 +236,7 @@ class SQLModelCrud(BaseCrud, SQLModelSelector):
215236

216237
def __init__(
217238
self,
218-
model: Type[SQLModel],
239+
model: Type[ModelT],
219240
engine: SqlalchemyDatabase,
220241
fields: List[SQLModelListField] = None,
221242
router: APIRouter = None,
@@ -231,7 +252,7 @@ def __init__(
231252
"Please replace them with update_fields and update_exclude."
232253
)
233254

234-
def _create_schema_list(self):
255+
def _create_schema_list(self) -> Type[SchemaListT]:
235256
if self.schema_list:
236257
return self.schema_list
237258
modelfields = list(
@@ -247,7 +268,7 @@ def _create_schema_list(self):
247268
extra=Extra.allow,
248269
)
249270

250-
def _create_schema_filter(self):
271+
def _create_schema_filter(self) -> Type[SchemaFilterT]:
251272
if self.schema_filter:
252273
return self.schema_filter
253274
self.list_filter = self.list_filter or self._select_entities.values()
@@ -275,7 +296,7 @@ def _create_schema_filter(self):
275296
set_none=True,
276297
)
277298

278-
def _create_schema_update(self):
299+
def _create_schema_update(self) -> Type[SchemaUpdateT]:
279300
if self.schema_update:
280301
return self.schema_update
281302
if not self.readonly_fields and not self.update_fields:
@@ -288,7 +309,7 @@ def _create_schema_update(self):
288309
modelfields = [field for field in modelfields if field.name not in readonly_fields]
289310
return schema_create_by_modelfield(f"{self.schema_name_prefix}Update", modelfields, set_none=True)
290311

291-
def _create_schema_create(self):
312+
def _create_schema_create(self) -> Type[SchemaCreateT]:
292313
if self.schema_create:
293314
return self.schema_create
294315
if not self.create_fields:
@@ -301,12 +322,12 @@ def _create_schema_create(self):
301322
)
302323
return schema_create_by_modelfield(f"{self.schema_name_prefix}Create", modelfields)
303324

304-
def read_item(self, obj: SQLModel) -> BaseModel:
325+
def read_item(self, obj: ModelT) -> SchemaReadT:
305326
"""read database data and parse to schema_read"""
306327
parse = self.schema_read.from_orm if self.schema_read.Config.orm_mode else self.schema_read.parse_obj
307328
return parse(obj)
308329

309-
def update_item(self, obj: SQLModel, values: Dict[str, Any]) -> None:
330+
def update_item(self, obj: ModelT, values: Dict[str, Any]) -> None:
310331
"""update schema_update data to database,support relational attributes"""
311332
for k, v in values.items():
312333
if isinstance(v, dict) and hasattr(obj, k):
@@ -315,7 +336,7 @@ def update_item(self, obj: SQLModel, values: Dict[str, Any]) -> None:
315336
else:
316337
setattr(obj, k, v)
317338

318-
def _fetch_item_scalars(self, session: Session, item_id: List[str]) -> List[SQLModel]:
339+
def _fetch_item_scalars(self, session: Session, item_id: List[str]) -> List[ModelT]:
319340
stmt = select(self.model).where(self.pk.in_(list(map(get_python_type_parse(self.pk), item_id))))
320341
return session.scalars(stmt).all()
321342

@@ -335,7 +356,7 @@ def schema_name_prefix(self):
335356
return self.model.__name__
336357
return super().schema_name_prefix
337358

338-
async def on_create_pre(self, request: Request, obj: BaseModel, **kwargs) -> Dict[str, Any]:
359+
async def on_create_pre(self, request: Request, obj: SchemaCreateT, **kwargs) -> Dict[str, Any]:
339360
data_dict = obj.dict(by_alias=True) # exclude=set(self.pk)
340361
if self.pk_name in data_dict and not data_dict.get(self.pk_name):
341362
del data_dict[self.pk_name]
@@ -344,15 +365,15 @@ async def on_create_pre(self, request: Request, obj: BaseModel, **kwargs) -> Dic
344365
async def on_update_pre(
345366
self,
346367
request: Request,
347-
obj: BaseModel,
368+
obj: SchemaUpdateT,
348369
item_id: Union[List[str], List[int]],
349370
**kwargs,
350371
) -> Dict[str, Any]:
351372
data = obj.dict(exclude=self.update_exclude, exclude_unset=True, by_alias=True)
352373
data = {key: val for key, val in data.items() if val is not None or self.model.__fields__[key].allow_none}
353374
return data
354375

355-
async def on_filter_pre(self, request: Request, obj: BaseModel, **kwargs) -> Dict[str, Any]:
376+
async def on_filter_pre(self, request: Request, obj: SchemaFilterT, **kwargs) -> Dict[str, Any]:
356377
return obj and {k: v for k, v in obj.dict(exclude_unset=True, by_alias=True).items() if v is not None}
357378

358379
@property

fastapi_amis_admin/crud/base.py

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Callable, List, Optional, Type, Union
1+
from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
22

33
from fastapi import APIRouter, Depends
44
from pydantic import BaseModel
@@ -10,6 +10,13 @@
1010
from .schema import BaseApiOut, CrudEnum, ItemListSchema, Paginator
1111
from .utils import schema_create_by_schema
1212

13+
SchemaModelT = TypeVar("SchemaModelT", bound=BaseModel)
14+
SchemaListT = TypeVar("SchemaListT", bound=BaseModel)
15+
SchemaFilterT = TypeVar("SchemaFilterT", bound=BaseModel)
16+
SchemaCreateT = TypeVar("SchemaCreateT", bound=BaseModel)
17+
SchemaReadT = TypeVar("SchemaReadT", bound=BaseModel)
18+
SchemaUpdateT = TypeVar("SchemaUpdateT", bound=BaseModel)
19+
1320

1421
class RouterMixin:
1522
router: APIRouter = None
@@ -32,17 +39,17 @@ def error_no_router_permission(self, request: Request):
3239
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="No router permissions")
3340

3441

35-
class BaseCrud(RouterMixin):
36-
schema_model: Type[BaseModel] = None
37-
schema_list: Type[BaseModel] = None
38-
schema_filter: Type[BaseModel] = None
39-
schema_create: Type[BaseModel] = None
40-
schema_read: Type[BaseModel] = None
41-
schema_update: Type[BaseModel] = None
42+
class BaseCrud(RouterMixin, Generic[SchemaModelT, SchemaListT, SchemaFilterT, SchemaCreateT, SchemaReadT, SchemaUpdateT]):
43+
schema_model: Type[SchemaModelT] = None
44+
schema_list: Type[SchemaListT] = None
45+
schema_filter: Type[SchemaFilterT] = None
46+
schema_create: Type[SchemaCreateT] = None
47+
schema_read: Type[SchemaReadT] = None
48+
schema_update: Type[SchemaUpdateT] = None
4249
pk_name: str = "id"
4350
list_per_page_max: int = None
4451

45-
def __init__(self, schema_model: Type[BaseModel], router: APIRouter = None):
52+
def __init__(self, schema_model: Type[SchemaModelT], router: APIRouter = None):
4653
self.paginator = Paginator()
4754
self.schema_model = schema_model or self.schema_model
4855
assert self.schema_model, "schema_model is None"
@@ -59,11 +66,11 @@ def schema_name_prefix(self):
5966

6067
def register_crud(
6168
self,
62-
schema_list: Type[BaseModel] = None,
63-
schema_filter: Type[BaseModel] = None,
64-
schema_create: Type[BaseModel] = None,
65-
schema_read: Type[BaseModel] = None,
66-
schema_update: Type[BaseModel] = None,
69+
schema_list: Type[SchemaListT] = None,
70+
schema_filter: Type[SchemaFilterT] = None,
71+
schema_create: Type[SchemaCreateT] = None,
72+
schema_read: Type[SchemaReadT] = None,
73+
schema_update: Type[SchemaUpdateT] = None,
6774
list_per_page_max: int = None,
6875
depends_list: List[Depends] = None,
6976
depends_read: List[Depends] = None,
@@ -120,24 +127,24 @@ def register_crud(
120127
)
121128
return self
122129

123-
def _create_schema_list(self):
130+
def _create_schema_list(self) -> Type[SchemaListT]:
124131
return self.schema_list or self.schema_model
125132

126-
def _create_schema_filter(self):
133+
def _create_schema_filter(self) -> Type[SchemaFilterT]:
127134
return self.schema_filter or schema_create_by_schema(self.schema_list, f"{self.schema_name_prefix}Filter", set_none=True)
128135

129-
def _create_schema_read(self):
136+
def _create_schema_read(self) -> Type[SchemaReadT]:
130137
return self.schema_read or self.schema_model
131138

132-
def _create_schema_update(self):
139+
def _create_schema_update(self) -> Type[SchemaUpdateT]:
133140
return self.schema_update or schema_create_by_schema(
134141
self.schema_model,
135142
f"{self.schema_name_prefix}Update",
136143
exclude={self.pk_name},
137144
set_none=True,
138145
)
139146

140-
def _create_schema_create(self):
147+
def _create_schema_create(self) -> Type[SchemaCreateT]:
141148
return self.schema_create or schema_create_by_schema(self.schema_model, f"{self.schema_name_prefix}Create")
142149

143150
@property
@@ -164,12 +171,12 @@ async def has_list_permission(
164171
self,
165172
request: Request,
166173
paginator: Optional[Paginator],
167-
filters: Optional[BaseModel],
174+
filters: Optional[SchemaFilterT],
168175
**kwargs,
169176
) -> bool:
170177
return True
171178

172-
async def has_create_permission(self, request: Request, obj: Optional[BaseModel], **kwargs) -> bool:
179+
async def has_create_permission(self, request: Request, obj: Optional[SchemaCreateT], **kwargs) -> bool:
173180
return True
174181

175182
async def has_read_permission(self, request: Request, item_id: Optional[List[str]], **kwargs) -> bool:
@@ -179,7 +186,7 @@ async def has_update_permission(
179186
self,
180187
request: Request,
181188
item_id: Optional[List[str]],
182-
obj: Optional[BaseModel],
189+
obj: Optional[SchemaUpdateT],
183190
**kwargs,
184191
) -> bool:
185192
return True

0 commit comments

Comments
 (0)