11import os
2+ from contextlib import contextmanager
23
34import structlog
4- from fastapi import FastAPI
5+ from fastapi import FastAPI , Request
6+ from fastapi .concurrency import contextmanager_in_threadpool
57from sqlalchemy import engine_from_config
68from sqlalchemy .ext .declarative import DeferredReflection , declarative_base
7- from sqlalchemy .orm .session import sessionmaker
9+ from sqlalchemy .orm .session import Session , sessionmaker
810
9- __all__ = ["Base" , "setup" ]
11+ __all__ = ["Base" , "setup" , "with_session" ]
1012
1113logger = structlog .get_logger (__name__ )
1214
1517
1618def setup (app : FastAPI ):
1719 app .add_event_handler ("startup" , startup )
20+ app .middleware ("http" )(add_session_to_request )
1821
1922
2023def startup ():
@@ -27,3 +30,76 @@ def startup():
2730
2831class Base (declarative_base (cls = DeferredReflection )): # type: ignore
2932 __abstract__ = True
33+
34+
35+ @contextmanager
36+ def open_session () -> Session :
37+ """Context manager that opens a session and properly closes session when exiting.
38+
39+ If no exception is raised before exiting context, session is committed when exiting
40+ context. If an exception is raised, session is rollbacked.
41+ """
42+ session = _Session ()
43+ logger .bind (db_session = session )
44+
45+ try :
46+ yield session
47+ logger .debug ("committing" )
48+ session .commit ()
49+ except Exception :
50+ logger .exception ("rolling back" )
51+ session .rollback ()
52+ raise
53+ finally :
54+ session .close ()
55+
56+
57+ def with_session (request : Request ) -> Session :
58+ """Yield the sqlalchmey session for that request.
59+
60+ It is meant to be used as a FastAPI® dependency::
61+
62+ from er import sqla
63+ from fastapi import APIRouter, Depends
64+
65+ router = APIRouter()
66+
67+ @router.get("/users")
68+ def get_users(db: sqla.Session = Depends(sqla.with_session)):
69+ pass
70+ """
71+ try :
72+ yield request .scope ["sqla_session" ]
73+ except KeyError : # pragma: no cover
74+ raise Exception (
75+ "No session found in request, please ensure you've setup fastapi_sqla."
76+ )
77+
78+
79+ async def add_session_to_request (request : Request , call_next ):
80+ """Middleware which injects a new sqla session into every request.
81+
82+ Handles creation of session, as well as commit, rollback, and closing of session.
83+
84+ Usage::
85+
86+ import fastapi_sqla
87+ from fastapi import FastApi
88+
89+ app = FastApi()
90+
91+ fastapi_sqla.setup(app) # includes middleware
92+
93+ @app.get("/users")
94+ def get_users(session: sqla.Session = Depends(sqla.new_session)):
95+ return session.query(...) # use your session here
96+ """
97+ async with contextmanager_in_threadpool (open_session ()) as session :
98+ request .scope ["sqla_session" ] = session
99+ response = await call_next (request )
100+ if response .status_code >= 400 :
101+ # If ever a route handler returns an http exception, we do not want the
102+ # session opened by current context manager to commit anything in db.
103+ session .rollback ()
104+
105+ return response
0 commit comments