Skip to content

Commit f19f841

Browse files
authored
feat: sqlalchemy_url accepts async dialect - DIA-46202 (#62)
1 parent 34f6ccc commit f19f841

7 files changed

Lines changed: 114 additions & 41 deletions

File tree

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ EOF
9292

9393
Running the app:
9494
```bash
95-
sqlalchemy_url=sqlite:///db.sqlite?check_same_thread=false uvicorn main:app --reload
95+
sqlalchemy_url=sqlite:///db.sqlite?check_same_thread=false uvicorn main:app
9696
```
9797

9898
# Configuration
@@ -119,10 +119,10 @@ To enable `asyncio` support against a Postgres DB, install `asyncpg`:
119119
pip install asyncpg
120120
```
121121

122-
And define environment variable `async_sqlalchemy_url` with `postgres+asyncpg` scheme:
122+
And define environment variable `sqlalchemy_url` with `postgres+asyncpg` scheme:
123123

124124
```bash
125-
export async_sqlalchemy_url=postgresql+asyncpg://postgres@localhost
125+
export sqlalchemy_url=postgresql+asyncpg://postgres@localhost
126126
```
127127

128128
## Setup the app:

fastapi_sqla/__init__.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from pydantic import BaseModel, Field
1313
from pydantic.generics import GenericModel
1414
from sqlalchemy import engine_from_config, text
15+
from sqlalchemy.engine import Engine
1516
from sqlalchemy.ext.declarative import DeferredReflection
1617
from sqlalchemy.orm import Query as LegacyQuery
1718
from sqlalchemy.orm.session import Session as SqlaSession
@@ -51,38 +52,48 @@
5152
_Session = sessionmaker()
5253

5354

55+
def new_engine(*, envvar_prefix: str = None) -> Engine:
56+
envvar_prefix = envvar_prefix if envvar_prefix else "sqlalchemy_"
57+
lowercase_environ = {
58+
k.lower(): v for k, v in os.environ.items() if k.lower() != "sqlalchemy_warn_20"
59+
}
60+
return engine_from_config(lowercase_environ, prefix=envvar_prefix)
61+
62+
63+
def is_async_dialect(engine):
64+
return engine.dialect.is_async if hasattr(engine.dialect, "is_async") else False
65+
66+
5467
def setup(app: FastAPI):
55-
app.add_event_handler("startup", startup)
56-
app.middleware("http")(add_session_to_request)
68+
engine = new_engine()
69+
70+
if not is_async_dialect(engine):
71+
app.add_event_handler("startup", startup)
72+
app.middleware("http")(add_session_to_request)
5773

58-
async_sqlalchemy_url = os.getenv("async_sqlalchemy_url")
59-
if async_sqlalchemy_url:
74+
has_async_config = "async_sqlalchemy_url" in os.environ or is_async_dialect(engine)
75+
if has_async_config:
6076
assert asyncio_support, asyncio_support_err
6177
app.add_event_handler("startup", asyncio_support.startup)
6278
app.middleware("http")(asyncio_support.add_session_to_request)
6379

6480

6581
def startup():
66-
lowercase_environ = {
67-
k.lower(): v for k, v in os.environ.items() if k.lower() != "sqlalchemy_warn_20"
68-
}
69-
engine = engine_from_config(lowercase_environ, prefix="sqlalchemy_")
82+
engine = new_engine()
7083
aws_rds_iam_support.setup(engine.engine)
7184

72-
Base.metadata.bind = engine
73-
Base.prepare(engine)
74-
_Session.configure(bind=engine)
75-
7685
# Fail early:
7786
try:
78-
with open_session() as session:
79-
session.execute(text("select 'OK'"))
87+
with engine.connect() as connection:
88+
connection.execute(text("select 'OK'"))
8089
except Exception:
8190
logger.critical(
8291
"Fail querying db: is sqlalchemy_url envvar correctly configured?"
8392
)
8493
raise
8594

95+
Base.prepare(engine)
96+
_Session.configure(bind=engine)
8697
logger.info("startup", engine=engine)
8798

8899

fastapi_sqla/_pytest_plugin.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -162,29 +162,23 @@ async def async_sqla_connection(async_engine, event_loop):
162162
await connection.rollback()
163163

164164
@fixture
165-
async def patch_create_async_engine(
166-
async_sqlalchemy_url, async_sqla_connection, request
167-
):
165+
async def patch_new_engine(async_sqlalchemy_url, async_sqla_connection, request):
168166
"""So that all async DB operations are never written to db for real."""
169167
from fastapi_sqla.asyncio_support import _AsyncSession
170168

171169
if "dont_patch_engines" in request.keywords:
172170
yield
173171

174172
else:
175-
with patch(
176-
"fastapi_sqla.asyncio_support.create_async_engine"
177-
) as create_async_engine:
178-
create_async_engine.return_value = async_sqla_connection
173+
with patch("fastapi_sqla.asyncio_support.new_engine") as new_engine:
174+
new_engine.return_value = async_sqla_connection
179175
_AsyncSession.configure(
180176
bind=async_sqla_connection, expire_on_commit=False
181177
)
182-
yield create_async_engine
178+
yield new_engine
183179

184180
@fixture
185-
async def async_session(
186-
async_sqla_connection, sqla_reflection, patch_create_async_engine
187-
):
181+
async def async_session(async_sqla_connection, sqla_reflection, patch_new_engine):
188182
from fastapi_sqla.asyncio_support import _AsyncSession
189183

190184
session = _AsyncSession(bind=async_sqla_connection)
@@ -194,5 +188,5 @@ async def async_session(
194188
else:
195189

196190
@fixture
197-
async def patch_create_async_engine():
191+
async def patch_new_engine():
198192
pass

fastapi_sqla/asyncio_support.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,33 +4,45 @@
44
import structlog
55
from fastapi import Request
66
from sqlalchemy import text
7+
from sqlalchemy.ext.asyncio import AsyncEngine
78
from sqlalchemy.ext.asyncio import AsyncSession as SqlaAsyncSession
8-
from sqlalchemy.ext.asyncio import create_async_engine
99
from sqlalchemy.orm.session import sessionmaker
1010

11-
from . import aws_rds_iam_support
11+
from . import Base, aws_rds_iam_support, new_engine
1212

1313
logger = structlog.get_logger(__name__)
1414
_ASYNC_SESSION_KEY = "fastapi_sqla_async_session"
1515
_AsyncSession = sessionmaker(class_=SqlaAsyncSession)
1616

1717

18+
def new_async_engine():
19+
envvar_prefix = None
20+
if "async_sqlalchemy_url" in os.environ:
21+
envvar_prefix = "async_sqlalchemy_"
22+
23+
engine = new_engine(envvar_prefix=envvar_prefix)
24+
return AsyncEngine(engine)
25+
26+
1827
async def startup():
19-
async_sqlalchemy_url = os.environ["async_sqlalchemy_url"]
20-
engine = create_async_engine(async_sqlalchemy_url)
28+
engine = new_async_engine()
2129
aws_rds_iam_support.setup(engine.sync_engine)
22-
_AsyncSession.configure(bind=engine, expire_on_commit=False)
2330

2431
# Fail early:
2532
try:
26-
async with open_session() as session:
27-
await session.execute(text("select 'ok'"))
33+
async with engine.connect() as connection:
34+
await connection.execute(text("select 'ok'"))
2835
except Exception:
2936
logger.critical(
30-
"Fail querying db: is async_sqlalchemy_url envvar correctly configured?"
37+
"Failed querying db: is sqlalchemy_url or async_sqlalchemy_url envvar "
38+
"correctly configured?"
3139
)
3240
raise
3341

42+
async with engine.connect() as connection:
43+
await connection.run_sync(Base.prepare)
44+
45+
_AsyncSession.configure(bind=engine, expire_on_commit=False)
3446
logger.info("startup", async_engine=engine)
3547

3648

tests/test_asyncio_support.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1-
from pytest import fixture, mark
1+
from unittest.mock import AsyncMock, patch
2+
3+
from pytest import fixture, mark, raises
24
from sqlalchemy import text
35

46
pytestmark = [mark.sqlalchemy("1.4"), mark.require_asyncpg]
57

68

7-
@fixture(autouse=True)
9+
@fixture
810
async def setup(environ):
911
from fastapi_sqla.asyncio_support import startup
1012

@@ -23,10 +25,40 @@ async def test_startup_configure_async_session():
2325
assert res.scalar() == 123
2426

2527

26-
async def test_open_async_session():
28+
async def test_open_async_session(setup):
2729
from fastapi_sqla.asyncio_support import open_session
2830

2931
async with open_session() as session:
3032
res = await session.execute(text("select 123"))
3133

3234
assert res.scalar() == 123
35+
36+
37+
async def test_new_async_engine_without_async_alchemy_url(
38+
monkeypatch, async_sqlalchemy_url
39+
):
40+
from fastapi_sqla.asyncio_support import new_async_engine
41+
42+
monkeypatch.delenv("async_sqlalchemy_url")
43+
monkeypatch.setenv("sqlalchemy_url", async_sqlalchemy_url)
44+
45+
assert new_async_engine()
46+
47+
48+
@fixture
49+
def AsyncSessionMock():
50+
with patch("fastapi_sqla.asyncio_support._AsyncSession") as AsyncSessionMock:
51+
AsyncSessionMock.return_value = AsyncMock()
52+
yield AsyncSessionMock
53+
54+
55+
async def test_context_manager_rollbacks_on_error(AsyncSessionMock):
56+
from fastapi_sqla.asyncio_support import open_session
57+
58+
session = AsyncSessionMock.return_value
59+
with raises(Exception) as raise_info:
60+
async with open_session():
61+
raise Exception("boom!")
62+
63+
session.rollback.assert_awaited_once_with()
64+
assert raise_info.value.args == ("boom!",)

tests/test_setup.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from unittest.mock import Mock
2+
3+
from pytest import mark
4+
5+
6+
@mark.sqlalchemy("1.4")
7+
@mark.require_asyncpg
8+
def test_setup_with_async_sqlalchemy_url_adds_asyncio_support_startup(
9+
monkeypatch, async_sqlalchemy_url
10+
):
11+
from fastapi_sqla import asyncio_support, setup
12+
13+
monkeypatch.delenv("async_sqlalchemy_url")
14+
monkeypatch.setenv("sqlalchemy_url", async_sqlalchemy_url)
15+
16+
app = Mock()
17+
setup(app)
18+
19+
app.add_event_handler.assert_called_once_with("startup", asyncio_support.startup)
20+
app.middleware.assert_called_once_with("http")
21+
app.middleware.return_value.assert_called_once_with(
22+
asyncio_support.add_session_to_request
23+
)

tests/test_startup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from pytest import fixture, mark, raises
77
from sqlalchemy import text
88

9-
pytestmark = mark.usefixtures("patch_engine_from_config", "patch_create_async_engine")
9+
pytestmark = mark.usefixtures("patch_engine_from_config", "patch_new_engine")
1010

1111

1212
@fixture(params=[True, False])
@@ -50,6 +50,7 @@ def test_startup(case_sensitive_environ):
5050
assert session.execute(text("SELECT 1")).scalar() == 1
5151

5252

53+
@mark.dont_patch_engines
5354
async def test_fastapi_integration():
5455
from fastapi_sqla import _Session, setup
5556

0 commit comments

Comments
 (0)