Skip to content

Commit cd535b2

Browse files
authored
feat: Add support for SQLModel - DIA-65344 (#114)
## Problem * FastAPI-SQLA is not compatible with [SQLModel](https://sqlmodel.tiangolo.com/). ## Solution * Instantiate an SQLModel session if sqlmodel can be imported. * SQLAlchemy's `sessionmaker` expect a `class_` argument and because SQLModel's Session inherits from SQLAlchemy's Session, it just works. Also: * Move `fastapi_sqla.models.Base` to `fastapi_sqla.sqla`: It is not a model but the SQLAlchemy declarative base.
1 parent 9ffd1a1 commit cd535b2

14 files changed

Lines changed: 374 additions & 93 deletions

.circleci/config.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,10 @@ workflows:
5555
matrix:
5656
parameters:
5757
python_version: ["3.9", "3.10", "3.11"]
58-
sqlalchemy_version: ["1.4", "2.0"]
58+
sqlalchemy_version: ["1.4", "2.0", "2.0-sqlmodel"]
5959
asyncpg: ["asyncpg", "noasyncpg"]
6060
aws_rds_iam: ["aws_rds_iam", "noaws_rds_iam"]
61-
pydantic_version: ["1.10", "2.0", "2.1"]
61+
pydantic_version: ["1", "2"]
6262

6363
- release/release:
6464
name: release
@@ -97,7 +97,7 @@ jobs:
9797
Specify which version of python to run the tests against
9898
sqlalchemy_version:
9999
type: enum
100-
enum: ["2.0", "1.4", "1.3"]
100+
enum: ["2.0-sqlmodel", "2.0", "1.4", "1.3"]
101101
description: |
102102
Specify which version of sqlalchemy to run the tests against
103103
asyncpg:
@@ -110,7 +110,7 @@ jobs:
110110
description: To run tests with and without asyncpg installed.
111111
pydantic_version:
112112
type: enum
113-
enum: ["1.10", "2.0", "2.1"]
113+
enum: ["1", "2"]
114114
executor:
115115
name: python-postgres
116116
python_version: <<parameters.python_version>>

README.md

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99

1010
Fastapi-SQLA is an [SQLAlchemy] extension for [FastAPI] easy to setup with support for
11-
pagination, asyncio, and [pytest].
11+
pagination, asyncio, [SQLModel] and [pytest].
1212
It supports SQLAlchemy>=1.3 and is fully compliant with [SQLAlchemy 2.0].
1313
It is developped, maintained and used on production by the team at [@dialoguemd] with
1414
love from Montreal 🇨🇦.
@@ -539,6 +539,41 @@ async def async_all_users_alt(
539539
return await paginate(select(User))
540540
```
541541

542+
# SQLModel support 🎉
543+
544+
If your project uses [SQLModel], then `Session` dependency is an SQLModel session::
545+
546+
```python
547+
from http import HTTPStatus
548+
549+
from fastapi import FastAPI, HTTPException
550+
from fastapi_sqla import Item, Page, Paginate, Session, setup
551+
from sqlmodel import Field, SQLModel, select
552+
553+
class Hero(SQLModel, table=True):
554+
id: int | None = Field(default=None, primary_key=True)
555+
name: str
556+
secret_name: str
557+
age: int | None = None
558+
559+
560+
app = FastAPI()
561+
setup(app)
562+
563+
@app.get("/heros", response_model=Page[Hero])
564+
def list_hero(paginate: Paginate) -> Page[Hero]:
565+
return paginate(select(Hero))
566+
567+
568+
@app.get("/heros/{hero_id}", response_model=Item[Hero])
569+
def get_hero(hero_id: int, session: Session) -> Item[Hero]:
570+
hero = session.get(Hero, hero_id)
571+
if hero is None:
572+
raise HTTPException(HTTPStatus.NOT_FOUND)
573+
return {"data": hero}
574+
575+
```
576+
542577
# Pytest fixtures
543578

544579
This library provides a set of utility fixtures, through its PyTest plugin, which is
@@ -709,6 +744,7 @@ $ poetry run tox
709744
[FastAPI background tasks]: https://fastapi.tiangolo.com/tutorial/background-tasks/
710745
[SQLAlchemy]: http://sqlalchemy.org/
711746
[SQLAlchemy 2.0]: https://docs.sqlalchemy.org/en/20/changelog/migration_20.html
747+
[SQLModel]: https://sqlmodel.tiangolo.com
712748
[`asyncpg`]: https://magicstack.github.io/asyncpg/current/
713749
[scalars]: https://docs.sqlalchemy.org/en/20/core/connections.html#sqlalchemy.engine.Result.scalars
714750
[alembic]: https://alembic.sqlalchemy.org/

fastapi_sqla/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
from fastapi_sqla.base import setup
2-
from fastapi_sqla.models import Base, Collection, Item, Page
2+
from fastapi_sqla.models import Collection, Item, Page
33
from fastapi_sqla.pagination import Paginate, PaginateSignature, Pagination
4-
from fastapi_sqla.sqla import Session, SessionDependency, SqlaSession, open_session
4+
from fastapi_sqla.sqla import (
5+
Base,
6+
Session,
7+
SessionDependency,
8+
SqlaSession,
9+
open_session,
10+
)
511

612
__all__ = [
713
"Base",

fastapi_sqla/async_pagination.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020

2121
async def default_query_count(session: SqlaAsyncSession, query: Select) -> int:
22-
result = await session.execute(select(func.count()).select_from(query.subquery()))
22+
result = await session.execute(select(func.count()).select_from(query.subquery())) # type: ignore # noqa
2323
return cast(int, result.scalar())
2424

2525

@@ -34,10 +34,10 @@ async def paginate_query(
3434
) -> Page:
3535
total_pages = math.ceil(total_items / limit)
3636
page_number = offset / limit + 1
37-
query = query.offset(offset).limit(limit)
37+
query = query.offset(offset).limit(limit) # type: ignore
3838
result = await session.execute(query)
3939
data = iter(
40-
cast(Iterator, result.unique().scalars() if scalars else result.mappings())
40+
cast(Iterator, result.unique().scalars() if scalars else result.mappings()) # type: ignore # noqa
4141
)
4242
return Page(
4343
data=data,

fastapi_sqla/async_sqla.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@
1111
from sqlalchemy.orm.session import sessionmaker
1212

1313
from fastapi_sqla import aws_aurora_support, aws_rds_iam_support
14-
from fastapi_sqla.models import Base
15-
from fastapi_sqla.sqla import _DEFAULT_SESSION_KEY, new_engine
14+
from fastapi_sqla.sqla import _DEFAULT_SESSION_KEY, Base, new_engine
1615

1716
logger = structlog.get_logger(__name__)
1817

fastapi_sqla/models.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,6 @@
22

33
from pydantic import BaseModel, Field
44
from pydantic import __version__ as pydantic_version
5-
from sqlalchemy.ext.declarative import DeferredReflection
6-
7-
try:
8-
from sqlalchemy.orm import DeclarativeBase
9-
except ImportError:
10-
from sqlalchemy.ext.declarative import declarative_base
11-
12-
DeclarativeBase = declarative_base() # type: ignore
135

146
major, _, _ = [int(v) for v in pydantic_version.split(".")]
157
is_pydantic2 = major == 2
@@ -19,10 +11,6 @@
1911
from pydantic.generics import GenericModel # type:ignore
2012

2113

22-
class Base(DeclarativeBase, DeferredReflection):
23-
__abstract__ = True
24-
25-
2614
ItemT = TypeVar("ItemT")
2715

2816

fastapi_sqla/pagination.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def _paginate_legacy(
6868
total_pages = math.ceil(total_items / limit)
6969
page_number = offset / limit + 1
7070
return Page(
71-
data=query.offset(offset).limit(limit).all(),
71+
data=query.offset(offset).limit(limit).all(), # type: ignore
7272
meta={
7373
"offset": offset,
7474
"total_items": total_items,
@@ -90,10 +90,10 @@ def _paginate(
9090
) -> Page:
9191
total_pages = math.ceil(total_items / limit)
9292
page_number = offset / limit + 1
93-
query = query.offset(offset).limit(limit)
93+
query = query.offset(offset).limit(limit) # type: ignore
9494
result = session.execute(query)
9595
data = iter(
96-
cast(Iterator, result.unique().scalars() if scalars else result.mappings())
96+
cast(Iterator, result.unique().scalars() if scalars else result.mappings()) # type: ignore # noqa
9797
)
9898
return Page(
9999
data=data,

fastapi_sqla/sqla.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,33 @@
22
import os
33
from collections.abc import Generator
44
from contextlib import contextmanager
5-
from typing import Annotated
5+
from typing import Annotated, Generic, TypeVar
66

77
import structlog
88
from fastapi import Depends, Request
99
from fastapi.concurrency import contextmanager_in_threadpool
1010
from fastapi.responses import PlainTextResponse
1111
from sqlalchemy import engine_from_config, text
1212
from sqlalchemy.engine import Engine
13+
from sqlalchemy.ext.declarative import DeferredReflection
1314
from sqlalchemy.orm.session import Session as SqlaSession
1415
from sqlalchemy.orm.session import sessionmaker
1516

1617
from fastapi_sqla import aws_aurora_support, aws_rds_iam_support
17-
from fastapi_sqla.models import Base
18+
19+
try:
20+
from sqlalchemy.orm import DeclarativeBase
21+
except ImportError:
22+
from sqlalchemy.ext.declarative import declarative_base
23+
24+
DeclarativeBase = declarative_base() # type: ignore
25+
26+
try:
27+
from sqlmodel import Session as SqlaSession # type: ignore # noqa
28+
29+
except ImportError:
30+
pass
31+
1832

1933
logger = structlog.get_logger(__name__)
2034

@@ -23,6 +37,10 @@
2337
_session_factories: dict[str, sessionmaker] = {}
2438

2539

40+
class Base(DeclarativeBase, DeferredReflection):
41+
__abstract__ = True
42+
43+
2644
def new_engine(key: str = _DEFAULT_SESSION_KEY) -> Engine:
2745
envvar_prefix = "sqlalchemy_"
2846
if key != _DEFAULT_SESSION_KEY:
@@ -50,7 +68,8 @@ def startup(key: str = _DEFAULT_SESSION_KEY):
5068
raise
5169

5270
Base.prepare(engine)
53-
_session_factories[key] = sessionmaker(bind=engine)
71+
72+
_session_factories[key] = sessionmaker(bind=engine, class_=SqlaSession)
5473

5574
logger.info("engine startup", engine_key=key, engine=engine)
5675

@@ -146,11 +165,14 @@ def get_users(session: fastapi_sqla.Session):
146165
return response
147166

148167

149-
class SessionDependency:
168+
S = TypeVar("S", bound=SqlaSession)
169+
170+
171+
class SessionDependency(Generic[S]):
150172
def __init__(self, key: str = _DEFAULT_SESSION_KEY) -> None:
151173
self.key = key
152174

153-
def __call__(self, request: Request) -> SqlaSession:
175+
def __call__(self, request: Request) -> S:
154176
"""Yield the sqlalchemy session for that request.
155177
156178
It is meant to be used as a FastAPI dependency::
@@ -175,5 +197,5 @@ def get_users(session: SqlaSession = Depends(SessionDependency())):
175197
raise
176198

177199

178-
default_session_dep = SessionDependency()
200+
default_session_dep = SessionDependency[SqlaSession]()
179201
Session = Annotated[SqlaSession, Depends(default_session_dep)]

0 commit comments

Comments
 (0)