|
| 1 | +"""Yardstick adapter for importing SQL models with AS MEASURE semantics.""" |
| 2 | + |
| 3 | +from functools import lru_cache |
| 4 | +from pathlib import Path |
| 5 | +from typing import Literal, get_args, get_origin |
| 6 | + |
| 7 | +import sqlglot |
| 8 | +from sqlglot import exp |
| 9 | +from sqlglot.dialects.duckdb import DuckDB |
| 10 | +from sqlglot.tokens import TokenType |
| 11 | + |
| 12 | +from sidemantic.adapters.base import BaseAdapter |
| 13 | +from sidemantic.core.dimension import Dimension |
| 14 | +from sidemantic.core.metric import Metric |
| 15 | +from sidemantic.core.model import Model |
| 16 | +from sidemantic.core.semantic_graph import SemanticGraph |
| 17 | + |
| 18 | + |
| 19 | +def _extract_literal_strings(annotation) -> set[str]: |
| 20 | + if get_origin(annotation) is Literal: |
| 21 | + return {value for value in get_args(annotation) if isinstance(value, str)} |
| 22 | + |
| 23 | + values = set() |
| 24 | + for arg in get_args(annotation): |
| 25 | + values.update(_extract_literal_strings(arg)) |
| 26 | + return values |
| 27 | + |
| 28 | + |
@lru_cache(maxsize=1)
def _supported_metric_aggs() -> set[str]:
    """Return the aggregation names accepted by ``Metric.agg`` (computed once).

    Reads the ``Literal`` alternatives off the pydantic field annotation so the
    adapter stays in sync with the Metric model automatically.
    """
    return _extract_literal_strings(Metric.model_fields["agg"].annotation)
| 33 | + |
| 34 | + |
class YardstickDialect(DuckDB):
    """DuckDB dialect extension that supports `AS MEASURE <alias>`."""

    class Parser(DuckDB.Parser):
        """Parser extension for Yardstick's measure alias syntax.

        Overrides sqlglot's alias parsing so that `expr AS MEASURE name` is
        accepted and the resulting exp.Alias node is tagged with
        `yardstick_measure=True`, which YardstickAdapter later reads to tell
        measures apart from ordinary aliased columns.
        """

        def _parse_alias(self, this: exp.Expression | None, explicit: bool = False) -> exp.Expression | None:
            # NOTE(review): this mirrors upstream sqlglot's _parse_alias
            # statement-for-statement, adding only the MEASURE detection —
            # keep it in sync with upstream when bumping sqlglot.
            if self._can_parse_limit_or_offset():
                # Upcoming tokens belong to LIMIT/OFFSET, not an alias.
                return this

            any_token = self._match(TokenType.ALIAS)
            comments = self._prev_comments or []

            if explicit and not any_token:
                # Caller required an explicit AS and there wasn't one.
                return this

            # Multi-alias form: `expr AS (a, b, c)`.
            if self._match(TokenType.L_PAREN):
                aliases = self.expression(
                    exp.Aliases,
                    comments=comments,
                    this=this,
                    expressions=self._parse_csv(lambda: self._parse_id_var(any_token)),
                )
                self._match_r_paren(aliases)
                return aliases

            # The Yardstick extension: `AS MEASURE <name>` — consume the
            # MEASURE keyword only when an AS token was actually matched.
            is_measure_alias = bool(any_token and self._match_texts({"MEASURE"}))
            alias = self._parse_id_var(any_token, tokens=self.ALIAS_TOKENS) or (
                self.STRING_ALIASES and self._parse_string_as_identifier()
            )

            if alias:
                # Move any comments attached to the alias identifier onto the
                # Alias node itself, matching upstream behavior.
                comments.extend(alias.pop_comments())
                this = self.expression(exp.Alias, comments=comments, this=this, alias=alias)
                if is_measure_alias:
                    # Marker consumed by YardstickAdapter._model_from_create_view.
                    this.set("yardstick_measure", True)

            # Relocate trailing comments from the aliased expression
            # (`expr # comment AS alias`) onto the alias node.
            column = this.this
            if not this.comments and column and column.comments:
                this.comments = column.pop_comments()

            return this
| 77 | + |
| 78 | + |
class YardstickAdapter(BaseAdapter):
    """Adapter for Yardstick SQL definitions.

    Yardstick defines measures inside CREATE VIEW statements with:
    `AGG(expr) AS MEASURE measure_name`.
    """

    # Maps sqlglot aggregate expression classes to sidemantic agg names.
    _SIMPLE_AGGREGATIONS: dict[type[exp.Expression], str] = {
        exp.Sum: "sum",
        exp.Avg: "avg",
        exp.Min: "min",
        exp.Max: "max",
        exp.Median: "median",
        exp.Stddev: "stddev",
        exp.StddevPop: "stddev_pop",
        exp.Variance: "variance",
        exp.VariancePop: "variance_pop",
    }
    # Aggregates that sqlglot parses as exp.Anonymous rather than exp.AggFunc.
    _ANONYMOUS_AGGREGATIONS: set[str] = {"mode"}
    # Time granularity implied by common date/time functions. Hoisted to a
    # class constant so _infer_dimension_type does not rebuild it per call.
    _TIME_GRANULARITY_BY_FUNC: dict[str, str] = {
        "date": "day",
        "date_trunc": "day",
        "year": "year",
        "quarter": "quarter",
        "month": "month",
        "week": "week",
        "day": "day",
        "hour": "hour",
        "minute": "minute",
    }

    def parse(self, source: str | Path) -> SemanticGraph:
        """Parse Yardstick SQL files into a semantic graph.

        Args:
            source: A single .sql file, or a directory scanned recursively
                for *.sql files.

        Returns:
            A SemanticGraph containing one model per CREATE VIEW statement
            that declares at least one `AS MEASURE` projection.

        Raises:
            FileNotFoundError: If `source` does not exist.
        """
        source_path = Path(source)
        if not source_path.exists():
            raise FileNotFoundError(f"Path does not exist: {source_path}")

        graph = SemanticGraph()
        if source_path.is_dir():
            # Sorted for deterministic model registration order across platforms.
            for sql_file in sorted(source_path.rglob("*.sql")):
                self._parse_sql_file(sql_file, graph)
        else:
            self._parse_sql_file(source_path, graph)

        return graph

    def _parse_sql_file(self, path: Path, graph: SemanticGraph) -> None:
        """Add a model to `graph` for each measure-bearing CREATE VIEW in `path`."""
        # Explicit encoding: SQL sources are conventionally UTF-8; don't
        # depend on the platform locale's default codec.
        content = path.read_text(encoding="utf-8")
        if not content.strip():
            return

        for statement in self._parse_statements(content):
            # sqlglot.parse yields None for empty statements (e.g. stray ';').
            if not statement:
                continue
            if not isinstance(statement, exp.Create):
                continue
            if (statement.args.get("kind") or "").upper() != "VIEW":
                continue

            select = statement.expression
            if not isinstance(select, exp.Select):
                continue

            model = self._model_from_create_view(statement, select)
            if model:
                graph.add_model(model)

    def _parse_statements(self, sql: str) -> list[exp.Expression | None]:
        """Parse `sql` using the measure-aware Yardstick dialect."""
        return sqlglot.parse(sql, read=YardstickDialect)

    def _model_from_create_view(self, create_stmt: exp.Create, select: exp.Select) -> Model | None:
        """Build a Model from a CREATE VIEW, or None if it declares no measures.

        Projections tagged `yardstick_measure` by the parser become metrics;
        every other named projection becomes a dimension.
        """
        measure_aliases = {
            projection.output_name
            for projection in select.expressions
            if isinstance(projection, exp.Alias) and projection.args.get("yardstick_measure")
        }
        if not measure_aliases:
            return None

        view_name = create_stmt.this.name if isinstance(create_stmt.this, exp.Table) else None
        if not view_name:
            return None

        source_table, source_sql = self._extract_model_source(select)
        dimensions: list[Dimension] = []
        metrics: list[Metric] = []
        all_measure_names = set(measure_aliases)

        for projection in select.expressions:
            output_name = projection.output_name
            if not output_name:
                continue

            if output_name in measure_aliases:
                metric_expr = projection.this if isinstance(projection, exp.Alias) else projection
                metrics.append(self._metric_from_expression(output_name, metric_expr, all_measure_names))
            else:
                dim_expr = projection.this if isinstance(projection, exp.Alias) else projection
                if isinstance(dim_expr, exp.Star):
                    # `SELECT *` carries no usable dimension name/type info.
                    continue
                dim_type, dim_granularity = self._infer_dimension_type(dim_expr)
                dimensions.append(
                    Dimension(
                        name=output_name,
                        type=dim_type,
                        sql=dim_expr.sql(dialect="duckdb"),
                        granularity=dim_granularity,
                    )
                )

        if not metrics:
            return None

        yardstick_metadata: dict[str, str] = {"view_sql": select.sql(dialect="duckdb")}
        if source_table:
            yardstick_metadata["base_table"] = source_table
        if source_sql:
            yardstick_metadata["base_relation_sql"] = source_sql

        # Best-effort primary key: first dimension if any, falling back to "id".
        primary_key = dimensions[0].name if dimensions else "id"
        model_kwargs: dict[str, object] = {
            "name": view_name,
            "primary_key": primary_key,
            "dimensions": dimensions,
            "metrics": metrics,
            "metadata": {"yardstick": yardstick_metadata},
        }
        # Prefer a derived-SQL source, then a plain base table, then the view
        # name itself as a last resort.
        if source_sql:
            model_kwargs["sql"] = source_sql
        elif source_table:
            model_kwargs["table"] = source_table
        else:
            model_kwargs["table"] = view_name

        return Model(**model_kwargs)

    def _metric_from_expression(self, name: str, expression: exp.Expression, all_measure_names: set[str]) -> Metric:
        """Translate one measure projection into a Metric.

        Classification order: references to sibling measures -> derived;
        FILTER(...) aggregates -> agg + filters; plain supported aggregates
        -> agg; anything else is kept as raw SQL (marked derived when Metric
        cannot infer an agg or type on its own).
        """
        expression_sql = expression.sql(dialect="duckdb")
        if self._references_other_measures(name, expression, all_measure_names):
            return Metric(name=name, type="derived", sql=expression_sql)

        filtered_aggregation = self._extract_filtered_aggregation(expression)
        if filtered_aggregation:
            agg, inner_sql, filters = filtered_aggregation
            return Metric(name=name, agg=agg, sql=inner_sql, filters=filters)

        simple_aggregation = self._extract_supported_aggregation(expression)
        if simple_aggregation:
            agg, inner_sql = simple_aggregation
            return Metric(name=name, agg=agg, sql=inner_sql)

        if self._has_aggregate_semantics(expression):
            # Aggregate we can't map to a named agg: keep the raw SQL.
            return Metric(name=name, sql=expression_sql)

        metric = Metric(name=name, sql=expression_sql)
        if metric.agg is None and metric.type is None:
            # Metric inferred nothing from the SQL; treat it as derived.
            return Metric(name=name, type="derived", sql=expression_sql)
        return metric

    def _extract_model_source(self, select: exp.Select) -> tuple[str | None, str | None]:
        """Return (base_table, base_sql) for the view's FROM clause.

        A bare, unaliased single-table FROM yields (table_name, None); any
        richer relation (joins, WHERE, CTEs, subqueries) is rebuilt as a
        `SELECT * FROM ...` statement and returned as (None, sql).
        """
        from_clause = select.args.get("from")
        joins = select.args.get("joins") or []
        where_clause = select.args.get("where")
        with_clause = select.args.get("with")

        if (
            isinstance(from_clause, exp.From)
            and isinstance(from_clause.this, exp.Table)
            and not joins
            and where_clause is None
            and with_clause is None
        ):
            table_expr = from_clause.this
            is_simple_table = isinstance(table_expr.this, exp.Identifier) and table_expr.args.get("alias") is None
            if is_simple_table:
                return table_expr.sql(dialect="duckdb"), None

        if from_clause is None:
            return None, None

        # Copy each clause so we never mutate the parsed view statement.
        base_relation = exp.select("*")
        if with_clause is not None:
            base_relation.set("with", with_clause.copy())
        base_relation.set("from", from_clause.copy())
        if joins:
            base_relation.set("joins", [join.copy() for join in joins])
        if where_clause is not None:
            base_relation.set("where", where_clause.copy())

        return None, base_relation.sql(dialect="duckdb")

    def _references_other_measures(self, name: str, expression: exp.Expression, all_measure_names: set[str]) -> bool:
        """True if `expression` references any sibling measure (case-insensitive)."""
        measure_lookup = {
            measure_name.lower() for measure_name in all_measure_names if measure_name.lower() != name.lower()
        }
        referenced_columns = {column.name.lower() for column in expression.find_all(exp.Column)}
        return bool(referenced_columns & measure_lookup)

    def _extract_filtered_aggregation(self, expression: exp.Expression) -> tuple[str, str, list[str] | None] | None:
        """Decompose `AGG(x) FILTER (WHERE cond)` into (agg, inner_sql, filters)."""
        if not isinstance(expression, exp.Filter):
            return None

        aggregation = self._extract_supported_aggregation(expression.this)
        if aggregation is None:
            return None

        agg, inner_sql = aggregation
        where_expression = expression.args.get("expression")
        if isinstance(where_expression, exp.Where):
            # Unwrap the WHERE node to keep just the predicate SQL.
            filter_sql = where_expression.this.sql(dialect="duckdb")
        elif isinstance(where_expression, exp.Expression):
            filter_sql = where_expression.sql(dialect="duckdb")
        else:
            filter_sql = ""

        filters = [filter_sql] if filter_sql else None
        return agg, inner_sql, filters

    @staticmethod
    def _count_aggregation(count_expr: exp.Expression | None) -> tuple[str, str]:
        """Shared COUNT / COUNT(DISTINCT ...) handling for both parse shapes."""
        if isinstance(count_expr, exp.Distinct):
            if count_expr.expressions:
                inner_sql = ", ".join(expr.sql(dialect="duckdb") for expr in count_expr.expressions)
            else:
                inner_sql = count_expr.sql(dialect="duckdb")
            return "count_distinct", inner_sql

        if count_expr is None or isinstance(count_expr, exp.Star):
            return "count", "*"
        return "count", count_expr.sql(dialect="duckdb")

    def _extract_supported_aggregation(self, expression: exp.Expression) -> tuple[str, str] | None:
        """Return (agg_name, inner_sql) if `expression` is a supported aggregate.

        Checks typed sqlglot nodes first (exp.Count and _SIMPLE_AGGREGATIONS),
        then falls back to matching generic functions by name against the agg
        names Metric supports. Returns None for unrecognized expressions.
        """
        if isinstance(expression, exp.Count):
            return self._count_aggregation(expression.this)

        for expression_type, aggregation_name in self._SIMPLE_AGGREGATIONS.items():
            if isinstance(expression, expression_type):
                inner_expression = expression.this
                if inner_expression is None:
                    return aggregation_name, "*"
                return aggregation_name, inner_expression.sql(dialect="duckdb")

        if isinstance(expression, exp.Func):
            function_name = (expression.name or "").lower()
            # Some dialect paths parse COUNT as a generic function; the
            # argument may land in `this` or in `expressions`.
            if function_name == "count":
                count_expr = expression.this or (expression.expressions[0] if expression.expressions else None)
                return self._count_aggregation(count_expr)

            supported_function_aggs = _supported_metric_aggs() - {"count", "count_distinct"}
            if function_name in supported_function_aggs:
                inner_expression = expression.this or (expression.expressions[0] if expression.expressions else None)
                if inner_expression is None:
                    return function_name, "*"
                return function_name, inner_expression.sql(dialect="duckdb")

        return None

    def _has_aggregate_semantics(self, expression: exp.Expression) -> bool:
        """True if the expression tree contains any aggregate function call."""
        # Single pass: check both typed aggregates and the anonymous-function
        # names we know to be aggregates (previously two separate walks).
        for node in expression.walk():
            if isinstance(node, exp.AggFunc):
                return True
            if isinstance(node, exp.Anonymous) and (node.name or "").lower() in self._ANONYMOUS_AGGREGATIONS:
                return True
        return False

    def _infer_dimension_type(self, expression: exp.Expression) -> tuple[str, str | None]:
        """Guess a dimension's type and time granularity from its expression.

        Column names containing timestamp/date/time hint at a time dimension;
        known date functions map to a granularity; everything else defaults to
        categorical (or numeric/boolean for literal values).
        """
        if isinstance(expression, exp.Boolean):
            return "boolean", None
        if isinstance(expression, exp.Literal):
            if expression.is_number:
                return "numeric", None
            return "categorical", None
        if isinstance(expression, exp.Column):
            column_name = expression.name.lower()
            # Order matters: "timestamp" also contains "time"; "date" is
            # checked before "time" so e.g. "date_time" maps to day.
            if "timestamp" in column_name:
                return "time", "second"
            if "date" in column_name:
                return "time", "day"
            if "time" in column_name:
                return "time", "second"
            return "categorical", None
        if isinstance(expression, exp.Func):
            function_name = (expression.name or "").lower()
            granularity = self._TIME_GRANULARITY_BY_FUNC.get(function_name)
            if granularity is not None:
                return "time", granularity
            return "categorical", None
        # Fallback for any other node type (CASE, string concatenation, ...).
        # Previously these fell off the end of the method and returned None,
        # which crashed the tuple unpack in _model_from_create_view.
        return "categorical", None
0 commit comments