@@ -74,13 +74,16 @@ class CoverageAnalyzer:
7474 - How to rewrite queries using the semantic layer
7575 """
7676
77- def __init__ (self , layer : SemanticLayer ):
77+ def __init__ (self , layer : SemanticLayer , connection = None ):
7878 """Initialize analyzer.
7979
8080 Args:
8181 layer: Semantic layer to analyze coverage for
82+ connection: Optional database connection for querying information_schema.
83+ Should have an execute() method that returns results.
8284 """
8385 self .layer = layer
86+ self .connection = connection
8487 self .analyses : list [QueryAnalysis ] = []
8588
8689 # Build mapping from table names to model names
@@ -89,6 +92,63 @@ def __init__(self, layer: SemanticLayer):
8992 if model .table :
9093 self .table_to_model [model .table ] = model_name
9194
95+ # Cache information_schema data if connection available
96+ self .primary_keys : dict [str , set [str ]] = {} # table -> pk columns
97+ self .foreign_keys : dict [tuple [str , str ], tuple [str , str ]] = {} # (fk_table, fk_col) -> (pk_table, pk_col)
98+ self .table_columns : dict [str , set [str ]] = defaultdict (set ) # table -> columns
99+
100+ if self .connection :
101+ self ._load_schema_metadata ()
102+
def _load_schema_metadata(self) -> None:
    """Fill the primary-key, foreign-key and column caches from information_schema.

    Runs three queries through ``self.connection.execute(...).fetchall()`` and
    stores the rows in ``self.primary_keys`` (table -> set of PK columns),
    ``self.foreign_keys`` ((fk_table, fk_column) -> (pk_table, pk_column)) and
    ``self.table_columns`` (table -> set of columns).

    Loading is best-effort: each query is attempted independently, so a
    database that lacks one view still contributes whatever the other
    queries return. A failed query is logged as a warning and leaves its
    cache untouched; nothing is raised to the caller.
    """
    # Function-scope import keeps this block self-contained.
    import logging

    log = logging.getLogger(__name__)

    # Join on schema + table as well as constraint name, so a same-named
    # constraint on another table cannot be mistaken for a primary key.
    pk_query = """
        SELECT kcu.table_name, kcu.column_name
        FROM information_schema.key_column_usage kcu
        JOIN information_schema.table_constraints tc
            ON tc.constraint_name = kcu.constraint_name
            AND tc.table_schema = kcu.table_schema
            AND tc.table_name = kcu.table_name
        WHERE tc.constraint_type = 'PRIMARY KEY'
    """

    # referential_constraints maps each FK constraint to the PK/unique
    # constraint it references. Columns are paired by ordinal position so a
    # composite key produces one row per column pair instead of a cross
    # product.
    # NOTE(review): assumes FK columns are declared in the same order as the
    # referenced key's columns - confirm for the target database.
    fk_query = """
        SELECT
            fk_kcu.table_name AS fk_table,
            fk_kcu.column_name AS fk_column,
            pk_kcu.table_name AS pk_table,
            pk_kcu.column_name AS pk_column
        FROM information_schema.referential_constraints rc
        JOIN information_schema.key_column_usage fk_kcu
            ON rc.constraint_name = fk_kcu.constraint_name
        JOIN information_schema.key_column_usage pk_kcu
            ON rc.unique_constraint_name = pk_kcu.constraint_name
            AND fk_kcu.ordinal_position = pk_kcu.ordinal_position
    """

    col_query = """
        SELECT table_name, column_name
        FROM information_schema.columns
    """

    def fetch(label: str, query: str) -> list:
        # Deliberately broad: the exception type depends on the database
        # driver, and any failure just means falling back to pattern-based
        # detection for this kind of metadata.
        try:
            return self.connection.execute(query).fetchall()
        except Exception as exc:
            log.warning("Could not load schema metadata (%s): %s", label, exc)
            return []

    for table_name, column_name in fetch("primary keys", pk_query):
        self.primary_keys.setdefault(table_name, set()).add(column_name)

    for fk_table, fk_column, pk_table, pk_column in fetch("foreign keys", fk_query):
        self.foreign_keys[(fk_table, fk_column)] = (pk_table, pk_column)

    for table_name, column_name in fetch("columns", col_query):
        self.table_columns.setdefault(table_name, set()).add(column_name)
151+
92152 def analyze_queries (self , queries : list [str ]) -> CoverageReport :
93153 """Analyze a list of SQL queries.
94154
@@ -183,13 +243,66 @@ def _extract_tables(self, parsed: exp.Expression, analysis: QueryAnalysis) -> No
183243 analysis .table_aliases [table .alias ] = table_name
184244
def _extract_columns(self, parsed: exp.Expression, analysis: QueryAnalysis) -> None:
    """Record every column reference in ``parsed`` under the table it belongs to.

    A qualified reference (``qualifier.column``) is filed under the table the
    qualifier maps to in ``analysis.table_aliases``, or under the qualifier
    itself when it is not a known alias. An unqualified reference is filed
    under whatever ``_infer_table_for_column`` returns and is dropped when
    that returns nothing.

    Args:
        parsed: Parsed expression to scan for ``exp.Column`` nodes.
        analysis: Query analysis whose ``columns`` mapping is updated in place.
    """
    for column in parsed.find_all(exp.Column):
        name = column.name
        if not name:
            continue

        qualifier = column.table if column.table else None
        if qualifier:
            # The qualifier may be an alias; map it back to the real table.
            owner = analysis.table_aliases.get(qualifier, qualifier)
        else:
            owner = self._infer_table_for_column(name, analysis)
            if not owner:
                # No table could be inferred - leave the column unrecorded.
                continue

        analysis.columns[owner].add(name)
268+
def _infer_table_for_column(self, col_name: str, analysis: QueryAnalysis) -> str | None:
    """Guess the table that an unqualified column reference belongs to.

    Resolution order:
      1. No tables in the query -> ``None``.
      2. Exactly one table in the query -> that table.
      3. Cached column metadata (``self.table_columns``) names exactly one
         of the query's tables as containing the column -> that table.
      4. Anything else (no metadata, no match, or several matches) ->
         ``analysis.from_table``.

    Args:
        col_name: Unqualified column name.
        analysis: Analysis of the query currently being processed.

    Returns:
        The inferred table name, or ``None`` when the query has no tables
        (or ``analysis.from_table`` is itself unset).
    """
    candidates = list(analysis.tables)
    if not candidates:
        return None
    if len(candidates) == 1:
        return candidates[0]

    if self.table_columns:
        owners = [
            table
            for table in candidates
            if col_name in self.table_columns.get(table, ())
        ]
        if len(owners) == 1:
            return owners[0]

    # Ambiguous or unknown column: attribute it to the FROM table.
    return analysis.from_table
193306
194307 def _extract_aggregations (self , parsed : exp .Expression , analysis : QueryAnalysis ) -> None :
195308 """Extract aggregation functions using sqlglot's AggFunc base class."""
@@ -337,7 +450,10 @@ def _extract_joins(self, parsed: exp.Expression, analysis: QueryAnalysis) -> Non
337450 analysis .joins .append ((from_table , from_alias , to_table , to_alias , join_type , on_clause ))
338451
339452 def _extract_relationships (self , parsed : exp .Expression , analysis : QueryAnalysis ) -> None :
340- """Extract relationships from JOIN ON conditions."""
453+ """Extract relationships from JOIN ON conditions.
454+
455+ Uses information_schema data if available, falls back to pattern matching.
456+ """
341457 for join in parsed .find_all (exp .Join ):
342458 if not isinstance (join .this , exp .Table ):
343459 continue
@@ -362,47 +478,67 @@ def _extract_relationships(self, parsed: exp.Expression, analysis: QueryAnalysis
362478 left_table = analysis .table_aliases .get (left .table , left .table ) if left .table else ""
363479 right_table = analysis .table_aliases .get (right .table , right .table ) if right .table else ""
364480
365- # Determine which side has the foreign key by checking column names
366- # Foreign keys typically end with _id (e.g., customer_id, product_id)
367-
368- left_is_fk = left . name . endswith ( "_id" )
369- right_is_fk = right . name . endswith ( "_id" )
481+ # Try information_schema first
482+ fk_table = None
483+ fk_column = None
484+ pk_table = None
485+ pk_column = None
370486
371- if left_is_fk and not right_is_fk :
372- # Left has the FK, right has the PK
487+ # Check if left side is a known FK
488+ if (left_table , left .name ) in self .foreign_keys :
489+ pk_table , pk_column = self .foreign_keys [(left_table , left .name )]
373490 fk_table = left_table
374491 fk_column = left .name
375- pk_table = right_table
376- pk_column = right .name
377- elif right_is_fk and not left_is_fk :
378- # Right has the FK, left has the PK
492+ # Check if right side is a known FK
493+ elif (right_table , right .name ) in self .foreign_keys :
494+ pk_table , pk_column = self .foreign_keys [(right_table , right .name )]
379495 fk_table = right_table
380496 fk_column = right .name
381- pk_table = left_table
382- pk_column = left .name
383- else :
384- # Can't determine from column names, fall back to join direction
385- # The table being joined TO (to_table) usually has the FK
386- if right_table == to_table :
387- fk_table = right_table
388- fk_column = right .name
389- pk_table = left_table
390- pk_column = left .name
391- elif left_table == to_table :
497+
498+ # Fall back to pattern matching if information_schema didn't help
499+ if not fk_table :
500+ # Determine which side has the foreign key by checking column names
501+ # Foreign keys typically end with _id (e.g., customer_id, product_id)
502+ left_is_fk = left .name .endswith ("_id" )
503+ right_is_fk = right .name .endswith ("_id" )
504+
505+ if left_is_fk and not right_is_fk :
506+ # Left has the FK, right has the PK
392507 fk_table = left_table
393508 fk_column = left .name
394509 pk_table = right_table
395510 pk_column = right .name
511+ elif right_is_fk and not left_is_fk :
512+ # Right has the FK, left has the PK
513+ fk_table = right_table
514+ fk_column = right .name
515+ pk_table = left_table
516+ pk_column = left .name
396517 else :
397- # Neither side matches the joined table, skip
398- continue
518+ # Can't determine from column names, fall back to join direction
519+ # The table being joined TO (to_table) usually has the FK
520+ if right_table == to_table :
521+ fk_table = right_table
522+ fk_column = right .name
523+ pk_table = left_table
524+ pk_column = left .name
525+ elif left_table == to_table :
526+ fk_table = left_table
527+ fk_column = left .name
528+ pk_table = right_table
529+ pk_column = right .name
530+ else :
531+ # Neither side matches the joined table, skip
532+ continue
399533
400- # Infer primary key column - if it ends with _id, the actual PK is probably "id"
401- inferred_pk_column = "id" if pk_column .endswith ("_id" ) else pk_column
534+ # Infer primary key column if not from information_schema
535+ if pk_column and not self .foreign_keys :
536+ # If it ends with _id, the actual PK is probably "id"
537+ pk_column = "id" if pk_column .endswith ("_id" ) else pk_column
402538
403539 # Generate relationships for both directions
404540 # FK table has many_to_one relationship to PK table
405- analysis .relationships .append ((fk_table , pk_table , "many_to_one" , fk_column , inferred_pk_column ))
541+ analysis .relationships .append ((fk_table , pk_table , "many_to_one" , fk_column , pk_column ))
406542
407543 # PK table has one_to_many relationship to FK table
408544 # For one_to_many, we don't store FK (it's on the other side)
0 commit comments