Skip to content

Commit e43c318

Browse files
authored
chore: prepare fastapi-sqla to be compliant with sqlalchemy 2.0 - DIA-40134 (#48)
1 parent f57e7ac commit e43c318

11 files changed

Lines changed: 158 additions & 84 deletions

fastapi_sqla/__init__.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from fastapi.responses import PlainTextResponse
1212
from pydantic import BaseModel, Field
1313
from pydantic.generics import GenericModel
14-
from sqlalchemy import engine_from_config
14+
from sqlalchemy import engine_from_config, text
1515
from sqlalchemy.ext.declarative import DeferredReflection
1616
from sqlalchemy.orm import Query as LegacyQuery
1717
from sqlalchemy.orm.session import Session as SqlaSession
@@ -63,7 +63,9 @@ def setup(app: FastAPI):
6363

6464

6565
def startup():
66-
lowercase_environ = {k.lower(): v for k, v in os.environ.items()}
66+
lowercase_environ = {
67+
k.lower(): v for k, v in os.environ.items() if k.lower() != "sqlalchemy_warn_20"
68+
}
6769
engine = engine_from_config(lowercase_environ, prefix="sqlalchemy_")
6870
aws_rds_iam_support.setup(engine.engine)
6971

@@ -74,7 +76,7 @@ def startup():
7476
# Fail early:
7577
try:
7678
with open_session() as session:
77-
session.execute("select 'OK'")
79+
session.execute(text("select 'OK'"))
7880
except Exception:
7981
logger.critical(
8082
"Fail querying db: is sqlalchemy_url envvar correctly configured?"
@@ -234,7 +236,9 @@ def default_query_count(session: Session, query: DbQuery) -> int:
234236
result = query.count()
235237

236238
elif isinstance(query, Select):
237-
result = session.execute(select(func.count()).select_from(query)).scalar()
239+
result = session.execute(
240+
select(func.count()).select_from(query.subquery())
241+
).scalar()
238242

239243
else: # pragma no cover
240244
raise NotImplementedError(f"Query type {type(query)!r} is not supported")

fastapi_sqla/_pytest_plugin.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,19 @@
1616
asyncio_support = False
1717

1818

19+
@fixture(scope="session")
20+
def sqla_version_tuple():
21+
from sqlalchemy import __version__
22+
23+
return tuple(int(i) for i in __version__.split("."))
24+
25+
1926
@fixture(scope="session")
2027
def db_host():
2128
"""Default db host used by depending fixtures.
2229
23-
When CI key is set in environment variables, it uses `postgres` as host name else, host used is `localhost`
30+
When CI key is set in environment variables, it uses `postgres` as host name else,
31+
host used is `localhost`
2432
"""
2533

2634
return "postgres" if "CI" in os.environ else "localhost"
@@ -49,9 +57,8 @@ def db_url(db_host, db_user):
4957
@fixture(scope="session")
5058
def sqla_connection(db_url):
5159
engine = create_engine(db_url)
52-
connection = engine.connect()
53-
yield connection
54-
connection.close()
60+
with engine.connect() as connection:
61+
yield connection
5562

5663

5764
@fixture(scope="session")
@@ -69,7 +76,10 @@ def db_migration(db_url, sqla_connection, alembic_ini_path):
6976
alembic_config = Config(file_=alembic_ini_path)
7077
alembic_config.set_main_option("sqlalchemy.url", db_url)
7178

72-
sqla_connection.execute(text("DROP SCHEMA public CASCADE; CREATE SCHEMA public;"))
79+
with sqla_connection.begin():
80+
sqla_connection.execute(
81+
text("DROP SCHEMA public CASCADE; CREATE SCHEMA public;")
82+
)
7383

7484
command.upgrade(alembic_config, "head")
7585
yield
@@ -112,7 +122,7 @@ def sqla_transaction(sqla_connection):
112122

113123

114124
@fixture
115-
def session(sqla_transaction, sqla_connection):
125+
def session(sqla_transaction, sqla_connection, sqla_version_tuple):
116126
"""Sqla session to use when creating db fixtures.
117127
118128
While it does not write any record in DB, the application will still be able to
@@ -121,8 +131,15 @@ def session(sqla_transaction, sqla_connection):
121131
import fastapi_sqla
122132

123133
session = fastapi_sqla._Session(bind=sqla_connection)
124-
yield session
125-
session.close()
134+
135+
if sqla_version_tuple >= (1, 4, 0):
136+
with session.begin():
137+
yield session
138+
session.rollback()
139+
else:
140+
with session.begin(subtransactions=True):
141+
session.begin_nested()
142+
yield session
126143

127144

128145
def format_async_async_sqlalchemy_url(url):

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,3 +110,7 @@ extras =
110110
tests
111111
commands = pytest -vv --cov={envsitepackagesdir}/fastapi_sqla --cov-report xml --cov-report html --junitxml=test-reports/pytest/junit.xml
112112
"""
113+
114+
[tool.black]
115+
exclude = ".mypy_cache|.pytest_cache|.vscode|.eggs|venv"
116+
--skip-magic-trailing-comma = true

setup.cfg

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ addopts =
88
--cov-report term
99
--cov-report term-missing
1010

11+
filterwarnings =
12+
error:.*removed in version 2\.0.*:
13+
1114
[pytest-watch]
1215
ext = .py,.yaml,.cfg,.yml
1316

tests/conftest.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import importlib
2+
import os
23
from unittest.mock import patch
34

45
from faker import Faker
56
from pytest import fixture, skip
6-
from sqlalchemy import engine_from_config
7-
from sqlalchemy.orm.session import close_all_sessions
7+
8+
# Must be done before importing anything from sqlalchemy:
9+
os.environ["SQLALCHEMY_WARN_20"] = "true"
810

911
pytest_plugins = ["fastapi_sqla._pytest_plugin", "pytester"]
1012

@@ -23,6 +25,7 @@ def pytest_configure(config):
2325
config.addinivalue_line(
2426
"markers", "require_boto3: skip test if boto3 is not installed"
2527
)
28+
config.addinivalue_line("markers", "dont_patch_engines: do not patch engines")
2629

2730

2831
@fixture(scope="session")
@@ -52,7 +55,11 @@ def is_boto3_installed():
5255

5356
@fixture(scope="session", autouse=True)
5457
def environ(db_url, sqla_version_tuple, async_sqlalchemy_url):
55-
values = {"sqlalchemy_url": db_url, "PYTHONASYNCIODEBUG": "1"}
58+
values = {
59+
"PYTHONASYNCIODEBUG": "1",
60+
"sqlalchemy_url": db_url,
61+
"SQLALCHEMY_WARN_20": "true",
62+
}
5663

5764
if sqla_version_tuple >= (1, 4, 0) and is_asyncpg_installed():
5865
values["async_sqlalchemy_url"] = async_sqlalchemy_url
@@ -63,12 +70,16 @@ def environ(db_url, sqla_version_tuple, async_sqlalchemy_url):
6370

6471
@fixture(scope="session")
6572
def engine(environ):
73+
from sqlalchemy import engine_from_config
74+
6675
engine = engine_from_config(environ, prefix="sqlalchemy_")
6776
return engine
6877

6978

7079
@fixture(autouse=True)
71-
def tear_down():
80+
def tear_down(environ):
81+
from sqlalchemy.orm.session import close_all_sessions
82+
7283
import fastapi_sqla
7384

7485
yield

tests/test_base.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@
66
@fixture(autouse=True, scope="module")
77
def setup_tear_down(engine):
88
with engine.connect() as connection:
9-
connection.execute(
10-
text("CREATE TABLE IF NOT EXISTS test_table (id integer primary key)")
11-
)
9+
with connection.begin():
10+
connection.execute(
11+
text("CREATE TABLE IF NOT EXISTS test_table (id integer primary key)")
12+
)
1213
yield
13-
connection.execute(text("DROP TABLE test_table"))
14+
with connection.begin():
15+
connection.execute(text("DROP TABLE test_table"))
1416

1517

1618
def test_startup_reflect_test_table():

tests/test_middleware.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,21 @@
1414
@fixture(scope="module", autouse=True)
1515
def setup_tear_down(engine):
1616
with engine.connect() as connection:
17-
connection.execute(
18-
text(
19-
"""
20-
CREATE TABLE IF NOT EXISTS public.user (
21-
id integer primary key,
22-
first_name varchar,
23-
last_name varchar
17+
with connection.begin():
18+
connection.execute(
19+
text(
20+
"""
21+
CREATE TABLE IF NOT EXISTS public.user (
22+
id integer primary key,
23+
first_name varchar,
24+
last_name varchar
25+
)
26+
"""
2427
)
25-
"""
2628
)
27-
)
2829
yield
29-
connection.execute(text("DROP TABLE public.user"))
30+
with connection.begin():
31+
connection.execute(text("DROP TABLE public.user"))
3032

3133

3234
@fixture
@@ -104,11 +106,12 @@ async def test_session_dependency(client, faker, session):
104106
first_name = faker.first_name()
105107
last_name = faker.last_name()
106108
res = await client.post(
107-
"/users",
108-
json={"id": userid, "first_name": first_name, "last_name": last_name},
109+
"/users", json={"id": userid, "first_name": first_name, "last_name": last_name}
109110
)
110111
assert res.status_code == 200, res.json()
111-
row = session.execute(f"select * from public.user where id = {userid}").fetchone()
112+
row = session.execute(
113+
text(f"select * from public.user where id = {userid}")
114+
).fetchone()
112115
assert row == (userid, first_name, last_name)
113116

114117

@@ -124,7 +127,9 @@ async def test_async_session_dependency(client, faker, async_session):
124127
)
125128
assert res.status_code == 200, res.json()
126129
row = (
127-
await async_session.execute(f"select * from public.user where id = {userid}")
130+
await async_session.execute(
131+
text(f"select * from public.user where id = {userid}")
132+
)
128133
).fetchone()
129134
assert row == (userid, first_name, last_name)
130135

tests/test_open_session.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,16 @@
55

66
@fixture(autouse=True, scope="module")
77
def module_setup_tear_down(engine, sqla_connection):
8-
engine.execute(
9-
"CREATE TABLE IF NOT EXISTS test_table (id integer primary key, value varchar)"
10-
)
8+
with sqla_connection.begin():
9+
sqla_connection.execute(
10+
text(
11+
"CREATE TABLE IF NOT EXISTS test_table "
12+
"(id integer primary key, value varchar) "
13+
)
14+
)
1115
yield
12-
engine.execute("DROP TABLE test_table")
16+
with sqla_connection.begin():
17+
sqla_connection.execute(text("DROP TABLE test_table"))
1318

1419

1520
@fixture(autouse=True)

tests/test_pagination.py

Lines changed: 38 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -13,37 +13,39 @@
1313
@fixture(scope="module", autouse=True)
1414
def setup_tear_down(sqla_connection):
1515
faker = Faker(seed=0)
16-
sqla_connection.execute(
17-
text(
18-
"create table if not exists public.user "
19-
"(id serial primary key, name varchar)"
16+
with sqla_connection.begin():
17+
sqla_connection.execute(
18+
text(
19+
"create table if not exists public.user "
20+
"(id serial primary key, name varchar)"
21+
)
2022
)
21-
)
22-
sqla_connection.execute(
23-
text(
23+
sqla_connection.execute(
24+
text(
25+
"""
26+
create table if not exists note (
27+
user_id integer,
28+
id serial,
29+
content text,
30+
primary key (user_id, id),
31+
foreign key (user_id) references public.user (id)
32+
)
2433
"""
25-
create table if not exists note (
26-
user_id integer,
27-
id serial,
28-
content text,
29-
primary key (user_id, id),
30-
foreign key (user_id) references public.user (id)
31-
)
32-
"""
34+
)
3335
)
34-
)
35-
metadata = MetaData()
36-
user = Table("user", metadata, autoload_with=sqla_connection)
37-
note = Table("note", metadata, autoload_with=sqla_connection)
38-
user_params = [{"name": faker.name()} for _ in range(1, 43)]
39-
note_params = [
40-
{"user_id": i % 42 + 1, "content": faker.text()} for i in range(0, 22 * 42)
41-
]
42-
sqla_connection.execute(user.insert(), *user_params)
43-
sqla_connection.execute(note.insert(), *note_params)
36+
metadata = MetaData()
37+
user = Table("user", metadata, autoload_with=sqla_connection)
38+
note = Table("note", metadata, autoload_with=sqla_connection)
39+
user_params = [{"name": faker.name()} for _ in range(1, 43)]
40+
note_params = [
41+
{"user_id": i % 42 + 1, "content": faker.text()} for i in range(0, 22 * 42)
42+
]
43+
sqla_connection.execute(user.insert(), user_params)
44+
sqla_connection.execute(note.insert(), note_params)
4445
yield
45-
sqla_connection.execute(text("drop table note cascade"))
46-
sqla_connection.execute(text("drop table public.user cascade"))
46+
with sqla_connection.begin():
47+
sqla_connection.execute(text("drop table note cascade"))
48+
sqla_connection.execute(text("drop table public.user cascade"))
4749

4850

4951
@fixture
@@ -80,7 +82,7 @@ class Note(Base):
8082
def test_pagination(session, user_cls, offset, limit, total_pages, page_number):
8183
from fastapi_sqla import Paginate
8284

83-
query = session.query(user_cls).options(joinedload("notes"))
85+
query = session.query(user_cls).options(joinedload(user_cls.notes))
8486
result = Paginate(session, offset, limit)(query)
8587

8688
assert result.meta.total_items == 42
@@ -99,7 +101,7 @@ def test_pagination_with_legacy_query_count(
99101
):
100102
from fastapi_sqla import Paginate
101103

102-
query = session.query(user_cls).options(joinedload("notes"))
104+
query = session.query(user_cls).options(joinedload(user_cls.notes))
103105
result = Paginate(session, offset, limit)(query)
104106

105107
assert result.meta.total_items == 42
@@ -154,13 +156,17 @@ class UserWithMeta(User):
154156
@app.get("/v1/users", response_model=Page[UserWithNotes])
155157
def sqla_13_all_users(session: Session = Depends(), paginate: Paginate = Depends()):
156158
query = (
157-
session.query(user_cls).options(joinedload("notes")).order_by(user_cls.id)
159+
session.query(user_cls)
160+
.options(joinedload(user_cls.notes))
161+
.order_by(user_cls.id)
158162
)
159163
return paginate(query)
160164

161165
@app.get("/v2/users", response_model=Page[UserWithNotes])
162166
def sqla_14_all_users(paginate: Paginate = Depends()):
163-
query = select(user_cls).options(joinedload("notes")).order_by(user_cls.id)
167+
query = (
168+
select(user_cls).options(joinedload(user_cls.notes)).order_by(user_cls.id)
169+
)
164170
return paginate(query)
165171

166172
@app.get("/v2/users-with-notes-count", response_model=Page[UserWithNotesCount])
@@ -182,8 +188,7 @@ def query_with_JSON_result(paginate: Paginate = Depends()):
182188
user_cls.id,
183189
user_cls.name,
184190
cast(
185-
func.format('{"notes_count": %s}', func.count(note_cls.id)),
186-
JSON,
191+
func.format('{"notes_count": %s}', func.count(note_cls.id)), JSON
187192
).label("meta"),
188193
)
189194
.join(note_cls)

0 commit comments

Comments
 (0)