Skip to content

Commit d783ac7

Browse files
committed
feat: support custom postgres enums
1 parent a94f5d2 commit d783ac7

2 files changed

Lines changed: 99 additions & 1 deletion

File tree

fastapi_forge/dtos.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515

1616
BoundedStr = Annotated[str, Field(..., min_length=1, max_length=100)]
1717
SnakeCaseStr = Annotated[BoundedStr, Field(..., pattern=r"^[a-z][a-z0-9_]*$")]
18+
PascalCaseStr = Annotated[
19+
BoundedStr,
20+
Field(..., pattern=r"^[A-Z][a-zA-Z0-9]*$"),
21+
]
1822
ModelName = SnakeCaseStr
1923
FieldName = SnakeCaseStr
2024
BackPopulates = Annotated[str, Field(..., pattern=r"^[a-z][a-z0-9_]*$")]
@@ -36,6 +40,51 @@ class ModelFieldMetadata(_Base):
3640
is_foreign_key: bool = False
3741

3842

43+
class CustomEnumValue(_Base):
44+
"""Represents a single name/value pair in a custom enum."""
45+
46+
name: Annotated[
47+
BoundedStr,
48+
Field(
49+
...,
50+
pattern=r"^[a-zA-Z][a-zA-Z0-9_]*$",
51+
),
52+
]
53+
value: BoundedStr
54+
55+
56+
class CustomEnum(_Base):
57+
"""Represents a custom PostgreSQL ENUM type."""
58+
59+
name: PascalCaseStr
60+
values: Annotated[list[CustomEnumValue], Field(..., min_length=1)]
61+
62+
@model_validator(mode="after")
63+
def _validate_enum(self) -> Self:
64+
names = [v.name for v in self.values]
65+
values = [v.value for v in self.values]
66+
67+
if len(names) != len(set(names)):
68+
raise ValueError(f"Enum '{self.name}' has duplicate names.")
69+
if len(values) != len(set(values)):
70+
raise ValueError(f"Enum '{self.name}' has duplicate values.")
71+
return self
72+
73+
@computed_field
74+
@property
75+
def class_definition(self) -> str:
76+
"""Returns a string representing the Python Enum class definition."""
77+
lines: list[str] = []
78+
lines.extend([f"class {self.name}(StrEnum):"])
79+
80+
value_lines: list[str] = []
81+
for v in self.values:
82+
value_lines.extend([f' {v.name} = "{v.value}"'])
83+
84+
lines.extend(value_lines)
85+
return "\n".join(lines)
86+
87+
3988
class ModelField(_Base):
4089
"""Represents a field in a model with validation and computed properties."""
4190

tests/test_dtos.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,14 @@
33
import pytest
44
from pydantic import ValidationError
55

6-
from fastapi_forge.dtos import Model, ModelField, ModelRelationship, ProjectSpec
6+
from fastapi_forge.dtos import (
7+
CustomEnum,
8+
CustomEnumValue,
9+
Model,
10+
ModelField,
11+
ModelRelationship,
12+
ProjectSpec,
13+
)
714
from fastapi_forge.enums import FieldDataTypeEnum
815

916
########################
@@ -126,3 +133,45 @@ def test_project_spec_non_existing_target_model() -> None:
126133
"'restaurant' has a relationship to 'non_existing', which does not exist."
127134
in str(exc_info.value)
128135
)
136+
137+
138+
##############
139+
# CustomEnum #
140+
##############
141+
142+
143+
def test_custom_enum_not_unique_values() -> None:
144+
with pytest.raises(ValidationError) as exc_info:
145+
CustomEnum(
146+
name="MyEnum",
147+
values=[
148+
CustomEnumValue(name="HELLO", value="hello"),
149+
CustomEnumValue(name="HI", value="hello"),
150+
],
151+
)
152+
assert "Enum 'MyEnum' has duplicate values." in str(exc_info.value)
153+
154+
155+
def test_custom_enum_not_unique_names() -> None:
156+
with pytest.raises(ValidationError) as exc_info:
157+
CustomEnum(
158+
name="MyEnum",
159+
values=[
160+
CustomEnumValue(name="HELLO", value="hello"),
161+
CustomEnumValue(name="HELLO", value="hi"),
162+
],
163+
)
164+
assert "Enum 'MyEnum' has duplicate names." in str(exc_info.value)
165+
166+
167+
def test_custom_enum_valid() -> None:
168+
enum = CustomEnum(
169+
name="MyEnum",
170+
values=[
171+
CustomEnumValue(name="FoO", value="foo"),
172+
CustomEnumValue(name="BAR", value="bar"),
173+
],
174+
)
175+
assert enum.class_definition == (
176+
'class MyEnum(StrEnum):\n FoO = "foo"\n BAR = "bar"'
177+
)

0 commit comments

Comments
 (0)