diff --git a/docs/build.py b/docs/build.py
index ac70441..1866efe 100644
--- a/docs/build.py
+++ b/docs/build.py
@@ -594,6 +594,34 @@ def get_page_context(page_url):
"content": "Model-admin-specific methods and attributes:",
},
{"type": "code-python", "content": inspect.getsource(ModelAdmin)},
+ {
+ "type": "text",
+ "content": "You can customize relation loading and relation search by overriding orm_get_list and forwarding prefetch_related_fields and additional_search_fields:",
+ },
+ {
+ "type": "code-python",
+ "content": """class TaskAdmin(TortoiseModelAdmin):
+ search_fields = ("title",)
+
+ async def orm_get_list(
+ self,
+ offset=None,
+ limit=None,
+ search=None,
+ sort_by=None,
+ filters=None,
+ ):
+ return await super().orm_get_list(
+ offset=offset,
+ limit=limit,
+ search=search,
+ sort_by=sort_by,
+ filters=filters,
+ prefetch_related_fields=["user"],
+ additional_search_fields=["user__email"],
+ )
+""",
+ },
]
case "#model-form-field-types":
return [
diff --git a/docs/index.html b/docs/index.html
index 3cfb315..8d06fa8 100644
--- a/docs/index.html
+++ b/docs/index.html
@@ -340,7 +340,7 @@
+ You can customize relation loading and relation search by overriding orm_get_list and forwarding prefetch_related_fields and additional_search_fields:
+
+
+class TaskAdmin(TortoiseModelAdmin):
+ search_fields = ("title",)
+
+ async def orm_get_list(
+ self,
+ offset=None,
+ limit=None,
+ search=None,
+ sort_by=None,
+ filters=None,
+ ):
+ return await super().orm_get_list(
+ offset=offset,
+ limit=limit,
+ search=search,
+ sort_by=sort_by,
+ filters=filters,
+ prefetch_related_fields=["user"],
+ additional_search_fields=["user__email"],
+ )
+
+
+
+
+
+
+
diff --git a/fastadmin/models/base.py b/fastadmin/models/base.py
index d5815ad..53c6b51 100644
--- a/fastadmin/models/base.py
+++ b/fastadmin/models/base.py
@@ -227,6 +227,8 @@ async def orm_get_list(
search: str | None = None,
sort_by: str | None = None,
filters: dict | None = None,
+ prefetch_related_fields: list[str] | None = None,
+ additional_search_fields: list[str] | None = None,
) -> tuple[list[Any], int]:
"""This method is used to get list of orm/db model objects.
@@ -235,6 +237,8 @@ async def orm_get_list(
:params search: a search query.
:params sort_by: a sort by field name.
:params filters: a dict of filters.
+ :params prefetch_related_fields: a list of related fields to prefetch.
+ :params additional_search_fields: a list of additional search fields.
:return: A tuple of list of objects and total count.
"""
raise NotImplementedError
diff --git a/fastadmin/models/orms/django.py b/fastadmin/models/orms/django.py
index f9148b2..6ca6eee 100644
--- a/fastadmin/models/orms/django.py
+++ b/fastadmin/models/orms/django.py
@@ -1,4 +1,3 @@
-import operator
from base64 import b64decode
from typing import Any
from uuid import UUID
@@ -245,6 +244,8 @@ def orm_get_list(
search: str | None = None,
sort_by: str | None = None,
filters: dict | None = None,
+ prefetch_related_fields: list[str] | None = None,
+ additional_search_fields: list[str] | None = None,
) -> tuple[list[Any], int]:
"""This method is used to get list of orm/db model objects.
@@ -253,19 +254,30 @@ def orm_get_list(
:params search: a search query.
:params sort_by: a sort by field name.
:params filters: a dict of filters.
+ :params prefetch_related_fields: a list of related fields to prefetch.
+ :params additional_search_fields: a list of additional search fields.
:return: A tuple of list of objects and total count.
"""
qs = self.model_cls.objects.all()
+ if prefetch_related_fields:
+ qs = qs.prefetch_related(*prefetch_related_fields)
+
if filters:
for field_with_condition, value in filters.items():
field = field_with_condition[0]
condition = field_with_condition[1]
qs = qs.filter(**{f"{field}__{condition}" if condition != "exact" else field: value})
- if search and self.search_fields:
- search_conditions = [Q(**{f + "__icontains": search}) for f in self.search_fields]
- search_q = search_conditions[0] if len(search_conditions) == 1 else operator.or_(*search_conditions)
+ search_fields = list(self.search_fields)
+ if additional_search_fields:
+ search_fields.extend(additional_search_fields)
+
+ if search and search_fields:
+ search_conditions = [Q(**{f + "__icontains": search}) for f in search_fields]
+ search_q = search_conditions[0]
+ for condition in search_conditions[1:]:
+ search_q |= condition
qs = qs.filter(search_q)
if sort_by:
diff --git a/fastadmin/models/orms/ponyorm.py b/fastadmin/models/orms/ponyorm.py
index 11bff45..c803612 100644
--- a/fastadmin/models/orms/ponyorm.py
+++ b/fastadmin/models/orms/ponyorm.py
@@ -224,6 +224,8 @@ def orm_get_list(
search: str | None = None,
sort_by: str | None = None,
filters: dict | None = None,
+ prefetch_related_fields: list[str] | None = None,
+ additional_search_fields: list[str] | None = None,
) -> tuple[list[Any], int]:
"""This method is used to get list of orm/db model objects.
@@ -232,6 +234,8 @@ def orm_get_list(
:params search: a search query.
:params sort_by: a sort by field name.
:params filters: a dict of filters.
+ :params prefetch_related_fields: a list of related fields to prefetch.
+ :params additional_search_fields: a list of additional search fields.
:return: A tuple of list of objects and total count.
"""
@@ -271,9 +275,13 @@ def orm_get_list(
filter_expr = f""""{value}" {pony_condition} m.{field}"""
qs = qs.filter(filter_expr)
- if search and self.search_fields:
+ search_fields = list(self.search_fields)
+ if additional_search_fields:
+ search_fields.extend(additional_search_fields)
+
+ if search and search_fields:
ids = []
- for search_field in self.search_fields:
+ for search_field in search_fields:
pony_search_field = search_field.replace("__", ".")
# Pony string filter for case-insensitive search
filter_expr = f'"{search.lower()}" in m.{pony_search_field}.lower()'
@@ -296,6 +304,9 @@ def orm_get_list(
if self.list_select_related:
qs = qs.prefetch(*[getattr(self.model_cls, field) for field in self.list_select_related])
+ if prefetch_related_fields:
+ qs = qs.prefetch(*[getattr(self.model_cls, field) for field in prefetch_related_fields])
+
if offset is not None and limit is not None:
qs = qs.limit(limit, offset=offset)
diff --git a/fastadmin/models/orms/sqlalchemy.py b/fastadmin/models/orms/sqlalchemy.py
index 722cca4..b34b335 100644
--- a/fastadmin/models/orms/sqlalchemy.py
+++ b/fastadmin/models/orms/sqlalchemy.py
@@ -300,6 +300,8 @@ async def orm_get_list(
search: str | None = None,
sort_by: str | None = None,
filters: dict | None = None,
+ prefetch_related_fields: list[str] | None = None,
+ additional_search_fields: list[str] | None = None,
) -> tuple[list[Any], int]:
"""This method is used to get list of orm/db model objects.
@@ -308,6 +310,8 @@ async def orm_get_list(
:params search: a search query.
:params sort_by: a sort by field name.
:params filters: a dict of filters.
+ :params prefetch_related_fields: a list of related fields to prefetch.
+ :params additional_search_fields: a list of additional search fields.
:return: A tuple of list of objects and total count.
"""
@@ -354,9 +358,13 @@ def convert_sort_by(sort_by: str) -> str:
q.append(model_field.ilike(f"%{value}%"))
qs = qs.where(and_(*q))
- if search and self.search_fields:
+ search_fields = list(self.search_fields)
+ if additional_search_fields:
+ search_fields.extend(additional_search_fields)
+
+ if search and search_fields:
q = []
- for field in self.search_fields:
+ for field in search_fields:
condition = self._build_search_condition(field, search)
if condition is not None:
q.append(condition)
@@ -377,6 +385,28 @@ def convert_sort_by(sort_by: str) -> str:
for field in self.list_select_related:
qs = qs.options(selectinload(getattr(self.model_cls, field)))
+ if prefetch_related_fields:
+ for field_path in prefetch_related_fields:
+ parts = field_path.split("__")
+ current_model = self.model_cls
+ attr = getattr(current_model, parts[0], None)
+ if attr is None:
+ continue
+ option = selectinload(attr)
+ current_model = getattrs(attr, "property.mapper.class_")
+ for part in parts[1:]:
+ if current_model is None:
+ break
+ nested_attr = getattr(current_model, part, None)
+ if nested_attr is None:
+ break
+ next_model = getattrs(nested_attr, "property.mapper.class_")
+ if next_model is None:
+ break
+ option = option.selectinload(nested_attr)
+ current_model = next_model
+ qs = qs.options(option)
+
if offset is not None and limit is not None:
qs = qs.offset(offset)
qs = qs.limit(limit)
diff --git a/fastadmin/models/orms/tortoise.py b/fastadmin/models/orms/tortoise.py
index dda56fa..02ce4fe 100644
--- a/fastadmin/models/orms/tortoise.py
+++ b/fastadmin/models/orms/tortoise.py
@@ -254,6 +254,8 @@ async def orm_get_list(
search: str | None = None,
sort_by: str | None = None,
filters: dict | None = None,
+ prefetch_related_fields: list[str] | None = None,
+ additional_search_fields: list[str] | None = None,
) -> tuple[list[Any], int]:
"""This method is used to get list of orm/db model objects.
@@ -262,21 +264,30 @@ async def orm_get_list(
:params search: a search query.
:params sort_by: a sort by field name.
:params filters: a dict of filters.
+ :params prefetch_related_fields: a list of related fields to prefetch.
+ :params additional_search_fields: a list of additional search fields.
:return: A tuple of list of objects and total count.
"""
qs = self.model_cls.all()
+ if prefetch_related_fields:
+ qs = qs.prefetch_related(*prefetch_related_fields).distinct()
+
if filters:
for field_with_condition, value in filters.items():
field = field_with_condition[0]
condition = field_with_condition[1]
qs = qs.filter(**{f"{field}__{condition}" if condition != "exact" else field: value})
- if search and self.search_fields:
+ search_fields = list(self.search_fields)
+ if additional_search_fields:
+ search_fields.extend(additional_search_fields)
+
+ if search and search_fields:
qs = qs.filter(
functools.reduce(
operator.or_,
- (Q(**{f + "__icontains": search}) for f in self.search_fields),
+ (Q(**{f + "__icontains": search}) for f in search_fields),
Q(),
)
)
diff --git a/tests/models/test_orm.py b/tests/models/test_orm.py
index c1530bf..5dd019e 100644
--- a/tests/models/test_orm.py
+++ b/tests/models/test_orm.py
@@ -456,6 +456,64 @@ async def test_sqlalchemy_orm_get_list_search_nested_relation(event, session_wit
assert any(getattr(obj, "id", None) == event.id for obj in objs)
+async def test_sqlalchemy_orm_get_list_supports_additional_search_and_prefetch(event, session_with_type):
+ _, session_type = session_with_type
+ if session_type != "sqlalchemy":
+ return
+
+ admin_model = get_admin_model(event.__class__)
+ admin_model.search_fields = ["name"]
+ objs, total = await admin_model.orm_get_list(
+ search="Test Tournament",
+ prefetch_related_fields=["tournament"],
+ additional_search_fields=["tournament__name"],
+ )
+
+ assert isinstance(total, int)
+ assert total > 0
+ assert any(getattr(obj, "id", None) == event.id for obj in objs)
+
+
+async def test_sqlalchemy_orm_get_list_prefetch_edge_cases(event, session_with_type):
+ _, session_type = session_with_type
+ if session_type != "sqlalchemy":
+ return
+
+ admin_model = get_admin_model(event.__class__)
+ objs, total = await admin_model.orm_get_list(
+ prefetch_related_fields=[
+ "does_not_exist",
+ "tournament__does_not_exist",
+ "tournament__events",
+ "tournament__name",
+ ]
+ )
+ assert isinstance(total, int)
+ assert isinstance(objs, list)
+
+
+async def test_sqlalchemy_orm_get_list_prefetch_no_related_model(event, session_with_type, monkeypatch):
+ _, session_type = session_with_type
+ if session_type != "sqlalchemy":
+ return
+
+ from fastadmin.models.orms import sqlalchemy as sqlalchemy_orm
+
+ admin_model = get_admin_model(event.__class__)
+ original_getattrs = sqlalchemy_orm.getattrs
+
+ def fake_getattrs(obj, attr_path, default=None):
+ if attr_path == "property.mapper.class_":
+ return None
+ return original_getattrs(obj, attr_path, default=default)
+
+ monkeypatch.setattr(sqlalchemy_orm, "getattrs", fake_getattrs)
+
+ objs, total = await admin_model.orm_get_list(prefetch_related_fields=["tournament__events"])
+ assert isinstance(total, int)
+ assert isinstance(objs, list)
+
+
def test_sqlalchemy_resolve_ordering_field_for_relation(monkeypatch):
from types import SimpleNamespace
@@ -748,6 +806,24 @@ async def test_ponyorm_orm_get_list_search_nested_relation(event, session_with_t
assert any(getattr(obj, "id", None) == event.id for obj in objs)
+async def test_ponyorm_orm_get_list_supports_additional_search_and_prefetch(event, session_with_type):
+ _, session_type = session_with_type
+ if session_type != "ponyorm":
+ return
+
+ admin_model = get_admin_model(event.__class__)
+ admin_model.search_fields = ["name"]
+ objs, total = await admin_model.orm_get_list(
+ search="Test Tournament",
+ prefetch_related_fields=["tournament"],
+ additional_search_fields=["tournament__name"],
+ )
+
+ assert isinstance(total, int)
+ assert total > 0
+ assert any(getattr(obj, "id", None) == event.id for obj in objs)
+
+
async def test_ponyorm_edge_cases(event, session_with_type):
from types import SimpleNamespace
@@ -964,6 +1040,44 @@ def test_tortoise_resolve_ordering_field_edge_cases():
assert admin._resolve_ordering_field("unknown") == "unknown"
+async def test_tortoise_orm_get_list_supports_additional_search_and_prefetch(event, session_with_type):
+ _, session_type = session_with_type
+ if session_type != "tortoiseorm":
+ return
+
+ from fastadmin.models.orms.tortoise import TortoiseModelAdmin
+
+ admin_model = TortoiseModelAdmin(event.__class__)
+ admin_model.search_fields = ["name"]
+ objs, total = await admin_model.orm_get_list(
+ search="Test Tournament",
+ prefetch_related_fields=["tournament"],
+ additional_search_fields=["tournament__name"],
+ )
+
+ assert isinstance(total, int)
+ assert total > 0
+ assert any(getattr(obj, "id", None) == event.id for obj in objs)
+
+
+async def test_django_orm_get_list_supports_additional_search_and_prefetch(event, session_with_type):
+ _, session_type = session_with_type
+ if session_type != "djangoorm":
+ return
+
+ admin_model = get_admin_model(event.__class__)
+ admin_model.search_fields = ["name"]
+ objs, total = await admin_model.orm_get_list(
+ search="Test Tournament",
+ prefetch_related_fields=["tournament"],
+ additional_search_fields=["tournament__name"],
+ )
+
+ assert isinstance(total, int)
+ assert total > 0
+ assert any(getattr(obj, "id", None) == event.id for obj in objs)
+
+
def test_django_field_mapping_special_cases():
from types import SimpleNamespace