Skip to content

Commit c07e33c

Browse files
committed
Add comprehensive tests and fix rewritten queries to SQL format
Changes:
- Added 9 comprehensive tests for model generation functionality
- Fixed rewritten queries to output SQL format instead of Python code
- Rewritten queries now use semantic layer syntax (model.dimension, model.metric)
- Fixed COUNT(DISTINCT col) parsing (use expressions[0] not .this)
- Write rewritten queries as .sql files not .py files

Tests cover:
- Generating models from multiple query types
- COUNT(DISTINCT) metric generation
- Duplicate metric handling
- Rewritten query generation with filters
- Skipping unparseable queries
- Writing models and queries to disk
- Multiple aggregations on same column

All 16 coverage analyzer tests passing.
1 parent: b6473f9 · commit: c07e33c

2 files changed

Lines changed: 299 additions & 31 deletions

File tree

sidemantic/core/coverage_analyzer.py

Lines changed: 33 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -194,9 +194,10 @@ def _extract_aggregations(self, parsed: exp.Expression, analysis: QueryAnalysis)
194194
analysis.aggregations.append((agg_name, "*", ""))
195195
elif isinstance(col, exp.Distinct):
196196
# COUNT(DISTINCT col) - handle specially
197-
if isinstance(col.this, exp.Column):
198-
col_name = col.this.name
199-
table_name = col.this.table if col.this.table else None
197+
if col.expressions and isinstance(col.expressions[0], exp.Column):
198+
distinct_col = col.expressions[0]
199+
col_name = distinct_col.name
200+
table_name = distinct_col.table if distinct_col.table else None
200201
analysis.aggregations.append(("count_distinct", col_name, table_name or ""))
201202

202203
def _extract_group_by(self, parsed: exp.Expression, analysis: QueryAnalysis) -> None:
@@ -591,29 +592,30 @@ def write_model_files(self, models: dict[str, dict], output_dir: str) -> None:
591592
print(f"Generated: {file_path}")
592593

593594
def generate_rewritten_queries(self, report: CoverageReport) -> dict[str, str]:
594-
"""Generate rewritten queries using semantic layer.
595+
"""Generate rewritten SQL queries using semantic layer syntax.
595596
596597
Args:
597598
report: Coverage report
598599
599600
Returns:
600-
Dictionary mapping query names to rewritten Python code
601+
Dictionary mapping query names to rewritten SQL
601602
"""
602603
rewritten = {}
603604

604605
for i, analysis in enumerate(report.query_analyses, 1):
605-
if analysis.parse_error or not analysis.can_rewrite:
606+
if analysis.parse_error:
606607
continue
607608

608-
# Build dimension references
609-
dimensions = []
609+
# Build SELECT clause with model.dimension and model.metric format
610+
select_parts = []
611+
612+
# Add dimensions
610613
for table_name, col_name in analysis.group_by_columns:
611614
if not table_name and len(analysis.tables) == 1:
612615
table_name = list(analysis.tables)[0]
613-
dimensions.append(f"{table_name}.{col_name}")
616+
select_parts.append(f"{table_name}.{col_name}")
614617

615-
# Build metric references
616-
metrics = []
618+
# Add metrics
617619
for agg_type, col_name, table_name in analysis.aggregations:
618620
if not table_name and len(analysis.tables) == 1:
619621
table_name = list(analysis.tables)[0]
@@ -626,32 +628,32 @@ def generate_rewritten_queries(self, report: CoverageReport) -> dict[str, str]:
626628
else:
627629
metric_name = f"{agg_type}_{col_name}"
628630

629-
metrics.append(f"{table_name}.{metric_name}")
631+
select_parts.append(f"{table_name}.{metric_name}")
632+
633+
if not select_parts:
634+
continue
635+
636+
# Build SQL query
637+
sql = "SELECT\n"
638+
sql += " " + ",\n ".join(select_parts)
639+
640+
# Determine main table
641+
if len(analysis.tables) == 1:
642+
main_table = list(analysis.tables)[0]
643+
sql += f"\nFROM {main_table}"
630644

631-
# Build filter clause
632-
where_clause = None
645+
# Add WHERE clause
633646
if analysis.filters:
634647
where_clause = analysis.filters[0]
635-
636-
# Generate Python code
637-
parts = []
638-
if dimensions:
639-
parts.append(f" dimensions={dimensions}")
640-
if metrics:
641-
parts.append(f" metrics={metrics}")
642-
if where_clause:
643-
parts.append(f' where="{where_clause}"')
648+
sql += f"\nWHERE {where_clause}"
644649

645650
query_name = f"query_{i}"
646-
code = f"# Original query:\n# {analysis.query.strip()}\n\n"
647-
code += "result = layer.query(\n" + ",\n".join(parts) + "\n)"
648-
649-
rewritten[query_name] = code
651+
rewritten[query_name] = sql
650652

651653
return rewritten
652654

653655
def write_rewritten_queries(self, queries: dict[str, str], output_dir: str) -> None:
654-
"""Write rewritten queries to Python files.
656+
"""Write rewritten queries to SQL files.
655657
656658
Args:
657659
queries: Dictionary of rewritten queries from generate_rewritten_queries()
@@ -662,10 +664,10 @@ def write_rewritten_queries(self, queries: dict[str, str], output_dir: str) -> N
662664
output_path = Path(output_dir)
663665
output_path.mkdir(parents=True, exist_ok=True)
664666

665-
for query_name, code in queries.items():
666-
file_path = output_path / f"{query_name}.py"
667+
for query_name, sql in queries.items():
668+
file_path = output_path / f"{query_name}.sql"
667669
with open(file_path, "w") as f:
668-
f.write(code)
670+
f.write(sql)
669671
f.write("\n")
670672

671673
print(f"Generated: {file_path}")
Lines changed: 266 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,266 @@
1+
"""Tests for coverage analyzer model and query generation."""
2+
3+
from sidemantic import SemanticLayer
4+
from sidemantic.core.coverage_analyzer import CoverageAnalyzer
5+
6+
7+
def test_generate_models_from_queries():
8+
"""Test generating model definitions from queries."""
9+
layer = SemanticLayer(auto_register=False)
10+
analyzer = CoverageAnalyzer(layer)
11+
12+
queries = [
13+
"""
14+
SELECT status, region, SUM(amount), COUNT(*)
15+
FROM orders
16+
GROUP BY status, region
17+
""",
18+
"""
19+
SELECT category, AVG(price), COUNT(DISTINCT product_id)
20+
FROM products
21+
GROUP BY category
22+
""",
23+
]
24+
25+
report = analyzer.analyze_queries(queries)
26+
models = analyzer.generate_models(report)
27+
28+
# Should generate 2 models
29+
assert len(models) == 2
30+
assert "orders" in models
31+
assert "products" in models
32+
33+
# Check orders model
34+
orders = models["orders"]
35+
assert orders["model"]["name"] == "orders"
36+
assert orders["model"]["table"] == "orders"
37+
38+
# Check orders dimensions
39+
assert len(orders["dimensions"]) == 2
40+
dim_names = {d["name"] for d in orders["dimensions"]}
41+
assert "status" in dim_names
42+
assert "region" in dim_names
43+
44+
# Check orders metrics
45+
assert len(orders["metrics"]) == 2
46+
metric_names = {m["name"] for m in orders["metrics"]}
47+
assert "sum_amount" in metric_names
48+
assert "count" in metric_names
49+
50+
# Check products model
51+
products = models["products"]
52+
assert products["model"]["name"] == "products"
53+
54+
# Check products dimensions
55+
assert len(products["dimensions"]) == 1
56+
assert products["dimensions"][0]["name"] == "category"
57+
58+
# Check products metrics
59+
assert len(products["metrics"]) == 2
60+
metric_names = {m["name"] for m in products["metrics"]}
61+
assert "avg_price" in metric_names
62+
assert "product_id_count" in metric_names
63+
64+
65+
def test_generate_models_count_distinct():
66+
"""Test COUNT(DISTINCT col) generates correct metric."""
67+
layer = SemanticLayer(auto_register=False)
68+
analyzer = CoverageAnalyzer(layer)
69+
70+
queries = [
71+
"""
72+
SELECT status, COUNT(DISTINCT customer_id)
73+
FROM orders
74+
GROUP BY status
75+
"""
76+
]
77+
78+
report = analyzer.analyze_queries(queries)
79+
models = analyzer.generate_models(report)
80+
81+
orders = models["orders"]
82+
metrics = {m["name"]: m for m in orders["metrics"]}
83+
84+
assert "customer_id_count" in metrics
85+
assert metrics["customer_id_count"]["agg"] == "count_distinct"
86+
assert metrics["customer_id_count"]["sql"] == "customer_id"
87+
88+
89+
def test_generate_models_no_duplicate_metrics():
90+
"""Test that duplicate metrics are not generated."""
91+
layer = SemanticLayer(auto_register=False)
92+
analyzer = CoverageAnalyzer(layer)
93+
94+
queries = [
95+
"SELECT status, SUM(amount) FROM orders GROUP BY status",
96+
"SELECT region, SUM(amount) FROM orders GROUP BY region",
97+
]
98+
99+
report = analyzer.analyze_queries(queries)
100+
models = analyzer.generate_models(report)
101+
102+
orders = models["orders"]
103+
metric_names = [m["name"] for m in orders["metrics"]]
104+
105+
# sum_amount should only appear once
106+
assert metric_names.count("sum_amount") == 1
107+
108+
109+
def test_generate_rewritten_queries():
110+
"""Test generating rewritten queries."""
111+
layer = SemanticLayer(auto_register=False)
112+
analyzer = CoverageAnalyzer(layer)
113+
114+
queries = [
115+
"""
116+
SELECT status, SUM(amount), COUNT(*)
117+
FROM orders
118+
GROUP BY status
119+
"""
120+
]
121+
122+
report = analyzer.analyze_queries(queries)
123+
rewritten = analyzer.generate_rewritten_queries(report)
124+
125+
# Should generate 1 rewritten query
126+
assert len(rewritten) == 1
127+
assert "query_1" in rewritten
128+
129+
sql = rewritten["query_1"]
130+
131+
# Check it's SQL format
132+
assert "SELECT" in sql
133+
assert "FROM orders" in sql
134+
135+
# Check it uses semantic layer syntax (model.dimension, model.metric)
136+
assert "orders.status" in sql
137+
assert "orders.count" in sql
138+
assert "orders.sum_amount" in sql
139+
140+
141+
def test_generate_rewritten_queries_with_filter():
142+
"""Test generating rewritten queries with WHERE clause."""
143+
layer = SemanticLayer(auto_register=False)
144+
analyzer = CoverageAnalyzer(layer)
145+
146+
queries = [
147+
"""
148+
SELECT status, SUM(amount)
149+
FROM orders
150+
WHERE status = 'completed'
151+
GROUP BY status
152+
"""
153+
]
154+
155+
report = analyzer.analyze_queries(queries)
156+
rewritten = analyzer.generate_rewritten_queries(report)
157+
158+
sql = rewritten["query_1"]
159+
160+
# Check it includes WHERE clause
161+
assert "WHERE" in sql
162+
assert "status = 'completed'" in sql or "status='completed'" in sql
163+
164+
165+
def test_generate_rewritten_queries_skips_unparseable():
166+
"""Test that unparseable queries are skipped."""
167+
layer = SemanticLayer(auto_register=False)
168+
analyzer = CoverageAnalyzer(layer)
169+
170+
queries = [
171+
"SELECT FROM WHERE", # Invalid
172+
"SELECT status, COUNT(*) FROM orders GROUP BY status", # Valid
173+
]
174+
175+
report = analyzer.analyze_queries(queries)
176+
rewritten = analyzer.generate_rewritten_queries(report)
177+
178+
# Should only generate 1 query (skip the invalid one)
179+
assert len(rewritten) == 1
180+
181+
182+
def test_write_model_files(tmp_path):
183+
"""Test writing model files to disk."""
184+
layer = SemanticLayer(auto_register=False)
185+
analyzer = CoverageAnalyzer(layer)
186+
187+
queries = [
188+
"SELECT status, SUM(amount) FROM orders GROUP BY status",
189+
]
190+
191+
report = analyzer.analyze_queries(queries)
192+
models = analyzer.generate_models(report)
193+
194+
output_dir = tmp_path / "models"
195+
analyzer.write_model_files(models, str(output_dir))
196+
197+
# Check file was created
198+
orders_file = output_dir / "orders.yml"
199+
assert orders_file.exists()
200+
201+
# Check file contents
202+
import yaml
203+
204+
with open(orders_file) as f:
205+
data = yaml.safe_load(f)
206+
207+
assert data["model"]["name"] == "orders"
208+
assert len(data["dimensions"]) == 1
209+
assert len(data["metrics"]) == 1
210+
211+
212+
def test_write_rewritten_queries(tmp_path):
213+
"""Test writing rewritten queries to disk."""
214+
layer = SemanticLayer(auto_register=False)
215+
analyzer = CoverageAnalyzer(layer)
216+
217+
queries = [
218+
"SELECT status, COUNT(*) FROM orders GROUP BY status",
219+
]
220+
221+
report = analyzer.analyze_queries(queries)
222+
rewritten = analyzer.generate_rewritten_queries(report)
223+
224+
output_dir = tmp_path / "queries"
225+
analyzer.write_rewritten_queries(rewritten, str(output_dir))
226+
227+
# Check file was created
228+
query_file = output_dir / "query_1.sql"
229+
assert query_file.exists()
230+
231+
# Check file contents
232+
content = query_file.read_text()
233+
assert "SELECT" in content
234+
assert "FROM orders" in content
235+
assert "orders.status" in content
236+
assert "orders.count" in content
237+
238+
239+
def test_generate_models_multiple_aggregations_same_column():
240+
"""Test handling multiple aggregation types on same column."""
241+
layer = SemanticLayer(auto_register=False)
242+
analyzer = CoverageAnalyzer(layer)
243+
244+
queries = [
245+
"""
246+
SELECT
247+
status,
248+
SUM(amount),
249+
AVG(amount),
250+
MIN(amount),
251+
MAX(amount)
252+
FROM orders
253+
GROUP BY status
254+
"""
255+
]
256+
257+
report = analyzer.analyze_queries(queries)
258+
models = analyzer.generate_models(report)
259+
260+
orders = models["orders"]
261+
metric_names = {m["name"] for m in orders["metrics"]}
262+
263+
assert "sum_amount" in metric_names
264+
assert "avg_amount" in metric_names
265+
assert "min_amount" in metric_names
266+
assert "max_amount" in metric_names

0 commit comments

Comments
 (0)