Skip to content

Commit ca3bb3f

Browse files
committed
Implement session_factory context manager
1 parent bf0c6f1 commit ca3bb3f

4 files changed

Lines changed: 59 additions & 17 deletions

File tree

README.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,27 @@ async def create_user(self):
6161

6262
Do not use explicit `commit()`. `Transactional` class automatically do.
6363

64+
### Query with asyncio.gather()
65+
When executing queries concurrently through `asyncio.gather()`, you must use the `session_factory` context manager rather than the globally used session.
66+
67+
```python
68+
from core.db import session_factory
69+
70+
71+
async def get_by_id(self, *, user_id) -> User:
72+
stmt = select(User)
73+
async with session_factory() as read_session:
74+
return await read_session.execute(query).scalars().first()
75+
76+
77+
async def main() -> None:
78+
user_1, user_2 = await asyncio.gather(
79+
get_by_id(user_id=1),
80+
get_by_id(user_id=2),
81+
)
82+
```
83+
If you do not use a database connection like `session.add()`, it is recommended to use a globally provided session.
84+
6485
### Multiple databases
6586

6687
Go to `core/config.py` and edit `WRITER_DB_URL` and `READER_DB_URL` in the config class.
Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
from sqlalchemy import select, or_, and_
1+
from sqlalchemy import and_, or_, select
22

33
from app.user.domain.entity.user import User
44
from app.user.domain.repository.user import UserRepo
5-
from core.db.session import session
5+
from core.db.session import session, session_factory
66

77

88
class UserSQLAlchemyRepo(UserRepo):
@@ -21,7 +21,9 @@ async def get_users(
2121
limit = 12
2222

2323
query = query.limit(limit)
24-
result = await session.execute(query)
24+
async with session_factory() as read_session:
25+
result = await read_session.execute(query)
26+
2527
return result.scalars().all()
2628

2729
async def get_user_by_email_or_nickname(
@@ -30,25 +32,28 @@ async def get_user_by_email_or_nickname(
3032
email: str,
3133
nickname: str,
3234
) -> User | None:
33-
stmt = await session.execute(
34-
select(User).where(or_(User.email == email, User.nickname == nickname)),
35-
)
36-
return stmt.scalars().first()
35+
async with session_factory() as read_session:
36+
stmt = await read_session.execute(
37+
select(User).where(or_(User.email == email, User.nickname == nickname)),
38+
)
39+
return stmt.scalars().first()
3740

3841
async def get_user_by_id(self, *, user_id: int) -> User | None:
39-
query = await session.execute(select(User).where(User.id == user_id))
40-
return query.scalars().first()
42+
async with session_factory() as read_session:
43+
stmt = await read_session.execute(select(User).where(User.id == user_id))
44+
return stmt.scalars().first()
4145

4246
async def get_user_by_email_and_password(
4347
self,
4448
*,
4549
email: str,
4650
password: str,
4751
) -> User | None:
48-
stmt = await session.execute(
49-
select(User).where(and_(User.email == email, password == password))
50-
)
51-
return stmt.scalars().first()
52+
async with session_factory() as read_session:
53+
stmt = await read_session.execute(
54+
select(User).where(and_(User.email == email, password == password))
55+
)
56+
return stmt.scalars().first()
5257

5358
async def save(self, *, user: User) -> None:
5459
session.add(user)

core/db/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
from .session import Base, session
1+
from .session import Base, session, session_factory
22
from .transactional import Transactional
33

44
__all__ = [
55
"Base",
66
"session",
77
"Transactional",
8+
"session_factory",
89
]

core/db/session.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
1+
from contextlib import asynccontextmanager
12
from contextvars import ContextVar, Token
23
from enum import Enum
4+
from typing import AsyncGenerator
35

46
from sqlalchemy.ext.asyncio import (
57
AsyncSession,
68
async_scoped_session,
79
async_sessionmaker,
810
create_async_engine,
911
)
10-
from sqlalchemy.orm import Session, DeclarativeBase
12+
from sqlalchemy.orm import DeclarativeBase, Session
1113
from sqlalchemy.sql.expression import Delete, Insert, Update
1214

1315
from core.config import config
@@ -46,16 +48,29 @@ def get_bind(self, mapper=None, clause=None, **kw):
4648
return engines[EngineType.READER].sync_engine
4749

4850

49-
async_session_factory = async_sessionmaker(
51+
_async_session_factory = async_sessionmaker(
5052
class_=AsyncSession,
5153
sync_session_class=RoutingSession,
5254
expire_on_commit=False,
5355
)
5456
session = async_scoped_session(
55-
session_factory=async_session_factory,
57+
session_factory=_async_session_factory,
5658
scopefunc=get_session_context,
5759
)
5860

5961

6062
class Base(DeclarativeBase):
6163
...
64+
65+
66+
@asynccontextmanager
67+
async def session_factory() -> AsyncGenerator[AsyncSession, None]:
68+
_session = async_sessionmaker(
69+
class_=AsyncSession,
70+
sync_session_class=RoutingSession,
71+
expire_on_commit=False,
72+
)()
73+
try:
74+
yield _session
75+
finally:
76+
await _session.close()

0 commit comments

Comments
 (0)