Skip to content

Commit a070743

Browse files
authored
feat(pytest): support startup in tests (#299)
1 parent 950c52f commit a070743

6 files changed

Lines changed: 164 additions & 92 deletions

File tree

fastapi_sqla/_pytest_plugin.py

Lines changed: 58 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,23 @@
11
import os
2+
from collections.abc import AsyncGenerator, Generator
23
from unittest.mock import patch
34
from urllib.parse import urlsplit, urlunsplit
45

56
from alembic import command
67
from alembic.config import Config
7-
from pytest import fixture
8+
from pytest import FixtureRequest, fixture
89
from sqlalchemy import create_engine, text
9-
from sqlalchemy.orm.session import sessionmaker
10+
from sqlalchemy.engine import Connection, Engine
11+
from sqlalchemy.orm.session import Session, sessionmaker
1012

1113
try:
1214
import asyncpg # noqa
13-
from sqlalchemy.ext.asyncio import create_async_engine
15+
from sqlalchemy.ext.asyncio import (
16+
create_async_engine,
17+
AsyncEngine,
18+
AsyncConnection,
19+
AsyncSession,
20+
)
1421

1522
asyncio_support = True
1623
except ImportError:
@@ -22,7 +29,7 @@ def pytest_configure(config):
2229

2330

2431
@fixture(scope="session")
25-
def db_host():
32+
def db_host() -> str:
2633
"""Default db host used by depending fixtures.
2734
2835
When CI key is set in environment variables, it uses `postgres` as host name else,
@@ -32,7 +39,7 @@ def db_host():
3239

3340

3441
@fixture(scope="session")
35-
def db_user():
42+
def db_user() -> str:
3643
"""Default db user used by depending fixtures.
3744
3845
postgres
@@ -41,7 +48,7 @@ def db_user():
4148

4249

4350
@fixture(scope="session")
44-
def db_url(db_host, db_user):
51+
def db_url(db_host: str, db_user: str) -> str:
4552
"""Default db url used by depending fixtures.
4653
4754
db url example postgresql://{db_user}@{db_host}/postgres
@@ -50,24 +57,24 @@ def db_url(db_host, db_user):
5057

5158

5259
@fixture(scope="session")
53-
def engine(db_url):
60+
def engine(db_url: str) -> Engine:
5461
return create_engine(db_url)
5562

5663

5764
@fixture(scope="session")
58-
def sqla_connection(engine):
65+
def sqla_connection(engine: Engine) -> Generator[Connection]:
5966
with engine.connect() as connection:
6067
yield connection
6168

6269

6370
@fixture(scope="session")
64-
def alembic_ini_path(): # pragma: no cover
71+
def alembic_ini_path() -> str: # pragma: no cover
6572
"""Path for alembic.ini file, defaults to `./alembic.ini`."""
6673
return "./alembic.ini"
6774

6875

6976
@fixture(scope="session")
70-
def db_migration(db_url, sqla_connection, alembic_ini_path):
77+
def db_migration(db_url: str, sqla_connection: Connection, alembic_ini_path: str):
7178
"""Run alembic upgrade at test session setup and downgrade at tear down.
7279
7380
Override fixture `alembic_ini_path` to change path of `alembic.ini` file.
@@ -94,54 +101,52 @@ def sqla_modules():
94101

95102

96103
@fixture
97-
def sqla_reflection(sqla_modules, sqla_connection):
104+
def sqla_reflection(sqla_modules, sqla_connection: Connection):
98105
import fastapi_sqla
99106

100-
fastapi_sqla.Base.metadata.bind = sqla_connection
107+
fastapi_sqla.Base.metadata.bind = sqla_connection # type: ignore
101108
fastapi_sqla.Base.prepare(sqla_connection.engine)
102109

103110

104111
@fixture
105-
def patch_engine_from_config(request, sqla_connection):
112+
def patch_new_engine(request: FixtureRequest, sqla_connection: Connection):
106113
"""So that all DB operations are never written to db for real."""
107114
if "dont_patch_engines" in request.keywords:
108115
yield
109116
else:
110-
transaction = sqla_connection.begin()
111-
112-
with patch("fastapi_sqla.sqla.engine_from_config") as engine_from_config:
113-
engine_from_config.return_value = sqla_connection
114-
yield
117+
with sqla_connection.begin() as transaction:
118+
with patch("fastapi_sqla.sqla.new_engine", return_value=sqla_connection):
119+
yield
115120

116-
transaction.rollback()
121+
transaction.rollback()
117122

118123

119124
@fixture
120-
def session_factory():
121-
return sessionmaker()
125+
def session_factory(
126+
sqla_connection: Connection, sqla_reflection, patch_new_engine
127+
) -> sessionmaker:
128+
return sessionmaker(bind=sqla_connection)
122129

123130

124131
@fixture
125-
def session(
126-
session_factory, sqla_connection, sqla_reflection, patch_engine_from_config
127-
):
132+
def session(session_factory: sessionmaker) -> Generator[Session]:
128133
"""Sqla session to use when creating db fixtures.
129134
130135
While it does not write any record in DB, the application will still be able to
131136
access any record committed with that session.
132137
"""
133-
session = session_factory(bind=sqla_connection)
138+
session: Session = session_factory()
134139
yield session
135140
session.close()
136141

137142

138-
def format_async_async_sqlalchemy_url(url):
143+
def format_async_async_sqlalchemy_url(url: str) -> str:
139144
scheme, location, path, query, fragment = urlsplit(url)
140145
return urlunsplit([f"{scheme}+asyncpg", location, path, query, fragment])
141146

142147

143148
@fixture(scope="session")
144-
def async_sqlalchemy_url(db_url):
149+
def async_sqlalchemy_url(db_url: str) -> str:
145150
"""Default async db url.
146151
147152
It is the same as `db_url` with `postgresql+asyncpg://` as scheme.
@@ -152,46 +157,56 @@ def async_sqlalchemy_url(db_url):
152157
if asyncio_support:
153158

154159
@fixture
155-
def async_engine(async_sqlalchemy_url):
160+
def async_engine(async_sqlalchemy_url: str) -> AsyncEngine:
156161
return create_async_engine(async_sqlalchemy_url)
157162

158163
@fixture
159-
async def async_sqla_connection(async_engine):
164+
async def async_sqla_connection(
165+
async_engine: AsyncEngine,
166+
) -> AsyncGenerator[AsyncConnection]:
160167
async with async_engine.connect() as connection:
161168
yield connection
162169

163170
@fixture
164-
async def patch_new_engine(request, async_sqla_connection):
171+
async def patch_new_async_engine(
172+
request: FixtureRequest, async_sqla_connection: AsyncConnection
173+
):
165174
"""So that all async DB operations are never written to db for real."""
166175
if "dont_patch_engines" in request.keywords:
167176
yield
168177
else:
169178
async with async_sqla_connection.begin() as transaction:
170-
with patch("fastapi_sqla.async_sqla.new_engine") as new_engine:
171-
new_engine.return_value = async_sqla_connection
179+
with patch(
180+
"fastapi_sqla.async_sqla.new_async_engine",
181+
return_value=async_sqla_connection,
182+
):
172183
yield
173184

174185
await transaction.rollback()
175186

176187
@fixture
177-
async def async_sqla_reflection(sqla_modules, async_sqla_connection):
188+
async def async_sqla_reflection(
189+
sqla_modules, async_sqla_connection: AsyncConnection
190+
):
178191
from fastapi_sqla import Base
179192

180193
await async_sqla_connection.run_sync(lambda conn: Base.prepare(conn.engine))
181194

182195
@fixture
183-
def async_session_factory():
184-
from fastapi_sqla.async_sqla import SqlaAsyncSession
185-
186-
return sessionmaker(class_=SqlaAsyncSession)
196+
def async_session_factory(
197+
async_sqla_connection: AsyncConnection,
198+
async_sqla_reflection,
199+
patch_new_async_engine,
200+
) -> sessionmaker:
201+
# TODO: Use async_sessionmaker once only supporting 2.x+
202+
return sessionmaker(
203+
bind=async_sqla_connection, expire_on_commit=False, class_=AsyncSession
204+
) # type: ignore
187205

188206
@fixture
189207
async def async_session(
190-
async_session_factory,
191-
async_sqla_connection,
192-
async_sqla_reflection,
193-
patch_new_engine,
194-
):
195-
session = async_session_factory(bind=async_sqla_connection)
208+
async_session_factory: sessionmaker,
209+
) -> AsyncGenerator[AsyncSession]:
210+
session: AsyncSession = async_session_factory()
196211
yield session
197212
await session.close()

fastapi_sqla/async_sqla.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,53 @@
1+
import os
12
from collections.abc import AsyncGenerator
23
from contextlib import asynccontextmanager
3-
from typing import Annotated
4+
from typing import Annotated, Union
45

56
import structlog
67
from fastapi import Depends, Request, Response
78
from fastapi.responses import PlainTextResponse
89
from sqlalchemy import text
9-
from sqlalchemy.ext.asyncio import AsyncEngine
10+
from sqlalchemy.ext.asyncio import (
11+
AsyncConnection,
12+
AsyncEngine,
13+
async_engine_from_config,
14+
)
1015
from sqlalchemy.ext.asyncio import AsyncSession as SqlaAsyncSession
1116
from sqlalchemy.orm.session import sessionmaker
1217
from starlette.types import ASGIApp, Message, Receive, Scope, Send
1318

1419
from fastapi_sqla import aws_aurora_support, aws_rds_iam_support
15-
from fastapi_sqla.sqla import _DEFAULT_SESSION_KEY, Base, new_engine
20+
from fastapi_sqla.sqla import _DEFAULT_SESSION_KEY, Base, get_envvar_prefix
1621

1722
logger = structlog.get_logger(__name__)
1823

1924
_ASYNC_REQUEST_SESSION_KEY = "fastapi_sqla_async_session"
2025
_async_session_factories: dict[str, sessionmaker] = {}
2126

2227

23-
def new_async_engine(key: str = _DEFAULT_SESSION_KEY):
24-
engine = new_engine(key)
25-
return AsyncEngine(engine)
28+
def new_async_engine(
29+
key: str = _DEFAULT_SESSION_KEY,
30+
) -> Union[AsyncEngine, AsyncConnection]:
31+
envvar_prefix = get_envvar_prefix(key)
32+
lowercase_environ = {k.lower(): v for k, v in os.environ.items()}
33+
lowercase_environ.pop(f"{envvar_prefix}warn_20", None)
34+
return async_engine_from_config(lowercase_environ, prefix=envvar_prefix)
2635

2736

2837
async def startup(key: str = _DEFAULT_SESSION_KEY):
29-
engine = new_async_engine(key)
30-
aws_rds_iam_support.setup(engine.sync_engine)
31-
aws_aurora_support.setup(engine.sync_engine)
38+
engine_or_connection = new_async_engine(key)
39+
aws_rds_iam_support.setup(engine_or_connection.sync_engine)
40+
aws_aurora_support.setup(engine_or_connection.sync_engine)
41+
42+
async_engine = (
43+
engine_or_connection
44+
if isinstance(engine_or_connection, AsyncEngine)
45+
else engine_or_connection.engine
46+
)
3247

3348
# Fail early
3449
try:
35-
async with engine.connect() as connection:
50+
async with async_engine.connect() as connection:
3651
await connection.execute(text("select 'ok'"))
3752
except Exception:
3853
logger.critical(
@@ -41,14 +56,15 @@ async def startup(key: str = _DEFAULT_SESSION_KEY):
4156
)
4257
raise
4358

44-
async with engine.connect() as connection:
59+
async with async_engine.connect() as connection:
4560
await connection.run_sync(lambda conn: Base.prepare(conn.engine))
4661

62+
# TODO: Use async_sessionmaker once only supporting 2.x+
4763
_async_session_factories[key] = sessionmaker(
48-
class_=SqlaAsyncSession, bind=engine, expire_on_commit=False
49-
)
64+
class_=SqlaAsyncSession, bind=engine_or_connection, expire_on_commit=False
65+
) # type: ignore
5066

51-
logger.info("engine startup", engine_key=key, async_engine=engine)
67+
logger.info("engine startup", engine_key=key, async_engine=engine_or_connection)
5268

5369

5470
@asynccontextmanager

fastapi_sqla/base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import functools
22
import os
33
import re
4+
from typing import Union
45

56
from deprecated import deprecated
67
from fastapi import FastAPI
7-
from sqlalchemy.engine import Engine
8+
from sqlalchemy.engine import Connection, Engine
89

910
from fastapi_sqla import sqla
1011

@@ -72,5 +73,5 @@ def _get_engine_keys() -> set[str]:
7273
return keys
7374

7475

75-
def _is_async_dialect(engine: Engine):
76+
def _is_async_dialect(engine: Union[Engine, Connection]):
7677
return engine.dialect.is_async if hasattr(engine.dialect, "is_async") else False

fastapi_sqla/sqla.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22
import os
33
from collections.abc import Generator
44
from contextlib import contextmanager
5-
from typing import Annotated
5+
from typing import Annotated, Union
66

77
import structlog
88
from fastapi import Depends, Request, Response
99
from fastapi.concurrency import contextmanager_in_threadpool
1010
from fastapi.responses import PlainTextResponse
1111
from sqlalchemy import engine_from_config, text
12-
from sqlalchemy.engine import Engine
12+
from sqlalchemy.engine import Connection, Engine
1313
from sqlalchemy.ext.declarative import DeferredReflection
1414
from sqlalchemy.orm.session import Session as SqlaSession
1515
from sqlalchemy.orm.session import sessionmaker
@@ -42,24 +42,29 @@ class Base(DeclarativeBase, DeferredReflection):
4242
__abstract__ = True
4343

4444

45-
def new_engine(key: str = _DEFAULT_SESSION_KEY) -> Engine:
45+
def get_envvar_prefix(key: str) -> str:
4646
envvar_prefix = "sqlalchemy_"
4747
if key != _DEFAULT_SESSION_KEY:
4848
envvar_prefix = f"fastapi_sqla__{key}__{envvar_prefix}"
4949

50+
return envvar_prefix
51+
52+
53+
def new_engine(key: str = _DEFAULT_SESSION_KEY) -> Union[Engine, Connection]:
54+
envvar_prefix = get_envvar_prefix(key)
5055
lowercase_environ = {k.lower(): v for k, v in os.environ.items()}
5156
lowercase_environ.pop(f"{envvar_prefix}warn_20", None)
5257
return engine_from_config(lowercase_environ, prefix=envvar_prefix)
5358

5459

5560
def startup(key: str = _DEFAULT_SESSION_KEY):
56-
engine = new_engine(key)
57-
aws_rds_iam_support.setup(engine.engine)
58-
aws_aurora_support.setup(engine.engine)
61+
engine_or_connection = new_engine(key)
62+
aws_rds_iam_support.setup(engine_or_connection.engine)
63+
aws_aurora_support.setup(engine_or_connection.engine)
5964

6065
# Fail early
6166
try:
62-
with engine.connect() as connection:
67+
with engine_or_connection.engine.connect() as connection:
6368
connection.execute(text("select 'OK'"))
6469
except Exception:
6570
logger.critical(
@@ -68,11 +73,13 @@ def startup(key: str = _DEFAULT_SESSION_KEY):
6873
)
6974
raise
7075

71-
Base.prepare(engine)
76+
Base.prepare(engine_or_connection.engine)
7277

73-
_session_factories[key] = sessionmaker(bind=engine, class_=SqlaSession)
78+
_session_factories[key] = sessionmaker(
79+
bind=engine_or_connection, class_=SqlaSession
80+
)
7481

75-
logger.info("engine startup", engine_key=key, engine=engine)
82+
logger.info("engine startup", engine_key=key, engine=engine_or_connection)
7683

7784

7885
@contextmanager

0 commit comments

Comments
 (0)