Skip to content

Commit cf241dc

Browse files
authored
Add Yardstick adapter and SQL rewrite parity (#104)
* Add Yardstick adapter and parity SQL rewriting
* Fix Yardstick strict passthrough and alias scoping
* Port Yardstick measure syntax with replay parity
* Fix Yardstick scope rewrite and alias quoting
1 parent 6b483ff commit cf241dc

16 files changed

Lines changed: 7225 additions & 65 deletions

sidemantic/adapters/yardstick.py

Lines changed: 374 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,374 @@
1+
"""Yardstick adapter for importing SQL models with AS MEASURE semantics."""
2+
3+
from functools import lru_cache
4+
from pathlib import Path
5+
from typing import Literal, get_args, get_origin
6+
7+
import sqlglot
8+
from sqlglot import exp
9+
from sqlglot.dialects.duckdb import DuckDB
10+
from sqlglot.tokens import TokenType
11+
12+
from sidemantic.adapters.base import BaseAdapter
13+
from sidemantic.core.dimension import Dimension
14+
from sidemantic.core.metric import Metric
15+
from sidemantic.core.model import Model
16+
from sidemantic.core.semantic_graph import SemanticGraph
17+
18+
19+
def _extract_literal_strings(annotation) -> set[str]:
    """Collect every string value found in ``Literal[...]`` parts of *annotation*.

    A ``Literal`` contributes its string members (non-string members are
    ignored). Any other annotation is searched recursively through its type
    arguments, so wrappers such as unions are looked through. An annotation
    without type arguments yields an empty set.
    """
    type_args = get_args(annotation)
    if get_origin(annotation) is not Literal:
        # Not a Literal itself: merge whatever its type arguments contain.
        return set().union(*(_extract_literal_strings(inner) for inner in type_args))
    return {member for member in type_args if isinstance(member, str)}
27+
28+
29+
@lru_cache(maxsize=1)
def _supported_metric_aggs() -> set[str]:
    """Return the string values allowed by the ``Metric.agg`` annotation.

    The annotation is read from ``Metric.model_fields`` and scanned with
    ``_extract_literal_strings``. The result is memoised, so every call
    returns the same set object; callers should treat it as read-only.
    """
    return _extract_literal_strings(Metric.model_fields["agg"].annotation)
33+
34+
35+
class YardstickDialect(DuckDB):
    """DuckDB dialect extension that supports `AS MEASURE <alias>`."""

    class Parser(DuckDB.Parser):
        """Parser extension for Yardstick's measure alias syntax."""

        def _parse_alias(self, this: exp.Expression | None, explicit: bool = False) -> exp.Expression | None:
            """Parse an optional alias after *this*, flagging `AS MEASURE` aliases.

            Returns *this* unchanged when no alias is consumed. A parenthesised
            alias list produces an ``exp.Aliases`` node. Otherwise an
            ``exp.Alias`` node is returned, and when the alias was written as
            ``AS MEASURE <name>`` the node additionally carries the
            ``yardstick_measure`` arg set to ``True``.
            """
            if self._can_parse_limit_or_offset():
                return this

            # Token consumption order below is significant; do not reorder.
            saw_as = self._match(TokenType.ALIAS)
            pending_comments = self._prev_comments or []

            if explicit and not saw_as:
                return this

            if self._match(TokenType.L_PAREN):
                alias_list = self.expression(
                    exp.Aliases,
                    comments=pending_comments,
                    this=this,
                    expressions=self._parse_csv(lambda: self._parse_id_var(saw_as)),
                )
                self._match_r_paren(alias_list)
                return alias_list

            # MEASURE is only looked for after an explicit AS token.
            measure_flag = bool(saw_as and self._match_texts({"MEASURE"}))

            alias_node = self._parse_id_var(saw_as, tokens=self.ALIAS_TOKENS)
            if not alias_node:
                alias_node = self.STRING_ALIASES and self._parse_string_as_identifier()

            if not alias_node:
                return this

            pending_comments.extend(alias_node.pop_comments())
            this = self.expression(exp.Alias, comments=pending_comments, this=this, alias=alias_node)
            if measure_flag:
                this.set("yardstick_measure", True)

            # If the alias node ended up without comments, adopt the ones
            # attached to the aliased expression.
            aliased = this.this
            if not this.comments and aliased and aliased.comments:
                this.comments = aliased.pop_comments()

            return this
77+
78+
79+
class YardstickAdapter(BaseAdapter):
    """Adapter for Yardstick SQL definitions.

    Yardstick defines measures inside CREATE VIEW statements with:
    `AGG(expr) AS MEASURE measure_name`.

    Each such view becomes one ``Model``: projections flagged as measures
    become metrics and the remaining named projections become dimensions.
    """

    # sqlglot aggregate node types that map one-to-one onto a metric ``agg`` name.
    _SIMPLE_AGGREGATIONS: dict[type[exp.Expression], str] = {
        exp.Sum: "sum",
        exp.Avg: "avg",
        exp.Min: "min",
        exp.Max: "max",
        exp.Median: "median",
        exp.Stddev: "stddev",
        exp.StddevPop: "stddev_pop",
        exp.Variance: "variance",
        exp.VariancePop: "variance_pop",
    }
    # Function names treated as aggregates when they appear as exp.Anonymous nodes.
    _ANONYMOUS_AGGREGATIONS: set[str] = {"mode"}
    # Granularity assigned to a time dimension, keyed by lower-cased function name.
    _GRANULARITY_BY_FUNCTION: dict[str, str] = {
        "date": "day",
        "date_trunc": "day",
        "year": "year",
        "quarter": "quarter",
        "month": "month",
        "week": "week",
        "day": "day",
        "hour": "hour",
        "minute": "minute",
    }

    def parse(self, source: str | Path) -> SemanticGraph:
        """Parse Yardstick SQL files into a semantic graph.

        Args:
            source: A single SQL file, or a directory that is searched
                recursively for ``*.sql`` files (processed in sorted order).

        Returns:
            A ``SemanticGraph`` holding one model per view that defines measures.

        Raises:
            FileNotFoundError: If *source* does not exist.
        """
        source_path = Path(source)
        if not source_path.exists():
            raise FileNotFoundError(f"Path does not exist: {source_path}")

        graph = SemanticGraph()
        if source_path.is_dir():
            for sql_file in sorted(source_path.rglob("*.sql")):
                self._parse_sql_file(sql_file, graph)
        else:
            self._parse_sql_file(source_path, graph)

        return graph

    def _parse_sql_file(self, path: Path, graph: SemanticGraph) -> None:
        """Add a model to *graph* for every measure-bearing CREATE VIEW in *path*.

        Blank files, and statements that are not ``CREATE VIEW ... AS SELECT``,
        are skipped.
        """
        # Decode explicitly as UTF-8 so the result does not depend on the
        # platform's locale encoding.
        content = path.read_text(encoding="utf-8")
        if not content.strip():
            return

        for statement in self._parse_statements(content):
            if not statement or not isinstance(statement, exp.Create):
                continue

            if (statement.args.get("kind") or "").upper() != "VIEW":
                continue

            select = statement.expression
            if not isinstance(select, exp.Select):
                continue

            model = self._model_from_create_view(statement, select)
            if model:
                graph.add_model(model)

    def _parse_statements(self, sql: str) -> list[exp.Expression | None]:
        """Parse *sql* into statements using the Yardstick dialect."""
        return sqlglot.parse(sql, read=YardstickDialect)

    def _model_from_create_view(self, create_stmt: exp.Create, select: exp.Select) -> Model | None:
        """Build a ``Model`` from one CREATE VIEW statement.

        Returns ``None`` when the view has no measure projections, when its
        target is not a plain table node with a name, or when no metric could
        be produced.
        """
        measure_aliases = {
            projection.output_name
            for projection in select.expressions
            if isinstance(projection, exp.Alias) and projection.args.get("yardstick_measure")
        }
        if not measure_aliases:
            return None

        view_name = create_stmt.this.name if isinstance(create_stmt.this, exp.Table) else None
        if not view_name:
            return None

        source_table, source_sql = self._extract_model_source(select)
        dimensions: list[Dimension] = []
        metrics: list[Metric] = []

        for projection in select.expressions:
            output_name = projection.output_name
            if not output_name:
                continue

            # Work on the aliased expression rather than on the alias wrapper.
            inner = projection.this if isinstance(projection, exp.Alias) else projection

            if output_name in measure_aliases:
                metrics.append(self._metric_from_expression(output_name, inner, measure_aliases))
                continue

            if isinstance(inner, exp.Star):
                continue
            dim_type, dim_granularity = self._infer_dimension_type(inner)
            dimensions.append(
                Dimension(
                    name=output_name,
                    type=dim_type,
                    sql=inner.sql(dialect="duckdb"),
                    granularity=dim_granularity,
                )
            )

        if not metrics:
            return None

        yardstick_metadata: dict[str, str] = {"view_sql": select.sql(dialect="duckdb")}
        if source_table:
            yardstick_metadata["base_table"] = source_table
        if source_sql:
            yardstick_metadata["base_relation_sql"] = source_sql

        # The first dimension is used as the primary key; "id" is the fallback.
        primary_key = dimensions[0].name if dimensions else "id"
        model_kwargs: dict[str, object] = {
            "name": view_name,
            "primary_key": primary_key,
            "dimensions": dimensions,
            "metrics": metrics,
            "metadata": {"yardstick": yardstick_metadata},
        }
        # Prefer the derived relation SQL, then the plain base table, and
        # finally fall back to the view's own name.
        if source_sql:
            model_kwargs["sql"] = source_sql
        elif source_table:
            model_kwargs["table"] = source_table
        else:
            model_kwargs["table"] = view_name

        return Model(**model_kwargs)

    def _metric_from_expression(self, name: str, expression: exp.Expression, all_measure_names: set[str]) -> Metric:
        """Translate one measure expression into a ``Metric``.

        Classification is attempted in this order: derived (references
        another measure), filtered aggregation, supported plain aggregation,
        any other expression with aggregate semantics, and finally a generic
        metric that is marked ``derived`` when it has neither ``agg`` nor
        ``type`` after construction.
        """
        expression_sql = expression.sql(dialect="duckdb")
        if self._references_other_measures(name, expression, all_measure_names):
            return Metric(name=name, type="derived", sql=expression_sql)

        filtered_aggregation = self._extract_filtered_aggregation(expression)
        if filtered_aggregation:
            agg, inner_sql, filters = filtered_aggregation
            return Metric(name=name, agg=agg, sql=inner_sql, filters=filters)

        simple_aggregation = self._extract_supported_aggregation(expression)
        if simple_aggregation:
            agg, inner_sql = simple_aggregation
            return Metric(name=name, agg=agg, sql=inner_sql)

        if self._has_aggregate_semantics(expression):
            return Metric(name=name, sql=expression_sql)

        metric = Metric(name=name, sql=expression_sql)
        if metric.agg is None and metric.type is None:
            return Metric(name=name, type="derived", sql=expression_sql)
        return metric

    def _extract_model_source(self, select: exp.Select) -> tuple[str | None, str | None]:
        """Describe the relation the view selects from.

        Returns:
            ``(table, None)`` when the view reads a single unaliased table
            with no joins, WHERE or WITH clause; ``(None, None)`` when there
            is no FROM clause; otherwise ``(None, sql)`` where *sql* is a
            ``SELECT *`` over the view's WITH / FROM / JOIN / WHERE clauses.
        """
        from_clause = select.args.get("from")
        joins = select.args.get("joins") or []
        where_clause = select.args.get("where")
        with_clause = select.args.get("with")

        if (
            isinstance(from_clause, exp.From)
            and isinstance(from_clause.this, exp.Table)
            and not joins
            and where_clause is None
            and with_clause is None
        ):
            table_expr = from_clause.this
            is_simple_table = isinstance(table_expr.this, exp.Identifier) and table_expr.args.get("alias") is None
            if is_simple_table:
                return table_expr.sql(dialect="duckdb"), None

        if from_clause is None:
            return None, None

        # Rebuild the relation from copies so the parsed view is not mutated.
        base_relation = exp.select("*")
        if with_clause is not None:
            base_relation.set("with", with_clause.copy())
        base_relation.set("from", from_clause.copy())
        if joins:
            base_relation.set("joins", [join.copy() for join in joins])
        if where_clause is not None:
            base_relation.set("where", where_clause.copy())

        return None, base_relation.sql(dialect="duckdb")

    def _references_other_measures(self, name: str, expression: exp.Expression, all_measure_names: set[str]) -> bool:
        """Return True if *expression* has a column named like another measure.

        Names are compared case-insensitively and the measure *name* itself
        is excluded from the comparison.
        """
        measure_lookup = {
            measure_name.lower() for measure_name in all_measure_names if measure_name.lower() != name.lower()
        }
        referenced_columns = {column.name.lower() for column in expression.find_all(exp.Column)}
        return bool(referenced_columns & measure_lookup)

    def _extract_filtered_aggregation(self, expression: exp.Expression) -> tuple[str, str, list[str] | None] | None:
        """Split ``AGG(expr) FILTER (WHERE cond)`` into its parts.

        Returns:
            ``(agg, inner_sql, filters)`` when *expression* is an
            ``exp.Filter`` wrapping a supported aggregation, where *filters*
            is a one-element list holding the condition SQL or ``None`` when
            no condition SQL is available; otherwise ``None``.
        """
        if not isinstance(expression, exp.Filter):
            return None

        aggregation = self._extract_supported_aggregation(expression.this)
        if aggregation is None:
            return None

        agg, inner_sql = aggregation
        where_expression = expression.args.get("expression")
        if isinstance(where_expression, exp.Where):
            # Drop the WHERE wrapper and keep only the condition.
            filter_sql = where_expression.this.sql(dialect="duckdb")
        elif isinstance(where_expression, exp.Expression):
            filter_sql = where_expression.sql(dialect="duckdb")
        else:
            filter_sql = ""

        filters = [filter_sql] if filter_sql else None
        return agg, inner_sql, filters

    @staticmethod
    def _first_argument(expression: exp.Expression) -> exp.Expression | None:
        """Return ``expression.this``, else its first ``expressions`` entry, else None."""
        return expression.this or (expression.expressions[0] if expression.expressions else None)

    @staticmethod
    def _count_aggregation(count_expr: exp.Expression | None) -> tuple[str, str]:
        """Classify the argument of a COUNT call as ``count`` or ``count_distinct``."""
        if isinstance(count_expr, exp.Distinct):
            if count_expr.expressions:
                inner_sql = ", ".join(expr.sql(dialect="duckdb") for expr in count_expr.expressions)
            else:
                inner_sql = count_expr.sql(dialect="duckdb")
            return "count_distinct", inner_sql

        if count_expr is None or isinstance(count_expr, exp.Star):
            return "count", "*"
        return "count", count_expr.sql(dialect="duckdb")

    def _extract_supported_aggregation(self, expression: exp.Expression) -> tuple[str, str] | None:
        """Map an aggregate call onto ``(agg_name, inner_sql)``.

        Checked in order: ``exp.Count`` nodes, the node types listed in
        ``_SIMPLE_AGGREGATIONS``, then any other ``exp.Func`` whose
        lower-cased name is ``count`` or one of the supported metric
        aggregations. A missing argument is reported as ``"*"``. Returns
        ``None`` when nothing matches.
        """
        if isinstance(expression, exp.Count):
            return self._count_aggregation(expression.this)

        for expression_type, aggregation_name in self._SIMPLE_AGGREGATIONS.items():
            if isinstance(expression, expression_type):
                inner_expression = expression.this
                if inner_expression is None:
                    return aggregation_name, "*"
                return aggregation_name, inner_expression.sql(dialect="duckdb")

        if isinstance(expression, exp.Func):
            function_name = (expression.name or "").lower()
            if function_name == "count":
                return self._count_aggregation(self._first_argument(expression))

            # count / count_distinct are handled above, so exclude them here.
            supported_function_aggs = _supported_metric_aggs() - {"count", "count_distinct"}
            if function_name in supported_function_aggs:
                inner_expression = self._first_argument(expression)
                if inner_expression is None:
                    return function_name, "*"
                return function_name, inner_expression.sql(dialect="duckdb")

        return None

    def _has_aggregate_semantics(self, expression: exp.Expression) -> bool:
        """Return True if any node in *expression* is an aggregate.

        A node qualifies when it is an ``exp.AggFunc`` or an ``exp.Anonymous``
        call whose lower-cased name is in ``_ANONYMOUS_AGGREGATIONS``.
        """
        return any(
            isinstance(node, exp.AggFunc)
            or (isinstance(node, exp.Anonymous) and (node.name or "").lower() in self._ANONYMOUS_AGGREGATIONS)
            for node in expression.walk()
        )

    def _infer_dimension_type(self, expression: exp.Expression) -> tuple[str, str | None]:
        """Guess a dimension's ``(type, granularity)`` from its expression.

        Boolean and literal nodes are typed directly. Columns are classified
        by substrings of their lower-cased name. Function calls are looked up
        by name in ``_GRANULARITY_BY_FUNCTION``. Everything else is
        ``("categorical", None)``.
        """
        if isinstance(expression, exp.Boolean):
            return "boolean", None

        if isinstance(expression, exp.Literal):
            if expression.is_number:
                return "numeric", None
            return "categorical", None

        if isinstance(expression, exp.Column):
            column_name = expression.name.lower()
            # Check order matters: e.g. "datetime" contains both "date" and "time".
            if "timestamp" in column_name:
                return "time", "second"
            if "date" in column_name:
                return "time", "day"
            if "time" in column_name:
                return "time", "second"
            return "categorical", None

        if isinstance(expression, exp.Func):
            # NOTE(review): this relies on `expression.name` yielding the
            # function name; confirm that holds for every sqlglot function
            # node this lookup is meant to cover.
            function_name = (expression.name or "").lower()
            if function_name in self._GRANULARITY_BY_FUNCTION:
                return "time", self._GRANULARITY_BY_FUNCTION[function_name]

        return "categorical", None

0 commit comments

Comments
 (0)