|
| 1 | +import contextlib |
1 | 2 | from typing import Any |
2 | 3 | from uuid import UUID |
3 | 4 |
|
4 | | -from sqlalchemy import and_, func, inspect, or_, select, text |
| 5 | +from sqlalchemy import BIGINT, Integer, and_, func, inspect, or_, select, text |
5 | 6 | from sqlalchemy.orm import selectinload |
6 | 7 |
|
7 | 8 | from fastadmin.models.base import InlineModelAdmin, ModelAdmin |
@@ -261,21 +262,27 @@ def convert_sort_by(sort_by: str) -> str: |
261 | 262 | for field_with_condition, value in filters.items(): |
262 | 263 | field = field_with_condition[0] |
263 | 264 | condition = field_with_condition[1] |
| 265 | + model_field = getattr(self.model_cls, field) |
| 266 | + |
| 267 | + if isinstance(model_field.expression.type, BIGINT | Integer): |
| 268 | + with contextlib.suppress(ValueError): |
| 269 | + value = int(value) |
| 270 | + |
264 | 271 | match condition: |
265 | 272 | case "lte": |
266 | | - q.append(getattr(self.model_cls, field) >= value) |
| 273 | + q.append(model_field >= value) |
267 | 274 | case "gte": |
268 | | - q.append(getattr(self.model_cls, field) <= value) |
| 275 | + q.append(model_field <= value) |
269 | 276 | case "lt": |
270 | | - q.append(getattr(self.model_cls, field) > value) |
| 277 | + q.append(model_field > value) |
271 | 278 | case "gt": |
272 | | - q.append(getattr(self.model_cls, field) < value) |
| 279 | + q.append(model_field < value) |
273 | 280 | case "exact": |
274 | | - q.append(getattr(self.model_cls, field) == value) |
| 281 | + q.append(model_field == value) |
275 | 282 | case "contains": |
276 | | - q.append(getattr(self.model_cls, field).like(f"%{value}%")) |
| 283 | + q.append(model_field.like(f"%{value}%")) |
277 | 284 | case "icontains": |
278 | | - q.append(getattr(self.model_cls, field).ilike(f"%{value}%")) |
| 285 | + q.append(model_field.ilike(f"%{value}%")) |
279 | 286 | qs = qs.filter(and_(*q)) |
280 | 287 |
|
281 | 288 | if search and self.search_fields: |
@@ -313,13 +320,26 @@ async def orm_get_obj(self, id: UUID | int) -> Any | None: |
313 | 320 | async with sessionmaker() as session: |
314 | 321 | return await session.get(self.model_cls, id) |
315 | 322 |
|
| 323 | + def _get_foreign_key_fields(self) -> list[str]: |
| 324 | + """Returns a list of foreign key fields for the model. |
| 325 | +
|
| 326 | + :return: List of foreign key field names. |
| 327 | + """ |
| 328 | + return [column.name for column in self.model_cls.__table__.columns if column.foreign_keys] |
| 329 | + |
316 | 330 | async def orm_save_obj(self, id: UUID | Any | None, payload: dict) -> Any: |
317 | 331 | """This method is used to save orm/db model object. |
318 | 332 |
|
319 | 333 | :params id: an id of object. |
320 | 334 | :params payload: a dict of payload. |
321 | 335 | :return: An object. |
322 | 336 | """ |
| 337 | + for fk_field_name in self._get_foreign_key_fields(): |
| 338 | + if fk_field_name in payload and isinstance(payload[fk_field_name], str): |
| 339 | + with contextlib.suppress(ValueError): |
| 340 | + # convert string to int for foreign key fields for postgresql alchemy |
| 341 | + payload[fk_field_name] = int(payload[fk_field_name]) |
| 342 | + |
323 | 343 | sessionmaker = self.get_sessionmaker() |
324 | 344 | async with sessionmaker() as session: |
325 | 345 | if id: |
|
0 commit comments