Skip to content

Commit 552c2d9

Browse files
committed
Add post-processing SQL over semantic query results
Support arbitrary SQL (CASE, window functions, arithmetic, etc.) on top of semantic query results via subquery wrapping. The rewriter now walks the query tree recursively so nested subqueries and JOIN subqueries that reference semantic models are compiled correctly. Also adds a post_process parameter to compile() and query() for the Python API path, with automatic CTE hoisting.
1 parent 98406be commit 552c2d9

3 files changed

Lines changed: 364 additions & 32 deletions

File tree

sidemantic/core/semantic_layer.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,7 @@ def query(
435435
ungrouped: bool = False,
436436
parameters: dict[str, any] | None = None,
437437
use_preaggregations: bool | None = None,
438+
post_process: str | None = None,
438439
):
439440
"""Execute a query against the semantic layer.
440441
@@ -448,6 +449,9 @@ def query(
448449
ungrouped: If True, return raw rows without aggregation (no GROUP BY)
449450
parameters: Template parameters for Jinja2 rendering
450451
use_preaggregations: Override pre-aggregation routing setting for this query
452+
post_process: Optional SQL to wrap around the semantic query result.
453+
Use {inner} as a placeholder for the compiled semantic query, e.g.:
454+
"SELECT *, revenue / count AS avg_value FROM ({inner})"
451455
452456
Returns:
453457
DuckDB relation object (can convert to DataFrame with .df() or .to_df())
@@ -462,6 +466,7 @@ def query(
462466
ungrouped=ungrouped,
463467
parameters=parameters,
464468
use_preaggregations=use_preaggregations,
469+
post_process=post_process,
465470
)
466471

467472
return self.adapter.execute(sql)
@@ -479,6 +484,7 @@ def compile(
479484
ungrouped: bool = False,
480485
parameters: dict[str, any] | None = None,
481486
use_preaggregations: bool | None = None,
487+
post_process: str | None = None,
482488
) -> str:
483489
"""Compile a query to SQL without executing.
484490
@@ -493,6 +499,9 @@ def compile(
493499
dialect: SQL dialect override (defaults to layer's dialect)
494500
ungrouped: If True, return raw rows without aggregation (no GROUP BY)
495501
use_preaggregations: Override pre-aggregation routing setting for this query
502+
post_process: Optional SQL to wrap around the semantic query result.
503+
Use {inner} as a placeholder for the compiled semantic query, e.g.:
504+
"SELECT *, revenue / count AS avg_value FROM ({inner})"
496505
497506
Returns:
498507
SQL query string
@@ -520,7 +529,7 @@ def compile(
520529
preagg_schema=self.preagg_schema,
521530
)
522531

523-
return generator.generate(
532+
inner_sql = generator.generate(
524533
metrics=metrics,
525534
dimensions=dimensions,
526535
filters=filters,
@@ -533,6 +542,34 @@ def compile(
533542
use_preaggregations=use_preaggs,
534543
)
535544

545+
if post_process is not None:
546+
if "{inner}" not in post_process:
547+
raise ValueError("post_process must contain a {inner} placeholder")
548+
549+
# Strip sidemantic instrumentation comment
550+
stripped = inner_sql.rstrip()
551+
last_line = stripped.split("\n")[-1].strip()
552+
if last_line.startswith("-- sidemantic:"):
553+
stripped = "\n".join(stripped.split("\n")[:-1])
554+
555+
# If inner SQL starts with WITH (CTEs), hoist them outside
556+
# the subquery position so the SQL is valid.
557+
if stripped.lstrip().upper().startswith("WITH "):
558+
import sqlglot
559+
560+
target_dialect = dialect or self.dialect
561+
parsed_inner = sqlglot.parse_one(stripped, dialect=target_dialect)
562+
with_clause = parsed_inner.args.get("with")
563+
if with_clause:
564+
parsed_inner.set("with", None)
565+
body = parsed_inner.sql(dialect=target_dialect)
566+
ctes = with_clause.sql(dialect=target_dialect)
567+
return ctes + "\n" + post_process.replace("{inner}", body)
568+
569+
return post_process.replace("{inner}", stripped)
570+
571+
return inner_sql
572+
536573
def explain(
537574
self,
538575
metrics: list[str] | None = None,

sidemantic/sql/query_rewriter.py

Lines changed: 40 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,9 @@ def rewrite(self, sql: str, strict: bool = True) -> str:
118118
# Check if this is a CTE-based query or has subqueries
119119
has_ctes = parsed.args.get("with") is not None
120120
has_subquery_in_from = self._has_subquery_in_from(parsed)
121+
has_subquery_in_joins = any(isinstance(join.this, exp.Subquery) for join in (parsed.args.get("joins") or []))
121122

122-
if has_ctes or has_subquery_in_from:
123+
if has_ctes or has_subquery_in_from or has_subquery_in_joins:
123124
# Handle CTEs and subqueries
124125
return self._rewrite_with_ctes_or_subqueries(parsed)
125126

@@ -1851,41 +1852,51 @@ def _has_subquery_in_from(self, select: exp.Select) -> bool:
18511852
def _rewrite_with_ctes_or_subqueries(self, parsed: exp.Select) -> str:
18521853
"""Rewrite query that contains CTEs or subqueries.
18531854
1854-
Strategy:
1855-
1. Rewrite each CTE that references semantic models
1856-
2. Rewrite subqueries in FROM clause
1857-
3. Return the modified SQL
1855+
Recursively walks the query tree bottom-up, rewriting any
1856+
SELECT whose FROM target resolves to a semantic model.
1857+
Outer queries are left as plain SQL, so post-processing
1858+
(CASE, window functions, arithmetic, etc.) works naturally.
18581859
"""
1859-
# Handle CTEs
1860-
if parsed.args.get("with"):
1861-
with_clause = parsed.args["with"]
1862-
for cte in with_clause.expressions:
1863-
# Each CTE has a name (alias) and a query (this)
1860+
self._rewrite_select_tree(parsed)
1861+
return parsed.sql(dialect=self.dialect)
1862+
1863+
def _rewrite_select_tree(self, select: exp.Select):
1864+
"""Recursively rewrite semantic subqueries and CTEs (bottom-up).
1865+
1866+
At each level: recurse into children first, then rewrite this
1867+
node if it directly references a semantic model.
1868+
"""
1869+
# Recurse into CTEs
1870+
if select.args.get("with"):
1871+
for cte in select.args["with"].expressions:
18641872
cte_query = cte.this
18651873
if isinstance(cte_query, exp.Select):
1866-
# Check if this CTE references a semantic model
1874+
self._rewrite_select_tree(cte_query)
18671875
if self._references_semantic_model(cte_query):
1868-
# Rewrite the CTE query
1869-
rewritten_cte_sql = self._rewrite_simple_query(cte_query)
1870-
# Parse the rewritten SQL and replace the CTE query
1871-
rewritten_cte = sqlglot.parse_one(rewritten_cte_sql, dialect=self.dialect)
1872-
cte.set("this", rewritten_cte)
1873-
1874-
# Handle subquery in FROM
1875-
from_clause = parsed.args.get("from")
1876+
rewritten_sql = self._rewrite_simple_query(cte_query)
1877+
cte.set("this", sqlglot.parse_one(rewritten_sql, dialect=self.dialect))
1878+
1879+
# Recurse into FROM subquery
1880+
from_clause = select.args.get("from")
18761881
if from_clause and isinstance(from_clause.this, exp.Subquery):
18771882
subquery = from_clause.this
18781883
subquery_select = subquery.this
1879-
if isinstance(subquery_select, exp.Select) and self._references_semantic_model(subquery_select):
1880-
# Rewrite the subquery
1881-
rewritten_subquery_sql = self._rewrite_simple_query(subquery_select)
1882-
rewritten_subquery = sqlglot.parse_one(rewritten_subquery_sql, dialect=self.dialect)
1883-
subquery.set("this", rewritten_subquery)
1884-
1885-
# Return the modified SQL
1886-
# Note: Individual CTEs/subqueries are already instrumented by _rewrite_simple_query -> generator
1887-
# The outer query wrapper doesn't need separate instrumentation
1888-
return parsed.sql(dialect=self.dialect)
1884+
if isinstance(subquery_select, exp.Select):
1885+
self._rewrite_select_tree(subquery_select)
1886+
if self._references_semantic_model(subquery_select):
1887+
rewritten_sql = self._rewrite_simple_query(subquery_select)
1888+
subquery.set("this", sqlglot.parse_one(rewritten_sql, dialect=self.dialect))
1889+
1890+
# Recurse into JOIN subqueries
1891+
for join in select.args.get("joins") or []:
1892+
join_expr = join.this
1893+
if isinstance(join_expr, exp.Subquery):
1894+
join_select = join_expr.this
1895+
if isinstance(join_select, exp.Select):
1896+
self._rewrite_select_tree(join_select)
1897+
if self._references_semantic_model(join_select):
1898+
rewritten_sql = self._rewrite_simple_query(join_select)
1899+
join_expr.set("this", sqlglot.parse_one(rewritten_sql, dialect=self.dialect))
18891900

18901901
def _references_semantic_model(self, select: exp.Select) -> bool:
18911902
"""Check if a SELECT statement references any semantic models."""

0 commit comments

Comments
 (0)