Skip to content

Commit 9ea877c

Browse files
committed
Fin
1 parent d32eb69 commit 9ea877c

8 files changed

Lines changed: 247 additions & 62 deletions

File tree

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,4 +170,5 @@ cython_debug/
170170
.pypirc
171171

172172
# Miscellaneous
173-
delete_me*
173+
delete_me*
174+
game_zone*

fastapi_forge/data_type_registry.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -49,17 +49,6 @@ def all(self) -> list[DataTypeInfo]:
4949
),
5050
)
5151

52-
registry.register(
53-
FieldDataType.INTEGER,
54-
DataTypeInfo(
55-
sqlalchemy_type="Integer",
56-
sqlalchemy_prefix=True,
57-
python_type="int",
58-
faker_field_value=faker_placeholder.format(placeholder='"random_int"'),
59-
value="1",
60-
test_value="2",
61-
),
62-
)
6352

6453
registry.register(
6554
FieldDataType.FLOAT,
@@ -123,3 +112,15 @@ def all(self) -> list[DataTypeInfo]:
123112
test_value='{"another_key": 123}',
124113
),
125114
)
115+
116+
registry.register(
117+
FieldDataType.INTEGER,
118+
DataTypeInfo(
119+
sqlalchemy_type="Integer",
120+
sqlalchemy_prefix=True,
121+
python_type="int",
122+
faker_field_value=faker_placeholder.format(placeholder='"random_int"'),
123+
value="1",
124+
test_value="2",
125+
),
126+
)

fastapi_forge/enums.py

Lines changed: 50 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from enum import StrEnum
2+
from functools import lru_cache
23

34

45
class FieldDataType(StrEnum):
@@ -10,26 +11,58 @@ class FieldDataType(StrEnum):
1011
UUID = "UUID"
1112
JSONB = "JSONB"
1213

14+
@classmethod
15+
@lru_cache
16+
def get_type_mappings(cls) -> dict[str, list[str]]:
17+
return {
18+
cls.STRING: ["character varying", "text", "varchar", "char"],
19+
cls.INTEGER: [
20+
"integer",
21+
"int",
22+
"serial",
23+
"smallint",
24+
"bigint",
25+
"bigserial",
26+
],
27+
cls.FLOAT: [
28+
"real",
29+
"float4",
30+
"double precision",
31+
"float8",
32+
],
33+
cls.BOOLEAN: ["boolean", "bool"],
34+
cls.DATETIME: [
35+
"timestamp",
36+
"timestamp with time zone",
37+
"timestamp without time zone",
38+
"date",
39+
"datetime",
40+
"time",
41+
],
42+
cls.UUID: ["uuid"],
43+
cls.JSONB: ["json", "jsonb"],
44+
}
45+
46+
@classmethod
47+
def get_custom_types(cls) -> dict[str, "FieldDataType"]:
48+
return {}
49+
1350
@classmethod
1451
def from_db_type(cls, db_type: str) -> "FieldDataType":
1552
db_type = db_type.lower()
16-
match db_type:
17-
case _ if db_type.startswith("character varying") or db_type == "text":
18-
return cls.STRING
19-
case "integer" | "bigint" | "smallint":
20-
return cls.INTEGER
21-
case "numeric":
22-
return cls.FLOAT
23-
case "boolean":
24-
return cls.BOOLEAN
25-
case "uuid":
26-
return cls.UUID
27-
case _ if db_type.startswith("timestamp") or "date":
28-
return cls.DATETIME
29-
case "jsonb":
30-
return cls.JSONB
31-
case _:
32-
raise ValueError(f"Unsupported database type: {db_type}")
53+
54+
custom_types = cls.get_custom_types()
55+
if db_type in custom_types:
56+
return custom_types[db_type]
57+
58+
for field_type, patterns in cls.get_type_mappings().items():
59+
if any(pattern in db_type for pattern in patterns):
60+
return field_type if isinstance(field_type, cls) else cls(field_type)
61+
62+
raise ValueError(
63+
f"Unsupported database type: {db_type}. "
64+
f"Supported types are: {list(cls.get_type_mappings().keys())}"
65+
)
3366

3467

3568
class HTTPMethod(StrEnum):

fastapi_forge/jinja.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@
2121
from sqlalchemy.dialects.postgresql import JSONB
2222
from uuid import UUID
2323
from typing import Any, Annotated
24-
from datetime import datetime, timezone
24+
from datetime import datetime, timezone, timedelta
25+
26+
2527
{% set unique_relationships = model.relationships | unique(attribute='target') %}
2628
{% for relation in unique_relationships -%}
2729
from src.models.{{ relation.target_model }}_models import {{ relation.target }}
@@ -45,7 +47,9 @@ class {{ model.name_cc }}(Base):
4547
"""
4648

4749
dto_template = """
48-
from datetime import datetime
50+
from datetime import datetime, timezone, timedelta
51+
52+
4953
from pydantic import BaseModel, ConfigDict, Field
5054
from fastapi import Depends
5155
from uuid import UUID
@@ -177,7 +181,9 @@ async def get_{{ model.name }}(
177181
from tests import factories
178182
from src.daos import AllDAOs
179183
from httpx import AsyncClient
180-
from datetime import datetime, timezone
184+
from datetime import datetime, timezone, timedelta
185+
186+
181187
from typing import Any
182188
from uuid import UUID
183189
@@ -221,7 +227,9 @@ async def test_post_{{ model.name }}(client: AsyncClient, daos: AllDAOs,) -> Non
221227
import pytest
222228
from tests import factories
223229
from httpx import AsyncClient
224-
from datetime import datetime
230+
from datetime import datetime, timezone, timedelta
231+
232+
225233
from uuid import UUID
226234
227235
URI = "/api/v1/{{ model.name_hyphen }}s/"
@@ -246,7 +254,9 @@ async def test_get_{{ model.name }}s(client: AsyncClient,) -> None:
246254
import pytest
247255
from tests import factories
248256
from httpx import AsyncClient
249-
from datetime import datetime
257+
from datetime import datetime, timezone, timedelta
258+
259+
250260
from uuid import UUID
251261
252262
URI = "/api/v1/{{ model.name_hyphen }}s/{ {{- model.name -}}_id}"
@@ -276,7 +286,9 @@ async def test_get_{{ model.name }}_by_id(client: AsyncClient,) -> None:
276286
from tests import factories
277287
from src.daos import AllDAOs
278288
from httpx import AsyncClient
279-
from datetime import datetime, timezone
289+
from datetime import datetime, timezone, timedelta
290+
291+
280292
from typing import Any
281293
from uuid import UUID
282294
@@ -322,7 +334,9 @@ async def test_patch_{{ model.name }}(client: AsyncClient, daos: AllDAOs,) -> No
322334
from tests import factories
323335
from src.daos import AllDAOs
324336
from httpx import AsyncClient
325-
from datetime import datetime
337+
from datetime import datetime, timezone, timedelta
338+
339+
326340
from uuid import UUID
327341
328342
URI = "/api/v1/{{ model.name_hyphen }}s/{ {{- model.name -}}_id}"

fastapi_forge/project_io.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -273,22 +273,32 @@ def load_project_spec_from_db(
273273

274274
data_type = FieldDataType.from_db_type(column.pop("type"))
275275
column["type"] = data_type
276+
default = None
277+
extra_kwargs = None
276278

277279
metadata = ModelFieldMetadata()
278280
if data_type == FieldDataType.DATETIME:
279281
column_name = column["name"]
280282
default_timestamp = column.get("default") == "CURRENT_TIMESTAMP"
281283
if default_timestamp:
282-
if "created" in column_name:
284+
if "create" in column_name:
283285
metadata.is_created_at_timestamp = True
284-
elif "updated" in column_name:
286+
default = "datetime.now(timezone.utc)"
287+
elif "update" in column_name:
285288
metadata.is_updated_at_timestamp = True
289+
default = "datetime.now(timezone.utc)"
290+
extra_kwargs = {"onupdate": "datetime.now(timezone.utc)"}
286291

287292
# temporary until any primary key name is supported
288293
if column["primary_key"] is True:
289294
column["name"] = "id"
290295

291-
field = ModelField(**column, metadata=metadata)
296+
field = ModelField(
297+
**column,
298+
metadata=metadata,
299+
default_value=default,
300+
extra_kwargs=extra_kwargs,
301+
)
292302
fields.append(field)
293303

294304
model = Model(

fastapi_forge/template/{{cookiecutter.project_name}}/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ test:
1111
$(DOCKER_CMD) pytest ./tests -v -s
1212

1313
test-filter:
14-
$(DOCKER_CMD) /tests -v -s -k $(filter)
14+
$(DOCKER_CMD) pytest ./tests -v -s -k $(filter)
1515

1616
mig-gen:
1717
$(DOCKER_CMD) alembic revision --autogenerate -m "$(name)"

fastapi_forge/template/{{cookiecutter.project_name}}/tests/factories.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
from src.db import Base
77
import factory
88
from typing import Any
9+
from datetime import datetime, timezone, timedelta
10+
11+
912

1013
{% for model in cookiecutter.models.models -%}
1114
from src.models.{{ model.name }}_models import {{ model.name_cc }}

0 commit comments

Comments
 (0)