Skip to content

Commit 2e969aa

Browse files
authored
fix: other middlewares don't run when there's an exception in fastapi-sqla context manager (#16)
1 parent e5fb670 commit 2e969aa

3 files changed

Lines changed: 144 additions & 10 deletions

File tree

fastapi_sqla/__init__.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import structlog
99
from fastapi import Depends, FastAPI, Query, Request
1010
from fastapi.concurrency import contextmanager_in_threadpool
11+
from fastapi.responses import PlainTextResponse
1112
from pydantic import BaseModel, Field
1213
from pydantic.generics import GenericModel
1314
from sqlalchemy import engine_from_config
@@ -106,13 +107,20 @@ def open_session() -> Session:
106107

107108
try:
108109
yield session
109-
session.commit()
110110

111111
except Exception:
112-
logger.exception("commit failed, rolling back")
112+
logger.warning("context failed, rolling back", exc_info=True)
113113
session.rollback()
114114
raise
115115

116+
else:
117+
try:
118+
session.commit()
119+
except Exception:
120+
logger.exception("commit failed, rolling back")
121+
session.rollback()
122+
raise
123+
116124
finally:
117125
session.close()
118126

@@ -137,11 +145,27 @@ def get_users(session: fastapi_sqla.Session = Depends()):
137145
"""
138146
async with contextmanager_in_threadpool(open_session()) as session:
139147
request.scope[_SESSION_KEY] = session
148+
140149
response = await call_next(request)
150+
151+
loop = asyncio.get_running_loop()
152+
153+
# try to commit after response, so that we can return a proper 500 response
154+
# and not raise a true internal server error
155+
if response.status_code < 400:
156+
157+
try:
158+
await loop.run_in_executor(None, session.commit)
159+
except Exception:
160+
logger.exception("commit failed, returning http error")
161+
response = PlainTextResponse(
162+
content="Internal Server Error", status_code=500
163+
)
164+
141165
if response.status_code >= 400:
142166
# If ever a route handler returns an http exception, we do not want the
143167
# session opened by current context manager to commit anything in db.
144-
loop = asyncio.get_running_loop()
168+
logger.warning("http error, rolling back", status_code=response.status_code)
145169
await loop.run_in_executor(None, session.rollback)
146170

147171
return response

tests/test_middleware.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
from unittest.mock import patch
1+
from unittest.mock import Mock, patch
22

33
import httpx
44
from asgi_lifespan import LifespanManager
5-
from fastapi import Depends, FastAPI
5+
from fastapi import Depends, FastAPI, HTTPException
66
from pydantic import BaseModel
77
from pytest import fixture, mark
88
from sqlalchemy import text
@@ -68,9 +68,26 @@ def create_user_with_async_session(
6868
):
6969
session.add(User(**user.dict()))
7070

71+
@app.get("/404")
72+
def get_users(session: Session = Depends(Session)):
73+
raise HTTPException(status_code=404, detail="YOLO")
74+
7175
return app
7276

7377

78+
@fixture
79+
def mock_middleware(app: FastAPI):
80+
mock_middleware = Mock()
81+
82+
@app.middleware("http")
83+
async def a_middleware(request, call_next):
84+
res = await call_next(request)
85+
mock_middleware()
86+
return res
87+
88+
return mock_middleware
89+
90+
7491
@fixture
7592
async def client(app):
7693

@@ -120,24 +137,36 @@ def user_1(sqla_connection):
120137
yield
121138

122139

123-
async def test_commit_error_returns_500(client, user_1):
140+
async def test_commit_error_returns_500(client, user_1, mock_middleware):
124141
with capture_logs() as caplog:
125142
res = await client.post(
126-
"/users", json={"id": 1, "first_name": "Bob", "last_name": "Morane"}
143+
"/users",
144+
json={"id": 1, "first_name": "Bob", "last_name": "Morane"},
145+
headers={"origin": "localhost"},
127146
)
128147

129148
assert res.status_code == 500
149+
130150
assert {
131-
"event": "commit failed, rolling back",
132-
"log_level": "error",
151+
"event": "commit failed, returning http error",
133152
"exc_info": True,
153+
"log_level": "error",
154+
} in caplog
155+
156+
assert {
157+
"event": "http error, rolling back",
158+
"log_level": "warning",
159+
"status_code": 500,
134160
} in caplog
135161

162+
mock_middleware.assert_called_once()
163+
136164

137-
async def test_rollback_on_http_exception(client):
165+
async def test_rollback_on_http_exception(client, mock_middleware):
138166
with patch("fastapi_sqla.open_session") as open_session:
139167
session = open_session.return_value.__enter__.return_value
140168

141169
await client.get("/404")
142170

143171
session.rollback.assert_called_once_with()
172+
mock_middleware.assert_called_once()

tests/test_open_session.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
from pytest import fixture, mark, raises
2+
from sqlalchemy import insert, select, text
3+
from sqlalchemy.exc import IntegrityError
4+
5+
6+
@fixture(autouse=True, scope="module")
7+
def module_setup_tear_down(engine, sqla_connection):
8+
engine.execute(
9+
"CREATE TABLE IF NOT EXISTS test_table (id integer primary key, value varchar)"
10+
)
11+
yield
12+
engine.execute("DROP TABLE test_table")
13+
14+
15+
@fixture(autouse=True)
16+
def setup(sqla_connection):
17+
from fastapi_sqla import _Session
18+
19+
_Session.configure(bind=sqla_connection)
20+
21+
22+
@fixture(scope="module")
23+
def TestTable(module_setup_tear_down):
24+
from fastapi_sqla import Base, startup
25+
26+
class TestTable(Base):
27+
__tablename__ = "test_table"
28+
29+
startup()
30+
31+
return TestTable
32+
33+
34+
@mark.sqlalchemy("1.4")
35+
def test_open_session():
36+
from fastapi_sqla import open_session
37+
38+
with open_session() as session:
39+
res = session.execute(select(text("'OK'"))).scalar()
40+
41+
assert res == "OK"
42+
43+
44+
@mark.sqlalchemy("1.4")
45+
def test_open_session_rollback_when_error_occurs_in_context(TestTable, session):
46+
from fastapi_sqla import open_session
47+
48+
error = Exception("Error in context")
49+
50+
class Custom(Exception):
51+
pass
52+
53+
with raises(Exception) as raise_info:
54+
with open_session() as session:
55+
session.execute(insert(TestTable).values(id=1, value="bobby drop tables"))
56+
raise error
57+
58+
assert raise_info.value == error
59+
60+
res = session.execute(select(TestTable)).fetchall()
61+
assert res == [], "insert has not been rolled back"
62+
63+
64+
@fixture
65+
def existing_record(TestTable, session):
66+
id = 1
67+
session.execute(insert(TestTable).values(id=id, value="bob morane was there."))
68+
session.flush()
69+
yield (1, "bob morane was there.")
70+
71+
72+
def test_open_session_re_raise_exception_when_commit_fails(
73+
TestTable, existing_record, session
74+
):
75+
from fastapi_sqla import open_session
76+
77+
with raises(Exception) as raise_info:
78+
with open_session() as session:
79+
session.add(TestTable(id=1, value="bobby already exists"))
80+
81+
assert isinstance(raise_info.value, IntegrityError)

0 commit comments

Comments
 (0)