-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathmixins.py
More file actions
338 lines (254 loc) · 9.42 KB
/
mixins.py
File metadata and controls
338 lines (254 loc) · 9.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
from __future__ import annotations
import typing as t
from datetime import datetime
import sqlalchemy
import sqlalchemy.event
import sqlalchemy.exc
import sqlalchemy.ext
import sqlalchemy.ext.asyncio
import sqlalchemy.orm
import sqlalchemy.util
import typing_extensions as tx
from sqlalchemy.orm import Mapped
from ..util import camel_to_snake_case
sa = sqlalchemy
class ORMModel(tx.Protocol):
    """Structural type for any object that carries a SQLAlchemy ``__table__``."""

    # Table object the model is mapped against.
    __table__: sa.Table
class SerializingModel(ORMModel):
    """Typing stub for a model that offers a ``to_dict`` serializer.

    NOTE(review): ``ORMModel`` is a Protocol, but ``Protocol`` is not listed
    again in the bases here, so type checkers treat this as an ordinary class
    rather than a protocol -- confirm that is intended.
    """

    __table__: sa.Table

    def to_dict(
        self: ORMModel,
        obj: t.Optional[t.Any] = None,
        max_depth: int = 3,
        _children_seen: t.Optional[set] = None,
        _relations_seen: t.Optional[set] = None,
    ) -> t.Dict[str, t.Any]:
        """Return a dict representation of the model (signature-only stub)."""
        ...
class TableNameMixin:
    """Supply ``__tablename__`` automatically from the class name."""

    __abstract__ = True
    __table__: sa.Table

    @sa.orm.declared_attr.directive
    def __tablename__(cls: t.Type[ORMModel]) -> str:
        """Return the class name passed through ``camel_to_snake_case``."""
        class_name = cls.__name__
        return camel_to_snake_case(class_name)
class ReprMixin:
    """Provide a compact ``__repr__`` derived from the instance's ORM state."""

    __abstract__ = True
    __table__: sa.Table

    def __repr__(self: ORMModel) -> str:
        """Return ``<ClassName pk>`` for the instance.

        Transient and pending instances have no database identity yet, so
        they are labelled with their state and ``id(self)`` instead.  Objects
        that SQLAlchemy cannot inspect fall back to the inherited ``__repr__``.
        """
        # raiseerr=False makes inspect() return None for objects it does not
        # know about; with the default it raises NoInspectionAvailable and the
        # fallback branch below could never run.
        state = sa.inspect(self, raiseerr=False)
        if state is None:
            return super().__repr__()

        if state.transient:
            pk = f"(transient {id(self)})"
        elif state.pending:
            pk = f"(pending {id(self)})"
        else:
            # Composite primary keys are rendered comma-separated.
            pk = ", ".join(map(str, state.identity))
        return f"<{type(self).__name__} {pk}>"
class ComparableMixin:
    """Value-based equality over the non-primary-key mapped columns.

    NOTE: by Python's rules a class that defines ``__eq__`` without
    ``__hash__`` gets ``__hash__ = None``, i.e. it is unhashable unless a
    subclass restores hashing.
    """

    __abstract__ = True
    __table__: sa.Table

    def __eq__(self: ORMModel, other: ORMModel) -> bool:
        """Return True when ``other`` has the same class name and column values.

        Primary-key columns are ignored; every other mapped column attribute
        must compare equal.
        """
        same_kind = type(self).__name__ == type(other).__name__
        if not same_kind:
            return False

        mapper = sa.inspect(type(self))
        return all(
            getattr(self, attr_name) == getattr(other, attr_name)
            for attr_name, col in mapper.columns.items()
            if not col.primary_key
        )
class TotalOrderMixin:
    """Order instances by their primary-key values via ``__lt__``."""

    __abstract__ = True
    __table__: sa.Table

    def __lt__(self: ORMModel, other: ORMModel) -> bool:
        """Return True when this instance's primary key sorts before ``other``'s.

        Primary-key values are collected in mapper order and compared as
        lists, so composite keys compare lexicographically.

        Returns:
            The comparison result, or ``NotImplemented`` when ``other`` is of
            a differently named class, so Python can try the reflected
            operation and finally raise its standard ``TypeError``.
        """
        if type(self).__name__ != type(other).__name__:
            # NotImplemented is a sentinel value, not an exception: it must be
            # returned. Raising it produces an unrelated
            # "exceptions must derive from BaseException" TypeError.
            return NotImplemented
        primary_keys = sa.inspect(type(self)).primary_key
        self_keys = [getattr(self, col.name) for col in primary_keys]
        other_keys = [getattr(other, col.name) for col in primary_keys]
        return self_keys < other_keys
class SimpleDictMixin:
    """Add a flat ``to_dict`` covering only the table's own columns."""

    __abstract__ = True
    __table__: sa.Table

    def to_dict(self) -> t.Dict[str, t.Any]:
        """Return ``{column name: attribute value}`` for each table column."""
        column_names = (column.name for column in self.__table__.columns)
        return {name: getattr(self, name) for name in column_names}
class RecursiveDictMixin:
    """Add a ``to_dict`` that also serializes related objects, depth-limited."""

    __abstract__ = True
    __table__: sa.Table

    def to_dict(
        self: tx.Self,
        obj: t.Optional[t.Any] = None,
        max_depth: int = 1,
        _children_seen: t.Optional[set] = None,
        _relations_seen: t.Optional[set] = None,
    ) -> t.Dict[str, t.Any]:
        """Convert a model to a python dict, recursing into relationships.

        Args:
            obj:
                Mapped instance to serialize; defaults to ``self``.
            max_depth:
                Number of relationship levels to follow, defaults to 1.
                At 0 only column attributes are emitted.
            _children_seen:
                Internal. ``repr()`` strings of collection members already
                serialized; shared across the whole recursion.
            _relations_seen:
                Internal. Names of relationships declaring a backref that
                were already traversed; shared across the whole recursion.

        Returns:
            Dict of column attribute values, plus one entry per traversed
            relationship: a nested dict for scalar relationships, a list of
            nested dicts for collections.  Relationships whose value is
            ``None`` are omitted.
        """
        if obj is None:
            obj = self
        if _children_seen is None:
            _children_seen = set()
        if _relations_seen is None:
            _relations_seen = set()

        mapper = sa.inspect(obj).mapper
        # Mapper.columns is keyed by mapped attribute name, which is what
        # getattr() needs; Column.key can differ from the attribute name.
        data = {key: getattr(obj, key) for key in mapper.columns.keys()}

        if max_depth <= 0:
            return data

        for name, relation in mapper.relationships.items():
            if name in _relations_seen:
                continue
            if relation.backref:
                # Remember backref-declaring relationships so they are not
                # walked again deeper in the recursion.
                _relations_seen.add(name)

            related = getattr(obj, name)
            if related is None:
                continue

            if not relation.uselist:
                data[name] = self.to_dict(
                    related,
                    max_depth=max_depth - 1,
                    _children_seen=_children_seen,
                    _relations_seen=_relations_seen,
                )
                continue

            children = []
            for child in related:
                # repr() serves as the identity token for "already emitted".
                token = repr(child)
                if token in _children_seen:
                    continue
                _children_seen.add(token)
                children.append(
                    self.to_dict(
                        child,
                        max_depth=max_depth - 1,
                        _children_seen=_children_seen,
                        _relations_seen=_relations_seen,
                    )
                )
            data[name] = children

        return data
class IdentityMixin:
    """Add an integer ``id`` primary key declared with ``sa.Identity()``."""

    __abstract__ = True
    __table__: sa.Table

    # Auto-incrementing surrogate primary key.
    id: Mapped[int] = sa.orm.mapped_column(
        sa.Identity(),
        primary_key=True,
        autoincrement=True,
    )
class SoftDeleteMixin:
    """Use as a mixin in a class to opt-in to the soft-delete feature.

    The mixin adds an ``is_active`` column that defaults to ``True``.  The
    filtering itself is performed by ``soft_delete_filter``, which has to be
    registered on a session's ``do_orm_execute`` event
    (``setup_soft_delete_for_session`` does that).  Once registered, SELECT
    statements only load rows whose ``is_active`` is true, unless the
    statement carries the ``include_inactive=True`` execution option.

    Example:
        class User(db.Model, SoftDeleteMixin):
            id: Mapped[int] = sa.orm.mapped_column(primary_key=True)
            email: Mapped[str] = sa.orm.mapped_column()

        db.create_all()

        u = User(email="joe@magic.link")
        db.session.add(u)
        db.session.commit()

        statement = select(User).where(User.email == "joe@magic.link")

        # returns user
        result = db.session.execute(statement).scalars().one()

        # Mark inactive
        u.is_active = False
        db.session.add(u)
        db.session.commit()

        # User not found!
        result = db.session.execute(statement).scalars().one()

        # User found (when manually adding include_inactive execution option).
        # Now you can reactivate them if you like.
        result = db.session.execute(
            statement.execution_options(include_inactive=True)
        ).scalars().one()

    see: https://docs.sqlalchemy.org/en/20/orm/session_events.html#adding-global-where-on-criteria
    """

    __abstract__ = True
    __table__: sa.Table

    # False marks the row as soft-deleted.
    is_active: Mapped[bool] = sa.orm.mapped_column(default=True)
class TimestampMixin:
    """Add ``created_at`` and ``updated_at`` datetime columns."""

    __abstract__ = True
    __table__: sa.Table

    # Set to now() on insert; FetchedValue marks it as server-generated too.
    created_at: Mapped[datetime] = sa.orm.mapped_column(
        default=sa.func.now(),
        server_default=sa.FetchedValue(),
    )

    # Set to now() on insert and again on every update.
    updated_at: Mapped[datetime] = sa.orm.mapped_column(
        default=sa.func.now(),
        onupdate=sa.func.now(),
        server_default=sa.FetchedValue(),
        server_onupdate=sa.FetchedValue(),
    )
class VersionMixin:
    """Add a ``version_id`` column and register it as the mapper's version column."""

    __abstract__ = True
    __table__: sa.Table

    # Row version counter handed to the mapper as ``version_id_col``.
    version_id: Mapped[int] = sa.orm.mapped_column(nullable=False)

    @sa.orm.declared_attr.directive
    def __mapper_args__(cls) -> dict[str, t.Any]:
        """Return mapper arguments pointing ``version_id_col`` at ``version_id``."""
        return {"version_id_col": cls.version_id}
class EagerDefaultsMixin:
    """Turn on the mapper's ``eager_defaults`` option.

    https://docs.sqlalchemy.org/en/20/orm/mapping_api.html#sqlalchemy.orm.Mapper.params.eager_defaults
    """

    __abstract__ = True
    __table__: sa.Table

    @sa.orm.declared_attr.directive
    def __mapper_args__(cls) -> dict[str, t.Any]:
        """Return mapper arguments with ``eager_defaults`` enabled."""
        return {"eager_defaults": True}
def soft_delete_filter(execute_state: sa.orm.ORMExecuteState) -> None:
    """Restrict SELECTs to active rows of ``SoftDeleteMixin`` entities.

    Intended as a ``do_orm_execute`` listener.  Non-SELECT statements, and
    statements run with the ``include_inactive`` execution option set to a
    truthy value, are left untouched.

    Args:
        execute_state: State of the ORM execution being intercepted; its
            ``statement`` is replaced with a filtered version.
    """
    if not execute_state.is_select:
        return
    if execute_state.execution_options.get("include_inactive", False):
        return

    # include_aliases=True applies the criteria to aliased entities as well.
    active_only = sa.orm.with_loader_criteria(
        SoftDeleteMixin,
        lambda cls: cls.is_active == sa.true(),
        include_aliases=True,
    )
    execute_state.statement = execute_state.statement.options(active_only)
def setup_soft_delete_for_session(session: t.Type[sa.orm.Session]) -> None:
    """Attach ``soft_delete_filter`` to ``session``'s ``do_orm_execute`` event.

    Does nothing when the listener is already attached, so repeated calls do
    not register it twice.

    Args:
        session: Event target to listen on (annotated as a Session class).
    """
    event_name = "do_orm_execute"
    if sa.event.contains(session, event_name, soft_delete_filter):
        return

    # propagate=True extends the listener to subclasses of the target.
    sa.event.listen(session, event_name, soft_delete_filter, propagate=True)
def accumulate_mappings(class_, attribute) -> t.Dict[str, t.Any]:
    """Merge a mapping-valued attribute across the ancestors of ``class_``.

    Ancestors are visited from the end of the MRO towards ``class_``, so
    values from classes nearer to ``class_`` win on key conflicts.
    ``class_`` itself is not consulted.

    Args:
        class_: Class whose ancestors are scanned.
        attribute: Name of the attribute to read from each ancestor.

    Returns:
        The merged mapping; empty when no ancestor provides the attribute.
    """
    merged: t.Dict[str, t.Any] = {}
    for ancestor in reversed(class_.__mro__):
        if ancestor is not class_:
            merged.update(getattr(ancestor, attribute, {}))
    return merged
def accumulate_tuples_with_mapping(class_, attribute) -> t.Sequence[t.Any]:
    """Merge a ``__table_args__``-style attribute across the ancestors of ``class_``.

    Each ancestor's value may be a sequence of positional items, optionally
    containing mappings of keyword options, or a single mapping (the
    dict-only form).  Positional items are concatenated in visiting order;
    all mappings are merged into one dict that is appended as the last
    element.  Ancestors are visited from the end of the MRO towards
    ``class_`` (which itself is skipped), so nearer classes win on key
    conflicts.

    NOTE: values are read with ``getattr``, so an ancestor that merely
    inherits the attribute contributes the inherited value again.

    Args:
        class_: Class whose ancestors are scanned.
        attribute: Name of the attribute to read from each ancestor.

    Returns:
        Tuple of positional items, followed by the merged mapping when it is
        non-empty.
    """
    accumulated_map: t.Dict[str, t.Any] = {}
    accumulated_args: t.List[t.Any] = []

    for base_class in reversed(class_.__mro__):
        if base_class is class_:
            continue
        args = getattr(base_class, attribute, ())
        if args is None:
            # An explicit ``None`` carries nothing to merge.
            continue
        if isinstance(args, t.Mapping):
            # Dict-only form, e.g. ``__table_args__ = {"schema": "x"}``.
            # Iterating it would leak its keys into the positional items.
            accumulated_map |= args
            continue
        for arg in args:
            if isinstance(arg, t.Mapping):
                accumulated_map |= arg
            else:
                accumulated_args.append(arg)

    if accumulated_map:
        accumulated_args.append(accumulated_map)
    return tuple(accumulated_args)
class DynamicArgsMixin:
    """Build ``__mapper_args__`` and ``__table_args__`` from the class's ancestors."""

    __abstract__ = True
    __table__: sa.Table

    @sa.orm.declared_attr.directive
    def __mapper_args__(cls) -> t.Dict[str, t.Any]:
        """Return the ``__mapper_args__`` merged by ``accumulate_mappings``."""
        mapper_args = accumulate_mappings(cls, "__mapper_args__")
        return mapper_args

    @sa.orm.declared_attr.directive
    def __table_args__(cls) -> t.Sequence[t.Any]:
        """Return the ``__table_args__`` merged by ``accumulate_tuples_with_mapping``."""
        table_args = accumulate_tuples_with_mapping(cls, "__table_args__")
        return table_args