Skip to content

Commit 178fc76

Browse files
committed
feat: New feature, support updating relational attributes.
1 parent b47bc4a commit 178fc76

4 files changed

Lines changed: 165 additions & 38 deletions

File tree

fastapi_amis_admin/crud/_sqlmodel.py

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import datetime
2+
import logging
23
import re
34
from enum import Enum
45
from typing import Any, Callable, Dict, List, Optional, Pattern, Tuple, Type, Union
56

67
from fastapi import APIRouter, Body, Depends, Query
8+
from fastapi.encoders import DictIntStrAny, SetIntStr
79
from pydantic import BaseModel, Extra, Json
8-
from sqlalchemy import Column, Table, delete, func, insert, update
10+
from sqlalchemy import Column, Table, delete, func, insert
911
from sqlalchemy.future import select
1012
from sqlalchemy.orm import InstrumentedAttribute, Session
1113
from sqlalchemy.sql import Select
@@ -206,7 +208,10 @@ class SQLModelCrud(BaseCrud, SQLModelSelector):
206208
engine: SqlalchemyDatabase = None
207209
create_fields: List[SQLModelField] = [] # 新增数据字段
208210
readonly_fields: List[SQLModelListField] = [] # 只读字段
211+
"""readonly fields, deprecated, not recommended, will be removed in version 0.4.0"""
209212
update_fields: List[SQLModelListField] = [] # 可编辑字段
213+
update_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None
214+
"""update exclude fields, such as: {'id', 'key', 'name'} or {'id': True, 'category': {'id', 'name'}}"""
210215

211216
def __init__(
212217
self,
@@ -220,6 +225,11 @@ def __init__(
220225
self.db = get_engine_db(self.engine)
221226
SQLModelSelector.__init__(self, model, fields)
222227
BaseCrud.__init__(self, self.model, router)
228+
if self.readonly_fields:
229+
logging.warning(
230+
"readonly fields, deprecated, not recommended, will be removed in version 0.4.0."
231+
"Please replace them with update_fields and update_exclude."
232+
)
223233

224234
def _create_schema_list(self):
225235
if self.schema_list:
@@ -291,11 +301,33 @@ def _create_schema_create(self):
291301
)
292302
return schema_create_by_modelfield(f"{self.schema_name_prefix}Create", modelfields)
293303

294-
def _read_items(self, session: Session, item_id: List[str]):
295-
stmt = select(self.model).where(self.pk.in_(list(map(get_python_type_parse(self.pk), item_id))))
296-
items = session.scalars(stmt).all()
304+
def read_item(self, obj: SQLModel) -> BaseModel:
305+
"""read database data and parse to schema_read"""
297306
parse = self.schema_read.from_orm if self.schema_read.Config.orm_mode else self.schema_read.parse_obj
298-
return [parse(obj) for obj in items]
307+
return parse(obj)
308+
309+
def update_item(self, obj: SQLModel, values: Dict[str, Any]) -> None:
310+
"""update schema_update data to database,support relational attributes"""
311+
for k, v in values.items():
312+
if isinstance(v, dict) and hasattr(obj, k):
313+
# Relational attributes, nested;such as: setattr(article.content, "body", "new body")
314+
self.update_item(getattr(obj, k), v)
315+
else:
316+
setattr(obj, k, v)
317+
318+
def _fetch_item_scalars(self, session: Session, item_id: List[str]) -> List[SQLModel]:
319+
stmt = select(self.model).where(self.pk.in_(list(map(get_python_type_parse(self.pk), item_id))))
320+
return session.scalars(stmt).all()
321+
322+
def _read_items(self, session: Session, item_id: List[str]):
323+
items = self._fetch_item_scalars(session, item_id)
324+
return [self.read_item(obj) for obj in items]
325+
326+
def _update_items(self, session: Session, item_id: List[str], values: Dict[str, Any]):
327+
items = self._fetch_item_scalars(session, item_id)
328+
for item in items:
329+
self.update_item(item, values)
330+
return len(items)
299331

300332
@property
301333
def schema_name_prefix(self):
@@ -316,7 +348,7 @@ async def on_update_pre(
316348
item_id: Union[List[str], List[int]],
317349
**kwargs,
318350
) -> Dict[str, Any]:
319-
data = obj.dict(exclude_unset=True, by_alias=True)
351+
data = obj.dict(exclude=self.update_exclude, exclude_unset=True, by_alias=True)
320352
data = {key: val for key, val in data.items() if val is not None or self.model.__fields__[key].allow_none}
321353
return data
322354

@@ -415,13 +447,11 @@ async def route(
415447
if not await self.has_update_permission(request, item_id, data):
416448
return self.error_no_router_permission(request)
417449
item_id = list(map(get_python_type_parse(self.pk), item_id))
418-
stmt = update(self.model).where(self.pk.in_(item_id))
419450
values = await self.on_update_pre(request, data, item_id=item_id)
420451
if not values:
421452
return self.error_data_handle(request)
422-
stmt = stmt.values(values)
423-
result = await self.db.async_execute(stmt)
424-
return BaseApiOut(data=getattr(result, "rowcount", None))
453+
result = await self.db.async_run_sync(self._update_items, item_id, values)
454+
return BaseApiOut(data=result)
425455

426456
return route
427457

tests/conftest.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
import pytest
12
from sqlalchemy_database import AsyncDatabase, Database
23

34
# sqlite
45
sync_db = Database.create("sqlite:///amisadmin.db?check_same_thread=False")
56
async_db = AsyncDatabase.create("sqlite+aiosqlite:///amisadmin.db?check_same_thread=False")
67

8+
79
# mysql
810
# sync_db = Database.create('mysql+pymysql://root:123456@127.0.0.1:3306/amisadmin?charset=utf8mb4')
911
# async_db = AsyncDatabase.create('mysql+aiomysql://root:123456@127.0.0.1:3306/amisadmin?charset=utf8mb4')
@@ -17,3 +19,15 @@
1719

1820
# SQL Server
1921
# sync_db = Database.create('mssql+pyodbc://scott:tiger@mydsn')
22+
23+
24+
@pytest.fixture
25+
def session():
26+
with sync_db.session_maker() as session:
27+
yield session
28+
29+
30+
@pytest.fixture
31+
async def async_session():
32+
async with async_db.session_maker() as session:
33+
yield session

tests/test_crud/conftest.py

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,10 @@
44
import pytest
55
from fastapi import FastAPI
66
from httpx import AsyncClient
7-
from sqlalchemy import insert
87
from sqlmodel import SQLModel
98

109
from tests.conftest import async_db as db
11-
from tests.models import Article, User
10+
from tests.models import Article, ArticleContent, Category, User
1211

1312
pytestmark = pytest.mark.asyncio
1413

@@ -33,30 +32,43 @@ async def async_client(app: FastAPI, prepare_database: Any) -> AsyncGenerator[As
3332

3433

3534
@pytest.fixture
36-
async def fake_users() -> List[User]:
35+
async def fake_users(async_session) -> List[User]:
3736
data = [
38-
{
39-
"id": i,
40-
"username": f"User_{i}",
41-
"password": f"password_{i}",
42-
"create_time": datetime.datetime.strptime(f"2022-01-0{i} 00:00:00", "%Y-%m-%d %H:%M:%S"),
43-
}
37+
User(
38+
id=i,
39+
username=f"User_{i}",
40+
password=f"password_{i}",
41+
create_time=datetime.datetime.strptime(f"2022-01-0{i} 00:00:00", "%Y-%m-%d %H:%M:%S"),
42+
)
4443
for i in range(1, 6)
4544
]
46-
await db.execute(insert(User).values(data))
47-
return [User.parse_obj(item) for item in data]
45+
async_session.add_all(data)
46+
await async_session.commit()
47+
return data
4848

4949

5050
@pytest.fixture
51-
async def fake_articles(fake_users) -> List[Article]:
51+
async def fake_categorys(async_session) -> List[Category]:
52+
data = [Category(id=i, name=f"Category_{i}") for i in range(1, 6)]
53+
async_session.add_all(data)
54+
await async_session.commit()
55+
return data
56+
57+
58+
@pytest.fixture
59+
async def fake_article_contents(async_session) -> List[ArticleContent]:
60+
data = [ArticleContent(id=i, content=f"Content_{i}") for i in range(1, 6)]
61+
async_session.add_all(data)
62+
await async_session.commit()
63+
return data
64+
65+
66+
@pytest.fixture
67+
async def fake_articles(async_session, fake_users, fake_categorys, fake_article_contents) -> List[Article]:
5268
data = [
53-
{
54-
"id": user.id,
55-
"title": f"Article_{user.id}",
56-
"description": f"Description_{user.id}",
57-
"user_id": user.id,
58-
}
59-
for user in fake_users
69+
Article(id=i, title=f"Article_{i}", description=f"Description_{i}", user_id=i, category_id=i, content_id=i)
70+
for i in range(1, 6)
6071
]
61-
await db.execute(insert(Article).values(data))
62-
return [Article.parse_obj(item) for item in data]
72+
async_session.add_all(data)
73+
await async_session.commit()
74+
return data

tests/test_crud/test_SQLModelCrud_schemas.py

Lines changed: 78 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
1+
from typing import Optional
2+
13
from fastapi import FastAPI
24
from httpx import AsyncClient
35
from pydantic import BaseModel
6+
from sqlmodel import SQLModel
47

58
from fastapi_amis_admin.crud import SQLModelCrud
69
from tests.conftest import async_db as db
7-
from tests.models import User
8-
9-
10-
class UserFilter(BaseModel):
11-
id: int = None
12-
name: str = None
10+
from tests.models import Article, ArticleContent, Category, User
1311

1412

1513
async def test_schema_update(app: FastAPI, async_client: AsyncClient, fake_users):
@@ -120,7 +118,7 @@ class UserCrud(SQLModelCrud):
120118
assert "password" not in items
121119

122120

123-
# todo perfect
121+
# todo perfect;test more comparison operators
124122
async def test_schema_filter(app: FastAPI, async_client: AsyncClient, fake_users):
125123
class UserFilter(BaseModel):
126124
id: int = None
@@ -152,3 +150,76 @@ class UserCrud(SQLModelCrud):
152150
res = await async_client.post("/user/list", json={"password": "new_password"})
153151
items = res.json()["data"]["items"]
154152
assert items
153+
154+
155+
async def test_schema_read_relationship(app: FastAPI, async_client: AsyncClient, fake_articles):
156+
class ArticleRead(SQLModel): # must be SQLModel, not BaseModel
157+
id: int
158+
title: str
159+
description: str
160+
category: Optional[Category] = None # Relationship
161+
content: Optional[ArticleContent] = None # Relationship
162+
user: Optional[User] = None # Relationship
163+
164+
class ArticleCrud(SQLModelCrud):
165+
router_prefix = "/article"
166+
schema_read = ArticleRead
167+
168+
ins = ArticleCrud(Article, db.engine).register_crud()
169+
170+
app.include_router(ins.router)
171+
172+
# test schemas
173+
openapi = app.openapi()
174+
schemas = openapi["components"]["schemas"]
175+
assert "category" in schemas["ArticleRead"]["properties"]
176+
assert schemas["ArticleRead"]["properties"]["category"]["$ref"] == "#/components/schemas/Category"
177+
assert "content" in schemas["ArticleRead"]["properties"]
178+
assert "user" in schemas["ArticleRead"]["properties"]
179+
180+
# test api
181+
res = await async_client.get("/article/item/1")
182+
items = res.json()["data"]
183+
assert items["id"] == 1
184+
assert "category" in items
185+
assert "content" in items
186+
assert "user" in items
187+
assert items["category"]["id"] == 1
188+
assert items["content"]["id"] == 1
189+
assert items["user"]["id"] == 1
190+
191+
192+
async def test_schema_update_relationship(app: FastAPI, async_client: AsyncClient, fake_articles, async_session):
193+
class ArticleUpdate(SQLModel): # must be SQLModel, not BaseModel
194+
title: str = None
195+
description: str = None
196+
content: Optional[ArticleContent] = None # Relationship
197+
198+
class ArticleCrud(SQLModelCrud):
199+
router_prefix = "/article"
200+
update_exclude = {"content": {"id"}}
201+
schema_update = ArticleUpdate
202+
203+
ins = ArticleCrud(Article, db.engine).register_crud()
204+
205+
app.include_router(ins.router)
206+
207+
# test schemas
208+
openapi = app.openapi()
209+
schemas = openapi["components"]["schemas"]
210+
211+
assert "content" in schemas["ArticleUpdate"]["properties"]
212+
assert schemas["ArticleUpdate"]["properties"]["content"]["$ref"] == "#/components/schemas/ArticleContent"
213+
214+
# test api
215+
res = await async_client.put("/article/item/1", json={"title": "new_title"})
216+
assert res.json()["data"] == 1
217+
article = await async_session.get(Article, 1)
218+
assert article.title == "new_title"
219+
220+
res = await async_client.put(
221+
"/article/item/1", json={"content": {"id": 2, "content": "new_content"}} # will be ignored by `update_exclude`
222+
)
223+
assert res.json()["data"] == 1
224+
content = await async_session.get(ArticleContent, 1)
225+
assert content.content == "new_content"

0 commit comments

Comments
 (0)