|
2 | 2 |
|
3 | 3 | import duckdb |
4 | 4 | import pytest |
| 5 | +import sqlglot |
| 6 | +from sqlglot import exp |
5 | 7 |
|
6 | | -from sidemantic import Dimension, Metric, Model, Relationship, SemanticLayer |
| 8 | +from sidemantic import Dimension, Metric, Model, Relationship, Segment, SemanticLayer |
7 | 9 | from tests.utils import df_rows |
8 | 10 |
|
9 | 11 |
|
@@ -361,6 +363,53 @@ def test_count_distinct_without_sql_uses_primary_key(layer): |
361 | 363 | assert "COUNT(DISTINCT" in sql |
362 | 364 |
|
363 | 365 |
|
| 366 | +def test_count_distinct_with_segment_filter_without_model_placeholder(layer): |
| 367 | + """Test count_distinct with segment filters that omit {model} placeholders.""" |
| 368 | + layer = SemanticLayer() |
| 369 | + |
| 370 | + location = Model( |
| 371 | + name="location", |
| 372 | + table="dim_location", |
| 373 | + primary_key="sk_location_id", |
| 374 | + dimensions=[ |
| 375 | + Dimension(name="city", type="categorical"), |
| 376 | + ], |
| 377 | + metrics=[ |
| 378 | + Metric(name="count", agg="count_distinct"), # No sql field |
| 379 | + ], |
| 380 | + segments=[ |
| 381 | + Segment(name="lockers_3000", sql="zipcode = '3000'"), |
| 382 | + ], |
| 383 | + ) |
| 384 | + |
| 385 | + layer.add_model(location) |
| 386 | + |
| 387 | + sql = layer.compile( |
| 388 | + metrics=["location.count"], |
| 389 | + dimensions=["location.city"], |
| 390 | + segments=["location.lockers_3000"], |
| 391 | + ) |
| 392 | + |
| 393 | + # Should still use primary key for count_distinct |
| 394 | + assert "sk_location_id AS count_raw" in sql |
| 395 | + assert "count AS count_raw" not in sql |
| 396 | + |
| 397 | + parsed = sqlglot.parse_one(sql) |
| 398 | + cte = None |
| 399 | + for cte_def in parsed.find_all(exp.CTE): |
| 400 | + if cte_def.alias == "location_cte": |
| 401 | + cte = cte_def |
| 402 | + break |
| 403 | + |
| 404 | + assert cte is not None |
| 405 | + |
| 406 | + where_clause = cte.this.find(exp.Where) |
| 407 | + assert where_clause is not None |
| 408 | + where_sql = where_clause.sql() |
| 409 | + assert "zipcode" in where_sql |
| 410 | + assert "3000" in where_sql |
| 411 | + |
| 412 | + |
364 | 413 | def test_count_distinct_with_explicit_sql(layer): |
365 | 414 | """Test that count_distinct with explicit sql uses that column.""" |
366 | 415 | layer = SemanticLayer() |
|
0 commit comments