Skip to content

Commit 50a54dd

Browse files
authored
Qualify segment filters in CTEs (#55)
* Qualify segment filters in CTEs * Format segment filter helper * Avoid qualifying subquery columns
1 parent 8fb84e5 commit 50a54dd

3 files changed

Lines changed: 116 additions & 2 deletions

File tree

sidemantic/sql/generator.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,33 @@ def _resolve_segments(self, segments: list[str]) -> list[str]:
439439
Raises:
440440
ValueError: If segment not found
441441
"""
442+
443+
def qualify_unaliased_columns(filter_sql: str, model_alias: str) -> str:
444+
"""Qualify unaliased columns in segment filters with model alias."""
445+
try:
446+
parsed = sqlglot.parse_one(filter_sql, dialect=self.dialect)
447+
except Exception:
448+
return filter_sql
449+
450+
def visit(node: exp.Expression) -> None:
451+
if isinstance(node, exp.Subquery):
452+
return
453+
454+
if isinstance(node, exp.Column) and not node.table:
455+
node.set("table", model_alias)
456+
457+
for arg in node.args.values():
458+
if isinstance(arg, exp.Expression):
459+
visit(arg)
460+
elif isinstance(arg, list):
461+
for item in arg:
462+
if isinstance(item, exp.Expression):
463+
visit(item)
464+
465+
visit(parsed)
466+
467+
return parsed.sql(dialect=self.dialect)
468+
442469
filters = []
443470
for seg_ref in segments:
444471
# Parse model.segment format
@@ -457,6 +484,7 @@ def _resolve_segments(self, segments: list[str]) -> list[str]:
457484
# Get SQL expression with model alias replaced
458485
# Use model_cte as the alias (consistent with CTE naming)
459486
filter_sql = segment.get_sql(f"{model_name}_cte")
487+
filter_sql = qualify_unaliased_columns(filter_sql, f"{model_name}_cte")
460488
filters.append(filter_sql)
461489

462490
return filters

tests/optimizations/test_predicate_pushdown.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import sqlglot
44
from sqlglot import exp
55

6-
from sidemantic import Dimension, Metric, Model
6+
from sidemantic import Dimension, Metric, Model, Segment
7+
from sidemantic.sql.generator import SQLGenerator
78

89

910
def test_single_model_filter_pushdown(layer):
@@ -283,6 +284,42 @@ def test_segment_filters_pushed_down(layer):
283284
assert "completed" in where_sql
284285

285286

287+
def test_segment_filter_skips_subquery_columns(layer):
288+
"""Test that segment filter qualification does not touch subquery columns."""
289+
model = Model(
290+
name="orders",
291+
table="orders_table",
292+
primary_key="id",
293+
metrics=[
294+
Metric(name="count", agg="count"),
295+
],
296+
segments=[
297+
Segment(name="in_other", sql="id in (select id from other_table where flag = 'y')"),
298+
],
299+
)
300+
301+
layer.add_model(model)
302+
303+
generator = SQLGenerator(layer.graph)
304+
filters = generator._resolve_segments(["orders.in_other"])
305+
assert len(filters) == 1
306+
307+
filter_sql = filters[0]
308+
parsed = sqlglot.parse_one(filter_sql)
309+
310+
assert any(col.table == "orders_cte" for col in parsed.find_all(exp.Column))
311+
312+
subquery = None
313+
for subquery_def in parsed.find_all(exp.Subquery):
314+
subquery = subquery_def
315+
break
316+
317+
assert subquery is not None
318+
319+
for col in subquery.find_all(exp.Column):
320+
assert not col.table
321+
322+
286323
def test_metric_level_filters_not_pushed(layer):
287324
"""Test that metric-level filters are applied via CASE WHEN, not pushed to CTE.
288325

tests/queries/test_basic.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22

33
import duckdb
44
import pytest
5+
import sqlglot
6+
from sqlglot import exp
57

6-
from sidemantic import Dimension, Metric, Model, Relationship, SemanticLayer
8+
from sidemantic import Dimension, Metric, Model, Relationship, Segment, SemanticLayer
79
from tests.utils import df_rows
810

911

@@ -361,6 +363,53 @@ def test_count_distinct_without_sql_uses_primary_key(layer):
361363
assert "COUNT(DISTINCT" in sql
362364

363365

366+
def test_count_distinct_with_segment_filter_without_model_placeholder(layer):
367+
"""Test count_distinct with segment filters that omit {model} placeholders."""
368+
layer = SemanticLayer()
369+
370+
location = Model(
371+
name="location",
372+
table="dim_location",
373+
primary_key="sk_location_id",
374+
dimensions=[
375+
Dimension(name="city", type="categorical"),
376+
],
377+
metrics=[
378+
Metric(name="count", agg="count_distinct"), # No sql field
379+
],
380+
segments=[
381+
Segment(name="lockers_3000", sql="zipcode = '3000'"),
382+
],
383+
)
384+
385+
layer.add_model(location)
386+
387+
sql = layer.compile(
388+
metrics=["location.count"],
389+
dimensions=["location.city"],
390+
segments=["location.lockers_3000"],
391+
)
392+
393+
# Should still use primary key for count_distinct
394+
assert "sk_location_id AS count_raw" in sql
395+
assert "count AS count_raw" not in sql
396+
397+
parsed = sqlglot.parse_one(sql)
398+
cte = None
399+
for cte_def in parsed.find_all(exp.CTE):
400+
if cte_def.alias == "location_cte":
401+
cte = cte_def
402+
break
403+
404+
assert cte is not None
405+
406+
where_clause = cte.this.find(exp.Where)
407+
assert where_clause is not None
408+
where_sql = where_clause.sql()
409+
assert "zipcode" in where_sql
410+
assert "3000" in where_sql
411+
412+
364413
def test_count_distinct_with_explicit_sql(layer):
365414
"""Test that count_distinct with explicit sql uses that column."""
366415
layer = SemanticLayer()

0 commit comments

Comments
 (0)