Skip to content

Commit b7e6aee

Browse files
committed
Add retention metric type for cohort retention analysis
1 parent 9c95340 commit b7e6aee

4 files changed

Lines changed: 539 additions & 1 deletion

File tree

sidemantic/adapters/sidemantic.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,13 @@ def _parse_model(self, model_def: dict) -> Model | None:
297297
conversion_event=measure_def.get("conversion_event"),
298298
conversion_window=measure_def.get("conversion_window"),
299299
offset_window=measure_def.get("offset_window"),
300+
# Retention parameters
301+
cohort_event=measure_def.get("cohort_event"),
302+
activity_event=measure_def.get("activity_event"),
303+
periods=measure_def.get("periods"),
304+
retention_granularity=measure_def.get("granularity")
305+
if measure_def.get("type") == "retention"
306+
else None,
300307
# Cumulative/window parameters
301308
window=measure_def.get("window"),
302309
grain_to_date=measure_def.get("grain_to_date"),
@@ -413,6 +420,10 @@ def _parse_metric(self, metric_def: dict) -> Metric | None:
413420
conversion_event=metric_def.get("conversion_event"),
414421
conversion_window=metric_def.get("conversion_window"),
415422
offset_window=metric_def.get("offset_window"),
423+
cohort_event=metric_def.get("cohort_event"),
424+
activity_event=metric_def.get("activity_event"),
425+
periods=metric_def.get("periods"),
426+
retention_granularity=metric_def.get("granularity") if metric_type == "retention" else None,
416427
window=metric_def.get("window"),
417428
grain_to_date=metric_def.get("grain_to_date"),
418429
window_expression=metric_def.get("window_expression"),
@@ -574,6 +585,15 @@ def _export_model(self, model: Model) -> dict:
574585
measure_def["conversion_window"] = measure.conversion_window
575586
if measure.offset_window:
576587
measure_def["offset_window"] = measure.offset_window
588+
# Retention parameters
589+
if measure.cohort_event:
590+
measure_def["cohort_event"] = measure.cohort_event
591+
if measure.activity_event:
592+
measure_def["activity_event"] = measure.activity_event
593+
if measure.periods is not None:
594+
measure_def["periods"] = measure.periods
595+
if measure.retention_granularity:
596+
measure_def["granularity"] = measure.retention_granularity
577597
# Cumulative/window parameters
578598
if measure.window:
579599
measure_def["window"] = measure.window
@@ -655,6 +675,14 @@ def _export_metric(self, measure: Metric, graph) -> dict:
655675
result["conversion_window"] = measure.conversion_window
656676
if measure.offset_window:
657677
result["offset_window"] = measure.offset_window
678+
if measure.cohort_event:
679+
result["cohort_event"] = measure.cohort_event
680+
if measure.activity_event:
681+
result["activity_event"] = measure.activity_event
682+
if measure.periods is not None:
683+
result["periods"] = measure.periods
684+
if measure.retention_granularity:
685+
result["granularity"] = measure.retention_granularity
658686
if measure.sql:
659687
result["sql"] = measure.sql
660688
# Auto-detect and export dependencies for derived measures

sidemantic/core/metric.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,10 +204,15 @@ def validate_type_specific_fields(self):
204204
raise ValueError("conversion metric requires 'base_event' field")
205205
if not self.conversion_event:
206206
raise ValueError("conversion metric requires 'conversion_event' field")
207+
if self.type == "retention":
208+
if not self.entity:
209+
raise ValueError("retention metric requires 'entity' field")
210+
if not self.cohort_event:
211+
raise ValueError("retention metric requires 'cohort_event' field")
207212
return self
208213

209214
# Metric type (if this is a complex metric, not just a simple aggregation)
210-
type: Literal["ratio", "derived", "cumulative", "time_comparison", "conversion"] | None = Field(
215+
type: Literal["ratio", "derived", "cumulative", "time_comparison", "conversion", "retention"] | None = Field(
211216
None, description="Metric type for complex calculations"
212217
)
213218

@@ -252,6 +257,18 @@ def validate_type_specific_fields(self):
252257
conversion_event: str | None = Field(None, description="Target event filter")
253258
conversion_window: str | None = Field(None, description="Conversion time window")
254259

260+
# Retention parameters
261+
cohort_event: str | None = Field(
262+
None, description="SQL filter for cohort-defining event (e.g., \"event = 'install'\")"
263+
)
264+
activity_event: str | None = Field(
265+
None, description='SQL filter for activity event (default: any event, e.g., "event IS NOT NULL")'
266+
)
267+
periods: int | None = Field(None, description="Number of retention periods to compute (e.g., 28 for 28-day)")
268+
retention_granularity: Literal["day", "week", "month"] | None = Field(
269+
None, description="Time granularity for retention periods (day, week, month)"
270+
)
271+
255272
# Common parameters
256273
filters: list[str] | None = Field(None, description="Optional WHERE clause filters")
257274
fill_nulls_with: int | float | str | None = Field(None, description="Default value when result is NULL")

sidemantic/sql/generator.py

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,9 @@ def metric_needs_window(m):
323323
# Conversion metrics need special handling
324324
if metric.type == "conversion":
325325
return True
326+
# Retention metrics need special handling
327+
if metric.type == "retention":
328+
return True
326329
return False
327330

328331
needs_window_functions = any(metric_needs_window(m) for m in metrics)
@@ -2197,6 +2200,170 @@ def resolve_ratio_ref(ref: str) -> str:
21972200

21982201
raise NotImplementedError(f"Metric type {metric.type} not yet implemented")
21992202

2203+
def _generate_retention_query(
2204+
self,
2205+
metric_name: str,
2206+
dimensions: list[str],
2207+
filters: list[str] | None = None,
2208+
order_by: list[str] | None = None,
2209+
limit: int | None = None,
2210+
) -> str:
2211+
"""Generate SQL for cohort retention metrics.
2212+
2213+
Uses a multi-CTE pattern:
2214+
1. cohorts: identify each entity's first qualifying event (cohort_date)
2215+
2. activity: distinct entity activity dates
2216+
3. retention: join cohorts to activity, compute periods_since and active_users
2217+
4. cohort_sizes: count of entities per cohort_date
2218+
5. Final SELECT: retention percentage per cohort_date and period
2219+
2220+
Args:
2221+
metric_name: Name of the retention metric
2222+
dimensions: List of dimension references (unused for retention, reserved)
2223+
filters: List of filter expressions
2224+
order_by: List of fields to order by
2225+
limit: Maximum number of rows to return
2226+
2227+
Returns:
2228+
SQL query string
2229+
"""
2230+
import re as _re
2231+
2232+
# Resolve metric and model
2233+
metric = None
2234+
model = None
2235+
2236+
if "." in metric_name:
2237+
model_name, measure_name = metric_name.split(".", 1)
2238+
model = self.graph.get_model(model_name)
2239+
if model:
2240+
metric = model.get_metric(measure_name)
2241+
else:
2242+
try:
2243+
metric = self.graph.get_metric(metric_name)
2244+
except KeyError:
2245+
pass
2246+
2247+
if not metric or not metric.entity or not metric.cohort_event:
2248+
raise ValueError(f"Retention metric {metric_name} missing required fields (entity, cohort_event)")
2249+
2250+
# Find the model that owns this metric if not already found
2251+
if not model:
2252+
for m_name, m in self.graph.models.items():
2253+
if m.get_metric(metric_name):
2254+
model = m
2255+
break
2256+
if not model:
2257+
for m_name, m in self.graph.models.items():
2258+
for dim in m.dimensions:
2259+
if dim.name == metric.entity:
2260+
model = m
2261+
break
2262+
if model:
2263+
break
2264+
2265+
if not model:
2266+
raise ValueError(f"No model found for retention metric {metric_name}")
2267+
2268+
# Defaults
2269+
periods = metric.periods or 28
2270+
granularity = metric.retention_granularity or "day"
2271+
activity_event = metric.activity_event or "TRUE"
2272+
2273+
# Validate entity identifier
2274+
if not _re.match(r"^[a-zA-Z_][a-zA-Z0-9_.]*$", metric.entity):
2275+
raise ValueError(f"Invalid entity identifier: {metric.entity}")
2276+
if not isinstance(periods, int) or periods < 1:
2277+
raise ValueError(f"Invalid periods value: {periods}")
2278+
2279+
# Find timestamp dimension
2280+
timestamp_dim = None
2281+
for dim in model.dimensions:
2282+
if dim.type == "time":
2283+
timestamp_dim = dim.name
2284+
break
2285+
2286+
if not timestamp_dim:
2287+
raise ValueError("Retention metrics require a time dimension on the model")
2288+
2289+
# Build FROM clause
2290+
if model.sql:
2291+
from_clause = f"({model.sql}) AS t"
2292+
else:
2293+
from_clause = model.table
2294+
2295+
# Build granularity-specific date truncation and interval
2296+
if granularity == "day":
2297+
trunc_expr = f"{timestamp_dim}::date"
2298+
diff_expr = "a.active_date - c.cohort_date"
2299+
periods_label = "days_since"
2300+
elif granularity == "week":
2301+
trunc_expr = f"{self._date_trunc('week', timestamp_dim)}::date"
2302+
diff_expr = "(a.active_date - c.cohort_date) / 7"
2303+
periods_label = "weeks_since"
2304+
elif granularity == "month":
2305+
trunc_expr = f"{self._date_trunc('month', timestamp_dim)}::date"
2306+
diff_expr = (
2307+
"(EXTRACT(YEAR FROM a.active_date) - EXTRACT(YEAR FROM c.cohort_date)) * 12"
2308+
" + (EXTRACT(MONTH FROM a.active_date) - EXTRACT(MONTH FROM c.cohort_date))"
2309+
)
2310+
periods_label = "months_since"
2311+
else:
2312+
raise ValueError(f"Unsupported retention granularity: {granularity}")
2313+
2314+
# Build optional WHERE filters for the source data
2315+
filter_clause = ""
2316+
if filters:
2317+
filter_clause = " AND " + " AND ".join(filters)
2318+
2319+
order_clause = "\nORDER BY r.cohort_date, r.periods_since"
2320+
if order_by:
2321+
order_fields = []
2322+
for field in order_by:
2323+
field_name = field.split(".", 1)[1] if "." in field else field
2324+
order_fields.append(field_name)
2325+
order_clause = f"\nORDER BY {', '.join(order_fields)}"
2326+
2327+
limit_clause = ""
2328+
if limit is not None:
2329+
limit_clause = f"\nLIMIT {limit}"
2330+
2331+
sql = f"""WITH cohorts AS (
2332+
SELECT {metric.entity}, MIN({trunc_expr}) AS cohort_date
2333+
FROM {from_clause}
2334+
WHERE {metric.cohort_event}{filter_clause}
2335+
GROUP BY {metric.entity}
2336+
),
2337+
activity AS (
2338+
SELECT DISTINCT {metric.entity}, {trunc_expr} AS active_date
2339+
FROM {from_clause}
2340+
WHERE {activity_event}{filter_clause}
2341+
),
2342+
retention AS (
2343+
SELECT
2344+
c.cohort_date,
2345+
CAST({diff_expr} AS INTEGER) AS periods_since,
2346+
COUNT(DISTINCT c.{metric.entity}) AS active_users
2347+
FROM cohorts c
2348+
JOIN activity a ON c.{metric.entity} = a.{metric.entity} AND a.active_date >= c.cohort_date
2349+
WHERE CAST({diff_expr} AS INTEGER) <= {periods}
2350+
GROUP BY 1, 2
2351+
),
2352+
cohort_sizes AS (
2353+
SELECT cohort_date, COUNT(DISTINCT {metric.entity}) AS cohort_size
2354+
FROM cohorts GROUP BY 1
2355+
)
2356+
SELECT
2357+
r.cohort_date,
2358+
r.periods_since AS {periods_label},
2359+
r.active_users,
2360+
c.cohort_size,
2361+
ROUND(r.active_users * 100.0 / c.cohort_size, 1) AS retention_pct
2362+
FROM retention r
2363+
JOIN cohort_sizes c ON r.cohort_date = c.cohort_date{order_clause}{limit_clause}"""
2364+
2365+
return sql.strip()
2366+
22002367
def _generate_conversion_query(
22012368
self,
22022369
metric_name: str,
@@ -2427,6 +2594,7 @@ def _generate_with_window_functions(
24272594
time_comparison_metrics = []
24282595
offset_ratio_metrics = []
24292596
conversion_metrics = []
2597+
retention_metrics = []
24302598
base_metrics = []
24312599
time_comparison_base_plans = {}
24322600
regular_expression_metric_plans = {}
@@ -2570,6 +2738,9 @@ def collect_leaf_base_metrics(
25702738
elif metric and metric.type == "conversion":
25712739
add_unique(conversion_metrics, m)
25722740
# Conversion metrics need special handling - don't add to base_metrics
2741+
elif metric and metric.type == "retention":
2742+
add_unique(retention_metrics, m)
2743+
# Retention metrics need special handling - don't add to base_metrics
25732744
else:
25742745
# Regular metric or measure
25752746
metric_ref = canonical_ref(m, resolved_context)
@@ -2596,6 +2767,10 @@ def collect_leaf_base_metrics(
25962767
if precomputed_cumulative_metrics:
25972768
cumulative_metrics = [m for m in cumulative_metrics if m not in precomputed_cumulative_metrics]
25982769

2770+
# Handle retention metrics separately - they need a completely different pattern
2771+
if retention_metrics:
2772+
return self._generate_retention_query(retention_metrics[0], dimensions, filters, order_by, limit)
2773+
25992774
# Handle conversion metrics separately - they need a completely different pattern
26002775
if conversion_metrics:
26012776
return self._generate_conversion_query(conversion_metrics[0], dimensions, filters, order_by, limit)

0 commit comments

Comments
 (0)