Skip to content

Commit 8f76222

Browse files
authored
fix(pytest_plugin): dont_patch_engines marker writes to db - DIA-61984 (#108)
1 parent 5bbc073 commit 8f76222

2 files changed

Lines changed: 46 additions & 28 deletions

File tree

fastapi_sqla/_pytest_plugin.py

Lines changed: 15 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -105,23 +105,19 @@ def sqla_reflection(sqla_modules, sqla_connection):
105105

106106

107107
@fixture
108-
def patch_engine_from_config(request, sqla_connection, sqla_transaction):
108+
def patch_engine_from_config(request, sqla_connection):
109109
"""So that all DB operations are never written to db for real."""
110110

111-
if "dont_patch_engines" in request.keywords: # pragma: no cover
111+
if "dont_patch_engines" in request.keywords:
112112
yield
113-
114113
else:
114+
transaction = sqla_connection.begin()
115+
115116
with patch("fastapi_sqla.sqla.engine_from_config") as engine_from_config:
116117
engine_from_config.return_value = sqla_connection
117-
yield engine_from_config
118-
118+
yield
119119

120-
@fixture
121-
def sqla_transaction(sqla_connection):
122-
transaction = sqla_connection.begin()
123-
yield transaction
124-
transaction.rollback()
120+
transaction.rollback()
125121

126122

127123
@fixture
@@ -131,11 +127,7 @@ def session_factory():
131127

132128
@fixture
133129
def session(
134-
session_factory,
135-
sqla_connection,
136-
sqla_transaction,
137-
sqla_reflection,
138-
patch_engine_from_config,
130+
session_factory, sqla_connection, sqla_reflection, patch_engine_from_config
139131
):
140132
"""Sqla session to use when creating db fixtures.
141133
@@ -173,22 +165,18 @@ async def async_sqla_connection(async_engine, event_loop):
173165
yield connection
174166

175167
@fixture
176-
async def async_sqla_transaction(async_sqla_connection):
177-
async with async_sqla_connection.begin() as transaction:
178-
yield transaction
179-
await transaction.rollback()
180-
181-
@fixture
182-
async def patch_new_engine(request, async_sqla_connection, async_sqla_transaction):
168+
async def patch_new_engine(request, async_sqla_connection):
183169
"""So that all async DB operations are never written to db for real."""
184170

185-
if "dont_patch_engines" in request.keywords: # pragma: no cover
171+
if "dont_patch_engines" in request.keywords:
186172
yield
187-
188173
else:
189-
with patch("fastapi_sqla.async_sqla.new_engine") as new_engine:
190-
new_engine.return_value = async_sqla_connection
191-
yield new_engine
174+
async with async_sqla_connection.begin() as transaction:
175+
with patch("fastapi_sqla.async_sqla.new_engine") as new_engine:
176+
new_engine.return_value = async_sqla_connection
177+
yield
178+
179+
await transaction.rollback()
192180

193181
@fixture
194182
async def async_sqla_reflection(sqla_modules, async_sqla_connection):
@@ -206,7 +194,6 @@ def async_session_factory():
206194
async def async_session(
207195
async_session_factory,
208196
async_sqla_connection,
209-
async_sqla_transaction,
210197
async_sqla_reflection,
211198
patch_new_engine,
212199
):

tests/test_pytest_plugin.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,37 @@ async def test_async_session_fixture_does_not_write_in_db(
5656
).scalar() == 0
5757

5858

59+
@fixture
60+
def truncate_table_tear_down(sqla_connection):
61+
yield
62+
with sqla_connection.begin():
63+
sqla_connection.execute(text("TRUNCATE TABLE singer"))
64+
65+
66+
@mark.dont_patch_engines
67+
def test_session_fixture_dont_patch_engine_writes_in_db(
68+
session, singer_cls, engine, truncate_table_tear_down
69+
):
70+
session.add(singer_cls(id=1, name="Bob Marley", country="Jamaica"))
71+
session.commit()
72+
with engine.connect() as connection:
73+
assert connection.execute(text("select count(*) from singer")).scalar() == 1
74+
75+
76+
@mark.dont_patch_engines
77+
@mark.require_asyncpg
78+
@mark.sqlalchemy("1.4")
79+
async def test_async_session_fixture_dont_patch_engine_writes_in_db(
80+
async_session, singer_cls, async_engine, truncate_table_tear_down
81+
):
82+
async_session.add(singer_cls(id=1, name="Bob Marley", country="Jamaica"))
83+
await async_session.commit()
84+
async with async_engine.connect() as connection:
85+
assert (
86+
await connection.execute(text("select count(*) from singer"))
87+
).scalar() == 1
88+
89+
5990
@mark.sqlalchemy("1.4")
6091
def test_all_opened_sessions_are_within_the_same_transaction(
6192
sqla_connection, session, session_factory, singer_cls

0 commit comments

Comments
 (0)