Skip to content

Commit b87b32c

Browse files
authored
Fix aggregation parsing for complex expressions (#31)
- Replace greedy regex with sqlglot in Metric class to properly handle expressions like SUM(x) / SUM(y) without mangling them - Fix COUNT DISTINCT detection using isinstance(parsed.this, exp.Distinct) - Add expression metric support in SQL generator for metrics with inline aggregations (agg=None, type=None, sql=<expression>) - Fix dependency analyzer to skip resolution for expression metrics with inline aggregations - Fix cumulative metrics to properly resolve references to other measures and generate valid aliases Add kitchen sink tests using patterns from rill-examples to catch edge cases.
1 parent 18b945f commit b87b32c

6 files changed

Lines changed: 601 additions & 68 deletions

File tree

sidemantic/adapters/rill.py

Lines changed: 4 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@
77
from pathlib import Path
88
from typing import Any
99

10-
import sqlglot
1110
import yaml
12-
from sqlglot import expressions as exp
1311

1412
from sidemantic.core.dimension import Dimension
1513
from sidemantic.core.metric import Metric
@@ -203,46 +201,14 @@ def _parse_measure(self, measure_def: dict[str, Any]) -> Metric | None:
203201
# "simple" = basic aggregation (None type), "derived" = calculation using other measures
204202
metric_type = "derived"
205203

206-
# Use sqlglot to detect simple aggregations
207-
agg_type = None
208-
agg_sql = None
209-
try:
210-
parsed = sqlglot.parse_one(expression, read="duckdb")
211-
212-
# Check if this is a simple aggregation function
213-
if isinstance(parsed, (exp.Sum, exp.Avg, exp.Count, exp.Min, exp.Max)):
214-
# Map sqlglot aggregation types to Sidemantic agg types
215-
if isinstance(parsed, exp.Sum):
216-
agg_type = "sum"
217-
elif isinstance(parsed, exp.Avg):
218-
agg_type = "avg"
219-
elif isinstance(parsed, exp.Count):
220-
if parsed.args.get("distinct"):
221-
agg_type = "count_distinct"
222-
else:
223-
agg_type = "count"
224-
elif isinstance(parsed, exp.Min):
225-
agg_type = "min"
226-
elif isinstance(parsed, exp.Max):
227-
agg_type = "max"
228-
229-
# Extract the aggregated column/expression
230-
agg_arg = parsed.this
231-
if agg_arg:
232-
agg_sql = agg_arg.sql(dialect="duckdb")
233-
elif isinstance(parsed, exp.Count):
234-
# COUNT(*) case
235-
agg_sql = None
236-
except Exception:
237-
# If parsing fails, treat as custom SQL expression
238-
pass
239-
204+
# Let the Metric class handle aggregation parsing via its model_validator.
205+
# This properly handles complex expressions like SUM(x) / SUM(y) and
206+
# COUNT(DISTINCT col) using sqlglot.
240207
return Metric(
241208
name=name,
242209
label=label,
243210
description=description,
244-
agg=agg_type,
245-
sql=agg_sql if agg_type else expression,
211+
sql=expression, # Pass full expression, Metric will parse aggregations
246212
type=metric_type,
247213
value_format_name=value_format_name,
248214
window_order=window_order,

sidemantic/core/dependency_analyzer.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,20 @@ def extract_metric_dependencies(metric_obj, graph=None, model_context=None) -> s
5656
deps.add(metric_obj.sql)
5757
return deps
5858

59+
# Check if this is an expression metric with inline aggregations
60+
# (e.g., SUM(x) / SUM(y), COUNT(DISTINCT col) * 1.0)
61+
# These don't have measure dependencies - the aggregations are inline
62+
try:
63+
parsed = sqlglot.parse_one(metric_obj.sql)
64+
# Check if the expression contains any aggregation functions
65+
agg_types = (exp.Sum, exp.Avg, exp.Count, exp.Min, exp.Max, exp.Median)
66+
has_inline_agg = any(parsed.find_all(*agg_types))
67+
if has_inline_agg and not metric_obj.type:
68+
# Expression metric with inline aggregations - no measure dependencies
69+
return deps
70+
except Exception:
71+
pass
72+
5973
# Extract column references from expression
6074
refs = extract_column_references(metric_obj.sql)
6175

sidemantic/core/metric.py

Lines changed: 58 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,11 @@ def handle_expr_and_parse_agg(cls, data):
4545
4646
1. Converts expr= to sql= for backwards compatibility
4747
2. Parses aggregation functions from SQL (e.g., SUM(amount) -> agg=sum, sql=amount)
48-
"""
49-
import re
5048
49+
Uses sqlglot to properly parse expressions and handle nested parentheses.
50+
Only extracts aggregation from SIMPLE expressions (single aggregation function).
51+
Complex expressions like SUM(x) / SUM(y) are preserved as-is.
52+
"""
5153
if isinstance(data, dict):
5254
# Step 1: Handle expr alias
5355
expr_val = data.get("expr")
@@ -72,23 +74,61 @@ def handle_expr_and_parse_agg(cls, data):
7274
# Parse if sql is provided and agg is not set
7375
# Allow parsing for simple metrics (no type) OR cumulative metrics (to support AVG/COUNT windows)
7476
if sql_val and not agg_val and (not type_val or type_val == "cumulative"):
75-
# Match aggregation functions at the start: SUM(expr), COUNT(expr), etc.
76-
agg_pattern = r"^\s*(SUM|COUNT|AVG|MIN|MAX|MEDIAN|COUNT_DISTINCT)\s*\((.*)\)\s*$"
77-
match = re.match(agg_pattern, sql_val, re.IGNORECASE)
78-
79-
if match:
80-
agg_func = match.group(1).lower()
81-
inner_expr = match.group(2).strip()
82-
83-
# Extract DISTINCT for COUNT(DISTINCT col)
84-
if agg_func == "count":
85-
distinct_match = re.match(r"^\s*DISTINCT\s+(.+)$", inner_expr, re.IGNORECASE)
86-
if distinct_match:
77+
try:
78+
import sqlglot
79+
from sqlglot import expressions as exp
80+
81+
parsed = sqlglot.parse_one(sql_val, read="duckdb")
82+
83+
# Only extract if the TOP-LEVEL expression is a simple aggregation
84+
# This prevents breaking expressions like SUM(x) / SUM(y)
85+
agg_map = {
86+
exp.Sum: "sum",
87+
exp.Avg: "avg",
88+
exp.Min: "min",
89+
exp.Max: "max",
90+
exp.Median: "median",
91+
}
92+
93+
agg_func = None
94+
inner_expr = None
95+
96+
# Check for standard aggregations
97+
for agg_class, agg_name in agg_map.items():
98+
if isinstance(parsed, agg_class):
99+
agg_func = agg_name
100+
if parsed.this:
101+
inner_expr = parsed.this.sql(dialect="duckdb")
102+
break
103+
104+
# Handle COUNT specially (need to detect DISTINCT)
105+
if isinstance(parsed, exp.Count):
106+
# Check if the argument is a Distinct expression
107+
if isinstance(parsed.this, exp.Distinct):
87108
agg_func = "count_distinct"
88-
inner_expr = distinct_match.group(1).strip()
89-
90-
data["agg"] = agg_func
91-
data["sql"] = inner_expr
109+
# Extract all expressions from inside Distinct
110+
# e.g., COUNT(DISTINCT a, b) -> "a, b"
111+
if parsed.this.expressions:
112+
inner_expr = ", ".join(e.sql(dialect="duckdb") for e in parsed.this.expressions)
113+
else:
114+
inner_expr = parsed.this.sql(dialect="duckdb")
115+
else:
116+
agg_func = "count"
117+
if parsed.this:
118+
inner_expr = parsed.this.sql(dialect="duckdb")
119+
# COUNT(*) case - inner_expr stays None
120+
121+
if agg_func:
122+
data["agg"] = agg_func
123+
if inner_expr is not None:
124+
data["sql"] = inner_expr
125+
elif agg_func == "count":
126+
# COUNT(*) - leave sql as None or "*"
127+
data["sql"] = None
128+
129+
except Exception:
130+
# If sqlglot parsing fails, leave the expression as-is
131+
pass
92132

93133
return data
94134

sidemantic/sql/generator.py

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1398,13 +1398,15 @@ def _build_main_select(
13981398
# Complex metric types (derived, ratio) can be built inline
13991399
# Note: cumulative, time_comparison, conversion are handled via special query generators
14001400
# and won't appear in this code path
1401-
if measure.type in ["derived", "ratio"]:
1401+
# Also handle "expression metrics" - metrics with inline aggregations like SUM(x)/SUM(y)
1402+
is_expression_metric = not measure.type and not measure.agg and measure.sql
1403+
if measure.type in ["derived", "ratio"] or is_expression_metric:
14021404
# Use complex metric builder
14031405
metric_expr = self._build_metric_sql(measure, model_name)
14041406
metric_expr = self._wrap_with_fill_nulls(metric_expr, measure)
14051407
select_exprs.append(f"{metric_expr} AS {alias}")
14061408
elif not measure.agg:
1407-
# Complex types that need special handling (shouldn't reach here normally)
1409+
# Unknown metric type that needs special handling
14081410
raise ValueError(
14091411
f"Metric '{measure.name}' with type '{measure.type}' cannot be queried directly. "
14101412
f"Use generate() instead of _build_main_select() for this metric type."
@@ -1736,7 +1738,12 @@ def _build_metric_sql(self, metric, model_context: str | None = None) -> str:
17361738

17371739
# Check if this is a SQL expression metric (has inline aggregations)
17381740
# These metrics already contain complete SQL and shouldn't have dependencies replaced
1739-
has_inline_agg = any(agg in formula.upper() for agg in ["COUNT(", "SUM(", "AVG(", "MIN(", "MAX("])
1741+
try:
1742+
parsed = sqlglot.parse_one(formula, read=self.dialect)
1743+
agg_types = (exp.Sum, exp.Avg, exp.Count, exp.Min, exp.Max, exp.Median)
1744+
has_inline_agg = any(parsed.find_all(*agg_types))
1745+
except Exception:
1746+
has_inline_agg = False
17401747

17411748
if has_inline_agg:
17421749
# This is a SQL expression metric with inline aggregations.
@@ -2009,7 +2016,12 @@ def _generate_with_window_functions(
20092016
cumulative_metrics.append(m)
20102017
# Add the base measure/metric to base_metrics
20112018
if metric.sql:
2012-
base_metrics.append(metric.sql)
2019+
base_ref = metric.sql
2020+
# Qualify unqualified references with the model name
2021+
if "." not in base_ref and "." in m:
2022+
model_name = m.split(".")[0]
2023+
base_ref = f"{model_name}.{base_ref}"
2024+
base_metrics.append(base_ref)
20132025
elif metric and metric.type == "time_comparison":
20142026
# Validate required fields
20152027
if not metric.base_metric:
@@ -2076,7 +2088,16 @@ def _generate_with_window_functions(
20762088

20772089
# Add cumulative metrics with window functions
20782090
for m in cumulative_metrics:
2079-
metric = self.graph.get_metric(m)
2091+
# Handle both qualified (model.measure) and unqualified references
2092+
if "." in m:
2093+
model_name, measure_name = m.split(".", 1)
2094+
model = self.graph.get_model(model_name)
2095+
metric = model.get_metric(measure_name) if model else None
2096+
# Use just the measure name as the alias (not model.measure)
2097+
metric_alias = measure_name
2098+
else:
2099+
metric = self.graph.get_metric(m)
2100+
metric_alias = m
20802101
if not metric or (not metric.sql and not metric.window_expression):
20812102
continue
20822103

@@ -2107,7 +2128,7 @@ def _generate_with_window_functions(
21072128
if metric.window_expression:
21082129
order_col = time_dim
21092130
frame = metric.window_frame or "ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"
2110-
window_expr = f"{metric.window_expression} OVER (ORDER BY {order_col} {frame}) AS {m}"
2131+
window_expr = f"{metric.window_expression} OVER (ORDER BY {order_col} {frame}) AS {metric_alias}"
21112132
select_exprs.append(window_expr)
21122133
continue
21132134

@@ -2118,8 +2139,22 @@ def _generate_with_window_functions(
21182139
# It's a direct measure reference - extract just the measure name
21192140
base_alias = base_ref.split(".")[1]
21202141
else:
2121-
# It's a metric reference - check if it exists and get its underlying measure
2122-
base_metric = self.graph.get_metric(base_ref)
2142+
# It's an unqualified reference - check model first, then graph-level
2143+
base_metric = None
2144+
# Get model name from the cumulative metric reference
2145+
cum_model_name = m.split(".")[0] if "." in m else None
2146+
if cum_model_name:
2147+
cum_model = self.graph.get_model(cum_model_name)
2148+
if cum_model:
2149+
base_metric = cum_model.get_metric(base_ref)
2150+
2151+
# Fallback to graph-level metric
2152+
if not base_metric:
2153+
try:
2154+
base_metric = self.graph.get_metric(base_ref)
2155+
except KeyError:
2156+
pass
2157+
21232158
if base_metric and base_metric.sql:
21242159
# Use the underlying measure name
21252160
if "." in base_metric.sql:
@@ -2145,20 +2180,20 @@ def _generate_with_window_functions(
21452180
grain = metric.grain_to_date
21462181
partition = self._date_trunc(grain, time_dim)
21472182

2148-
window_expr = f"{agg_func}({base_col}) OVER (PARTITION BY {partition} ORDER BY {time_dim} ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS {m}"
2183+
window_expr = f"{agg_func}({base_col}) OVER (PARTITION BY {partition} ORDER BY {time_dim} ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS {metric_alias}"
21492184
elif metric.window:
21502185
# Parse window (e.g., "7 days")
21512186
window_parts = metric.window.split()
21522187
if len(window_parts) == 2:
21532188
num, unit = window_parts
21542189
# For date-based windows, use RANGE
2155-
window_expr = f"{agg_func}({base_col}) OVER (ORDER BY {time_dim} RANGE BETWEEN INTERVAL '{num} {unit}' PRECEDING AND CURRENT ROW) AS {m}"
2190+
window_expr = f"{agg_func}({base_col}) OVER (ORDER BY {time_dim} RANGE BETWEEN INTERVAL '{num} {unit}' PRECEDING AND CURRENT ROW) AS {metric_alias}"
21562191
else:
21572192
# Fallback to rows
2158-
window_expr = f"{agg_func}({base_col}) OVER (ORDER BY {time_dim} ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS {m}"
2193+
window_expr = f"{agg_func}({base_col}) OVER (ORDER BY {time_dim} ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS {metric_alias}"
21592194
else:
21602195
# Running total (unbounded window)
2161-
window_expr = f"{agg_func}({base_col}) OVER (ORDER BY {time_dim} ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS {m}"
2196+
window_expr = f"{agg_func}({base_col}) OVER (ORDER BY {time_dim} ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS {metric_alias}"
21622197

21632198
select_exprs.append(window_expr)
21642199

0 commit comments

Comments
 (0)