Skip to content

Commit 5434248

Browse files
authored
feat: Let the user set and/or change the auth model (#21)
1 parent 43dab52 commit 5434248

11 files changed

Lines changed: 173 additions & 48 deletions

File tree

fastapi_forge/dtos.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,10 @@ def _validate_models(self) -> Self:
313313
msg = "Cannot use built-in auth if PostgreSQL is not enabled."
314314
raise ValueError(msg)
315315

316+
if self.use_builtin_auth and self.get_auth_model() is None:
317+
msg = "Cannot use built-in auth if no auth model is defined."
318+
raise ValueError(msg)
319+
316320
for model in self.models:
317321
for relationship in model.relationships:
318322
if relationship.target_model not in model_names_set:
@@ -364,3 +368,11 @@ def has_cycle(node):
364368
)
365369

366370
return self
371+
372+
def get_auth_model(self) -> Model | None:
373+
if not self.use_builtin_auth:
374+
return None
375+
for model in self.models:
376+
if model.metadata.is_auth_model:
377+
return model
378+
return None

fastapi_forge/forge.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,27 @@ async def build_project(spec: ProjectSpec) -> None:
3737

3838
template_path = str(_get_template_path())
3939

40+
extra_context = {
41+
**spec.model_dump(exclude={"models"}),
42+
"models": {
43+
"models": [model.model_dump() for model in spec.models],
44+
},
45+
}
46+
47+
if spec.use_builtin_auth:
48+
auth_user = spec.get_auth_model()
49+
if auth_user:
50+
extra_context["auth_model"] = auth_user.model_dump()
51+
else:
52+
logger.warning("No auth model found. Skipping authentication setup.")
53+
extra_context["use_builtin_auth"] = False
54+
4055
cookiecutter(
4156
template_path,
4257
output_dir=str(Path.cwd()),
4358
no_input=True,
4459
overwrite_if_exists=True,
45-
extra_context={
46-
**spec.model_dump(exclude={"models"}),
47-
"models": {
48-
"models": [model.model_dump() for model in spec.models],
49-
},
50-
},
60+
extra_context=extra_context,
5161
)
5262
logger.info(f"Project '{spec.project_name}' created successfully.")
5363

fastapi_forge/frontend/constants.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
from typing import Any
22

3+
from fastapi_forge.dtos import ModelField
4+
from fastapi_forge.enums import FieldDataType
5+
36
SELECTED_MODEL_TEXT_COLOR = "text-black-500 dark:text-amber-300"
47

58
FIELD_COLUMNS: list[dict[str, Any]] = [
@@ -46,3 +49,17 @@
4649
{"name": "index", "label": "Index", "field": "index", "align": "center"},
4750
{"name": "unique", "label": "Unique", "field": "unique", "align": "center"},
4851
]
52+
53+
54+
DEFAULT_AUTH_USER_FIELDS: list[ModelField] = [
55+
ModelField(
56+
name="email",
57+
type=FieldDataType.STRING,
58+
unique=True,
59+
index=True,
60+
),
61+
ModelField(
62+
name="password",
63+
type=FieldDataType.STRING,
64+
),
65+
]

fastapi_forge/frontend/modals/field_modal.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,10 @@ def remove_row():
122122

123123
def _show_field_preview(self) -> None:
124124
if not self.field_name.value:
125+
ui.notify("Set a field name first", type="warning")
126+
return
127+
if not self.field_type.value:
128+
ui.notify("Select a field type first", type="warning")
125129
return
126130
try:
127131
with ui.dialog() as modal, ui.card().classes("no-shadow border-[1px]"):

fastapi_forge/frontend/panels/model_editor_panel.py

Lines changed: 98 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55

66
from fastapi_forge.dtos import Model, ModelField, ModelFieldMetadata, ModelRelationship
77
from fastapi_forge.enums import FieldDataType
8-
from fastapi_forge.frontend.constants import FIELD_COLUMNS, RELATIONSHIP_COLUMNS
8+
from fastapi_forge.frontend.constants import (
9+
DEFAULT_AUTH_USER_FIELDS,
10+
FIELD_COLUMNS,
11+
RELATIONSHIP_COLUMNS,
12+
)
913
from fastapi_forge.frontend.modals import (
1014
AddFieldModal,
1115
AddRelationModal,
@@ -28,6 +32,7 @@ def __init__(self):
2832
state.select_model_fn = self.set_selected_model
2933
state.deselect_model_fn = self.deselect_model
3034
state.render_model_editor_fn = self.refresh
35+
state.render_actions_fn = self._render_action_group
3136

3237
self.add_field_modal: AddFieldModal = AddFieldModal(
3338
on_add_field=self._handle_modal_add_field,
@@ -56,6 +61,63 @@ def _show_code_preview(self) -> None:
5661
ui.code(code).classes("w-full")
5762
modal.open()
5863

64+
def _toggle_auth_model(self) -> None:
65+
if not state.selected_model or not state.use_builtin_auth:
66+
return
67+
68+
if not state.selected_model.metadata.is_auth_model and any(
69+
m.metadata.is_auth_model for m in state.models
70+
):
71+
ui.notify(
72+
"Cannot have more than one authentication model.", type="negative"
73+
)
74+
return
75+
76+
model = state.selected_model
77+
model.metadata.is_auth_model = not model.metadata.is_auth_model
78+
79+
if not model.metadata.is_auth_model:
80+
self._remove_auth_fields(model)
81+
if state.render_model_editor_fn:
82+
state.render_model_editor_fn()
83+
84+
if state.render_models_fn:
85+
state.render_models_fn()
86+
self._render_action_group.refresh()
87+
88+
if model.metadata.is_auth_model:
89+
self._setup_auth_model_fields(model)
90+
91+
def _remove_auth_fields(self, model: Model) -> None:
92+
for field_name in ("email", "password"):
93+
if field := next((f for f in model.fields if f.name == field_name), None):
94+
model.fields.remove(field)
95+
96+
def _setup_auth_model_fields(self, model: Model) -> None:
97+
self._remove_auth_fields(model)
98+
id_index = 0
99+
insert_position = id_index + 1 if id_index >= 0 else 0
100+
for field in reversed(DEFAULT_AUTH_USER_FIELDS):
101+
model.fields.insert(insert_position, field)
102+
103+
self._refresh_table(model.fields)
104+
105+
@ui.refreshable
106+
def _render_action_group(self) -> None:
107+
ui.button(
108+
icon="security",
109+
on_click=self._toggle_auth_model,
110+
color=(
111+
"green"
112+
if state.use_builtin_auth
113+
and state.selected_model
114+
and state.selected_model.metadata.is_auth_model
115+
else "grey"
116+
),
117+
).tooltip("Authentication model").bind_visibility_from(
118+
state, "use_builtin_auth"
119+
)
120+
59121
def _build(self) -> None:
60122
with self:
61123
with ui.row().classes("w-full justify-between items-center"):
@@ -67,6 +129,7 @@ def _build(self) -> None:
67129
).tooltip("Preview SQLAlchemy model code")
68130

69131
with ui.row().classes("gap-2 items-center"):
132+
self._render_action_group()
70133
with (
71134
ui.button(icon="menu").tooltip("Generate"),
72135
ui.menu(),
@@ -373,22 +436,35 @@ def _deselect_relation(self) -> None:
373436
self.relationship_table.selected = []
374437

375438
def _on_select_field(self, selection: list[dict[str, Any]]) -> None:
376-
if not state.selected_model:
439+
if not state.selected_model or not selection:
440+
self._deselect_field()
377441
return
378-
if not selection:
442+
443+
name = selection[0].get("name")
444+
445+
if (
446+
state.selected_model.metadata.is_auth_model
447+
and state.use_builtin_auth
448+
and name in ("password", "email")
449+
):
450+
ui.notify(
451+
f"Cannot edit {name} field in authentication model.",
452+
type="warning",
453+
)
379454
self._deselect_field()
380455
return
381-
if selection[0].get("name") == "id":
456+
457+
if name == "id":
382458
self._deselect_field()
383-
else:
384-
state.selected_field = next(
385-
(
386-
field
387-
for field in state.selected_model.fields
388-
if field.name == selection[0]["name"]
389-
),
390-
None,
459+
ui.notify(
460+
"Cannot edit the 'id' field, it is automatically generated.",
461+
type="warning",
391462
)
463+
return
464+
465+
state.selected_field = next(
466+
(field for field in state.selected_model.fields if field.name == name), None
467+
)
392468

393469
def _on_select_relation(self, selection: list[dict[str, Any]]) -> None:
394470
if not state.selected_model:
@@ -417,11 +493,18 @@ def _handle_update_field(
417493
default_value: str | None = None,
418494
extra_kwargs: dict[str, Any] | None = None,
419495
) -> None:
496+
if not state.selected_model or not state.selected_field:
497+
return
498+
499+
exclude_set = {"id"}
420500
if (
421-
not state.selected_model
422-
or not state.selected_field
423-
or state.selected_field.name == "id"
501+
state.selected_model
502+
and state.selected_model.metadata.is_auth_model
503+
and state.use_builtin_auth
424504
):
505+
exclude_set.update({"email", "password"})
506+
507+
if name in exclude_set:
425508
return
426509

427510
if state.selected_field.name != name and self._field_name_exists(name):

fastapi_forge/frontend/panels/project_config_panel.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
)
1313
from fastapi_forge.enums import FieldDataType
1414
from fastapi_forge.forge import build_project
15+
from fastapi_forge.frontend.constants import DEFAULT_AUTH_USER_FIELDS
1516
from fastapi_forge.frontend.notifications import notify_validation_error
1617
from fastapi_forge.frontend.state import state
1718

@@ -166,16 +167,7 @@ def _handle_builtin_auth_change(self, event: ValueChangeEventArguments) -> None:
166167
unique=True,
167168
index=True,
168169
),
169-
ModelField(
170-
name="email",
171-
type=FieldDataType.STRING,
172-
unique=True,
173-
index=True,
174-
),
175-
ModelField(
176-
name="password",
177-
type=FieldDataType.STRING,
178-
),
170+
*DEFAULT_AUTH_USER_FIELDS,
179171
ModelField(
180172
name="created_at",
181173
type=FieldDataType.DATETIME,

fastapi_forge/frontend/state.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class ProjectState(BaseModel):
2727

2828
render_models_fn: Callable | None = None
2929
render_model_editor_fn: Callable | None = None
30+
render_actions_fn: Callable | None = None
3031
select_model_fn: Callable[[Model], None] | None = None
3132
deselect_model_fn: Callable | None = None
3233

@@ -177,6 +178,8 @@ def _trigger_ui_refresh(self) -> None:
177178
self.render_models_fn()
178179
if self.render_model_editor_fn:
179180
self.render_model_editor_fn()
181+
if self.render_actions_fn:
182+
self.render_actions_fn.refresh()
180183

181184

182185
state: ProjectState = ProjectState()

fastapi_forge/template/cookiecutter.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
"models": {
2121
"models": []
2222
},
23+
"auth_model": null,
2324
"_extensions": [
2425
"local_extensions.camel_to_snake"
2526
]

fastapi_forge/template/{{cookiecutter.project_name}}/src/__main__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from src.settings import settings
22

33

4+
45
if __name__ == "__main__":
56
import uvicorn
67

fastapi_forge/template/{{cookiecutter.project_name}}/src/dependencies/auth_dependencies.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
{%- if cookiecutter.use_builtin_auth %}
1+
from typing import Annotated
2+
23
from fastapi import Depends, HTTPException, Request
34
from fastapi.security import HTTPBearer as _HTTPBearer
4-
from src.dtos.auth_user_dtos import AuthUserDTO
5-
from src.daos import GetDAOs
5+
66
from src import exceptions
7+
from src.daos import GetDAOs
8+
from src.dtos.{{ cookiecutter.auth_model.name }}_dtos import {{ cookiecutter.auth_model.name_cc }}DTO
79
from src.utils import auth_utils
8-
from typing import Annotated
910

1011

1112
class HTTPBearer(_HTTPBearer):
@@ -20,7 +21,8 @@ async def __call__(self, request: Request) -> str | None: # type: ignore
2021
obj = await super().__call__(request)
2122
return obj.credentials if obj else None
2223
except HTTPException:
23-
raise exceptions.Http401("Missing token.")
24+
msg = "Missing token."
25+
raise exceptions.Http401(msg)
2426

2527

2628
auth_scheme = HTTPBearer()
@@ -37,17 +39,17 @@ def get_token(token: str = Depends(auth_scheme)) -> str:
3739
async def get_current_user(
3840
token: GetToken,
3941
daos: GetDAOs,
40-
) -> AuthUserDTO:
42+
) -> {{ cookiecutter.auth_model.name_cc }}DTO:
4143
"""Get current user from token data."""
4244
token_data = auth_utils.decode_token(token)
4345

44-
user = await daos.auth_user.filter_first(id=token_data.user_id)
46+
user = await daos.{{ cookiecutter.auth_model.name }}.filter_first(id=token_data.user_id)
4547

4648
if not user:
47-
raise exceptions.Http404("Decoded user not found.")
49+
msg = "Decoded user not found."
50+
raise exceptions.Http404(msg)
4851

49-
return AuthUserDTO.model_validate(user)
52+
return {{ cookiecutter.auth_model.name_cc }}DTO.model_validate(user)
5053

5154

52-
GetCurrentUser = Annotated[AuthUserDTO, Depends(get_current_user)]
53-
{% endif %}
55+
GetCurrentUser = Annotated[{{ cookiecutter.auth_model.name_cc }}DTO, Depends(get_current_user)]

0 commit comments

Comments
 (0)