Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 33 additions & 11 deletions marimo/_ast/sql_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,37 @@ def get_ref_from_table(table: exp.Table) -> SQLRef | None:

refs: set[SQLRef] = set()

def _collect_table_refs_excluding_ctes(expression: exp.Expression) -> None:
"""Walk all Table nodes, filtering out unqualified CTE references.

find_all(exp.Table) doesn't understand CTE scope, so bare
references to CTE names would be misidentified as real tables.

We only collect CTEs from the statement-level WITH clause
rather than nested subqueries, because a subquery's CTE is
scoped to that subquery and must not mask a real table with the
same name in the outer query. We identify statement-level CTEs
by checking that the CTE's grandparent (With -> Expression) is
the top-level expression. Schema-qualified refs (e.g. schema.foo)
are always real tables even if a CTE shares the same base name.
"""
cte_names: set[str] = set()
for cte in expression.find_all(exp.CTE):
with_node = cte.parent
if with_node and with_node.parent is expression:
alias = cte.alias
if alias:
cte_names.add(alias.lower())
for table in expression.find_all(exp.Table):
if ref := get_ref_from_table(table):
is_unqualified_cte = (
ref.table.lower() in cte_names
and ref.schema is None
and ref.catalog is None
)
if not is_unqualified_cte:
refs.add(ref)

Comment on lines +544 to +573
Copy link

Copilot AI Apr 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_collect_table_refs_excluding_ctes only collects CTE names from the statement-level WITH clause. If an OptimizeError forces this fallback on a query that contains a nested subquery WITH (CTE scoped to that subquery), unqualified references to that nested CTE will still be returned as table refs, reintroducing false dependencies. Consider making the fallback traversal scope-aware by tracking active CTE names while recursively walking the AST (push CTE names when entering an expression with a WITH clause, and only filter matching Table nodes within that subtree).

Suggested change
"""Walk all Table nodes, filtering out unqualified CTE references.
find_all(exp.Table) doesn't understand CTE scope, so bare
references to CTE names would be misidentified as real tables.
We only collect CTEs from the statement-level WITH clause
(expression.args["with_"]) rather than traversing into nested
subqueries, because a subquery's CTE is scoped to that subquery
and must not mask a real table with the same name in the outer
query. Schema-qualified refs (e.g. schema.foo) are always real
tables even if a CTE shares the same base name.
"""
if expression is None:
return
cte_names: set[str] = set()
with_clause = expression.args.get("with_")
if with_clause:
for cte in with_clause.expressions:
alias = cte.alias
if alias:
cte_names.add(alias.lower())
for table in expression.find_all(exp.Table):
if ref := get_ref_from_table(table):
is_unqualified_cte = (
ref.table.lower() in cte_names
and ref.schema is None
and ref.catalog is None
)
if not is_unqualified_cte:
refs.add(ref)
"""Walk Table nodes, filtering out unqualified CTE references.
``find_all(exp.Table)`` doesn't understand CTE scope, so bare
references to CTE names can be misidentified as real tables.
Track active CTE names while recursively traversing the AST so
nested subqueries with their own WITH clauses only mask matching
unqualified table refs within that subtree. Schema-qualified refs
(e.g. schema.foo) are always treated as real tables even if a CTE
shares the same base name.
"""
if expression is None:
return
def _walk(
node: exp.Expression | None,
active_cte_names: set[str],
) -> None:
if node is None:
return
scoped_cte_names = active_cte_names
with_clause = node.args.get("with_")
if with_clause:
scoped_cte_names = set(active_cte_names)
for cte in with_clause.expressions:
alias = cte.alias
if alias:
scoped_cte_names.add(alias.lower())
if isinstance(node, exp.Table):
if ref := get_ref_from_table(node):
is_unqualified_cte = (
ref.table.lower() in scoped_cte_names
and ref.schema is None
and ref.catalog is None
)
if not is_unqualified_cte:
refs.add(ref)
return
for child in node.iter_expressions():
_walk(child, scoped_cte_names)
_walk(expression, set())

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it doesn't handle complex cases, but I think it's sufficient for now

for expression in expression_list:
if expression is None:
continue
Expand All @@ -558,9 +589,7 @@ def get_ref_from_table(table: exp.Table) -> SQLRef | None:
exp.Copy,
)
):
for table in expression.find_all(exp.Table):
if ref := get_ref_from_table(table):
refs.add(ref)
_collect_table_refs_excluding_ctes(expression) # type: ignore[arg-type]

# build_scope only works for select statements.
# It may raise OptimizeError for valid SQL with duplicate aliases
Expand All @@ -574,13 +603,6 @@ def get_ref_from_table(table: exp.Table) -> SQLRef | None:
if ref := get_ref_from_table(source):
refs.add(ref)
except OptimizeError:
# Fall back to extracting table references without scope analysis.
# This can happen with valid SQL that has duplicate aliases
# (e.g., cross-joined subqueries with the same column alias).
# We prefer build_scope when possible because it correctly handles
# CTEs - find_all would incorrectly report CTE names as table refs.
for table in expression.find_all(exp.Table):
if ref := get_ref_from_table(table):
refs.add(ref)
_collect_table_refs_excluding_ctes(expression) # type: ignore[arg-type]

return refs
173 changes: 173 additions & 0 deletions tests/_ast/test_sql_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -924,6 +924,179 @@ def test_duplicate_alias_with_table_refs(self) -> None:
SQLRef(table="table3"),
}

def test_cte_with_duplicate_join_aliases(self) -> None:
# Regression test for issue #9168
# When two JOIN-ed tables share the same alias and one of them
# is a CTE, build_scope raises OptimizeError ("Alias already used").
# The fallback path must still filter out CTE names.
sql = """
WITH
num_exams AS (
SELECT
student_id,
exam_type_id,
COUNT(*) AS num_exams
FROM
exam_records
GROUP BY
student_id,
exam_type_id
)
SELECT
student_id,
student_name,
exam_type_id,
c.class_name,
COALESCE(c.num_exams, 0) AS num_exams
FROM
students
LEFT JOIN num_exams c USING (student_id)
JOIN classes c USING (class_id)
"""
# num_exams is a CTE and should NOT appear as a dependency
assert find_sql_refs(sql) == {
SQLRef(table="exam_records"),
SQLRef(table="students"),
SQLRef(table="classes"),
}

def test_cte_with_duplicate_join_aliases_mixed_case(self) -> None:
# CTE defined as "Num_Exams" but referenced as "num_exams".
# SQL identifiers are case-insensitive, so these must match.
sql = """
WITH
Num_Exams AS (
SELECT student_id, exam_type_id, COUNT(*) AS num_exams
FROM exam_records
GROUP BY student_id, exam_type_id
)
SELECT
student_id, student_name, exam_type_id,
c.class_name, COALESCE(c.num_exams, 0) AS num_exams
FROM students
LEFT JOIN num_exams c USING (student_id)
JOIN classes c USING (class_id)
"""
assert find_sql_refs(sql) == {
SQLRef(table="exam_records"),
SQLRef(table="students"),
SQLRef(table="classes"),
}

def test_cte_with_duplicate_join_aliases_different_aliases(self) -> None:
# Same query as above but with distinct aliases β€” should work
# both before and after the fix (build_scope succeeds here).
sql = """
WITH
num_exams AS (
SELECT
student_id,
exam_type_id,
COUNT(*) AS num_exams
FROM
exam_records
GROUP BY
student_id,
exam_type_id
)
SELECT
student_id,
student_name,
exam_type_id,
cl.class_name,
COALESCE(ne.num_exams, 0) AS num_exams
FROM
students
LEFT JOIN num_exams ne USING (student_id)
JOIN classes cl USING (class_id)
"""
assert find_sql_refs(sql) == {
SQLRef(table="exam_records"),
SQLRef(table="students"),
SQLRef(table="classes"),
}

def test_multiple_ctes_with_duplicate_aliases(self) -> None:
# Multiple CTEs referenced with the same alias in joins
sql = """
WITH
cte1 AS (SELECT id, val FROM table1),
cte2 AS (SELECT id, val FROM table2)
SELECT *
FROM table3
JOIN cte1 x ON table3.id = x.id
JOIN cte2 x ON table3.id = x.id
"""
# Neither cte1 nor cte2 should appear as dependencies
assert find_sql_refs(sql) == {
SQLRef(table="table1"),
SQLRef(table="table2"),
SQLRef(table="table3"),
}

def test_cte_name_matches_real_table_with_duplicate_alias(self) -> None:
# Edge case: CTE name shadows a real table used elsewhere.
# The CTE itself still shouldn't be a dependency β€” only the
# tables referenced inside and outside it should be.
sql = """
WITH
shared_name AS (SELECT id FROM source_table)
SELECT *
FROM base_table
JOIN shared_name a ON base_table.id = a.id
JOIN other_table a ON base_table.id = a.id
"""
assert find_sql_refs(sql) == {
SQLRef(table="source_table"),
SQLRef(table="base_table"),
SQLRef(table="other_table"),
}

def test_schema_qualified_table_same_name_as_cte(self) -> None:
# A schema-qualified table reference should never be filtered,
# even if its base name matches a CTE in the same query.
sql = """
WITH foo AS (SELECT id FROM source)
SELECT *
FROM schema1.foo
JOIN foo a ON schema1.foo.id = a.id
JOIN bar a ON schema1.foo.id = a.id
"""
assert find_sql_refs(sql) == {
SQLRef(table="source"),
SQLRef(table="foo", schema="schema1"),
SQLRef(table="bar"),
}

def test_nested_subquery_cte_does_not_mask_outer_table(self) -> None:
# A CTE defined inside a subquery is scoped to that subquery.
# It must not mask a real table with the same name in the outer
# query. Duplicate aliases force the OptimizeError fallback.
sql = """
SELECT *
FROM my_table
JOIN (
WITH my_table AS (SELECT 1 AS id)
SELECT * FROM my_table
) a ON my_table.id = a.id
JOIN other_table a ON my_table.id = a.id
"""
assert find_sql_refs(sql) == {
SQLRef(table="my_table"),
SQLRef(table="other_table"),
}

Comment thread
Light2Dark marked this conversation as resolved.
def test_dml_with_cte(self) -> None:
# CTE names should be filtered in DML statements too.
sql = """
WITH cte AS (SELECT * FROM source_table)
INSERT INTO target_table SELECT * FROM cte;
"""
assert find_sql_refs(sql) == {
SQLRef(table="source_table"),
SQLRef(table="target_table"),
}

def test_multiple_statements_with_optimize_error(self) -> None:
# Verify that OptimizeError in one statement doesn't affect others.
# The try/except is inside the loop, so each statement is independent.
Expand Down
Loading