Skip to content

Commit 1808ee2

Browse files
authored
feat: add middleware (#8)
1 parent e2f3849 commit 1808ee2

5 files changed

Lines changed: 222 additions & 28 deletions

File tree

README.md

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,29 +2,46 @@
22

33
SqlAlchemy integration for FastAPI®
44

5-
65
## Configuration
76

8-
* Configure environ variables:
7+
### Environment variables:
98
The keys of interest in `os.environ` are prefixed with `sqlalchemy_`.
109
Each matching key (after the prefix is stripped) is treated as though it were the
1110
corresponding keyword argument to [`sqlalchemy.create_engine`](https://docs.sqlalchemy.org/en/13/core/engines.html?highlight=create_engine#sqlalchemy.create_engine) # noqa
1211
call.
1312

1413
The only required key is `sqlalchemy_url`, which provides the database URL.
1514

16-
* Setup the app:
17-
```python
18-
import fastapi_sqla
19-
from fastapi import FastAPI
20-
21-
app = FastAPI()
22-
fastapi_sqla.setup(app)
23-
```
24-
* Adding a new entity class:
25-
```python
26-
from fastapi_sqla import Base
27-
28-
class Entity(Base):
29-
__tablename__ = "table-name-in-db"
30-
```
15+
### Setup the app:
16+
17+
```python
18+
import fastapi_sqla
19+
from fastapi import FastAPI
20+
21+
app = FastAPI()
22+
fastapi_sqla.setup(app)
23+
```
24+
25+
### Adding a new entity class:
26+
27+
```python
28+
from fastapi_sqla import Base
29+
30+
31+
class Entity(Base):
32+
__tablename__ = "table-name-in-db"
33+
```
34+
35+
### Getting an sqla orm session
36+
37+
```python
38+
from fastapi import APIRouter, Depends
39+
from fastapi_sqla import Session, with_session
40+
41+
router = APIRouter()
42+
43+
44+
@router.get("/example")
45+
def example(session: Session = Depends(with_session)):
46+
return session.execute("SELECT now()").scalar()
47+
```

fastapi_sqla.py

Lines changed: 79 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import os
2+
from contextlib import contextmanager
23

34
import structlog
4-
from fastapi import FastAPI
5+
from fastapi import FastAPI, Request
6+
from fastapi.concurrency import contextmanager_in_threadpool
57
from sqlalchemy import engine_from_config
68
from sqlalchemy.ext.declarative import DeferredReflection, declarative_base
7-
from sqlalchemy.orm.session import sessionmaker
9+
from sqlalchemy.orm.session import Session, sessionmaker
810

9-
__all__ = ["Base", "setup"]
11+
__all__ = ["Base", "setup", "with_session"]
1012

1113
logger = structlog.get_logger(__name__)
1214

@@ -15,6 +17,7 @@
1517

1618
def setup(app: FastAPI):
1719
app.add_event_handler("startup", startup)
20+
app.middleware("http")(add_session_to_request)
1821

1922

2023
def startup():
@@ -27,3 +30,76 @@ def startup():
2730

2831
class Base(declarative_base(cls=DeferredReflection)): # type: ignore
2932
__abstract__ = True
33+
34+
35+
@contextmanager
36+
def open_session() -> Session:
37+
"""Context manager that opens a session and properly closes session when exiting.
38+
39+
If no exception is raised before exiting context, session is committed when exiting
40+
context. If an exception is raised, session is rollbacked.
41+
"""
42+
session = _Session()
43+
logger.bind(db_session=session)
44+
45+
try:
46+
yield session
47+
logger.debug("committing")
48+
session.commit()
49+
except Exception:
50+
logger.exception("rolling back")
51+
session.rollback()
52+
raise
53+
finally:
54+
session.close()
55+
56+
57+
def with_session(request: Request) -> Session:
58+
"""Yield the sqlalchmey session for that request.
59+
60+
It is meant to be used as a FastAPI® dependency::
61+
62+
from er import sqla
63+
from fastapi import APIRouter, Depends
64+
65+
router = APIRouter()
66+
67+
@router.get("/users")
68+
def get_users(db: sqla.Session = Depends(sqla.with_session)):
69+
pass
70+
"""
71+
try:
72+
yield request.scope["sqla_session"]
73+
except KeyError: # pragma: no cover
74+
raise Exception(
75+
"No session found in request, please ensure you've setup fastapi_sqla."
76+
)
77+
78+
79+
async def add_session_to_request(request: Request, call_next):
80+
"""Middleware which injects a new sqla session into every request.
81+
82+
Handles creation of session, as well as commit, rollback, and closing of session.
83+
84+
Usage::
85+
86+
import fastapi_sqla
87+
from fastapi import FastApi
88+
89+
app = FastApi()
90+
91+
fastapi_sqla.setup(app) # includes middleware
92+
93+
@app.get("/users")
94+
def get_users(session: sqla.Session = Depends(sqla.new_session)):
95+
return session.query(...) # use your session here
96+
"""
97+
async with contextmanager_in_threadpool(open_session()) as session:
98+
request.scope["sqla_session"] = session
99+
response = await call_next(request)
100+
if response.status_code >= 400:
101+
# If ever a route handler returns an http exception, we do not want the
102+
# session opened by current context manager to commit anything in db.
103+
session.rollback()
104+
105+
return response

tests/conftest.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from unittest.mock import patch
44

55
from pytest import fixture
6+
from sqlalchemy.orm.session import close_all_sessions
67

78

89
@fixture(scope="session")
@@ -22,5 +23,8 @@ def environ(db_uri):
2223
def tear_down():
2324
import fastapi_sqla
2425

26+
yield
27+
28+
close_all_sessions()
2529
# reload fastapi_sqla to clear sqla deferred reflection mapping stored in Base
2630
importlib.reload(fastapi_sqla)

tests/test_middleware.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
from unittest.mock import patch
2+
3+
import httpx
4+
from asgi_lifespan import LifespanManager
5+
from fastapi import Depends, FastAPI
6+
from pydantic import BaseModel
7+
from pytest import fixture, mark
8+
from sqlalchemy import engine_from_config
9+
from sqlalchemy.orm.session import close_all_sessions
10+
11+
pytestmark = mark.asyncio
12+
13+
14+
@fixture
15+
def engine(environ):
16+
engine = engine_from_config(environ, prefix="sqlalchemy_")
17+
return engine
18+
19+
20+
@fixture(autouse=True)
21+
def setup_tear_down(engine):
22+
engine.execute(
23+
"""
24+
CREATE TABLE IF NOT EXISTS public.user (
25+
id integer primary key,
26+
first_name varchar,
27+
last_name varchar
28+
)
29+
"""
30+
)
31+
yield
32+
close_all_sessions()
33+
engine.execute("DROP TABLE public.user")
34+
35+
36+
@fixture
37+
def User():
38+
from fastapi_sqla import Base
39+
40+
class User(Base):
41+
__tablename__ = "user"
42+
43+
return User
44+
45+
46+
@fixture
47+
def app(User):
48+
from fastapi_sqla import setup, with_session
49+
50+
app = FastAPI()
51+
setup(app)
52+
53+
class UserIn(BaseModel):
54+
id: int
55+
first_name: str
56+
last_name: str
57+
58+
@app.post("/users")
59+
def create_user(user: UserIn, session=Depends(with_session)):
60+
session.add(User(**user.dict()))
61+
session.flush()
62+
return {}
63+
64+
return app
65+
66+
67+
@fixture
68+
async def client(app):
69+
70+
async with LifespanManager(app):
71+
transport = httpx.ASGITransport(app=app, raise_app_exceptions=False)
72+
async with httpx.AsyncClient(
73+
transport=transport, base_url="http://example.local"
74+
) as client:
75+
yield client
76+
77+
78+
async def test_session_dependency(client):
79+
res = await client.post(
80+
"/users", json={"id": 1, "first_name": "Bob", "last_name": "Morane"}
81+
)
82+
assert res.status_code == 200
83+
84+
85+
@fixture
86+
def user_1(engine):
87+
engine.execute("INSERT INTO public.user VALUES (1, 'bob', 'morane') ")
88+
yield
89+
90+
91+
async def test_commit_error_returns_500(client, user_1):
92+
res = await client.post(
93+
"/users", json={"id": 1, "first_name": "Bob", "last_name": "Morane"}
94+
)
95+
assert res.status_code == 500
96+
97+
98+
async def test_rollback_on_http_exception(client):
99+
with patch("fastapi_sqla.open_session") as open_session:
100+
session = open_session.return_value.__enter__.return_value
101+
102+
await client.get("/404")
103+
104+
session.rollback.assert_called_once_with()

tests/test_startup.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,7 @@
11
import httpx
22
from asgi_lifespan import LifespanManager
33
from fastapi import FastAPI
4-
from pytest import fixture, mark
5-
from sqlalchemy.orm.session import close_all_sessions
6-
7-
8-
@fixture(autouse=True)
9-
def setup_tear_down():
10-
yield
11-
close_all_sessions()
4+
from pytest import mark
125

136

147
def test_startup():

0 commit comments

Comments
 (0)