Skip to content

Commit 71ca576

Browse files
committed
Fix: expand {model} placeholders, strict periods default, parse retention_granularity
1 parent f1c138b commit 71ca576

3 files changed

Lines changed: 250 additions & 8 deletions

File tree

sidemantic/adapters/sidemantic.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ def _parse_model(self, model_def: dict) -> Model | None:
301301
cohort_event=measure_def.get("cohort_event"),
302302
activity_event=measure_def.get("activity_event"),
303303
periods=measure_def.get("periods"),
304-
retention_granularity=measure_def.get("granularity")
304+
retention_granularity=(measure_def.get("retention_granularity") or measure_def.get("granularity"))
305305
if measure_def.get("type") == "retention"
306306
else None,
307307
# Cumulative/window parameters
@@ -423,7 +423,9 @@ def _parse_metric(self, metric_def: dict) -> Metric | None:
423423
cohort_event=metric_def.get("cohort_event"),
424424
activity_event=metric_def.get("activity_event"),
425425
periods=metric_def.get("periods"),
426-
retention_granularity=metric_def.get("granularity") if metric_type == "retention" else None,
426+
retention_granularity=(metric_def.get("retention_granularity") or metric_def.get("granularity"))
427+
if metric_type == "retention"
428+
else None,
427429
window=metric_def.get("window"),
428430
grain_to_date=metric_def.get("grain_to_date"),
429431
window_expression=metric_def.get("window_expression"),
@@ -593,7 +595,7 @@ def _export_model(self, model: Model) -> dict:
593595
if measure.periods is not None:
594596
measure_def["periods"] = measure.periods
595597
if measure.retention_granularity:
596-
measure_def["granularity"] = measure.retention_granularity
598+
measure_def["retention_granularity"] = measure.retention_granularity
597599
# Cumulative/window parameters
598600
if measure.window:
599601
measure_def["window"] = measure.window
@@ -682,7 +684,7 @@ def _export_metric(self, measure: Metric, graph) -> dict:
682684
if measure.periods is not None:
683685
result["periods"] = measure.periods
684686
if measure.retention_granularity:
685-
result["granularity"] = measure.retention_granularity
687+
result["retention_granularity"] = measure.retention_granularity
686688
if measure.sql:
687689
result["sql"] = measure.sql
688690
# Auto-detect and export dependencies for derived measures

sidemantic/sql/generator.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2310,11 +2310,20 @@ def _generate_retention_query(
23102310
if not model:
23112311
raise ValueError(f"No model found for retention metric {metric_name}")
23122312

2313-
# Defaults
2314-
periods = metric.periods or 28
2315-
granularity = metric.retention_granularity or "day"
2313+
# Defaults (use `is not None` to avoid converting 0 to the default)
2314+
periods = metric.periods if metric.periods is not None else 28
2315+
granularity = metric.retention_granularity if metric.retention_granularity is not None else "day"
23162316
activity_event = metric.activity_event or "TRUE"
23172317

2318+
# Replace {model} placeholders in event predicates with actual table alias
2319+
table_alias = "t" if model.sql else ""
2320+
if table_alias:
2321+
cohort_event = metric.cohort_event.replace("{model}", table_alias)
2322+
activity_event = activity_event.replace("{model}", table_alias)
2323+
else:
2324+
cohort_event = metric.cohort_event.replace("{model}.", "")
2325+
activity_event = activity_event.replace("{model}.", "")
2326+
23182327
# Validate entity identifier
23192328
if not _re.match(r"^[a-zA-Z_][a-zA-Z0-9_.]*$", metric.entity):
23202329
raise ValueError(f"Invalid entity identifier: {metric.entity}")
@@ -2385,7 +2394,7 @@ def _generate_retention_query(
23852394
sql = f"""WITH cohorts AS (
23862395
SELECT {metric.entity}, MIN({trunc_expr}) AS cohort_date
23872396
FROM {from_clause}
2388-
WHERE {metric.cohort_event}{filter_clause}
2397+
WHERE {cohort_event}{filter_clause}
23892398
GROUP BY {metric.entity}
23902399
),
23912400
activity AS (

tests/metrics/test_retention.py

Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,3 +316,234 @@ def test_retention_week_granularity():
316316
assert week_data[0][4] == 100.0
317317
assert week_data[1][4] == 100.0
318318
assert week_data[2][4] == 50.0
319+
320+
321+
def test_retention_model_placeholder_expansion_sql_model():
322+
"""Test that {model} placeholders in cohort_event/activity_event are expanded to table alias."""
323+
events = Model(
324+
name="events",
325+
sql="""
326+
SELECT 1 AS uid, 'signup' AS event, '2024-01-01'::DATE AS ts
327+
UNION ALL SELECT 1, 'login', '2024-01-02'::DATE
328+
""",
329+
primary_key="uid",
330+
dimensions=[
331+
Dimension(name="uid", sql="uid", type="categorical"),
332+
Dimension(name="event", sql="event", type="categorical"),
333+
Dimension(name="ts", sql="ts", type="time"),
334+
],
335+
metrics=[],
336+
)
337+
338+
retention = Metric(
339+
name="retention",
340+
type="retention",
341+
entity="uid",
342+
cohort_event="{model}.event = 'signup'",
343+
activity_event="{model}.event IS NOT NULL",
344+
periods=1,
345+
retention_granularity="day",
346+
)
347+
348+
graph = SemanticGraph()
349+
graph.add_model(events)
350+
graph.add_metric(retention)
351+
352+
generator = SQLGenerator(graph)
353+
sql = generator.generate(metrics=["retention"], dimensions=[])
354+
355+
# {model} should be replaced with 't' (SQL subquery alias)
356+
assert "{model}" not in sql
357+
assert "t.event = 'signup'" in sql
358+
assert "t.event IS NOT NULL" in sql
359+
360+
# Should still execute correctly
361+
conn = duckdb.connect(":memory:")
362+
result = conn.execute(sql)
363+
rows = df_rows(result)
364+
assert len(rows) > 0
365+
366+
367+
def test_retention_model_placeholder_expansion_table_model():
368+
"""Test that {model} placeholders are stripped for table-backed models."""
369+
conn = duckdb.connect(":memory:")
370+
conn.execute("""
371+
CREATE TABLE test_events AS
372+
SELECT 1 AS uid, 'signup' AS event, '2024-01-01'::DATE AS ts
373+
UNION ALL SELECT 1, 'login', '2024-01-02'::DATE
374+
""")
375+
376+
events = Model(
377+
name="events",
378+
table="test_events",
379+
primary_key="uid",
380+
dimensions=[
381+
Dimension(name="uid", sql="uid", type="categorical"),
382+
Dimension(name="event", sql="event", type="categorical"),
383+
Dimension(name="ts", sql="ts", type="time"),
384+
],
385+
metrics=[],
386+
)
387+
388+
retention = Metric(
389+
name="retention",
390+
type="retention",
391+
entity="uid",
392+
cohort_event="{model}.event = 'signup'",
393+
activity_event="{model}.event IS NOT NULL",
394+
periods=1,
395+
retention_granularity="day",
396+
)
397+
398+
graph = SemanticGraph()
399+
graph.add_model(events)
400+
graph.add_metric(retention)
401+
402+
generator = SQLGenerator(graph)
403+
sql = generator.generate(metrics=["retention"], dimensions=[])
404+
405+
# {model}. should be stripped for table-backed models
406+
assert "{model}" not in sql
407+
assert "event = 'signup'" in sql
408+
409+
result = conn.execute(sql)
410+
rows = df_rows(result)
411+
assert len(rows) > 0
412+
413+
414+
def test_retention_periods_zero_raises_validation_error():
415+
"""Test that periods=0 raises a validation error instead of silently becoming 28."""
416+
events = Model(
417+
name="events",
418+
sql="""
419+
SELECT 1 AS uid, 'signup' AS event, '2024-01-01'::DATE AS ts
420+
""",
421+
primary_key="uid",
422+
dimensions=[
423+
Dimension(name="uid", sql="uid", type="categorical"),
424+
Dimension(name="event", sql="event", type="categorical"),
425+
Dimension(name="ts", sql="ts", type="time"),
426+
],
427+
metrics=[],
428+
)
429+
430+
retention = Metric(
431+
name="retention",
432+
type="retention",
433+
entity="uid",
434+
cohort_event="event = 'signup'",
435+
periods=0,
436+
)
437+
438+
graph = SemanticGraph()
439+
graph.add_model(events)
440+
graph.add_metric(retention)
441+
442+
generator = SQLGenerator(graph)
443+
with pytest.raises(ValueError, match="Invalid periods value"):
444+
generator.generate(metrics=["retention"], dimensions=[])
445+
446+
447+
def test_retention_yaml_retention_granularity_key():
448+
"""Test that YAML with retention_granularity: week parses correctly."""
449+
import os
450+
import tempfile
451+
452+
from sidemantic.adapters.sidemantic import SidemanticAdapter
453+
454+
yaml_content = """
455+
models:
456+
- name: events
457+
table: events
458+
dimensions:
459+
- name: user_id
460+
type: categorical
461+
- name: ts
462+
type: time
463+
metrics:
464+
- name: weekly_retention
465+
type: retention
466+
entity: user_id
467+
cohort_event: "event = 'signup'"
468+
retention_granularity: week
469+
periods: 4
470+
"""
471+
with tempfile.NamedTemporaryFile(mode="w", suffix=".yml", delete=False) as f:
472+
f.write(yaml_content)
473+
tmp_path = f.name
474+
475+
try:
476+
adapter = SidemanticAdapter()
477+
graph = adapter.parse(tmp_path)
478+
model = graph.get_model("events")
479+
metric = model.get_metric("weekly_retention")
480+
assert metric.retention_granularity == "week"
481+
assert metric.periods == 4
482+
finally:
483+
os.unlink(tmp_path)
484+
485+
486+
def test_retention_yaml_granularity_fallback():
487+
"""Test that YAML with granularity: month also parses for retention metrics."""
488+
import os
489+
import tempfile
490+
491+
from sidemantic.adapters.sidemantic import SidemanticAdapter
492+
493+
yaml_content = """
494+
metrics:
495+
- name: monthly_retention
496+
type: retention
497+
entity: user_id
498+
cohort_event: "event = 'signup'"
499+
granularity: month
500+
periods: 12
501+
"""
502+
with tempfile.NamedTemporaryFile(mode="w", suffix=".yml", delete=False) as f:
503+
f.write(yaml_content)
504+
tmp_path = f.name
505+
506+
try:
507+
adapter = SidemanticAdapter()
508+
graph = adapter.parse(tmp_path)
509+
metric = graph.get_metric("monthly_retention")
510+
assert metric.retention_granularity == "month"
511+
assert metric.periods == 12
512+
finally:
513+
os.unlink(tmp_path)
514+
515+
516+
def test_retention_export_roundtrip_retention_granularity():
517+
"""Test that export uses retention_granularity key and roundtrips correctly."""
518+
import os
519+
import tempfile
520+
521+
from sidemantic.adapters.sidemantic import SidemanticAdapter
522+
523+
# Create a graph with a retention metric
524+
graph = SemanticGraph()
525+
retention = Metric(
526+
name="weekly_retention",
527+
type="retention",
528+
entity="user_id",
529+
cohort_event="event = 'signup'",
530+
retention_granularity="week",
531+
periods=4,
532+
)
533+
graph.add_metric(retention)
534+
535+
# Export
536+
with tempfile.NamedTemporaryFile(mode="w", suffix=".yml", delete=False) as f:
537+
tmp_path = f.name
538+
539+
try:
540+
adapter = SidemanticAdapter()
541+
adapter.export(graph, tmp_path)
542+
543+
# Re-parse and verify
544+
graph2 = adapter.parse(tmp_path)
545+
metric = graph2.get_metric("weekly_retention")
546+
assert metric.retention_granularity == "week"
547+
assert metric.periods == 4
548+
finally:
549+
os.unlink(tmp_path)

0 commit comments

Comments
 (0)