22import logging
33import re
44from enum import Enum
5- from typing import Any , Callable , Dict , List , Optional , Pattern , Tuple , Type , Union
5+ from typing import (
6+ Any ,
7+ Callable ,
8+ Dict ,
9+ Generic ,
10+ List ,
11+ Optional ,
12+ Pattern ,
13+ Tuple ,
14+ Type ,
15+ TypeVar ,
16+ Union ,
17+ )
618
719from fastapi import APIRouter , Body , Depends , Query
820from fastapi .encoders import DictIntStrAny , SetIntStr
9- from pydantic import BaseModel , Extra , Json
21+ from pydantic import Extra , Json
1022from sqlalchemy import Column , Table , delete , func , insert
1123from sqlalchemy .future import select
1224from sqlalchemy .orm import InstrumentedAttribute , Session
1729
1830from fastapi_amis_admin .utils .functools import cached_property
1931
20- from .base import BaseCrud
32+ from .base import (
33+ BaseCrud ,
34+ SchemaCreateT ,
35+ SchemaFilterT ,
36+ SchemaListT ,
37+ SchemaReadT ,
38+ SchemaUpdateT ,
39+ )
2140from .parser import (
2241 SqlField ,
2342 SQLModelField ,
5170 "-" : "between" ,
5271}
5372
73+ ModelT = TypeVar ("ModelT" , bound = SQLModel )
74+
5475
55- class SQLModelSelector :
56- model : Type [SQLModel ] = None # SQLModel模型
76+ class SQLModelSelector ( Generic [ ModelT ]) :
77+ model : Type [ModelT ] = None # SQLModel模型
5778 fields : List [SQLModelListField ] = [] # 需要查询的字段
5879 list_filter : List [SQLModelListField ] = [] # 查询可过滤的字段
5980 exclude : List [SQLModelField ] = [] # 不需要查询的字段
@@ -71,7 +92,7 @@ class SQLModelSelector:
7192 """
7293 pk_name : str = "id" # 主键名称
7394
74- def __init__ (self , model : Type [SQLModel ] = None , fields : List [SQLModelListField ] = None ) -> None :
95+ def __init__ (self , model : Type [ModelT ] = None , fields : List [SQLModelListField ] = None ) -> None :
7596 self .model = model or self .model
7697 assert self .model , "model is None"
7798 self .pk_name : str = self .pk_name or self .model .__table__ .primary_key .columns .keys ()[0 ]
@@ -215,7 +236,7 @@ class SQLModelCrud(BaseCrud, SQLModelSelector):
215236
216237 def __init__ (
217238 self ,
218- model : Type [SQLModel ],
239+ model : Type [ModelT ],
219240 engine : SqlalchemyDatabase ,
220241 fields : List [SQLModelListField ] = None ,
221242 router : APIRouter = None ,
@@ -231,7 +252,7 @@ def __init__(
231252 "Please replace them with update_fields and update_exclude."
232253 )
233254
234- def _create_schema_list (self ):
255+ def _create_schema_list (self ) -> Type [ SchemaListT ] :
235256 if self .schema_list :
236257 return self .schema_list
237258 modelfields = list (
@@ -247,7 +268,7 @@ def _create_schema_list(self):
247268 extra = Extra .allow ,
248269 )
249270
250- def _create_schema_filter (self ):
271+ def _create_schema_filter (self ) -> Type [ SchemaFilterT ] :
251272 if self .schema_filter :
252273 return self .schema_filter
253274 self .list_filter = self .list_filter or self ._select_entities .values ()
@@ -275,7 +296,7 @@ def _create_schema_filter(self):
275296 set_none = True ,
276297 )
277298
278- def _create_schema_update (self ):
299+ def _create_schema_update (self ) -> Type [ SchemaUpdateT ] :
279300 if self .schema_update :
280301 return self .schema_update
281302 if not self .readonly_fields and not self .update_fields :
@@ -288,7 +309,7 @@ def _create_schema_update(self):
288309 modelfields = [field for field in modelfields if field .name not in readonly_fields ]
289310 return schema_create_by_modelfield (f"{ self .schema_name_prefix } Update" , modelfields , set_none = True )
290311
291- def _create_schema_create (self ):
312+ def _create_schema_create (self ) -> Type [ SchemaCreateT ] :
292313 if self .schema_create :
293314 return self .schema_create
294315 if not self .create_fields :
@@ -301,12 +322,12 @@ def _create_schema_create(self):
301322 )
302323 return schema_create_by_modelfield (f"{ self .schema_name_prefix } Create" , modelfields )
303324
304- def read_item (self , obj : SQLModel ) -> BaseModel :
325+ def read_item (self , obj : ModelT ) -> SchemaReadT :
305326 """read database data and parse to schema_read"""
306327 parse = self .schema_read .from_orm if self .schema_read .Config .orm_mode else self .schema_read .parse_obj
307328 return parse (obj )
308329
309- def update_item (self , obj : SQLModel , values : Dict [str , Any ]) -> None :
330+ def update_item (self , obj : ModelT , values : Dict [str , Any ]) -> None :
310331 """update schema_update data to database,support relational attributes"""
311332 for k , v in values .items ():
312333 if isinstance (v , dict ) and hasattr (obj , k ):
@@ -315,7 +336,7 @@ def update_item(self, obj: SQLModel, values: Dict[str, Any]) -> None:
315336 else :
316337 setattr (obj , k , v )
317338
318- def _fetch_item_scalars (self , session : Session , item_id : List [str ]) -> List [SQLModel ]:
339+ def _fetch_item_scalars (self , session : Session , item_id : List [str ]) -> List [ModelT ]:
319340 stmt = select (self .model ).where (self .pk .in_ (list (map (get_python_type_parse (self .pk ), item_id ))))
320341 return session .scalars (stmt ).all ()
321342
@@ -335,7 +356,7 @@ def schema_name_prefix(self):
335356 return self .model .__name__
336357 return super ().schema_name_prefix
337358
338- async def on_create_pre (self , request : Request , obj : BaseModel , ** kwargs ) -> Dict [str , Any ]:
359+ async def on_create_pre (self , request : Request , obj : SchemaCreateT , ** kwargs ) -> Dict [str , Any ]:
339360 data_dict = obj .dict (by_alias = True ) # exclude=set(self.pk)
340361 if self .pk_name in data_dict and not data_dict .get (self .pk_name ):
341362 del data_dict [self .pk_name ]
@@ -344,15 +365,15 @@ async def on_create_pre(self, request: Request, obj: BaseModel, **kwargs) -> Dic
344365 async def on_update_pre (
345366 self ,
346367 request : Request ,
347- obj : BaseModel ,
368+ obj : SchemaUpdateT ,
348369 item_id : Union [List [str ], List [int ]],
349370 ** kwargs ,
350371 ) -> Dict [str , Any ]:
351372 data = obj .dict (exclude = self .update_exclude , exclude_unset = True , by_alias = True )
352373 data = {key : val for key , val in data .items () if val is not None or self .model .__fields__ [key ].allow_none }
353374 return data
354375
355- async def on_filter_pre (self , request : Request , obj : BaseModel , ** kwargs ) -> Dict [str , Any ]:
376+ async def on_filter_pre (self , request : Request , obj : SchemaFilterT , ** kwargs ) -> Dict [str , Any ]:
356377 return obj and {k : v for k , v in obj .dict (exclude_unset = True , by_alias = True ).items () if v is not None }
357378
358379 @property
0 commit comments