@@ -118,8 +118,9 @@ def rewrite(self, sql: str, strict: bool = True) -> str:
118118 # Check if this is a CTE-based query or has subqueries
119119 has_ctes = parsed .args .get ("with" ) is not None
120120 has_subquery_in_from = self ._has_subquery_in_from (parsed )
121+ has_subquery_in_joins = any (isinstance (join .this , exp .Subquery ) for join in (parsed .args .get ("joins" ) or []))
121122
122- if has_ctes or has_subquery_in_from :
123+ if has_ctes or has_subquery_in_from or has_subquery_in_joins :
123124 # Handle CTEs and subqueries
124125 return self ._rewrite_with_ctes_or_subqueries (parsed )
125126
@@ -1851,41 +1852,51 @@ def _has_subquery_in_from(self, select: exp.Select) -> bool:
18511852 def _rewrite_with_ctes_or_subqueries (self , parsed : exp .Select ) -> str :
18521853 """Rewrite query that contains CTEs or subqueries.
18531854
1854- Strategy:
1855- 1. Rewrite each CTE that references semantic models
1856- 2. Rewrite subqueries in FROM clause
1857- 3. Return the modified SQL
1855+ Recursively walks the query tree bottom-up, rewriting any
1856+ SELECT whose FROM target resolves to a semantic model.
1857+ Outer queries are left as plain SQL, so post-processing
1858+ (CASE, window functions, arithmetic, etc.) works naturally.
18581859 """
1859- # Handle CTEs
1860- if parsed .args .get ("with" ):
1861- with_clause = parsed .args ["with" ]
1862- for cte in with_clause .expressions :
1863- # Each CTE has a name (alias) and a query (this)
1860+ self ._rewrite_select_tree (parsed )
1861+ return parsed .sql (dialect = self .dialect )
1862+
1863+ def _rewrite_select_tree (self , select : exp .Select ):
1864+ """Recursively rewrite semantic subqueries and CTEs (bottom-up).
1865+
1866+ At each level: recurse into children first, then rewrite this
1867+ node if it directly references a semantic model.
1868+ """
1869+ # Recurse into CTEs
1870+ if select .args .get ("with" ):
1871+ for cte in select .args ["with" ].expressions :
18641872 cte_query = cte .this
18651873 if isinstance (cte_query , exp .Select ):
1866- # Check if this CTE references a semantic model
1874+ self . _rewrite_select_tree ( cte_query )
18671875 if self ._references_semantic_model (cte_query ):
1868- # Rewrite the CTE query
1869- rewritten_cte_sql = self ._rewrite_simple_query (cte_query )
1870- # Parse the rewritten SQL and replace the CTE query
1871- rewritten_cte = sqlglot .parse_one (rewritten_cte_sql , dialect = self .dialect )
1872- cte .set ("this" , rewritten_cte )
1873-
1874- # Handle subquery in FROM
1875- from_clause = parsed .args .get ("from" )
1876+ rewritten_sql = self ._rewrite_simple_query (cte_query )
1877+ cte .set ("this" , sqlglot .parse_one (rewritten_sql , dialect = self .dialect ))
1878+
1879+ # Recurse into FROM subquery
1880+ from_clause = select .args .get ("from" )
18761881 if from_clause and isinstance (from_clause .this , exp .Subquery ):
18771882 subquery = from_clause .this
18781883 subquery_select = subquery .this
1879- if isinstance (subquery_select , exp .Select ) and self ._references_semantic_model (subquery_select ):
1880- # Rewrite the subquery
1881- rewritten_subquery_sql = self ._rewrite_simple_query (subquery_select )
1882- rewritten_subquery = sqlglot .parse_one (rewritten_subquery_sql , dialect = self .dialect )
1883- subquery .set ("this" , rewritten_subquery )
1884-
1885- # Return the modified SQL
1886- # Note: Individual CTEs/subqueries are already instrumented by _rewrite_simple_query -> generator
1887- # The outer query wrapper doesn't need separate instrumentation
1888- return parsed .sql (dialect = self .dialect )
1884+ if isinstance (subquery_select , exp .Select ):
1885+ self ._rewrite_select_tree (subquery_select )
1886+ if self ._references_semantic_model (subquery_select ):
1887+ rewritten_sql = self ._rewrite_simple_query (subquery_select )
1888+ subquery .set ("this" , sqlglot .parse_one (rewritten_sql , dialect = self .dialect ))
1889+
1890+ # Recurse into JOIN subqueries
1891+ for join in select .args .get ("joins" ) or []:
1892+ join_expr = join .this
1893+ if isinstance (join_expr , exp .Subquery ):
1894+ join_select = join_expr .this
1895+ if isinstance (join_select , exp .Select ):
1896+ self ._rewrite_select_tree (join_select )
1897+ if self ._references_semantic_model (join_select ):
1898+ rewritten_sql = self ._rewrite_simple_query (join_select )
1899+ join_expr .set ("this" , sqlglot .parse_one (rewritten_sql , dialect = self .dialect ))
18891900
18901901 def _references_semantic_model (self , select : exp .Select ) -> bool :
18911902 """Check if a SELECT statement references any semantic models."""
0 commit comments