Skip to content

Commit 20abc90

Browse files
authored
feat: support custom postgres enums (#30)
1 parent 406a3c8 commit 20abc90

31 files changed

Lines changed: 1262 additions & 372 deletions

fastapi_forge/__main__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,15 +53,16 @@ def start(
5353
yaml_path = Path(from_yaml).expanduser().resolve()
5454
if not yaml_path.is_file():
5555
raise click.FileError(f"YAML file not found: {yaml_path}")
56-
project_spec = ProjectLoader(project_path=yaml_path).load_project_input()
56+
project_spec = ProjectLoader(project_path=yaml_path).load_project()
5757
elif db_url:
58-
project_spec = ProjectLoader.load_project_spec_from_db(
58+
project_spec = ProjectLoader.load_project_from_db(
5959
connection_string=db_url,
6060
)
6161
elif use_example:
6262
base_path = Path(__file__).parent / "example-projects"
6363
path = base_path / "game_zone.yaml"
64-
project_spec = ProjectLoader(project_path=path).load_project_input()
64+
65+
project_spec = ProjectLoader(project_path=path).load_project()
6566

6667
init(project_spec=project_spec, no_ui=no_ui)
6768

fastapi_forge/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
TAB = " "

fastapi_forge/dtos.py

Lines changed: 102 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,31 @@
99
model_validator,
1010
)
1111

12-
from fastapi_forge.data_type_registry import DataTypeInfo, registry
12+
from fastapi_forge.constants import TAB
1313
from fastapi_forge.enums import FieldDataTypeEnum, OnDeleteEnum
1414
from fastapi_forge.string_utils import camel_to_snake_hyphen, snake_to_camel
15+
from fastapi_forge.type_info_registry import TypeInfo, enum_registry, registry
1516

1617
BoundedStr = Annotated[str, Field(..., min_length=1, max_length=100)]
1718
SnakeCaseStr = Annotated[BoundedStr, Field(..., pattern=r"^[a-z][a-z0-9_]*$")]
19+
PascalCaseStr = Annotated[
20+
BoundedStr,
21+
Field(..., pattern=r"^[A-Z][a-zA-Z0-9]*$"),
22+
]
1823
ModelName = SnakeCaseStr
1924
FieldName = SnakeCaseStr
2025
BackPopulates = Annotated[str, Field(..., pattern=r"^[a-z][a-z0-9_]*$")]
2126
ProjectName = Annotated[
2227
BoundedStr,
2328
Field(..., pattern=r"^[a-zA-Z0-9](?:[a-zA-Z0-9._-]*[a-zA-Z0-9])?$"),
2429
]
30+
EnumStr = Annotated[
31+
BoundedStr,
32+
Field(
33+
...,
34+
pattern=r"^[a-zA-Z][a-zA-Z0-9_]*$",
35+
),
36+
]
2537

2638

2739
class _Base(BaseModel):
@@ -36,11 +48,67 @@ class ModelFieldMetadata(_Base):
3648
is_foreign_key: bool = False
3749

3850

51+
class CustomEnumValue(_Base):
52+
"""Represents a single name/value pair in a custom enum."""
53+
54+
name: EnumStr
55+
value: BoundedStr
56+
57+
58+
class CustomEnum(_Base):
59+
"""Represents a custom PostgreSQL ENUM type."""
60+
61+
name: EnumStr
62+
values: list[CustomEnumValue] = []
63+
64+
def __init__(self, **kwargs: Any):
65+
super().__init__(**kwargs)
66+
# dynamically register in the enum registry on instantiation
67+
enum_repr = f"enums.{self.name}"
68+
enum_value_repr = f"{enum_repr}.{self.values[0].name}"
69+
enum_registry.register(
70+
self.name,
71+
TypeInfo(
72+
sqlalchemy_type=f"Enum({enum_repr})",
73+
sqlalchemy_prefix=True,
74+
python_type=enum_repr,
75+
faker_field_value=enum_value_repr,
76+
value=enum_value_repr,
77+
test_value=enum_value_repr,
78+
),
79+
)
80+
81+
@model_validator(mode="after")
82+
def _validate_enum(self) -> Self:
83+
names = [v.name for v in self.values]
84+
85+
if len(names) != len(set(names)):
86+
raise ValueError(f"Enum '{self.name}' has duplicate names.")
87+
return self
88+
89+
@computed_field
90+
@property
91+
def class_definition(self) -> str:
92+
"""Returns a string representing the Python Enum class definition."""
93+
lines: list[str] = []
94+
lines.extend([f"class {self.name}(StrEnum):"])
95+
lines.extend([f'{TAB}"""{self.name} Enum."""\n'])
96+
97+
value_lines: list[str] = []
98+
for v in self.values:
99+
value_repr = v.value if v.value == "auto()" else f'"{v.value}"'
100+
value_lines.extend([f"{TAB}{v.name} = {value_repr}"])
101+
102+
lines.extend(value_lines)
103+
return "\n".join(lines)
104+
105+
39106
class ModelField(_Base):
40107
"""Represents a field in a model with validation and computed properties."""
41108

42109
name: FieldName
43110
type: FieldDataTypeEnum
111+
type_enum: EnumStr | None = None
44112
primary_key: bool = False
45113
nullable: bool = False
46114
unique: bool = False
@@ -58,9 +126,32 @@ def name_cc(self) -> str:
58126

59127
@computed_field
60128
@property
61-
def type_info(self) -> DataTypeInfo:
129+
def type_info(self) -> TypeInfo:
130+
if self.type_enum:
131+
return enum_registry.get(self.type_enum)
62132
return registry.get(self.type)
63133

134+
@model_validator(mode="after")
135+
def _validate_type(self) -> Self:
136+
if self.type == FieldDataTypeEnum.Enum and self.type_enum is None:
137+
msg = (
138+
f"ModelField '{self.name}' has field type 'ENUM', "
139+
"but is missing 'type_enum'."
140+
)
141+
raise ValueError(msg)
142+
143+
if self.type_enum and self.type != FieldDataTypeEnum.Enum:
144+
msg = (
145+
f"ModelField '{self.name}' has 'type_enum' set, "
146+
"but is not field type 'ENUM'."
147+
)
148+
raise ValueError(msg)
149+
150+
# if self.type_enum and self.default_value:
151+
# self.default_value = f"enums.{self.type_enum.name}.{self.default_value}"
152+
153+
return self
154+
64155
@model_validator(mode="after")
65156
def _validate(self) -> Self:
66157
"""Validate field constraints."""
@@ -183,10 +274,10 @@ def _validate(self) -> Self:
183274
if sum(field.primary_key for field in self.fields) != 1:
184275
raise ValueError(f"Model '{self.name}' must have exactly one primary key.")
185276

186-
unque_relationships = [
277+
unique_relationships = [
187278
relationship.field_name for relationship in self.relationships
188279
]
189-
if len(unque_relationships) != len(set(unque_relationships)):
280+
if len(unique_relationships) != len(set(unique_relationships)):
190281
raise ValueError(
191282
f"Model '{self.name}' contains duplicate relationship field names.",
192283
)
@@ -277,6 +368,7 @@ class ProjectSpec(_Base):
277368
use_rabbitmq: bool = False
278369
use_taskiq: bool = False
279370
models: list[Model] = []
371+
custom_enums: list[CustomEnum] = []
280372

281373
@model_validator(mode="after")
282374
def _validate_models(self) -> Self:
@@ -286,6 +378,11 @@ def _validate_models(self) -> Self:
286378
msg = "Model names must be unique."
287379
raise ValueError(msg)
288380

381+
enum_names = [enum.name for enum in self.custom_enums]
382+
if len(enum_names) != len(set(enum_names)):
383+
msg = "Enum names must be unique."
384+
raise ValueError(msg)
385+
289386
if self.use_alembic and not self.use_postgres:
290387
msg = "Cannot use Alembic if PostgreSQL is not enabled."
291388
raise ValueError(msg)
@@ -322,6 +419,7 @@ def _validate_models(self) -> Self:
322419
"TaskIQ is enabled, but the following are missing and required "
323420
f"for its operation: {', '.join(missing)}."
324421
)
422+
325423
return self
326424

327425
@model_validator(mode="after")

fastapi_forge/enums.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ class FieldDataTypeEnum(StrEnum):
1010
DATETIME = "DateTime"
1111
UUID = "UUID"
1212
JSONB = "JSONB"
13+
Enum = "Enum"
1314

1415
@classmethod
1516
@lru_cache

fastapi_forge/example-projects/game_zone.yaml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,14 @@ project:
77
use_rabbitmq: true
88
use_taskiq: true
99

10+
custom_enums:
11+
- name: UserRole
12+
values:
13+
- name: ADMIN
14+
value: auto()
15+
- name: USER
16+
value: auto()
17+
1018
models:
1119
- name: auth_user
1220
fields:
@@ -19,6 +27,10 @@ project:
1927
index: true
2028
- name: password
2129
type: String
30+
- name: role
31+
type: Enum
32+
type_enum: UserRole
33+
default_value: USER
2234
- name: created_at
2335
type: DateTime
2436
default_value: datetime.now(timezone.utc)

fastapi_forge/frontend/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from fastapi_forge.frontend.components.header import Header
2-
from fastapi_forge.frontend.components.model_create import ModelCreate
3-
from fastapi_forge.frontend.components.model_row import ModelRow
4-
from fastapi_forge.frontend.panels.model_panel import ModelPanel
2+
from fastapi_forge.frontend.components.item_create import ModelCreate, EnumCreate
3+
from fastapi_forge.frontend.components.item_row import ModelRow, EnumRow
4+
from fastapi_forge.frontend.panels.left_panel import LeftPanel
55
from fastapi_forge.frontend.panels.model_editor_panel import ModelEditorPanel
6+
from fastapi_forge.frontend.panels.enum_editor_panel import EnumEditorPanel
67
from fastapi_forge.frontend.panels.project_config_panel import ProjectConfigPanel
8+
from fastapi_forge.frontend.panels.item_editor_panel import ItemEditorPanel
79
from fastapi_forge.frontend.main import init
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
from collections.abc import Callable
2+
3+
from nicegui import ui
4+
5+
from fastapi_forge.frontend.state import state
6+
7+
8+
class _RowCreate(ui.row):
9+
def __init__(
10+
self,
11+
*,
12+
input_placeholder: str,
13+
input_tooltip: str,
14+
button_tooltip: str,
15+
on_add_item: Callable[[str], None],
16+
):
17+
super().__init__(wrap=False)
18+
self.input_placeholder = input_placeholder
19+
self.input_tooltip = input_tooltip
20+
self.button_tooltip = button_tooltip
21+
self.on_add_item = on_add_item
22+
23+
self._build()
24+
25+
def _build(self) -> None:
26+
with self.classes("w-full flex items-center justify-between"):
27+
self.item_input = (
28+
ui.input(placeholder=self.input_placeholder)
29+
.classes("self-center")
30+
.tooltip(
31+
self.input_tooltip,
32+
)
33+
)
34+
self.add_button = (
35+
ui.button(icon="add", on_click=self._add_item)
36+
.classes("self-center")
37+
.tooltip(self.button_tooltip)
38+
)
39+
40+
def _add_item(self) -> None:
41+
if not self.item_input.value:
42+
return
43+
value: str = self.item_input.value
44+
item_name = value.strip()
45+
if item_name:
46+
self.on_add_item(item_name)
47+
self.item_input.value = ""
48+
49+
50+
class ModelCreate(_RowCreate):
51+
def __init__(self):
52+
super().__init__(
53+
input_placeholder="Model name",
54+
input_tooltip="Model names should be singular (e.g., 'user' instead of 'users').",
55+
button_tooltip="Add Model",
56+
on_add_item=state.add_model,
57+
)
58+
59+
60+
class EnumCreate(_RowCreate):
61+
def __init__(self):
62+
super().__init__(
63+
input_placeholder="Enum name",
64+
input_tooltip="Enums can be used as data types for model fields.",
65+
button_tooltip="Add Enum",
66+
on_add_item=state.add_enum,
67+
)

0 commit comments

Comments
 (0)