Skip to content

Commit bb0ee78

Browse files
committed
Choose custom auth user
1 parent 27afa0b commit bb0ee78

6 files changed

Lines changed: 48 additions & 23 deletions

File tree

fastapi_forge/dtos.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,10 @@ def _validate_models(self) -> Self:
313313
msg = "Cannot use built-in auth if PostgreSQL is not enabled."
314314
raise ValueError(msg)
315315

316+
if self.use_builtin_auth and self.get_auth_model() is None:
317+
msg = "Cannot use built-in auth if no auth model is defined."
318+
raise ValueError(msg)
319+
316320
for model in self.models:
317321
for relationship in model.relationships:
318322
if relationship.target_model not in model_names_set:
@@ -364,3 +368,11 @@ def has_cycle(node):
364368
)
365369

366370
return self
371+
372+
def get_auth_model(self) -> Model | None:
373+
if not self.use_builtin_auth:
374+
return None
375+
for model in self.models:
376+
if model.metadata.is_auth_model:
377+
return model
378+
return None

fastapi_forge/forge.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,26 @@ async def build_project(spec: ProjectSpec) -> None:
3737

3838
template_path = str(_get_template_path())
3939

40+
extra_context = {
41+
**spec.model_dump(exclude={"models"}),
42+
"models": {
43+
"models": [model.model_dump() for model in spec.models],
44+
},
45+
}
46+
47+
if spec.use_builtin_auth:
48+
auth_user = spec.get_auth_model()
49+
if auth_user:
50+
extra_context["auth_model"] = auth_user.model_dump()
51+
else:
52+
logger.warning("No AuthUser model found in the project spec.")
53+
4054
cookiecutter(
4155
template_path,
4256
output_dir=str(Path.cwd()),
4357
no_input=True,
4458
overwrite_if_exists=True,
45-
extra_context={
46-
**spec.model_dump(exclude={"models"}),
47-
"models": {
48-
"models": [model.model_dump() for model in spec.models],
49-
},
50-
},
59+
extra_context=extra_context,
5160
)
5261
logger.info(f"Project '{spec.project_name}' created successfully.")
5362

fastapi_forge/template/cookiecutter.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
"models": {
2121
"models": []
2222
},
23+
"auth_model": null,
2324
"_extensions": [
2425
"local_extensions.camel_to_snake"
2526
]

fastapi_forge/template/{{cookiecutter.project_name}}/src/__main__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from src.settings import settings
22

33

4+
45
if __name__ == "__main__":
56
import uvicorn
67

fastapi_forge/template/{{cookiecutter.project_name}}/src/dependencies/auth_dependencies.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
{%- if cookiecutter.use_builtin_auth %}
1+
from typing import Annotated
2+
23
from fastapi import Depends, HTTPException, Request
34
from fastapi.security import HTTPBearer as _HTTPBearer
4-
from src.dtos.auth_user_dtos import AuthUserDTO
5-
from src.daos import GetDAOs
5+
66
from src import exceptions
7+
from src.daos import GetDAOs
8+
from src.dtos.{{ cookiecutter.auth_model.name }}_dtos import {{ cookiecutter.auth_model.name_cc }}DTO
79
from src.utils import auth_utils
8-
from typing import Annotated
910

1011

1112
class HTTPBearer(_HTTPBearer):
@@ -20,7 +21,8 @@ async def __call__(self, request: Request) -> str | None: # type: ignore
2021
obj = await super().__call__(request)
2122
return obj.credentials if obj else None
2223
except HTTPException:
23-
raise exceptions.Http401("Missing token.")
24+
msg = "Missing token."
25+
raise exceptions.Http401(msg)
2426

2527

2628
auth_scheme = HTTPBearer()
@@ -37,17 +39,17 @@ def get_token(token: str = Depends(auth_scheme)) -> str:
3739
async def get_current_user(
3840
token: GetToken,
3941
daos: GetDAOs,
40-
) -> AuthUserDTO:
42+
) -> {{ cookiecutter.auth_model.name_cc }}DTO:
4143
"""Get current user from token data."""
4244
token_data = auth_utils.decode_token(token)
4345

44-
user = await daos.auth_user.filter_first(id=token_data.user_id)
46+
user = await daos.{{ cookiecutter.auth_model.name }}.filter_first(id=token_data.user_id)
4547

4648
if not user:
47-
raise exceptions.Http404("Decoded user not found.")
49+
msg = "Decoded user not found."
50+
raise exceptions.Http404(msg)
4851

49-
return AuthUserDTO.model_validate(user)
52+
return {{ cookiecutter.auth_model.name_cc }}DTO.model_validate(user)
5053

5154

52-
GetCurrentUser = Annotated[AuthUserDTO, Depends(get_current_user)]
53-
{% endif %}
55+
GetCurrentUser = Annotated[{{ cookiecutter.auth_model.name_cc }}DTO, Depends(get_current_user)]

fastapi_forge/template/{{cookiecutter.project_name}}/src/routes/auth_routes.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
{%- if cookiecutter.use_builtin_auth %}
22
from fastapi import APIRouter
33
from src.dtos.auth_dtos import UserLoginDTO, UserCreateDTO, LoginResponse, TokenData
4-
from src.dtos.auth_user_dtos import AuthUserDTO, AuthUserInputDTO
4+
from src.dtos.{{ cookiecutter.auth_model.name }}_dtos import {{ cookiecutter.auth_model.name_cc }}DTO, {{ cookiecutter.auth_model.name_cc }}InputDTO
55
from src.dtos import DataResponse, CreatedResponse
66
from src.daos import GetDAOs
77
from src import exceptions
@@ -19,7 +19,7 @@ async def login(
1919
) -> DataResponse[LoginResponse]:
2020
"""Login by email and password."""
2121

22-
user = await daos.auth_user.filter_first(email=input_dto.email)
22+
user = await daos.{{ cookiecutter.auth_model.name }}.filter_first(email=input_dto.email)
2323

2424
if user is None:
2525
raise exceptions.Http401("Wrong email or password")
@@ -47,13 +47,13 @@ async def register(
4747
) -> DataResponse:
4848
"""Register by email and password."""
4949

50-
user = await daos.auth_user.filter_first(email=input_dto.email)
50+
user = await daos.{{ cookiecutter.auth_model.name }}.filter_first(email=input_dto.email)
5151

5252
if user:
5353
raise exceptions.Http401("User already exists")
5454

55-
user_id = await daos.auth_user.create(
56-
AuthUserInputDTO(
55+
user_id = await daos.{{ cookiecutter.auth_model.name }}.create(
56+
{{ cookiecutter.auth_model.name_cc }}InputDTO(
5757
email=input_dto.email,
5858
password=auth_utils.hash_password(
5959
input_dto.password.get_secret_value(),
@@ -72,7 +72,7 @@ async def register(
7272
@router.get("/users/me", status_code=200)
7373
async def get_current_user(
7474
current_user: GetCurrentUser,
75-
) -> DataResponse[AuthUserDTO]:
75+
) -> DataResponse[{{ cookiecutter.auth_model.name_cc }}DTO]:
7676
"""Get current user."""
7777

7878
return DataResponse(data=current_user)

0 commit comments

Comments
 (0)