1212from sqlalchemy import engine_from_config
1313from sqlalchemy .ext .declarative import DeferredReflection , declarative_base
1414from sqlalchemy .orm import Query as DbQuery
15- from sqlalchemy .orm .session import Session , sessionmaker
15+ from sqlalchemy .orm .session import Session as SqlaSession , sessionmaker
1616
17- __all__ = ["Base" , "setup " , "with_session " ]
17+ __all__ = ["Base" , "Session " , "setup " ]
1818
1919logger = structlog .get_logger (__name__ )
2020
@@ -40,6 +40,29 @@ class Base(declarative_base(cls=DeferredReflection)): # type: ignore
4040 __abstract__ = True
4141
4242
43+ class Session (SqlaSession ):
44+ def __new__ (cls , request : Request ) -> SqlaSession :
45+ """Yield the sqlalchmey session for that request.
46+
47+ It is meant to be used as a FastAPI dependency::
48+
49+ from fastapi import APIRouter, Depends
50+ from fastapi_sqla import Session
51+
52+ router = APIRouter()
53+
54+ @router.get("/users")
55+ def get_users(session: Session = Depends()):
56+ pass
57+ """
58+ try :
59+ return request .scope [_SESSION_KEY ]
60+ except KeyError : # pragma: no cover
61+ raise Exception (
62+ "No session found in request, please ensure you've setup fastapi_sqla."
63+ )
64+
65+
4366@contextmanager
4467def open_session () -> Session :
4568 """Context manager that opens a session and properly closes session when exiting.
@@ -63,28 +86,6 @@ def open_session() -> Session:
6386 session .close ()
6487
6588
66- def with_session (request : Request ) -> Session :
67- """Yield the sqlalchmey session for that request.
68-
69- It is meant to be used as a FastAPI® dependency::
70-
71- from er import sqla
72- from fastapi import APIRouter, Depends
73-
74- router = APIRouter()
75-
76- @router.get("/users")
77- def get_users(db: sqla.Session = Depends(sqla.with_session)):
78- pass
79- """
80- try :
81- yield request .scope [_SESSION_KEY ]
82- except KeyError : # pragma: no cover
83- raise Exception (
84- "No session found in request, please ensure you've setup fastapi_sqla."
85- )
86-
87-
8889async def add_session_to_request (request : Request , call_next ):
8990 """Middleware which injects a new sqla session into every request.
9091
@@ -165,7 +166,7 @@ def Pagination(
165166 query_count : Callable [[Session , DbQuery ], int ] = _query_count ,
166167) -> Callable [[Session , int , int ], PaginatedResult ]:
167168 def dependency (
168- session : Session = Depends (with_session ),
169+ session : Session = Depends (),
169170 offset : int = Query (0 , ge = 0 ),
170171 limit : int = Query (min_page_size , ge = 1 , le = max_page_size ),
171172 ) -> PaginatedResult :
0 commit comments