Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion sidemantic/core/semantic_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,7 @@ def query(
ungrouped: bool = False,
parameters: dict[str, any] | None = None,
use_preaggregations: bool | None = None,
post_process: str | None = None,
):
"""Execute a query against the semantic layer.

Expand All @@ -448,6 +449,9 @@ def query(
ungrouped: If True, return raw rows without aggregation (no GROUP BY)
parameters: Template parameters for Jinja2 rendering
use_preaggregations: Override pre-aggregation routing setting for this query
post_process: Optional SQL to wrap around the semantic query result.
Use {inner} as a placeholder for the compiled semantic query, e.g.:
"SELECT *, revenue / count AS avg_value FROM ({inner})"

Returns:
DuckDB relation object (can convert to DataFrame with .df() or .to_df())
Expand All @@ -462,6 +466,7 @@ def query(
ungrouped=ungrouped,
parameters=parameters,
use_preaggregations=use_preaggregations,
post_process=post_process,
)

return self.adapter.execute(sql)
Expand All @@ -479,6 +484,7 @@ def compile(
ungrouped: bool = False,
parameters: dict[str, any] | None = None,
use_preaggregations: bool | None = None,
post_process: str | None = None,
) -> str:
"""Compile a query to SQL without executing.

Expand All @@ -493,6 +499,9 @@ def compile(
dialect: SQL dialect override (defaults to layer's dialect)
ungrouped: If True, return raw rows without aggregation (no GROUP BY)
use_preaggregations: Override pre-aggregation routing setting for this query
post_process: Optional SQL to wrap around the semantic query result.
Use {inner} as a placeholder for the compiled semantic query, e.g.:
"SELECT *, revenue / count AS avg_value FROM ({inner})"

Returns:
SQL query string
Expand Down Expand Up @@ -520,7 +529,7 @@ def compile(
preagg_schema=self.preagg_schema,
)

return generator.generate(
inner_sql = generator.generate(
metrics=metrics,
dimensions=dimensions,
filters=filters,
Expand All @@ -533,6 +542,34 @@ def compile(
use_preaggregations=use_preaggs,
)

if post_process is not None:
if "{inner}" not in post_process:
raise ValueError("post_process must contain a {inner} placeholder")

# Strip sidemantic instrumentation comment
stripped = inner_sql.rstrip()
last_line = stripped.split("\n")[-1].strip()
if last_line.startswith("-- sidemantic:"):
stripped = "\n".join(stripped.split("\n")[:-1])

# If inner SQL starts with WITH (CTEs), hoist them outside
# the subquery position so the SQL is valid.
if stripped.lstrip().upper().startswith("WITH "):
import sqlglot

target_dialect = dialect or self.dialect
parsed_inner = sqlglot.parse_one(stripped, dialect=target_dialect)
with_clause = parsed_inner.args.get("with")
if with_clause:
parsed_inner.set("with", None)
body = parsed_inner.sql(dialect=target_dialect)
ctes = with_clause.sql(dialect=target_dialect)
return ctes + "\n" + post_process.replace("{inner}", body)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P2] Merge hoisted CTEs with outer post-process `WITH`

The post-processing path prepends the hoisted inner CTEs directly in front of the user-supplied `post_process` SQL. If `post_process` itself starts with `WITH`, the result becomes `WITH ...\nWITH ...`, which is invalid SQL and fails to parse or execute. Since the feature is documented as allowing arbitrary SQL wrapping, this makes valid CTE-based post-processing unusable whenever the inner semantic query emits CTEs.

Useful? React with 👍 / 👎.


return post_process.replace("{inner}", stripped)

return inner_sql

def explain(
self,
metrics: list[str] | None = None,
Expand Down
69 changes: 40 additions & 29 deletions sidemantic/sql/query_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,9 @@ def rewrite(self, sql: str, strict: bool = True) -> str:
# Check if this is a CTE-based query or has subqueries
has_ctes = parsed.args.get("with") is not None
has_subquery_in_from = self._has_subquery_in_from(parsed)
has_subquery_in_joins = any(isinstance(join.this, exp.Subquery) for join in (parsed.args.get("joins") or []))

if has_ctes or has_subquery_in_from:
if has_ctes or has_subquery_in_from or has_subquery_in_joins:
Comment on lines +121 to +123
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P1] Preserve root semantic rewrite when `JOIN` has a subquery

Routing every query with a `JOIN` subquery through `_rewrite_with_ctes_or_subqueries` skips `_rewrite_simple_query` for the root `SELECT`, because `_rewrite_select_tree` only rewrites child scopes. For queries whose root `FROM` is a semantic model (for example, `FROM orders o JOIN (SELECT ...) s`), metric references like `o.revenue` are no longer rewritten through the semantic pipeline or rejected by the explicit JOIN guard, so execution can now fail with binder errors or silently return raw row-level columns instead of semantic aggregates.

Useful? React with 👍 / 👎.

# Handle CTEs and subqueries
return self._rewrite_with_ctes_or_subqueries(parsed)

Expand Down Expand Up @@ -1851,41 +1852,51 @@ def _has_subquery_in_from(self, select: exp.Select) -> bool:
def _rewrite_with_ctes_or_subqueries(self, parsed: exp.Select) -> str:
"""Rewrite query that contains CTEs or subqueries.

Strategy:
1. Rewrite each CTE that references semantic models
2. Rewrite subqueries in FROM clause
3. Return the modified SQL
Recursively walks the query tree bottom-up, rewriting any
SELECT whose FROM target resolves to a semantic model.
Outer queries are left as plain SQL, so post-processing
(CASE, window functions, arithmetic, etc.) works naturally.
"""
# Handle CTEs
if parsed.args.get("with"):
with_clause = parsed.args["with"]
for cte in with_clause.expressions:
# Each CTE has a name (alias) and a query (this)
self._rewrite_select_tree(parsed)
return parsed.sql(dialect=self.dialect)

def _rewrite_select_tree(self, select: exp.Select):
"""Recursively rewrite semantic subqueries and CTEs (bottom-up).

At each level: recurse into children first, then rewrite this
node if it directly references a semantic model.
"""
# Recurse into CTEs
if select.args.get("with"):
for cte in select.args["with"].expressions:
cte_query = cte.this
if isinstance(cte_query, exp.Select):
# Check if this CTE references a semantic model
self._rewrite_select_tree(cte_query)
if self._references_semantic_model(cte_query):
# Rewrite the CTE query
rewritten_cte_sql = self._rewrite_simple_query(cte_query)
# Parse the rewritten SQL and replace the CTE query
rewritten_cte = sqlglot.parse_one(rewritten_cte_sql, dialect=self.dialect)
cte.set("this", rewritten_cte)

# Handle subquery in FROM
from_clause = parsed.args.get("from")
rewritten_sql = self._rewrite_simple_query(cte_query)
cte.set("this", sqlglot.parse_one(rewritten_sql, dialect=self.dialect))

# Recurse into FROM subquery
from_clause = select.args.get("from")
if from_clause and isinstance(from_clause.this, exp.Subquery):
subquery = from_clause.this
subquery_select = subquery.this
if isinstance(subquery_select, exp.Select) and self._references_semantic_model(subquery_select):
# Rewrite the subquery
rewritten_subquery_sql = self._rewrite_simple_query(subquery_select)
rewritten_subquery = sqlglot.parse_one(rewritten_subquery_sql, dialect=self.dialect)
subquery.set("this", rewritten_subquery)

# Return the modified SQL
# Note: Individual CTEs/subqueries are already instrumented by _rewrite_simple_query -> generator
# The outer query wrapper doesn't need separate instrumentation
return parsed.sql(dialect=self.dialect)
if isinstance(subquery_select, exp.Select):
self._rewrite_select_tree(subquery_select)
if self._references_semantic_model(subquery_select):
rewritten_sql = self._rewrite_simple_query(subquery_select)
subquery.set("this", sqlglot.parse_one(rewritten_sql, dialect=self.dialect))

# Recurse into JOIN subqueries
for join in select.args.get("joins") or []:
join_expr = join.this
if isinstance(join_expr, exp.Subquery):
join_select = join_expr.this
if isinstance(join_select, exp.Select):
self._rewrite_select_tree(join_select)
if self._references_semantic_model(join_select):
rewritten_sql = self._rewrite_simple_query(join_select)
join_expr.set("this", sqlglot.parse_one(rewritten_sql, dialect=self.dialect))

def _references_semantic_model(self, select: exp.Select) -> bool:
"""Check if a SELECT statement references any semantic models."""
Expand Down
Loading
Loading