Skip to content

Commit c2d73f6

Browse files
authored
fix: async middleware is equivalent to sync middleware - DIA-61984 (#107)
1 parent 0deb5b9 commit c2d73f6

11 files changed

Lines changed: 485 additions & 343 deletions

fastapi_sqla/_pytest_plugin.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,12 @@ def db_url(db_host, db_user):
5252

5353

5454
@fixture(scope="session")
55-
def sqla_connection(db_url):
56-
engine = create_engine(db_url)
55+
def engine(db_url):
56+
return create_engine(db_url)
57+
58+
59+
@fixture(scope="session")
60+
def sqla_connection(engine):
5761
with engine.connect() as connection:
5862
yield connection
5963

@@ -92,19 +96,19 @@ def sqla_modules():
9296

9397

9498
@fixture
95-
def sqla_reflection(sqla_modules, sqla_connection, db_url):
99+
def sqla_reflection(sqla_modules, sqla_connection):
96100
import fastapi_sqla
97101

98102
fastapi_sqla.Base.metadata.bind = sqla_connection
99103
fastapi_sqla.Base.prepare(sqla_connection.engine)
100104

101105

102106
@fixture
103-
def patch_engine_from_config(request, db_url, sqla_connection, sqla_transaction):
107+
def patch_engine_from_config(request, sqla_connection, sqla_transaction):
104108
"""So that all DB operations are never written to db for real."""
105109
from fastapi_sqla.sqla import _Session
106110

107-
if "dont_patch_engines" in request.keywords:
111+
if "dont_patch_engines" in request.keywords: # pragma: no cover
108112
yield
109113

110114
else:
@@ -149,24 +153,29 @@ def async_sqlalchemy_url(db_url):
149153
return format_async_async_sqlalchemy_url(db_url)
150154

151155

152-
if asyncio_support:
156+
if asyncio_support: # noqa: C901
153157

154158
@fixture
155-
async def async_engine(async_sqlalchemy_url):
159+
def async_engine(async_sqlalchemy_url):
156160
return create_async_engine(async_sqlalchemy_url)
157161

158162
@fixture
159163
async def async_sqla_connection(async_engine, event_loop):
160-
async with async_engine.begin() as connection:
164+
async with async_engine.connect() as connection:
161165
yield connection
162-
await connection.rollback()
166+
167+
@fixture
168+
async def async_sqla_transaction(async_sqla_connection):
169+
async with async_sqla_connection.begin() as transaction:
170+
yield transaction
171+
await transaction.rollback()
163172

164173
@fixture
165174
async def patch_new_engine(async_sqlalchemy_url, async_sqla_connection, request):
166175
"""So that all async DB operations are never written to db for real."""
167176
from fastapi_sqla.async_sqla import _AsyncSession
168177

169-
if "dont_patch_engines" in request.keywords:
178+
if "dont_patch_engines" in request.keywords: # pragma: no cover
170179
yield
171180

172181
else:
@@ -185,16 +194,13 @@ async def async_sqla_reflection(sqla_modules, async_sqla_connection):
185194

186195
@fixture
187196
async def async_session(
188-
async_sqla_connection, async_sqla_reflection, patch_new_engine
197+
async_sqla_connection,
198+
async_sqla_transaction,
199+
async_sqla_reflection,
200+
patch_new_engine,
189201
):
190202
from fastapi_sqla.async_sqla import _AsyncSession
191203

192204
session = _AsyncSession(bind=async_sqla_connection)
193205
yield session
194206
await session.close()
195-
196-
else:
197-
198-
@fixture
199-
async def patch_new_engine():
200-
pass

fastapi_sqla/async_sqla.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import structlog
77
from fastapi import Request
8+
from fastapi.responses import PlainTextResponse
89
from sqlalchemy import text
910
from sqlalchemy.ext.asyncio import AsyncEngine
1011
from sqlalchemy.ext.asyncio import AsyncSession as SqlaAsyncSession
@@ -87,13 +88,20 @@ async def open_session() -> AsyncGenerator[SqlaAsyncSession, None]:
8788

8889
try:
8990
yield session
90-
await session.commit()
9191

9292
except Exception:
93-
logger.exception("commit failed, rolling back")
93+
logger.warning("context failed, rolling back", exc_info=True)
9494
await session.rollback()
9595
raise
9696

97+
else:
98+
try:
99+
await session.commit()
100+
except Exception:
101+
logger.exception("commit failed, rolling back")
102+
await session.rollback()
103+
raise
104+
97105
finally:
98106
await session.close()
99107

@@ -119,9 +127,30 @@ async def get_users(session: fastapi_sqla.AsyncSession = Depends()):
119127
async with open_session() as session:
120128
request.scope[_ASYNC_SESSION_KEY] = session
121129
response = await call_next(request)
130+
131+
is_dirty = bool(session.dirty or session.deleted or session.new)
132+
133+
# try to commit after response, so that we can return a proper 500 response
134+
# and not raise a true internal server error
135+
if response.status_code < 400:
136+
try:
137+
await session.commit()
138+
except Exception:
139+
logger.exception("commit failed, returning http error")
140+
response = PlainTextResponse(
141+
content="Internal Server Error", status_code=500
142+
)
143+
122144
if response.status_code >= 400:
123145
# If ever a route handler returns an http exception, we do not want the
124146
# session opened by current context manager to commit anything in db.
147+
if is_dirty:
148+
# optimistically only log if there were uncommitted changes
149+
logger.warning(
150+
"http error, rolling back possibly uncommitted changes",
151+
status_code=response.status_code,
152+
)
153+
# since this is no-op if session is not dirty, we can always call it
125154
await session.rollback()
126155

127156
return response

tests/conftest.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -68,14 +68,6 @@ def environ(db_url, sqla_version_tuple, async_sqlalchemy_url):
6868
yield values
6969

7070

71-
@fixture(scope="session")
72-
def engine(environ):
73-
from sqlalchemy import engine_from_config
74-
75-
engine = engine_from_config(environ, prefix="sqlalchemy_")
76-
return engine
77-
78-
7971
@fixture(autouse=True)
8072
def tear_down(environ):
8173
from sqlalchemy.orm.session import close_all_sessions
@@ -88,6 +80,7 @@ def tear_down(environ):
8880
# reload fastapi_sqla to clear sqla deferred reflection mapping stored in Base
8981
importlib.reload(fastapi_sqla.models)
9082
importlib.reload(fastapi_sqla.sqla)
83+
importlib.reload(fastapi_sqla.async_sqla)
9184
importlib.reload(fastapi_sqla)
9285

9386

tests/middleware/conftest.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
from unittest.mock import Mock
2+
3+
import httpx
4+
from asgi_lifespan import LifespanManager
5+
from fastapi import Depends, FastAPI, HTTPException
6+
from pydantic import BaseModel
7+
from pytest import fixture
8+
from sqlalchemy import text
9+
10+
11+
@fixture(scope="module", autouse=True)
12+
def setup_tear_down(sqla_connection):
13+
with sqla_connection.begin():
14+
sqla_connection.execute(
15+
text(
16+
"""
17+
CREATE TABLE IF NOT EXISTS public.user (
18+
id integer primary key,
19+
first_name varchar,
20+
last_name varchar
21+
)
22+
"""
23+
)
24+
)
25+
yield
26+
with sqla_connection.begin():
27+
sqla_connection.execute(text("DROP TABLE public.user"))
28+
29+
30+
@fixture
31+
def User():
32+
from fastapi_sqla import Base
33+
34+
class User(Base):
35+
__tablename__ = "user"
36+
37+
return User
38+
39+
40+
@fixture
41+
def app(User):
42+
from fastapi_sqla import Session, setup
43+
44+
app = FastAPI()
45+
setup(app)
46+
47+
class UserIn(BaseModel):
48+
id: int
49+
first_name: str
50+
last_name: str
51+
52+
@app.post("/users")
53+
def create_user(user: UserIn, session: Session = Depends()):
54+
session.add(User(**dict(user)))
55+
56+
@app.get("/404")
57+
def get_users(session: Session = Depends(Session)):
58+
raise HTTPException(status_code=404, detail="YOLO")
59+
60+
return app
61+
62+
63+
@fixture
64+
def mock_middleware(app: FastAPI):
65+
mock_middleware = Mock()
66+
67+
@app.middleware("http")
68+
async def a_middleware(request, call_next):
69+
res = await call_next(request)
70+
mock_middleware()
71+
return res
72+
73+
return mock_middleware
74+
75+
76+
@fixture
77+
async def client(app, mock_middleware):
78+
async with LifespanManager(app):
79+
transport = httpx.ASGITransport(app=app, raise_app_exceptions=False)
80+
async with httpx.AsyncClient(
81+
transport=transport, base_url="http://example.local"
82+
) as client:
83+
yield client
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
from unittest.mock import patch
2+
3+
from fastapi import Depends, FastAPI, HTTPException
4+
from pydantic import BaseModel
5+
from pytest import fixture, mark
6+
from sqlalchemy import text
7+
from structlog.testing import capture_logs
8+
9+
pytestmark = [mark.sqlalchemy("1.4"), mark.require_asyncpg]
10+
11+
12+
@fixture
13+
def app(User):
14+
from fastapi_sqla import AsyncSession, setup
15+
16+
app = FastAPI()
17+
setup(app)
18+
19+
class UserIn(BaseModel):
20+
id: int
21+
first_name: str
22+
last_name: str
23+
24+
@app.post("/users")
25+
def create_user(user: UserIn, session: AsyncSession = Depends()):
26+
session.add(User(**dict(user)))
27+
28+
@app.get("/404")
29+
def get_users(session: AsyncSession = Depends(AsyncSession)):
30+
raise HTTPException(status_code=404, detail="YOLO")
31+
32+
return app
33+
34+
35+
async def test_async_session_dependency(client, faker, async_session):
36+
userid = faker.unique.random_int()
37+
first_name = faker.first_name()
38+
last_name = faker.last_name()
39+
res = await client.post(
40+
"/users", json={"id": userid, "first_name": first_name, "last_name": last_name}
41+
)
42+
assert res.status_code == 200, res.json()
43+
row = (
44+
await async_session.execute(
45+
text(f"select * from public.user where id = {userid}")
46+
)
47+
).fetchone()
48+
assert row == (userid, first_name, last_name)
49+
50+
51+
@fixture
52+
async def user_1(async_sqla_connection):
53+
async with async_sqla_connection.begin():
54+
await async_sqla_connection.execute(
55+
text("INSERT INTO public.user VALUES (1, 'bob', 'morane') ")
56+
)
57+
yield
58+
async with async_sqla_connection.begin():
59+
await async_sqla_connection.execute(
60+
text("DELETE FROM public.user WHERE id = 1")
61+
)
62+
63+
64+
async def test_commit_error_returns_500(client, user_1, mock_middleware):
65+
with capture_logs() as caplog:
66+
res = await client.post(
67+
"/users",
68+
json={"id": 1, "first_name": "Bob", "last_name": "Morane"},
69+
headers={"origin": "localhost"},
70+
)
71+
72+
assert res.status_code == 500
73+
74+
assert {
75+
"event": "commit failed, returning http error",
76+
"exc_info": True,
77+
"log_level": "error",
78+
} in caplog
79+
80+
assert {
81+
"event": "http error, rolling back possibly uncommitted changes",
82+
"log_level": "warning",
83+
"status_code": 500,
84+
} in caplog
85+
86+
mock_middleware.assert_called_once()
87+
88+
89+
async def test_rollback_on_http_exception(client, mock_middleware):
90+
with patch("fastapi_sqla.async_sqla.open_session") as open_session:
91+
session = open_session.return_value.__aenter__.return_value
92+
93+
await client.get("/404")
94+
95+
session.rollback.assert_awaited_once_with()
96+
mock_middleware.assert_called_once()
97+
98+
99+
async def test_rollback_on_http_exception_silent(client, mock_middleware):
100+
with capture_logs() as caplog:
101+
await client.get("/404")
102+
103+
mock_middleware.assert_called_once()
104+
105+
assert {
106+
"event": "http error, rolling back possibly uncommitted changes",
107+
"log_level": "warning",
108+
"status_code": 404,
109+
} not in caplog

0 commit comments

Comments
 (0)