Skip to content

Commit b542ae7

Browse files
committed
Handle CTE name collisions between user and generated CTEs
Two changes: 1. post_process path: remove CTE hoisting entirely. Inner SQL (with CTEs) is placed directly in the subquery position. CTEs inside subqueries are valid in all target databases and naturally scoped, so name collisions with post_process CTEs cannot occur. 2. Root semantic + user CTEs: detect name collisions between user CTEs and generated CTEs, raising a clear error instead of producing invalid SQL. Walk-based renaming was too aggressive (renamed user CTE references inside filter subqueries).
1 parent 6cc3958 commit b542ae7

3 files changed

Lines changed: 45 additions & 24 deletions

File tree

sidemantic/core/semantic_layer.py

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -552,30 +552,10 @@ def compile(
552552
if last_line.startswith("-- sidemantic:"):
553553
stripped = "\n".join(stripped.split("\n")[:-1])
554554

555-
# If inner SQL starts with WITH (CTEs), hoist them outside
556-
# the subquery position so the SQL is valid.
557-
if stripped.lstrip().upper().startswith("WITH "):
558-
import sqlglot
559-
560-
target_dialect = dialect or self.dialect
561-
parsed_inner = sqlglot.parse_one(stripped, dialect=target_dialect)
562-
inner_with = parsed_inner.args.get("with")
563-
if inner_with:
564-
parsed_inner.set("with", None)
565-
body = parsed_inner.sql(dialect=target_dialect)
566-
567-
# Substitute body into post_process, then merge CTEs
568-
outer_sql = post_process.replace("{inner}", body)
569-
outer_parsed = sqlglot.parse_one(outer_sql, dialect=target_dialect)
570-
outer_with = outer_parsed.args.get("with")
571-
if outer_with:
572-
# Prepend inner CTEs before outer CTEs
573-
merged = list(inner_with.expressions) + list(outer_with.expressions)
574-
outer_with.set("expressions", merged)
575-
else:
576-
outer_parsed.set("with", inner_with)
577-
return outer_parsed.sql(dialect=target_dialect)
578-
555+
# Inner SQL (including any CTEs) is placed directly in the
556+
# subquery position. CTEs inside subqueries are valid SQL in
557+
# all target databases and naturally scoped, avoiding name
558+
# collisions with CTEs in the post_process SQL.
579559
return post_process.replace("{inner}", stripped)
580560

581561
return inner_sql

sidemantic/sql/query_rewriter.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1876,6 +1876,15 @@ def _rewrite_with_ctes_or_subqueries(self, parsed: exp.Select) -> str:
18761876
rewritten = sqlglot.parse_one(rewritten_sql, dialect=self.dialect)
18771877
gen_with = rewritten.args.get("with")
18781878
if gen_with:
1879+
# Check for CTE name collisions between user and generated CTEs
1880+
user_names = {cte.alias for cte in original_with.expressions}
1881+
for gen_cte in gen_with.expressions:
1882+
if gen_cte.alias in user_names:
1883+
raise ValueError(
1884+
f"CTE name '{gen_cte.alias}' conflicts with an internally "
1885+
f"generated name. Please choose a different CTE name."
1886+
)
1887+
18791888
user_ctes = [cte.copy() for cte in original_with.expressions]
18801889
gen_with.set("expressions", user_ctes + list(gen_with.expressions))
18811890
# Preserve WITH RECURSIVE from the original query

tests/queries/test_sql_rewriter.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1433,3 +1433,35 @@ def test_post_process_with_own_ctes(semantic_layer):
14331433

14341434
assert len(rows) >= 1
14351435
assert all(row["revenue"] >= 200 for row in rows)
1436+
1437+
1438+
def test_post_process_cte_name_collision(semantic_layer):
1439+
"""post_process CTE with same name as generated CTE doesn't collide."""
1440+
result = semantic_layer.query(
1441+
metrics=["orders.revenue"],
1442+
dimensions=["orders.status"],
1443+
post_process="""
1444+
WITH orders_cte AS (SELECT 'custom' AS source)
1445+
SELECT sq.*, oc.source
1446+
FROM ({inner}) sq
1447+
CROSS JOIN orders_cte oc
1448+
""",
1449+
)
1450+
rows = _rows(result)
1451+
1452+
assert len(rows) >= 1
1453+
assert all(row["source"] == "custom" for row in rows)
1454+
1455+
1456+
def test_root_semantic_cte_name_collision(semantic_layer):
1457+
"""User CTE with same name as generated CTE raises a clear error."""
1458+
sql = """
1459+
WITH orders_cte AS (
1460+
SELECT 'completed' AS status
1461+
)
1462+
SELECT orders.revenue
1463+
FROM orders
1464+
WHERE orders.status IN (SELECT status FROM orders_cte)
1465+
"""
1466+
with pytest.raises(ValueError, match="conflicts with an internally generated name"):
1467+
semantic_layer.sql(sql)

0 commit comments

Comments
 (0)