Skip to content

Commit 638a2d6

Browse files
committed
Init
1 parent 1436b0d commit 638a2d6

4 files changed

Lines changed: 120 additions & 163 deletions

File tree

fastapi_forge/dtos.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
model_validator,
1010
)
1111

12-
from fastapi_forge.enums import FieldDataType
12+
from fastapi_forge.enums import DataTypeInfo, FieldDataType, registry
1313
from fastapi_forge.string_utils import camel_to_snake_hyphen, snake_to_camel
1414

1515
BoundedStr = Annotated[str, Field(..., min_length=1, max_length=100)]
@@ -44,10 +44,8 @@ class ModelField(_Base):
4444
nullable: bool = False
4545
unique: bool = False
4646
index: bool = False
47-
4847
default_value: str | None = None
4948
extra_kwargs: dict[str, Any] | None = None
50-
5149
metadata: ModelFieldMetadata = ModelFieldMetadata()
5250

5351
@computed_field
@@ -56,6 +54,10 @@ def name_cc(self) -> str:
5654
"""Convert field name to camelCase."""
5755
return snake_to_camel(self.name)
5856

57+
@property
58+
def type_info(self) -> DataTypeInfo:
59+
return registry.get(self.type)
60+
5961
@model_validator(mode="after")
6062
def _validate(self) -> Self:
6163
"""Validate field constraints."""

fastapi_forge/enums.py

Lines changed: 79 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,13 @@
11
from enum import StrEnum
2+
from typing import Any
3+
4+
5+
class HTTPMethod(StrEnum):
6+
GET = "get"
7+
GET_ID = "get_id"
8+
POST = "post"
9+
PATCH = "patch"
10+
DELETE = "delete"
211

312

413
class FieldDataType(StrEnum):
@@ -10,21 +19,75 @@ class FieldDataType(StrEnum):
1019
UUID = "UUID"
1120
JSONB = "JSONB"
1221

13-
def as_python_type(self) -> str:
14-
return {
15-
FieldDataType.STRING: "str",
16-
FieldDataType.INTEGER: "int",
17-
FieldDataType.FLOAT: "float",
18-
FieldDataType.BOOLEAN: "bool",
19-
FieldDataType.DATETIME: "datetime",
20-
FieldDataType.UUID: "UUID",
21-
FieldDataType.JSONB: "dict[str, Any]",
22-
}[self]
2322

23+
class DataTypeInfo:
2424

25-
class HTTPMethod(StrEnum):
26-
GET = "get"
27-
GET_ID = "get_id"
28-
POST = "post"
29-
PATCH = "patch"
30-
DELETE = "delete"
25+
def __init__(
26+
self,
27+
pydantic_annotation: str,
28+
sqlalchemy_type: str,
29+
sqlalchemy_prefix: bool,
30+
python_type: str,
31+
faker_generator: str,
32+
value: str,
33+
test_value: str,
34+
test_func: str,
35+
):
36+
self.pydantic_annotation = pydantic_annotation
37+
self.sqlalchemy_type = sqlalchemy_type
38+
self.sqlalchemy_prefix = sqlalchemy_prefix
39+
self.python_type = python_type
40+
self.faker_generator = faker_generator
41+
self.value = value
42+
self.test_value = test_value
43+
self.test_func = test_func
44+
45+
46+
class DataTypeInfoRegistry:
47+
def __init__(self):
48+
self._registry: dict[FieldDataType, DataTypeInfo] = {}
49+
50+
def register(self, field_data_type: FieldDataType, data_type: DataTypeInfo):
51+
if field_data_type in self._registry:
52+
raise ValueError(f"Data type '{field_data_type}' is already registered.")
53+
self._registry[field_data_type] = data_type
54+
55+
def get(self, field_data_type: FieldDataType) -> DataTypeInfo:
56+
if field_data_type not in self._registry:
57+
raise ValueError(f"Data type '{field_data_type}' not found.")
58+
return self._registry[field_data_type]
59+
60+
def all(self) -> list[DataTypeInfo]:
61+
return list(self._registry.values())
62+
63+
64+
registry = DataTypeInfoRegistry()
65+
66+
67+
registry.register(
68+
FieldDataType.STRING,
69+
DataTypeInfo(
70+
pydantic_annotation="",
71+
sqlalchemy_type="String",
72+
sqlalchemy_prefix=False,
73+
python_type="str",
74+
faker_generator="text",
75+
value="hello",
76+
test_value='"world"',
77+
test_func="",
78+
),
79+
)
80+
81+
registry.register(
82+
FieldDataType.DATETIME,
83+
DataTypeInfo(
84+
pydantic_annotation="",
85+
sqlalchemy_type="DateTime",
86+
sqlalchemy_prefix=False,
87+
python_type="datetime",
88+
faker_generator="text",
89+
value="datetime.now(timezone.utc)",
90+
test_value="datetime.now(timezone.utc)",
91+
test_func=".isoformat()",
92+
),
93+
)

fastapi_forge/jinja.py

Lines changed: 29 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ class {{ model.name_cc }}DTO(BaseOrmModel):
5959
6060
id: UUID
6161
{% for field in model.fields_sorted if not field.primary_key -%}
62-
{{ field.name }}: {{ type_mapping[field.type] }}{% if field.nullable %} | None{% endif %}
62+
{{ field.name }}: {{ field.type_info.python_type }}{% if field.nullable %} | None{% endif %}
6363
{% endfor %}
6464
6565
@@ -76,7 +76,7 @@ class {{ model.name_cc }}UpdateDTO(BaseModel):
7676
\"\"\"{{ model.name_cc }} update DTO.\"\"\"
7777
7878
{% for field in model.fields_sorted if not (field.metadata.is_created_at_timestamp or field.metadata.is_updated_at_timestamp or field.primary_key) -%}
79-
{{ field.name }}: {{ type_mapping[field.type] }} | None = None
79+
{{ field.name }}: {{ field.type_info.python_type }} | None = None
8080
{% endfor %}
8181
"""
8282

@@ -197,11 +197,7 @@ async def test_post_{{ model.name }}(client: AsyncClient, daos: AllDAOs,) -> Non
197197
{%- if not field.primary_key and field.name.endswith('_id') -%}
198198
"{{ field.name }}": str({{ field.name | replace('_id', '.id') }}),
199199
{%- elif not field.primary_key %}
200-
{%- if field.type == "DateTime" %}
201-
"{{ field.name }}": {{ type_to_input_value_mapping[field.type] }}.isoformat(),
202-
{%- else %}
203-
"{{ field.name }}": {{ type_to_input_value_mapping[field.type] }},
204-
{%- endif %}
200+
"{{ field.name }}": {{ field.type_info.test_value }}{{ field.type_info.test_func }},
205201
{%- endif %}
206202
{%- endfor %}
207203
}
@@ -217,11 +213,7 @@ async def test_post_{{ model.name }}(client: AsyncClient, daos: AllDAOs,) -> Non
217213
{%- if not field.primary_key and field.name.endswith('_id') %}
218214
assert db_{{ model.name }}.{{ field.name }} == UUID(input_json["{{ field.name }}"])
219215
{%- elif not field.primary_key %}
220-
{%- if field.type == "DateTime" %}
221-
assert db_{{ model.name }}.{{ field.name }}.isoformat() == input_json["{{ field.name }}"]
222-
{%- else %}
223-
assert db_{{ model.name }}.{{ field.name }} == input_json["{{ field.name }}"]
224-
{%- endif %}
216+
assert db_{{ model.name }}.{{ field.name }}{{ field.type_info.test_func }} == input_json["{{ field.name }}"]
225217
{%- endif %}
226218
{%- endfor %}
227219
"""
@@ -391,58 +383,65 @@ def _render(model: Model, template_name: str, **kwargs: Any) -> str:
391383

392384

393385
def render_model_to_model(model: Model) -> str:
394-
return _render(model, model_template, type_mapping=TYPE_MAPPING)
386+
return _render(
387+
model,
388+
model_template,
389+
)
395390

396391

397392
def render_model_to_dto(model: Model) -> str:
398-
return _render(model, dto_template, type_mapping=TYPE_MAPPING)
393+
return _render(
394+
model,
395+
dto_template,
396+
)
399397

400398

401399
def render_model_to_dao(model: Model) -> str:
402-
return _render(model, dao_template)
400+
return _render(
401+
model,
402+
dao_template,
403+
)
403404

404405

405406
def render_model_to_routers(model: Model) -> str:
406-
return _render(model, routers_template)
407+
return _render(
408+
model,
409+
routers_template,
410+
)
407411

408412

409413
def render_model_to_post_test(model: Model) -> str:
410414
return _render(
411415
model,
412416
test_template_post,
413-
type_to_input_value_mapping=TYPE_TO_INPUT_VALUE_MAPPING,
414417
)
415418

416419

417420
def render_model_to_get_test(model: Model) -> str:
418421
return _render(
419422
model,
420423
test_template_get,
421-
type_to_input_value_mapping=TYPE_TO_INPUT_VALUE_MAPPING,
422424
)
423425

424426

425427
def render_model_to_get_id_test(model: Model) -> str:
426428
return _render(
427429
model,
428430
test_template_get_id,
429-
type_to_input_value_mapping=TYPE_TO_INPUT_VALUE_MAPPING,
430431
)
431432

432433

433434
def render_model_to_patch_test(model: Model) -> str:
434435
return _render(
435436
model,
436437
test_template_patch,
437-
type_to_input_value_mapping=TYPE_TO_INPUT_VALUE_MAPPING,
438438
)
439439

440440

441441
def render_model_to_delete_test(model: Model) -> str:
442442
return _render(
443443
model,
444444
test_template_delete,
445-
type_to_input_value_mapping=TYPE_TO_INPUT_VALUE_MAPPING,
446445
)
447446

448447

@@ -473,73 +472,29 @@ def render_model_to_delete_test(model: Model) -> str:
473472
type=FieldDataType.STRING,
474473
primary_key=False,
475474
nullable=False,
476-
unique=False,
477-
index=False,
478-
),
479-
],
480-
),
481-
Model(
482-
name="model_a",
483-
fields=[
484-
ModelField(
485-
name="id",
486-
type=FieldDataType.UUID,
487-
primary_key=True,
488475
unique=True,
489-
),
490-
],
491-
),
492-
Model(
493-
name="model_b",
494-
fields=[
495-
ModelField(
496-
name="id",
497-
type=FieldDataType.UUID,
498-
primary_key=True,
499-
unique=True,
500-
),
501-
ModelField(
502-
name="updated_at",
503-
type=FieldDataType.DATETIME,
504-
metadata=ModelFieldMetadata(
505-
is_updated_at_timestamp=True,
506-
),
476+
index=True,
507477
),
508478
ModelField(
509-
name="created_at",
479+
name="timestamp",
510480
type=FieldDataType.DATETIME,
511-
metadata=ModelFieldMetadata(
512-
is_created_at_timestamp=True,
513-
),
514-
),
515-
ModelField(
516-
name="is_mohammad",
517-
type=FieldDataType.BOOLEAN,
518481
),
519482
],
520483
relationships=[
521484
ModelRelationship(
522-
field_name="model_a0_id",
523-
target_model="model_a",
524-
),
525-
ModelRelationship(
526-
field_name="model_a1_id",
527-
target_model="model_a",
528-
),
529-
ModelRelationship(
530-
field_name="user_id",
531-
target_model="auth_user",
532-
),
485+
field_name="yo_id",
486+
target_model="yo",
487+
)
533488
],
534-
),
489+
)
535490
]
536491

537492
render_funcs = [
538-
render_model_to_model,
493+
# render_model_to_model,
539494
# render_model_to_dto,
540495
# render_model_to_dao,
541496
# render_model_to_routers,
542-
# render_model_to_post_test,
497+
render_model_to_post_test,
543498
# render_model_to_get_test,
544499
# render_model_to_get_id_test,
545500
# render_model_to_patch_test,
@@ -553,4 +508,4 @@ def render_model_to_delete_test(model: Model) -> str:
553508
print("=" * 80)
554509
print()
555510

556-
print(fn(models[2]))
511+
print(fn(models[0]))

0 commit comments

Comments
 (0)