Skip to content

Commit 77b1676

Browse files
authored
update: database loading improved (#34)
1 parent 903626c commit 77b1676

12 files changed

Lines changed: 340 additions & 314 deletions

File tree

fastapi_forge/__main__.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,23 +28,23 @@ def main() -> None:
2828
help="Generate a project using a custom configuration from a YAML file.",
2929
)
3030
@click.option(
31-
"--db-url",
32-
help="PostgreSQL connection URL (e.g., postgresql://user:password@host:port/dbname)",
31+
"--conn-string",
32+
help="PostgreSQL connection string (e.g., postgresql://user:password@host:port/dbname)",
3333
)
3434
def start(
3535
use_example: bool = False,
3636
no_ui: bool = False,
3737
from_yaml: str | None = None,
38-
db_url: str | None = None,
38+
conn_string: str | None = None,
3939
) -> None:
4040
"""Start the FastAPI Forge server and generate a new project."""
41-
option_count = sum([use_example, bool(from_yaml), bool(db_url)])
41+
option_count = sum([use_example, bool(from_yaml), bool(conn_string)])
4242
if option_count > 1:
43-
msg = "Only one of '--use-example', '--from-yaml', or '--db-url' can be used."
43+
msg = "Only one of '--use-example', '--from-yaml', or '--conn-string' can be used."
4444
raise click.UsageError(msg)
4545

4646
if no_ui and option_count < 1:
47-
msg = "Option '--no-ui' requires one of '--use-example', '--from-yaml', or '--db-url' to be set."
47+
msg = "Option '--no-ui' requires one of '--use-example', '--from-yaml', or '--conn-string' to be set."
4848
raise click.UsageError(msg)
4949

5050
project_spec = None
@@ -53,16 +53,16 @@ def start(
5353
yaml_path = Path(from_yaml).expanduser().resolve()
5454
if not yaml_path.is_file():
5555
raise click.FileError(f"YAML file not found: {yaml_path}")
56-
project_spec = ProjectLoader(project_path=yaml_path).load_project()
57-
elif db_url:
58-
project_spec = ProjectLoader.load_project_from_db(
59-
connection_string=db_url,
56+
project_spec = ProjectLoader(project_path=yaml_path).load()
57+
elif conn_string:
58+
project_spec = ProjectLoader.load_from_conn_string(
59+
conn_string=conn_string,
6060
)
6161
elif use_example:
6262
base_path = Path(__file__).parent / "example-projects"
6363
path = base_path / "game_zone.yaml"
6464

65-
project_spec = ProjectLoader(project_path=path).load_project()
65+
project_spec = ProjectLoader(project_path=path).load()
6666

6767
init(project_spec=project_spec, no_ui=no_ui)
6868

fastapi_forge/dtos.py

Lines changed: 9 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@
1111

1212
from fastapi_forge.constants import TAB
1313
from fastapi_forge.enums import FieldDataTypeEnum, OnDeleteEnum
14-
from fastapi_forge.string_utils import camel_to_snake_hyphen, snake_to_camel
14+
from fastapi_forge.string_utils import (
15+
camel_to_snake,
16+
camel_to_snake_hyphen,
17+
snake_to_camel,
18+
)
1519
from fastapi_forge.type_info_registry import TypeInfo, enum_registry, registry
1620

1721
BoundedStr = Annotated[str, Field(..., min_length=1, max_length=100)]
@@ -65,11 +69,13 @@ def __init__(self, **kwargs: Any):
6569
super().__init__(**kwargs)
6670
# dynamically register in the enum registry on instantiation
6771
enum_repr = f"enums.{self.name}"
68-
enum_value_repr = f"{enum_repr}.{self.values[0].name}"
72+
enum_value_repr = (
73+
None if not self.values else f"{enum_repr}.{self.values[0].name}"
74+
)
6975
enum_registry.register(
7076
self.name,
7177
TypeInfo(
72-
sqlalchemy_type=f"Enum({enum_repr})",
78+
sqlalchemy_type=f'Enum({enum_repr}, name="{camel_to_snake(self.name)}")',
7379
sqlalchemy_prefix=True,
7480
python_type=enum_repr,
7581
faker_field_value=enum_value_repr,
@@ -147,9 +153,6 @@ def _validate_type(self) -> Self:
147153
)
148154
raise ValueError(msg)
149155

150-
# if self.type_enum and self.default_value:
151-
# self.default_value = f"enums.{self.type_enum.name}.{self.default_value}"
152-
153156
return self
154157

155158
@model_validator(mode="after")
@@ -422,44 +425,6 @@ def _validate_models(self) -> Self:
422425

423426
return self
424427

425-
@model_validator(mode="after")
426-
def _validate_circular_references(self) -> Self:
427-
relationship_graph = {}
428-
429-
model_names = {model.name for model in self.models}
430-
431-
for model in self.models:
432-
relationship_graph[model.name] = [
433-
rel.target_model
434-
for rel in model.relationships
435-
if rel.target_model in model_names
436-
]
437-
438-
visited = set()
439-
path = set()
440-
441-
def has_cycle(node):
442-
if node in visited:
443-
return False
444-
visited.add(node)
445-
path.add(node)
446-
447-
for neighbor in relationship_graph.get(node, []):
448-
if neighbor in path or has_cycle(neighbor):
449-
return True
450-
451-
path.remove(node)
452-
return False
453-
454-
for model_name in relationship_graph:
455-
if has_cycle(model_name):
456-
raise ValueError(
457-
f"Circular reference detected involving model '{model_name}'. "
458-
"Remove bidirectional relationships between models.",
459-
)
460-
461-
return self
462-
463428
def get_auth_model(self) -> Model | None:
464429
if not self.use_builtin_auth:
465430
return None

fastapi_forge/enums.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,13 @@ class FieldDataTypeEnum(StrEnum):
1616
@lru_cache
1717
def get_type_mappings(cls) -> dict[str, list[str]]:
1818
return {
19-
cls.STRING: ["character varying", "text", "varchar", "char"],
19+
cls.STRING: [
20+
"character varying",
21+
"text",
22+
"varchar",
23+
"char",
24+
"user-defined",
25+
],
2026
cls.INTEGER: [
2127
"integer",
2228
"int",

fastapi_forge/jinja.py

Lines changed: 39 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
from fastapi_forge.dtos import (
66
CustomEnum,
7-
CustomEnumValue,
87
Model,
98
ModelField,
109
)
@@ -26,7 +25,7 @@
2625
2726
2827
{% set unique_relationships = model.relationships | unique(attribute='target') %}
29-
{% for relation in unique_relationships -%}
28+
{% for relation in unique_relationships if relation.target != model.name_cc -%}
3029
from src.models.{{ relation.target_model }}_models import {{ relation.target }}
3130
{% endfor %}
3231
@@ -43,7 +42,7 @@ class {{ model.name_cc }}(Base):
4342
{% endfor %}
4443
4544
{% for relation in model.relationships -%}
46-
{{ relation | generate_relationship }}
45+
{{ relation | generate_relationship(model.name_cc == relation.target) }}
4746
{% endfor %}
4847
"""
4948

@@ -182,8 +181,10 @@ async def get_{{ model.name }}(
182181
import pytest
183182
from tests import factories
184183
from src.daos import AllDAOs
184+
from src import enums
185185
from httpx import AsyncClient
186186
from datetime import datetime, timezone, timedelta
187+
from uuid import uuid4
187188
188189
189190
from typing import Any
@@ -200,11 +201,11 @@ async def test_post_{{ model.name }}(client: AsyncClient, daos: AllDAOs,) -> Non
200201
{%- endfor %}
201202
202203
input_json = {
203-
{%- for field in model.fields if not (field.metadata.is_created_at_timestamp or field.metadata.is_updated_at_timestamp or field.primary_key) -%}
204-
{%- if not field.primary_key and field.name.endswith('_id') -%}
204+
{%- for field in model.fields if not (field.metadata.is_created_at_timestamp or field.metadata.is_updated_at_timestamp or field.primary_key or not field.type_info.test_value) -%}
205+
{%- if not field.primary_key and field.name.endswith('_id') and field.metadata.is_foreign_key -%}
205206
"{{ field.name }}": str({{ field.name | replace('_id', '.id') }}),
206207
{%- elif not field.primary_key %}
207-
"{{ field.name }}": {{ field.type_info.test_value }}{{ field.type_info.test_func }},
208+
"{{ field.name }}": {{ field.type_info.test_value }}{{ field.type_info.test_func if field.type_info.test_func else '' }},
208209
{%- endif %}
209210
{%- endfor %}
210211
}
@@ -216,11 +217,19 @@ async def test_post_{{ model.name }}(client: AsyncClient, daos: AllDAOs,) -> Non
216217
db_{{ model.name }} = await daos.{{ model.name }}.filter_first(id=response_data["id"])
217218
218219
assert db_{{ model.name }} is not None
219-
{%- for field in model.fields if not (field.metadata.is_created_at_timestamp or field.metadata.is_updated_at_timestamp or field.primary_key) %}
220-
{%- if not field.primary_key and field.name.endswith('_id') %}
221-
assert db_{{ model.name }}.{{ field.name }} == UUID(input_json["{{ field.name }}"])
220+
{%- for field in model.fields if not (field.metadata.is_created_at_timestamp or field.metadata.is_updated_at_timestamp or field.primary_key or not field.type_info.test_value) %}
221+
{%- if not field.primary_key and field.metadata.is_foreign_key %}
222+
{%- if field.type_info.encapsulate_assert %}
223+
assert db_{{ model.name }}.{{ field.name }} == {{ field.type_info.encapsulate_assert }}(input_json["{{ field.name }}"])
224+
{%- else %}
225+
assert db_{{ model.name }}.{{ field.name }} == input_json["{{ field.name }}"]
226+
{%- endif %}
222227
{%- elif not field.primary_key %}
223-
assert db_{{ model.name }}.{{ field.name }}{{ field.type_info.test_func }} == input_json["{{ field.name }}"]
228+
{%- if field.type_info.encapsulate_assert %}
229+
assert db_{{ model.name }}.{{ field.name }}{{ field.type_info.test_func if field.type_info.test_func else '' }} == {{ field.type_info.encapsulate_assert }}(input_json["{{ field.name }}"])
230+
{%- else %}
231+
assert db_{{ model.name }}.{{ field.name }}{{ field.type_info.test_func if field.type_info.test_func else '' }} == input_json["{{ field.name }}"]
232+
{%- endif %}
224233
{%- endif %}
225234
{%- endfor %}
226235
"""
@@ -278,7 +287,7 @@ async def test_get_{{ model.name }}_by_id(client: AsyncClient,) -> None:
278287
{%- if not field.primary_key and field.name.endswith('_id') %}
279288
assert response_data["{{ field.name }}"] == str({{ model.name }}.{{ field.name }})
280289
{%- elif not field.primary_key %}
281-
assert response_data["{{ field.name }}"] == {{ model.name }}.{{ field.name }}{{ field.type_info.test_func }}
290+
assert response_data["{{ field.name }}"] == {{ model.name }}.{{ field.name }}{{ field.type_info.test_func if field.type_info.test_func else '' }}
282291
{%- endif %}
283292
{%- endfor %}
284293
"""
@@ -287,8 +296,10 @@ async def test_get_{{ model.name }}_by_id(client: AsyncClient,) -> None:
287296
import pytest
288297
from tests import factories
289298
from src.daos import AllDAOs
299+
from src import enums
290300
from httpx import AsyncClient
291301
from datetime import datetime, timezone, timedelta
302+
from uuid import uuid4
292303
293304
294305
from typing import Any
@@ -306,11 +317,11 @@ async def test_patch_{{ model.name }}(client: AsyncClient, daos: AllDAOs,) -> No
306317
{{ model.name }} = await factories.{{ model.name_cc }}Factory.create()
307318
308319
input_json = {
309-
{%- for field in model.fields if not (field.metadata.is_created_at_timestamp or field.metadata.is_updated_at_timestamp or field.primary_key) -%}
310-
{%- if not field.primary_key and field.name.endswith('_id') -%}
320+
{%- for field in model.fields if not (field.metadata.is_created_at_timestamp or field.metadata.is_updated_at_timestamp or field.primary_key or not field.type_info.test_value) -%}
321+
{%- if not field.primary_key and field.name.endswith('_id') and field.metadata.is_foreign_key -%}
311322
"{{ field.name }}": str({{ field.name | replace('_id', '.id') }}),
312323
{% elif not field.primary_key %}
313-
"{{ field.name }}": {{ field.type_info.test_value }}{{ field.type_info.test_func }},
324+
"{{ field.name }}": {{ field.type_info.test_value }}{{ field.type_info.test_func if field.type_info.test_func else '' }},
314325
{%- endif %}
315326
{%- endfor %}
316327
}
@@ -321,11 +332,19 @@ async def test_patch_{{ model.name }}(client: AsyncClient, daos: AllDAOs,) -> No
321332
db_{{ model.name }} = await daos.{{ model.name }}.filter_first(id={{ model.name }}.id)
322333
323334
assert db_{{ model.name }} is not None
324-
{%- for field in model.fields if not (field.metadata.is_created_at_timestamp or field.metadata.is_updated_at_timestamp or field.primary_key) %}
325-
{%- if not field.primary_key and field.name.endswith('_id') %}
335+
{%- for field in model.fields if not (field.metadata.is_created_at_timestamp or field.metadata.is_updated_at_timestamp or field.primary_key or not field.type_info.test_value) %}
336+
{%- if not field.primary_key and field.metadata.is_foreign_key %}
337+
{%- if field.type_info.encapsulate_assert %}
338+
assert db_{{ model.name }}.{{ field.name }} == {{ field.type_info.encapsulate_assert }}(input_json["{{ field.name }}"])
339+
{%- else %}
326340
assert db_{{ model.name }}.{{ field.name }} == UUID(input_json["{{ field.name }}"])
341+
{%- endif %}
327342
{%- elif not field.primary_key %}
328-
assert db_{{ model.name }}.{{ field.name }}{{ field.type_info.test_func }} == input_json["{{ field.name }}"]
343+
{%- if field.type_info.encapsulate_assert %}
344+
assert db_{{ model.name }}.{{ field.name }}{{ field.type_info.test_func if field.type_info.test_func else '' }} == {{ field.type_info.encapsulate_assert }}(input_json["{{ field.name }}"])
345+
{%- else %}
346+
assert db_{{ model.name }}.{{ field.name }}{{ field.type_info.test_func if field.type_info.test_func else '' }} == input_json["{{ field.name }}"]
347+
{%- endif %}
329348
{%- endif %}
330349
{%- endfor %}
331350
@@ -454,8 +473,8 @@ def render_custom_enums_to_enums(custom_enums: list[CustomEnum]) -> str:
454473
enum0 = CustomEnum(
455474
name="MyEnum0",
456475
values=[
457-
CustomEnumValue(name="FoO", value="foo"),
458-
CustomEnumValue(name="BAR", value="bar"),
476+
# CustomEnumValue(name="FoO", value="foo"),
477+
# CustomEnumValue(name="BAR", value="bar"),
459478
],
460479
)
461480

@@ -480,4 +499,4 @@ def render_custom_enums_to_enums(custom_enums: list[CustomEnum]) -> str:
480499
),
481500
],
482501
)
483-
print(render_model_to_model(model))
502+
print(render_model_to_post_test(model))

fastapi_forge/jinja_utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,16 +58,18 @@ def generate_field(
5858
return _gen_field(field=field, target=target)
5959

6060

61-
def generate_relationship(relation: ModelRelationship) -> str:
61+
def generate_relationship(
62+
relation: ModelRelationship, is_self_reference: bool = False
63+
) -> str:
6264
args = []
63-
args.append(f'"{relation.target}"')
6465
args.append(f"foreign_keys=[{relation.field_name}]")
6566
if relation.back_populates:
6667
args.append(f'back_populates="{relation.back_populates}"')
6768
args.append("uselist=False")
6869

70+
target_repr = relation.target if not is_self_reference else f'"{relation.target}"'
6971
return f"""
70-
{relation.field_name_no_id}: Mapped[{relation.target}] = relationship(
72+
{relation.field_name_no_id}: Mapped[{target_repr}] = relationship(
7173
{",\n ".join(args)}
7274
)
7375
""".strip()

0 commit comments

Comments
 (0)