|
1 | | -from unittest.mock import patch |
| 1 | +from unittest.mock import Mock, patch |
2 | 2 |
|
3 | 3 | import httpx |
4 | 4 | from asgi_lifespan import LifespanManager |
5 | | -from fastapi import Depends, FastAPI |
| 5 | +from fastapi import Depends, FastAPI, HTTPException |
6 | 6 | from pydantic import BaseModel |
7 | 7 | from pytest import fixture, mark |
8 | 8 | from sqlalchemy import text |
@@ -68,9 +68,26 @@ def create_user_with_async_session( |
68 | 68 | ): |
69 | 69 | session.add(User(**user.dict())) |
70 | 70 |
|
| 71 | + @app.get("/404") |
| 72 | + def get_users(session: Session = Depends(Session)): |
| 73 | + raise HTTPException(status_code=404, detail="YOLO") |
| 74 | + |
71 | 75 | return app |
72 | 76 |
|
73 | 77 |
|
| 78 | +@fixture |
| 79 | +def mock_middleware(app: FastAPI): |
| 80 | + mock_middleware = Mock() |
| 81 | + |
| 82 | + @app.middleware("http") |
| 83 | + async def a_middleware(request, call_next): |
| 84 | + res = await call_next(request) |
| 85 | + mock_middleware() |
| 86 | + return res |
| 87 | + |
| 88 | + return mock_middleware |
| 89 | + |
| 90 | + |
74 | 91 | @fixture |
75 | 92 | async def client(app): |
76 | 93 |
|
@@ -120,24 +137,36 @@ def user_1(sqla_connection): |
120 | 137 | yield |
121 | 138 |
|
122 | 139 |
|
123 | | -async def test_commit_error_returns_500(client, user_1): |
| 140 | +async def test_commit_error_returns_500(client, user_1, mock_middleware): |
124 | 141 | with capture_logs() as caplog: |
125 | 142 | res = await client.post( |
126 | | - "/users", json={"id": 1, "first_name": "Bob", "last_name": "Morane"} |
| 143 | + "/users", |
| 144 | + json={"id": 1, "first_name": "Bob", "last_name": "Morane"}, |
| 145 | + headers={"origin": "localhost"}, |
127 | 146 | ) |
128 | 147 |
|
129 | 148 | assert res.status_code == 500 |
| 149 | + |
130 | 150 | assert { |
131 | | - "event": "commit failed, rolling back", |
132 | | - "log_level": "error", |
| 151 | + "event": "commit failed, returning http error", |
133 | 152 | "exc_info": True, |
| 153 | + "log_level": "error", |
| 154 | + } in caplog |
| 155 | + |
| 156 | + assert { |
| 157 | + "event": "http error, rolling back", |
| 158 | + "log_level": "warning", |
| 159 | + "status_code": 500, |
134 | 160 | } in caplog |
135 | 161 |
|
| 162 | + mock_middleware.assert_called_once() |
| 163 | + |
136 | 164 |
|
137 | | -async def test_rollback_on_http_exception(client): |
| 165 | +async def test_rollback_on_http_exception(client, mock_middleware): |
138 | 166 | with patch("fastapi_sqla.open_session") as open_session: |
139 | 167 | session = open_session.return_value.__enter__.return_value |
140 | 168 |
|
141 | 169 | await client.get("/404") |
142 | 170 |
|
143 | 171 | session.rollback.assert_called_once_with() |
| 172 | + mock_middleware.assert_called_once() |
0 commit comments