11import datetime
2+ import logging
23import re
34from enum import Enum
45from typing import Any , Callable , Dict , List , Optional , Pattern , Tuple , Type , Union
56
67from fastapi import APIRouter , Body , Depends , Query
8+ from fastapi .encoders import DictIntStrAny , SetIntStr
79from pydantic import BaseModel , Extra , Json
8- from sqlalchemy import Column , Table , delete , func , insert , update
10+ from sqlalchemy import Column , Table , delete , func , insert
911from sqlalchemy .future import select
1012from sqlalchemy .orm import InstrumentedAttribute , Session
1113from sqlalchemy .sql import Select
@@ -206,7 +208,10 @@ class SQLModelCrud(BaseCrud, SQLModelSelector):
206208 engine : SqlalchemyDatabase = None
207209 create_fields : List [SQLModelField ] = [] # 新增数据字段
208210 readonly_fields : List [SQLModelListField ] = [] # 只读字段
211+ """readonly fields, deprecated, not recommended, will be removed in version 0.4.0"""
209212 update_fields : List [SQLModelListField ] = [] # 可编辑字段
213+ update_exclude : Optional [Union [SetIntStr , DictIntStrAny ]] = None
214+ """update exclude fields, such as: {'id', 'key', 'name'} or {'id': True, 'category': {'id', 'name'}}"""
210215
211216 def __init__ (
212217 self ,
@@ -220,6 +225,11 @@ def __init__(
220225 self .db = get_engine_db (self .engine )
221226 SQLModelSelector .__init__ (self , model , fields )
222227 BaseCrud .__init__ (self , self .model , router )
228+ if self .readonly_fields :
229+ logging .warning (
230+ "readonly fields, deprecated, not recommended, will be removed in version 0.4.0."
231+ "Please replace them with update_fields and update_exclude."
232+ )
223233
224234 def _create_schema_list (self ):
225235 if self .schema_list :
@@ -291,11 +301,33 @@ def _create_schema_create(self):
291301 )
292302 return schema_create_by_modelfield (f"{ self .schema_name_prefix } Create" , modelfields )
293303
294- def _read_items (self , session : Session , item_id : List [str ]):
295- stmt = select (self .model ).where (self .pk .in_ (list (map (get_python_type_parse (self .pk ), item_id ))))
296- items = session .scalars (stmt ).all ()
304+ def read_item (self , obj : SQLModel ) -> BaseModel :
305+ """read database data and parse to schema_read"""
297306 parse = self .schema_read .from_orm if self .schema_read .Config .orm_mode else self .schema_read .parse_obj
298- return [parse (obj ) for obj in items ]
307+ return parse (obj )
308+
309+ def update_item (self , obj : SQLModel , values : Dict [str , Any ]) -> None :
310+ """update schema_update data to database,support relational attributes"""
311+ for k , v in values .items ():
312+ if isinstance (v , dict ) and hasattr (obj , k ):
313+ # Relational attributes, nested;such as: setattr(article.content, "body", "new body")
314+ self .update_item (getattr (obj , k ), v )
315+ else :
316+ setattr (obj , k , v )
317+
318+ def _fetch_item_scalars (self , session : Session , item_id : List [str ]) -> List [SQLModel ]:
319+ stmt = select (self .model ).where (self .pk .in_ (list (map (get_python_type_parse (self .pk ), item_id ))))
320+ return session .scalars (stmt ).all ()
321+
322+ def _read_items (self , session : Session , item_id : List [str ]):
323+ items = self ._fetch_item_scalars (session , item_id )
324+ return [self .read_item (obj ) for obj in items ]
325+
326+ def _update_items (self , session : Session , item_id : List [str ], values : Dict [str , Any ]):
327+ items = self ._fetch_item_scalars (session , item_id )
328+ for item in items :
329+ self .update_item (item , values )
330+ return len (items )
299331
300332 @property
301333 def schema_name_prefix (self ):
@@ -316,7 +348,7 @@ async def on_update_pre(
316348 item_id : Union [List [str ], List [int ]],
317349 ** kwargs ,
318350 ) -> Dict [str , Any ]:
319- data = obj .dict (exclude_unset = True , by_alias = True )
351+ data = obj .dict (exclude = self . update_exclude , exclude_unset = True , by_alias = True )
320352 data = {key : val for key , val in data .items () if val is not None or self .model .__fields__ [key ].allow_none }
321353 return data
322354
@@ -415,13 +447,11 @@ async def route(
415447 if not await self .has_update_permission (request , item_id , data ):
416448 return self .error_no_router_permission (request )
417449 item_id = list (map (get_python_type_parse (self .pk ), item_id ))
418- stmt = update (self .model ).where (self .pk .in_ (item_id ))
419450 values = await self .on_update_pre (request , data , item_id = item_id )
420451 if not values :
421452 return self .error_data_handle (request )
422- stmt = stmt .values (values )
423- result = await self .db .async_execute (stmt )
424- return BaseApiOut (data = getattr (result , "rowcount" , None ))
453+ result = await self .db .async_run_sync (self ._update_items , item_id , values )
454+ return BaseApiOut (data = result )
425455
426456 return route
427457
0 commit comments