55from fastapi import HTTPException , status
66
77from fastadmin .models .base import BaseModelAdmin
8+ from fastadmin .models .helpers import get_admin_model
89from fastadmin .schemas .configuration import WidgetType
910from fastadmin .settings import settings
1011
1112
1213class TortoiseModelAdmin (BaseModelAdmin ):
13- async def save_model (self , obj : Any , payload : dict , add : bool = False ) -> None :
14+ async def save_model (self , id : str | None , payload : dict ) -> Any | None :
1415 """This method is used to save orm/db model object.
1516
16- :params obj : an orm/db model object.
17+ :params id : an id of object.
1718 :params payload: a payload from request.
18- :params add: a flag for add or update object.
19- :return: None.
19+ :return: A saved object or None.
2020 """
21+ fields = self .get_model_fields ()
22+ m2m_fields = [f for f , v in fields .items () if v .get ("is_m2m" , False )]
23+
24+ if id :
25+ obj = await self .model_cls .filter (id = id ).first ()
26+ if not obj :
27+ return None
28+ else :
29+ obj = self .model_cls ()
30+
31+ update_fields = []
2132 for key , value in payload .items ():
22- setattr (obj , key , value )
23- await obj .save (update_fields = payload .keys () if not add else None )
33+ if key not in m2m_fields :
34+ setattr (obj , key , value )
35+ update_fields .append (key )
36+
37+ await obj .save (update_fields = update_fields if id else None )
38+
39+ for key , values in payload .items ():
40+ if key in m2m_fields :
41+ m2m_rel = getattr (obj , key , None )
42+ if m2m_rel is None :
43+ continue
44+ remote_model = m2m_rel .remote_model
45+ await m2m_rel .clear ()
46+ remote_model_objs = []
47+ for id in values :
48+ remote_model_obj = remote_model ()
49+ setattr (remote_model_obj , remote_model ._meta .pk_attr , id )
50+ setattr (remote_model_obj , "_saved_in_db" , True )
51+ remote_model_objs .append (remote_model_obj )
52+ if remote_model_objs :
53+ await m2m_rel .add (* remote_model_objs )
54+ return obj
2455
2556 async def delete_model (self , obj : Any ) -> None :
2657 """This method is used to delete orm/db model object.
@@ -36,7 +67,20 @@ async def get_obj(self, id: str) -> Any | None:
3667 :params id: an id of object.
3768 :return: An object or None.
3869 """
39- return await self .model_cls .filter (id = id ).first ()
70+ fields = self .get_model_fields ()
71+ m2m_fields = [f for f , v in fields .items () if v .get ("is_m2m" , False )]
72+ obj = await self .model_cls .filter (id = id ).first ()
73+ if not obj :
74+ return obj
75+ obj_dict = {k : v for k , v in obj .__dict__ .items () if not k .startswith ("_" )}
76+ for field in m2m_fields :
77+ m2m_rel = getattr (obj , field , None )
78+ if m2m_rel is None :
79+ continue
80+ remote_model = m2m_rel .remote_model
81+ remote_ids = await m2m_rel .all ().values_list (remote_model ._meta .pk_attr , flat = True )
82+ obj_dict [field ] = remote_ids
83+ return obj_dict
4084
4185 async def get_list (
4286 self ,
@@ -126,8 +170,19 @@ def get_model_fields(self) -> OrderedDict[str, dict]:
126170 parent_model_label = None
127171 parent_model = getattr (field , "model_name" , "" ).rsplit ("." , 1 )[- 1 ] or None
128172 if parent_model :
173+ parent_admin_model = get_admin_model (parent_model )
129174 parent_model_id = "id"
130175 parent_model_label = "id"
176+ if parent_admin_model :
177+ parent_model_id = parent_admin_model .model_cls ._meta .pk_attr
178+ parent_model_label = parent_admin_model .model_cls ._meta .pk_attr
179+ parent_fields = parent_admin_model .model_cls ._meta .fields_db_projection .keys ()
180+ if "name" in parent_fields :
181+ parent_model_label = "name"
182+ elif "title" in parent_fields :
183+ parent_model_label = "title"
184+ elif "email" in parent_fields :
185+ parent_model_label = "email"
131186
132187 form_hidden = (
133188 getattr (field , "_generated" , False )
@@ -208,7 +263,7 @@ def get_form_widget(self, field_name: str) -> tuple[WidgetType, dict]:
208263 "format" : settings .ADMIN_TIME_FORMAT ,
209264 }
210265 case "ForeignKeyFieldInstance" :
211- if field in self .raw_id_fields :
266+ if field_name in self .raw_id_fields :
212267 return WidgetType .Input , widget_props
213268 return WidgetType .AsyncSelect , {
214269 ** widget_props ,
@@ -217,9 +272,9 @@ def get_form_widget(self, field_name: str) -> tuple[WidgetType, dict]:
217272 "labelField" : field .get ("parent_model_label" ) or "id" ,
218273 }
219274 case "ManyToManyFieldInstance" :
220- if field in self .raw_id_fields :
275+ if field_name in self .raw_id_fields :
221276 return WidgetType .Input , widget_props
222- if field in self .filter_vertical or field in self .filter_horizontal :
277+ if field_name in self .filter_vertical or field_name in self .filter_horizontal :
223278 return WidgetType .AsyncTransfer , {
224279 ** widget_props ,
225280 "required" : False ,
@@ -237,7 +292,7 @@ def get_form_widget(self, field_name: str) -> tuple[WidgetType, dict]:
237292 "labelField" : field .get ("parent_model_label" ) or "id" ,
238293 }
239294 case "OneToOneFieldInstance" :
240- if field in self .raw_id_fields :
295+ if field_name in self .raw_id_fields :
241296 return WidgetType .Input , widget_props
242297 return WidgetType .AsyncSelect , {
243298 ** widget_props ,
@@ -246,7 +301,7 @@ def get_form_widget(self, field_name: str) -> tuple[WidgetType, dict]:
246301 "labelField" : field .get ("parent_model_label" ) or "id" ,
247302 }
248303 case "CharEnumFieldInstance" :
249- if field in self .radio_fields :
304+ if field_name in self .radio_fields :
250305 return WidgetType .RadioGroup , {
251306 ** widget_props ,
252307 "options" : [{"label" : k , "value" : k } for k in field .get ("enum_type" ) or []],
0 commit comments