Skip to content

Commit f6e3f03

Browse files
authored
fix: importing AsyncSession from fastapi_sqla - DIA-46202 (#68)
1 parent 4f02815 commit f6e3f03

13 files changed

Lines changed: 406 additions & 370 deletions

fastapi_sqla/__init__.py

Lines changed: 37 additions & 347 deletions
Large diffs are not rendered by default.

fastapi_sqla/_pytest_plugin.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,13 +102,13 @@ def sqla_reflection(sqla_modules, sqla_connection, db_url):
102102
@fixture
103103
def patch_engine_from_config(request, db_url, sqla_connection, sqla_transaction):
104104
"""So that all DB operations are never written to db for real."""
105-
from fastapi_sqla import _Session
105+
from fastapi_sqla.sqla import _Session
106106

107107
if "dont_patch_engines" in request.keywords:
108108
yield
109109

110110
else:
111-
with patch("fastapi_sqla.engine_from_config") as engine_from_config:
111+
with patch("fastapi_sqla.sqla.engine_from_config") as engine_from_config:
112112
engine_from_config.return_value = sqla_connection
113113
_Session.configure(bind=sqla_connection)
114114
yield engine_from_config
@@ -130,9 +130,9 @@ def session(
130130
While it does not write any record in DB, the application will still be able to
131131
access any record committed with that session.
132132
"""
133-
import fastapi_sqla
133+
import fastapi_sqla.sqla
134134

135-
yield fastapi_sqla._Session(bind=sqla_connection)
135+
yield fastapi_sqla.sqla._Session(bind=sqla_connection)
136136

137137

138138
def format_async_async_sqlalchemy_url(url):

fastapi_sqla/asyncio_support.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from sqlalchemy.orm.session import sessionmaker
1313
from sqlalchemy.sql import Select, func, select
1414

15-
from . import Base, Page, T, aws_rds_iam_support, new_engine
15+
from fastapi_sqla.sqla import Base, Page, T, aws_rds_iam_support, new_engine
1616

1717
logger = structlog.get_logger(__name__)
1818
_ASYNC_SESSION_KEY = "fastapi_sqla_async_session"

fastapi_sqla/sqla.py

Lines changed: 337 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,337 @@
1+
import asyncio
2+
import math
3+
import os
4+
from collections.abc import Callable, Generator
5+
from contextlib import contextmanager
6+
from functools import singledispatch
7+
from typing import Generic, Optional, TypeVar, Union
8+
9+
import structlog
10+
from fastapi import Depends, Query, Request
11+
from fastapi.concurrency import contextmanager_in_threadpool
12+
from fastapi.responses import PlainTextResponse
13+
from pydantic import BaseModel, Field
14+
from pydantic.generics import GenericModel
15+
from sqlalchemy import engine_from_config, text
16+
from sqlalchemy.engine import Engine
17+
from sqlalchemy.ext.declarative import DeferredReflection
18+
from sqlalchemy.orm import Query as LegacyQuery
19+
from sqlalchemy.orm.session import Session as SqlaSession
20+
from sqlalchemy.orm.session import sessionmaker
21+
from sqlalchemy.sql import Select, func, select
22+
23+
from fastapi_sqla import aws_rds_iam_support
24+
25+
try:
26+
from sqlalchemy.orm import declarative_base
27+
except ImportError:
28+
from sqlalchemy.ext.declarative import declarative_base
29+
30+
31+
logger = structlog.get_logger(__name__)
32+
33+
_SESSION_KEY = "fastapi_sqla_session"
34+
35+
_Session = sessionmaker()
36+
37+
38+
def new_engine(*, envvar_prefix: str = None) -> Engine:
39+
envvar_prefix = envvar_prefix if envvar_prefix else "sqlalchemy_"
40+
lowercase_environ = {
41+
k.lower(): v for k, v in os.environ.items() if k.lower() != "sqlalchemy_warn_20"
42+
}
43+
return engine_from_config(lowercase_environ, prefix=envvar_prefix)
44+
45+
46+
def is_async_dialect(engine):
47+
return engine.dialect.is_async if hasattr(engine.dialect, "is_async") else False
48+
49+
50+
def startup():
51+
engine = new_engine()
52+
aws_rds_iam_support.setup(engine.engine)
53+
54+
# Fail early:
55+
try:
56+
with engine.connect() as connection:
57+
connection.execute(text("select 'OK'"))
58+
except Exception:
59+
logger.critical(
60+
"Fail querying db: is sqlalchemy_url envvar correctly configured?"
61+
)
62+
raise
63+
64+
Base.prepare(engine)
65+
_Session.configure(bind=engine)
66+
logger.info("startup", engine=engine)
67+
68+
69+
class Base(declarative_base(cls=DeferredReflection)): # type: ignore
70+
__abstract__ = True
71+
72+
73+
class Session(SqlaSession):
74+
def __new__(cls, request: Request) -> SqlaSession:
75+
"""Yield the sqlalchmey session for that request.
76+
77+
It is meant to be used as a FastAPI dependency::
78+
79+
from fastapi import APIRouter, Depends
80+
from fastapi_sqla import Session
81+
82+
router = APIRouter()
83+
84+
@router.get("/users")
85+
def get_users(session: Session = Depends()):
86+
pass
87+
"""
88+
try:
89+
return request.scope[_SESSION_KEY]
90+
except KeyError: # pragma: no cover
91+
raise Exception(
92+
"No session found in request, please ensure you've setup fastapi_sqla."
93+
)
94+
95+
96+
@contextmanager
97+
def open_session() -> Generator[Session, None, None]:
98+
"""Context manager that opens a session and properly closes session when exiting.
99+
100+
If no exception is raised before exiting context, session is committed when exiting
101+
context. If an exception is raised, session is rollbacked.
102+
"""
103+
session = _Session()
104+
logger.bind(db_session=session)
105+
106+
try:
107+
yield session
108+
109+
except Exception:
110+
logger.warning("context failed, rolling back", exc_info=True)
111+
session.rollback()
112+
raise
113+
114+
else:
115+
try:
116+
session.commit()
117+
except Exception:
118+
logger.exception("commit failed, rolling back")
119+
session.rollback()
120+
raise
121+
122+
finally:
123+
session.close()
124+
125+
126+
async def add_session_to_request(request: Request, call_next):
127+
"""Middleware which injects a new sqla session into every request.
128+
129+
Handles creation of session, as well as commit, rollback, and closing of session.
130+
131+
Usage::
132+
133+
import fastapi_sqla
134+
from fastapi import FastApi
135+
136+
app = FastApi()
137+
138+
fastapi_sqla.setup(app) # includes middleware
139+
140+
@app.get("/users")
141+
def get_users(session: fastapi_sqla.Session = Depends()):
142+
return session.execute(...) # use your session here
143+
"""
144+
async with contextmanager_in_threadpool(open_session()) as session:
145+
request.scope[_SESSION_KEY] = session
146+
147+
response = await call_next(request)
148+
149+
is_dirty = bool(session.dirty or session.deleted or session.new)
150+
151+
loop = asyncio.get_running_loop()
152+
153+
# try to commit after response, so that we can return a proper 500 response
154+
# and not raise a true internal server error
155+
if response.status_code < 400:
156+
try:
157+
await loop.run_in_executor(None, session.commit)
158+
except Exception:
159+
logger.exception("commit failed, returning http error")
160+
response = PlainTextResponse(
161+
content="Internal Server Error", status_code=500
162+
)
163+
164+
if response.status_code >= 400:
165+
# If ever a route handler returns an http exception, we do not want the
166+
# session opened by current context manager to commit anything in db.
167+
if is_dirty:
168+
# optimistically only log if there were uncommitted changes
169+
logger.warning(
170+
"http error, rolling back possibly uncommitted changes",
171+
status_code=response.status_code,
172+
)
173+
# since this is no-op if session is not dirty, we can always call it
174+
await loop.run_in_executor(None, session.rollback)
175+
176+
return response
177+
178+
179+
T = TypeVar("T")
180+
181+
182+
class Item(GenericModel, Generic[T]):
183+
"""Item container."""
184+
185+
data: T
186+
187+
188+
class Collection(GenericModel, Generic[T]):
189+
"""Collection container."""
190+
191+
data: list[T]
192+
193+
194+
class Meta(BaseModel):
195+
"""Meta information on current page and collection"""
196+
197+
offset: int = Field(..., description="Current page offset")
198+
total_items: int = Field(..., description="Total number of items in the collection")
199+
total_pages: int = Field(..., description="Total number of pages in the collection")
200+
page_number: int = Field(..., description="Current page number. Starts at 1.")
201+
202+
203+
class Page(Collection, Generic[T]):
204+
"""A page of the collection with info on current page and total items in meta."""
205+
206+
meta: Meta
207+
208+
209+
DbQuery = Union[LegacyQuery, Select]
210+
QueryCountDependency = Callable[..., int]
211+
PaginateSignature = Callable[[DbQuery, Optional[bool]], Page[T]]
212+
DefaultDependency = Callable[[Session, int, int], PaginateSignature]
213+
WithQueryCountDependency = Callable[[Session, int, int, int], PaginateSignature]
214+
PaginateDependency = Union[DefaultDependency, WithQueryCountDependency]
215+
216+
217+
def default_query_count(session: Session, query: DbQuery) -> int:
218+
"""Default function used to count items returned by a query.
219+
220+
It is slower than a manually written query could be: It runs the query in a
221+
subquery, and count the number of elements returned.
222+
223+
See https://gist.github.com/hest/8798884
224+
"""
225+
if isinstance(query, LegacyQuery):
226+
result = query.count()
227+
228+
elif isinstance(query, Select):
229+
result = session.execute(
230+
select(func.count()).select_from(query.subquery())
231+
).scalar()
232+
233+
else: # pragma no cover
234+
raise NotImplementedError(f"Query type {type(query)!r} is not supported")
235+
236+
return result
237+
238+
239+
@singledispatch
240+
def paginate_query(
241+
query: DbQuery,
242+
session: Session,
243+
total_items: int,
244+
offset: int,
245+
limit: int,
246+
scalars: bool = True,
247+
) -> Page[T]: # pragma no cover
248+
"Dispatch on registered functions based on `query` type"
249+
raise NotImplementedError(f"no paginate_query registered for type {type(query)!r}")
250+
251+
252+
@paginate_query.register
253+
def _paginate_legacy(
254+
query: LegacyQuery,
255+
session: Session,
256+
total_items: int,
257+
offset: int,
258+
limit: int,
259+
scalars: bool = True,
260+
) -> Page[T]:
261+
total_pages = math.ceil(total_items / limit)
262+
page_number = offset / limit + 1
263+
return Page[T](
264+
data=query.offset(offset).limit(limit).all(),
265+
meta={
266+
"offset": offset,
267+
"total_items": total_items,
268+
"total_pages": total_pages,
269+
"page_number": page_number,
270+
},
271+
)
272+
273+
274+
@paginate_query.register
275+
def _paginate(
276+
query: Select,
277+
session: Session,
278+
total_items: int,
279+
offset: int,
280+
limit: int,
281+
*,
282+
scalars: bool = True,
283+
) -> Page[T]:
284+
total_pages = math.ceil(total_items / limit)
285+
page_number = offset / limit + 1
286+
query = query.offset(offset).limit(limit)
287+
result = session.execute(query)
288+
data = iter(result.unique().scalars() if scalars else result.mappings())
289+
return Page[T](
290+
data=data,
291+
meta={
292+
"offset": offset,
293+
"total_items": total_items,
294+
"total_pages": total_pages,
295+
"page_number": page_number,
296+
},
297+
)
298+
299+
300+
def Pagination(
301+
min_page_size: int = 10,
302+
max_page_size: int = 100,
303+
query_count: QueryCountDependency = None,
304+
) -> PaginateDependency:
305+
def default_dependency(
306+
session: Session = Depends(),
307+
offset: int = Query(0, ge=0),
308+
limit: int = Query(min_page_size, ge=1, le=max_page_size),
309+
) -> PaginateSignature:
310+
def paginate(query: DbQuery, scalars=True) -> Page[T]:
311+
total_items = default_query_count(session, query)
312+
return paginate_query(
313+
query, session, total_items, offset, limit, scalars=scalars
314+
)
315+
316+
return paginate
317+
318+
def with_query_count_dependency(
319+
session: Session = Depends(),
320+
offset: int = Query(0, ge=0),
321+
limit: int = Query(min_page_size, ge=1, le=max_page_size),
322+
total_items: int = Depends(query_count),
323+
) -> PaginateSignature:
324+
def paginate(query: DbQuery, scalars=True) -> Page[T]:
325+
return paginate_query(
326+
query, session, total_items, offset, limit, scalars=scalars
327+
)
328+
329+
return paginate
330+
331+
if query_count:
332+
return with_query_count_dependency
333+
else:
334+
return default_dependency
335+
336+
337+
Paginate: PaginateDependency = Pagination()

tests/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def tear_down(environ):
8585

8686
close_all_sessions()
8787
# reload fastapi_sqla to clear sqla deferred reflection mapping stored in Base
88+
importlib.reload(fastapi_sqla.sqla)
8889
importlib.reload(fastapi_sqla)
8990

9091

tests/pagination/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def setup_tear_down(sqla_connection, nb_users, nb_notes):
4444
metadata = MetaData()
4545
user = Table("user", metadata, autoload_with=sqla_connection)
4646
note = Table("note", metadata, autoload_with=sqla_connection)
47-
user_params = [{"name": faker.name()}] * nb_users
47+
user_params = [{"name": faker.name()} for i in range(0, nb_users)]
4848
note_params = [
4949
{"user_id": i % 42 + 1, "content": faker.text()} for i in range(0, nb_notes)
5050
]

0 commit comments

Comments
 (0)