Skip to content

Commit c862f71

Browse files
authored
update: improve project generation security (#39)
1 parent 5682000 commit c862f71

2 files changed

Lines changed: 72 additions & 25 deletions

File tree

fastapi_forge/forge.py

Lines changed: 64 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
import shutil
3-
from pathlib import Path
3+
from pathlib import Path, PurePath
44
from time import perf_counter
55

66
from cookiecutter.main import cookiecutter
@@ -11,32 +11,73 @@
1111

1212

1313
def _get_template_path() -> Path:
14-
"""Return the absolute path to the project template directory."""
15-
template_path = Path(__file__).parent / "template"
14+
"""Return the absolute path to the project template directory with validation."""
15+
template_path = Path(__file__).resolve().parent / "template"
1616
if not template_path.exists():
1717
raise RuntimeError(f"Template directory not found: {template_path}")
18+
if not template_path.is_dir():
19+
raise RuntimeError(f"Template path is not a directory: {template_path}")
1820
return template_path
1921

2022

21-
async def _teardown_project(project_name: str) -> None:
22-
"""Forcefully remove the project directory and all its contents."""
23-
project_dir = Path.cwd() / project_name
24-
if project_dir.exists():
25-
await asyncio.to_thread(shutil.rmtree, project_dir)
26-
logger.info(f"Removed project directory: {project_dir}")
23+
def _validate_project_name(project_name: str) -> None:
24+
"""Validate that the project name is safe to use in paths."""
25+
if not project_name:
26+
msg = "Project name cannot be empty"
27+
raise ValueError(msg)
28+
if PurePath(project_name).name != project_name:
29+
raise ValueError(
30+
f"Invalid project name: {project_name} (contains path traversal)"
31+
)
32+
if not project_name.isidentifier():
33+
logger.warning(
34+
f"Project name '{project_name}' may not be a valid Python identifier"
35+
)
36+
37+
38+
async def _teardown_project(project_name: str, *, dry_run: bool = False) -> bool:
39+
"""Safely remove the project directory and all its contents."""
40+
project_dir = Path.cwd().resolve() / project_name
41+
42+
if not project_dir.exists():
43+
logger.debug(f"Project directory does not exist: {project_dir}")
44+
return False
45+
46+
if not project_dir.is_dir():
47+
logger.warning(f"Path exists but is not a directory: {project_dir}")
48+
return False
49+
50+
if not any(project_dir.glob("pyproject.toml")):
51+
logger.warning(
52+
f"Directory {project_dir} does not appear to be a project "
53+
"(missing pyproject.toml) - skipping deletion"
54+
)
55+
return False
56+
57+
try:
58+
logger.info(
59+
f"{'Would remove' if dry_run else 'Removing'} project directory: {project_dir}"
60+
)
61+
if not dry_run:
62+
await asyncio.to_thread(shutil.rmtree, project_dir)
63+
except Exception as e:
64+
logger.error(f"Failed to remove project directory {project_dir}: {e!s}")
65+
return False
66+
return True
2767

2868

2969
async def build_project(spec: ProjectSpec) -> None:
3070
"""Create a new project using the provided template and specifications."""
71+
start_time = perf_counter()
72+
project_name = spec.project_name
73+
3174
try:
32-
start = perf_counter()
33-
logger.info(f"Building project '{spec.project_name}'...")
75+
_validate_project_name(project_name)
76+
logger.info(f"Building project '{project_name}'...")
3477

3578
builder = ProjectBuilder(spec)
3679
await builder.build_artifacts()
3780

38-
template_path = str(_get_template_path())
39-
4081
extra_context = {
4182
**spec.model_dump(exclude={"models"}),
4283
"models": {
@@ -53,17 +94,19 @@ async def build_project(spec: ProjectSpec) -> None:
5394
extra_context["use_builtin_auth"] = False
5495

5596
cookiecutter(
56-
template_path,
57-
output_dir=str(Path.cwd()),
97+
template=str(_get_template_path()),
98+
output_dir=str(Path.cwd().resolve()),
5899
no_input=True,
59100
overwrite_if_exists=True,
60101
extra_context=extra_context,
61102
)
62-
logger.info(f"Project '{spec.project_name}' created successfully.")
63103

64-
end = perf_counter()
65-
logger.info(f"Project built in {end - start:.2f} seconds.")
66-
except Exception as exc:
67-
logger.error(f"Failed to create project: {exc}")
68-
await _teardown_project(spec.project_name)
104+
build_time = perf_counter() - start_time
105+
logger.info(
106+
f"Project '{project_name}' created successfully in {build_time:.2f} seconds."
107+
)
108+
109+
except Exception as error:
110+
logger.error(f"Failed to create project '{project_name}': {error}")
111+
69112
raise

fastapi_forge/frontend/panels/project_config_panel.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,14 @@ async def _warn_overwrite(self) -> bool:
265265

266266
async def _create_project(self) -> None:
267267
"""Generate the project based on the current state."""
268+
if not state.project_name:
269+
ui.notify("Missing project name!", type="negative")
270+
return
271+
272+
if not state.models:
273+
ui.notify("No models to generate!", type="negative")
274+
return
275+
268276
project_path = Path(state.project_name)
269277

270278
if project_path.exists():
@@ -282,10 +290,6 @@ async def _create_project(self) -> None:
282290
ongoing_notification = ui.notification("Generating project...")
283291

284292
try:
285-
if not state.models:
286-
ui.notify("No models to generate!", type="negative")
287-
return
288-
289293
state.project_name = self.project_name.value
290294
state.use_postgres = self.use_postgres.value
291295
state.use_alembic = self.use_alembic.value

0 commit comments

Comments
 (0)