11import os
2+ from collections .abc import AsyncGenerator , Generator
23from unittest .mock import patch
34from urllib .parse import urlsplit , urlunsplit
45
56from alembic import command
67from alembic .config import Config
7- from pytest import fixture
8+ from pytest import FixtureRequest , fixture
89from sqlalchemy import create_engine , text
9- from sqlalchemy .orm .session import sessionmaker
10+ from sqlalchemy .engine import Connection , Engine
11+ from sqlalchemy .orm .session import Session , sessionmaker
1012
1113try :
1214 import asyncpg # noqa
13- from sqlalchemy .ext .asyncio import create_async_engine
15+ from sqlalchemy .ext .asyncio import (
16+ create_async_engine ,
17+ AsyncEngine ,
18+ AsyncConnection ,
19+ AsyncSession ,
20+ )
1421
1522 asyncio_support = True
1623except ImportError :
@@ -22,7 +29,7 @@ def pytest_configure(config):
2229
2330
2431@fixture (scope = "session" )
25- def db_host ():
32+ def db_host () -> str :
2633 """Default db host used by depending fixtures.
2734
2835 When CI key is set in environment variables, it uses `postgres` as host name else,
@@ -32,7 +39,7 @@ def db_host():
3239
3340
3441@fixture (scope = "session" )
35- def db_user ():
42+ def db_user () -> str :
3643 """Default db user used by depending fixtures.
3744
3845 postgres
@@ -41,7 +48,7 @@ def db_user():
4148
4249
4350@fixture (scope = "session" )
44- def db_url (db_host , db_user ) :
51+ def db_url (db_host : str , db_user : str ) -> str :
4552 """Default db url used by depending fixtures.
4653
4754 db url example postgresql://{db_user}@{db_host}/postgres
@@ -50,24 +57,24 @@ def db_url(db_host, db_user):
5057
5158
5259@fixture (scope = "session" )
53- def engine (db_url ) :
60+ def engine (db_url : str ) -> Engine :
5461 return create_engine (db_url )
5562
5663
5764@fixture (scope = "session" )
58- def sqla_connection (engine ) :
65+ def sqla_connection (engine : Engine ) -> Generator [ Connection ] :
5966 with engine .connect () as connection :
6067 yield connection
6168
6269
6370@fixture (scope = "session" )
64- def alembic_ini_path (): # pragma: no cover
71+ def alembic_ini_path () -> str : # pragma: no cover
6572 """Path for alembic.ini file, defaults to `./alembic.ini`."""
6673 return "./alembic.ini"
6774
6875
6976@fixture (scope = "session" )
70- def db_migration (db_url , sqla_connection , alembic_ini_path ):
77+ def db_migration (db_url : str , sqla_connection : Connection , alembic_ini_path : str ):
7178 """Run alembic upgrade at test session setup and downgrade at tear down.
7279
7380 Override fixture `alembic_ini_path` to change path of `alembic.ini` file.
@@ -94,54 +101,52 @@ def sqla_modules():
94101
95102
96103@fixture
97- def sqla_reflection (sqla_modules , sqla_connection ):
104+ def sqla_reflection (sqla_modules , sqla_connection : Connection ):
98105 import fastapi_sqla
99106
100- fastapi_sqla .Base .metadata .bind = sqla_connection
107+ fastapi_sqla .Base .metadata .bind = sqla_connection # type: ignore
101108 fastapi_sqla .Base .prepare (sqla_connection .engine )
102109
103110
104111@fixture
105- def patch_engine_from_config (request , sqla_connection ):
112+ def patch_new_engine (request : FixtureRequest , sqla_connection : Connection ):
106113 """So that all DB operations are never written to db for real."""
107114 if "dont_patch_engines" in request .keywords :
108115 yield
109116 else :
110- transaction = sqla_connection .begin ()
111-
112- with patch ("fastapi_sqla.sqla.engine_from_config" ) as engine_from_config :
113- engine_from_config .return_value = sqla_connection
114- yield
117+ with sqla_connection .begin () as transaction :
118+ with patch ("fastapi_sqla.sqla.new_engine" , return_value = sqla_connection ):
119+ yield
115120
116- transaction .rollback ()
121+ transaction .rollback ()
117122
118123
119124@fixture
120- def session_factory ():
121- return sessionmaker ()
125+ def session_factory (
126+ sqla_connection : Connection , sqla_reflection , patch_new_engine
127+ ) -> sessionmaker :
128+ return sessionmaker (bind = sqla_connection )
122129
123130
124131@fixture
125- def session (
126- session_factory , sqla_connection , sqla_reflection , patch_engine_from_config
127- ):
132+ def session (session_factory : sessionmaker ) -> Generator [Session ]:
128133 """Sqla session to use when creating db fixtures.
129134
130135 While it does not write any record in DB, the application will still be able to
131136 access any record committed with that session.
132137 """
133- session = session_factory (bind = sqla_connection )
138+ session : Session = session_factory ()
134139 yield session
135140 session .close ()
136141
137142
138- def format_async_async_sqlalchemy_url (url ) :
143+ def format_async_async_sqlalchemy_url (url : str ) -> str :
139144 scheme , location , path , query , fragment = urlsplit (url )
140145 return urlunsplit ([f"{ scheme } +asyncpg" , location , path , query , fragment ])
141146
142147
143148@fixture (scope = "session" )
144- def async_sqlalchemy_url (db_url ) :
149+ def async_sqlalchemy_url (db_url : str ) -> str :
145150 """Default async db url.
146151
147152 It is the same as `db_url` with `postgresql+asyncpg://` as scheme.
@@ -152,46 +157,56 @@ def async_sqlalchemy_url(db_url):
152157if asyncio_support :
153158
154159 @fixture
155- def async_engine (async_sqlalchemy_url ) :
160+ def async_engine (async_sqlalchemy_url : str ) -> AsyncEngine :
156161 return create_async_engine (async_sqlalchemy_url )
157162
158163 @fixture
159- async def async_sqla_connection (async_engine ):
164+ async def async_sqla_connection (
165+ async_engine : AsyncEngine ,
166+ ) -> AsyncGenerator [AsyncConnection ]:
160167 async with async_engine .connect () as connection :
161168 yield connection
162169
163170 @fixture
164- async def patch_new_engine (request , async_sqla_connection ):
171+ async def patch_new_async_engine (
172+ request : FixtureRequest , async_sqla_connection : AsyncConnection
173+ ):
165174 """So that all async DB operations are never written to db for real."""
166175 if "dont_patch_engines" in request .keywords :
167176 yield
168177 else :
169178 async with async_sqla_connection .begin () as transaction :
170- with patch ("fastapi_sqla.async_sqla.new_engine" ) as new_engine :
171- new_engine .return_value = async_sqla_connection
179+ with patch (
180+ "fastapi_sqla.async_sqla.new_async_engine" ,
181+ return_value = async_sqla_connection ,
182+ ):
172183 yield
173184
174185 await transaction .rollback ()
175186
176187 @fixture
177- async def async_sqla_reflection (sqla_modules , async_sqla_connection ):
188+ async def async_sqla_reflection (
189+ sqla_modules , async_sqla_connection : AsyncConnection
190+ ):
178191 from fastapi_sqla import Base
179192
180193 await async_sqla_connection .run_sync (lambda conn : Base .prepare (conn .engine ))
181194
182195 @fixture
183- def async_session_factory ():
184- from fastapi_sqla .async_sqla import SqlaAsyncSession
185-
186- return sessionmaker (class_ = SqlaAsyncSession )
196+ def async_session_factory (
197+ async_sqla_connection : AsyncConnection ,
198+ async_sqla_reflection ,
199+ patch_new_async_engine ,
200+ ) -> sessionmaker :
201+ # TODO: Use async_sessionmaker once only supporting 2.x+
202+ return sessionmaker (
203+ bind = async_sqla_connection , expire_on_commit = False , class_ = AsyncSession
204+ ) # type: ignore
187205
188206 @fixture
189207 async def async_session (
190- async_session_factory ,
191- async_sqla_connection ,
192- async_sqla_reflection ,
193- patch_new_engine ,
194- ):
195- session = async_session_factory (bind = async_sqla_connection )
208+ async_session_factory : sessionmaker ,
209+ ) -> AsyncGenerator [AsyncSession ]:
210+ session : AsyncSession = async_session_factory ()
196211 yield session
197212 await session .close ()
0 commit comments