Skip to content

Commit b6e7600

Browse files
committed
Add information_schema support for PK/FK detection and unqualified column resolution
- Add optional database connection parameter to CoverageAnalyzer
- Query information_schema for primary keys, foreign keys, and column metadata
- Use information_schema data first for FK/PK detection; fall back to pattern matching
- Infer the table for unqualified columns using information_schema column metadata
- Support both qualified (table.column) and unqualified (column) references
- Add tests for information_schema-based relationship detection and column inference
1 parent 9a1eac1 commit b6e7600

2 files changed

Lines changed: 289 additions & 32 deletions

File tree

sidemantic/core/coverage_analyzer.py

Lines changed: 168 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -74,13 +74,16 @@ class CoverageAnalyzer:
7474
- How to rewrite queries using the semantic layer
7575
"""
7676

77-
def __init__(self, layer: SemanticLayer):
77+
def __init__(self, layer: SemanticLayer, connection=None):
7878
"""Initialize analyzer.
7979
8080
Args:
8181
layer: Semantic layer to analyze coverage for
82+
connection: Optional database connection for querying information_schema.
83+
Should have an execute() method that returns results.
8284
"""
8385
self.layer = layer
86+
self.connection = connection
8487
self.analyses: list[QueryAnalysis] = []
8588

8689
# Build mapping from table names to model names
@@ -89,6 +92,63 @@ def __init__(self, layer: SemanticLayer):
8992
if model.table:
9093
self.table_to_model[model.table] = model_name
9194

95+
# Cache information_schema data if connection available
96+
self.primary_keys: dict[str, set[str]] = {} # table -> pk columns
97+
self.foreign_keys: dict[tuple[str, str], tuple[str, str]] = {} # (fk_table, fk_col) -> (pk_table, pk_col)
98+
self.table_columns: dict[str, set[str]] = defaultdict(set) # table -> columns
99+
100+
if self.connection:
101+
self._load_schema_metadata()
102+
103+
def _load_schema_metadata(self) -> None:
104+
"""Load schema metadata from information_schema."""
105+
try:
106+
# Load primary keys
107+
pk_query = """
108+
SELECT table_name, column_name
109+
FROM information_schema.key_column_usage
110+
WHERE constraint_name IN (
111+
SELECT constraint_name
112+
FROM information_schema.table_constraints
113+
WHERE constraint_type = 'PRIMARY KEY'
114+
)
115+
"""
116+
pk_results = self.connection.execute(pk_query).fetchall()
117+
for table_name, column_name in pk_results:
118+
self.primary_keys.setdefault(table_name, set()).add(column_name)
119+
120+
# Load foreign keys
121+
# DuckDB information_schema uses referential_constraints to map FK to PK constraints
122+
fk_query = """
123+
SELECT
124+
fk_kcu.table_name AS fk_table,
125+
fk_kcu.column_name AS fk_column,
126+
pk_kcu.table_name AS pk_table,
127+
pk_kcu.column_name AS pk_column
128+
FROM information_schema.referential_constraints rc
129+
JOIN information_schema.key_column_usage fk_kcu
130+
ON rc.constraint_name = fk_kcu.constraint_name
131+
JOIN information_schema.key_column_usage pk_kcu
132+
ON rc.unique_constraint_name = pk_kcu.constraint_name
133+
"""
134+
fk_results = self.connection.execute(fk_query).fetchall()
135+
for fk_table, fk_column, pk_table, pk_column in fk_results:
136+
self.foreign_keys[(fk_table, fk_column)] = (pk_table, pk_column)
137+
138+
# Load all columns for each table
139+
col_query = """
140+
SELECT table_name, column_name
141+
FROM information_schema.columns
142+
"""
143+
col_results = self.connection.execute(col_query).fetchall()
144+
for table_name, column_name in col_results:
145+
self.table_columns[table_name].add(column_name)
146+
147+
except Exception as e:
148+
# If information_schema queries fail, just continue without metadata
149+
# (will fall back to pattern-based detection)
150+
print(f"Warning: Could not load schema metadata: {e}")
151+
92152
def analyze_queries(self, queries: list[str]) -> CoverageReport:
93153
"""Analyze a list of SQL queries.
94154
@@ -183,13 +243,66 @@ def _extract_tables(self, parsed: exp.Expression, analysis: QueryAnalysis) -> No
183243
analysis.table_aliases[table.alias] = table_name
184244

185245
def _extract_columns(self, parsed: exp.Expression, analysis: QueryAnalysis) -> None:
186-
"""Extract column references grouped by table."""
246+
"""Extract column references grouped by table.
247+
248+
Handles both qualified (table.column) and unqualified (column) references.
249+
For unqualified columns, attempts to infer table using information_schema.
250+
"""
187251
for col in parsed.find_all(exp.Column):
188252
col_name = col.name
189253
table_name = col.table if col.table else None
190254

191-
if col_name and table_name:
192-
analysis.columns[table_name].add(col_name)
255+
if not col_name:
256+
continue
257+
258+
# If column has explicit table qualifier, use it
259+
if table_name:
260+
# Resolve alias to real table name
261+
real_table = analysis.table_aliases.get(table_name, table_name)
262+
analysis.columns[real_table].add(col_name)
263+
else:
264+
# Unqualified column - try to infer table
265+
inferred_table = self._infer_table_for_column(col_name, analysis)
266+
if inferred_table:
267+
analysis.columns[inferred_table].add(col_name)
268+
269+
def _infer_table_for_column(self, col_name: str, analysis: QueryAnalysis) -> str | None:
270+
"""Infer which table an unqualified column belongs to.
271+
272+
Uses information_schema data if available, otherwise falls back to heuristics.
273+
274+
Args:
275+
col_name: Column name to infer table for
276+
analysis: Current query analysis
277+
278+
Returns:
279+
Table name or None if can't be inferred
280+
"""
281+
# Get tables involved in this query
282+
query_tables = list(analysis.tables)
283+
284+
if not query_tables:
285+
return None
286+
287+
# If only one table in query, must be that table
288+
if len(query_tables) == 1:
289+
return query_tables[0]
290+
291+
# Use information_schema to find which tables have this column
292+
if self.table_columns:
293+
matching_tables = [t for t in query_tables if col_name in self.table_columns.get(t, set())]
294+
295+
# If exactly one table in the query has this column, use it
296+
if len(matching_tables) == 1:
297+
return matching_tables[0]
298+
299+
# If multiple tables have it, prefer the FROM table (if available)
300+
if len(matching_tables) > 1 and analysis.from_table:
301+
if analysis.from_table in matching_tables:
302+
return analysis.from_table
303+
304+
# Fall back to FROM table if we can't determine
305+
return analysis.from_table
193306

194307
def _extract_aggregations(self, parsed: exp.Expression, analysis: QueryAnalysis) -> None:
195308
"""Extract aggregation functions using sqlglot's AggFunc base class."""
@@ -337,7 +450,10 @@ def _extract_joins(self, parsed: exp.Expression, analysis: QueryAnalysis) -> Non
337450
analysis.joins.append((from_table, from_alias, to_table, to_alias, join_type, on_clause))
338451

339452
def _extract_relationships(self, parsed: exp.Expression, analysis: QueryAnalysis) -> None:
340-
"""Extract relationships from JOIN ON conditions."""
453+
"""Extract relationships from JOIN ON conditions.
454+
455+
Uses information_schema data if available, falls back to pattern matching.
456+
"""
341457
for join in parsed.find_all(exp.Join):
342458
if not isinstance(join.this, exp.Table):
343459
continue
@@ -362,47 +478,67 @@ def _extract_relationships(self, parsed: exp.Expression, analysis: QueryAnalysis
362478
left_table = analysis.table_aliases.get(left.table, left.table) if left.table else ""
363479
right_table = analysis.table_aliases.get(right.table, right.table) if right.table else ""
364480

365-
# Determine which side has the foreign key by checking column names
366-
# Foreign keys typically end with _id (e.g., customer_id, product_id)
367-
368-
left_is_fk = left.name.endswith("_id")
369-
right_is_fk = right.name.endswith("_id")
481+
# Try information_schema first
482+
fk_table = None
483+
fk_column = None
484+
pk_table = None
485+
pk_column = None
370486

371-
if left_is_fk and not right_is_fk:
372-
# Left has the FK, right has the PK
487+
# Check if left side is a known FK
488+
if (left_table, left.name) in self.foreign_keys:
489+
pk_table, pk_column = self.foreign_keys[(left_table, left.name)]
373490
fk_table = left_table
374491
fk_column = left.name
375-
pk_table = right_table
376-
pk_column = right.name
377-
elif right_is_fk and not left_is_fk:
378-
# Right has the FK, left has the PK
492+
# Check if right side is a known FK
493+
elif (right_table, right.name) in self.foreign_keys:
494+
pk_table, pk_column = self.foreign_keys[(right_table, right.name)]
379495
fk_table = right_table
380496
fk_column = right.name
381-
pk_table = left_table
382-
pk_column = left.name
383-
else:
384-
# Can't determine from column names, fall back to join direction
385-
# The table being joined TO (to_table) usually has the FK
386-
if right_table == to_table:
387-
fk_table = right_table
388-
fk_column = right.name
389-
pk_table = left_table
390-
pk_column = left.name
391-
elif left_table == to_table:
497+
498+
# Fall back to pattern matching if information_schema didn't help
499+
if not fk_table:
500+
# Determine which side has the foreign key by checking column names
501+
# Foreign keys typically end with _id (e.g., customer_id, product_id)
502+
left_is_fk = left.name.endswith("_id")
503+
right_is_fk = right.name.endswith("_id")
504+
505+
if left_is_fk and not right_is_fk:
506+
# Left has the FK, right has the PK
392507
fk_table = left_table
393508
fk_column = left.name
394509
pk_table = right_table
395510
pk_column = right.name
511+
elif right_is_fk and not left_is_fk:
512+
# Right has the FK, left has the PK
513+
fk_table = right_table
514+
fk_column = right.name
515+
pk_table = left_table
516+
pk_column = left.name
396517
else:
397-
# Neither side matches the joined table, skip
398-
continue
518+
# Can't determine from column names, fall back to join direction
519+
# The table being joined TO (to_table) usually has the FK
520+
if right_table == to_table:
521+
fk_table = right_table
522+
fk_column = right.name
523+
pk_table = left_table
524+
pk_column = left.name
525+
elif left_table == to_table:
526+
fk_table = left_table
527+
fk_column = left.name
528+
pk_table = right_table
529+
pk_column = right.name
530+
else:
531+
# Neither side matches the joined table, skip
532+
continue
399533

400-
# Infer primary key column - if it ends with _id, the actual PK is probably "id"
401-
inferred_pk_column = "id" if pk_column.endswith("_id") else pk_column
534+
# Infer primary key column if not from information_schema
535+
if pk_column and not self.foreign_keys:
536+
# If it ends with _id, the actual PK is probably "id"
537+
pk_column = "id" if pk_column.endswith("_id") else pk_column
402538

403539
# Generate relationships for both directions
404540
# FK table has many_to_one relationship to PK table
405-
analysis.relationships.append((fk_table, pk_table, "many_to_one", fk_column, inferred_pk_column))
541+
analysis.relationships.append((fk_table, pk_table, "many_to_one", fk_column, pk_column))
406542

407543
# PK table has one_to_many relationship to FK table
408544
# For one_to_many, we don't store FK (it's on the other side)

tests/test_coverage_analyzer_generation.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -714,3 +714,124 @@ def test_rewrite_query_with_derived_metrics():
714714
# Should NOT include base metrics in SELECT
715715
assert "orders.sum_revenue" not in sql
716716
assert "orders.count" not in sql
717+
718+
719+
def test_information_schema_relationship_detection():
    """Test that relationships are detected using information_schema."""
    import duckdb

    # In-memory database with a real FK constraint so information_schema
    # exposes the relationship to the analyzer.
    con = duckdb.connect(":memory:")
    con.execute(
        """
        CREATE TABLE customers (
            id INTEGER PRIMARY KEY,
            name VARCHAR,
            region VARCHAR
        )
        """
    )
    con.execute(
        """
        CREATE TABLE orders (
            id INTEGER PRIMARY KEY,
            customer_id INTEGER,
            amount DECIMAL,
            status VARCHAR,
            FOREIGN KEY (customer_id) REFERENCES customers(id)
        )
        """
    )

    # Analyzer wired to the live connection loads schema metadata eagerly.
    analyzer = CoverageAnalyzer(SemanticLayer(auto_register=False), connection=con)

    # Schema metadata should have been populated from information_schema.
    assert "customers" in analyzer.primary_keys
    assert "id" in analyzer.primary_keys["customers"]
    assert ("orders", "customer_id") in analyzer.foreign_keys
    assert analyzer.foreign_keys[("orders", "customer_id")] == ("customers", "id")

    report = analyzer.analyze_queries(
        [
            """
        SELECT c.region, COUNT(o.id)
        FROM customers c
        JOIN orders o ON c.id = o.customer_id
        GROUP BY c.region
        """
        ]
    )
    analysis = report.query_analyses[0]

    # One relationship per direction (many_to_one + one_to_many).
    assert len(analysis.relationships) == 2

    many_to_one = {
        src: (dst, rel_type, fk_col, pk_col)
        for src, dst, rel_type, fk_col, pk_col in analysis.relationships
        if rel_type == "many_to_one"
    }

    # orders has many_to_one to customers, with the PK taken straight from
    # information_schema rather than inferred from naming patterns.
    assert "orders" in many_to_one
    assert many_to_one["orders"][0] == "customers"
    assert many_to_one["orders"][2] == "customer_id"
    assert many_to_one["orders"][3] == "id"  # PK from information_schema, not inferred

    con.close()
783+
784+
785+
def test_information_schema_column_inference():
    """Test unqualified column inference using information_schema."""
    import duckdb

    # In-memory database; no constraints needed — only column metadata is used.
    con = duckdb.connect(":memory:")
    con.execute(
        """
        CREATE TABLE customers (
            id INTEGER,
            name VARCHAR,
            region VARCHAR
        )
        """
    )
    con.execute(
        """
        CREATE TABLE orders (
            id INTEGER,
            customer_id INTEGER,
            amount DECIMAL
        )
        """
    )

    analyzer = CoverageAnalyzer(SemanticLayer(auto_register=False), connection=con)

    # Column metadata should have been loaded for both tables.
    assert "region" in analyzer.table_columns["customers"]
    assert "amount" in analyzer.table_columns["orders"]

    # Query references 'region' and 'amount' without table qualifiers.
    report = analyzer.analyze_queries(
        [
            """
        SELECT region, SUM(amount)
        FROM customers
        JOIN orders ON customers.id = orders.customer_id
        GROUP BY region
        """
        ]
    )
    analysis = report.query_analyses[0]

    # 'region' only exists on customers -> attributed there.
    assert "customers" in analysis.columns
    assert "region" in analysis.columns["customers"]

    # 'amount' only exists on orders -> attributed there.
    assert "orders" in analysis.columns
    assert "amount" in analysis.columns["orders"]

    con.close()

0 commit comments

Comments
 (0)