@@ -558,6 +558,8 @@ def explain(
558558
559559 metrics = metrics or []
560560 dimensions = dimensions or []
561+ filters = list (filters ) if filters else []
562+ segments = segments or []
561563
562564 # Compile the actual SQL (respects use_preaggregations setting)
563565 sql = self .compile (
@@ -576,21 +578,15 @@ def explain(
576578
577579 use_preaggs = use_preaggregations if use_preaggregations is not None else self .use_preaggregations
578580
579- # Extract model names from metric/dimension/filter references
580- model_names = set ()
581- for ref in list (metrics ) + list (dimensions ):
582- if "." in ref :
583- model_name = ref .split ("." , 1 )[0 ]
584- if model_name :
585- model_names .add (model_name )
586-
587- # Also extract model names from filters (e.g., "customers.status = 'vip'")
588- import re
589-
590- for f in filters or []:
591- # Match "model.column" patterns before operators
592- for match in re .finditer (r"(\w+)\.(\w+)\s*[=<>!]" , f ):
593- model_names .add (match .group (1 ))
581+ generator = SQLGenerator (
582+ self .graph ,
583+ dialect = dialect or self .dialect ,
584+ preagg_database = self .preagg_database ,
585+ preagg_schema = self .preagg_schema ,
586+ )
587+ segment_filters = generator ._resolve_segments (segments )
588+ all_filters = filters + segment_filters
589+ model_names = generator ._find_required_models (metrics , dimensions , all_filters )
594590
595591 # Strip model prefixes from metrics and dimensions for matcher
596592 bare_metrics = []
@@ -610,18 +606,17 @@ def explain(
610606 bare_dims .append (dim_name )
611607
612608 bare_filters = []
613- if filters :
614- for f in filters :
615- # Strip any model prefix from filters
616- for mn in model_names :
617- f = f .replace (f"{ mn } ." , "" )
618- bare_filters .append (f )
609+ for f in all_filters :
610+ for mn in model_names :
611+ f = f .replace (f"{ mn } ." , "" )
612+ f = f .replace (f"{ mn } _cte." , "" )
613+ bare_filters .append (f )
619614
620615 # Check preconditions for preagg routing
621616 if not use_preaggs :
622617 return QueryPlan (
623618 sql = sql ,
624- model = next ( iter ( model_names ), None ) ,
619+ model = model_names [ 0 ] if model_names else None ,
625620 metrics = bare_metrics ,
626621 dimensions = bare_dims ,
627622 used_preaggregation = False ,
@@ -641,14 +636,14 @@ def explain(
641636 if ungrouped :
642637 return QueryPlan (
643638 sql = sql ,
644- model = next ( iter ( model_names )) ,
639+ model = model_names [ 0 ] ,
645640 metrics = bare_metrics ,
646641 dimensions = bare_dims ,
647642 used_preaggregation = False ,
648643 routing_reason = "ungrouped query, preaggs require aggregation" ,
649644 )
650645
651- model_name = next ( iter ( model_names ))
646+ model_name = model_names [ 0 ]
652647 try :
653648 model = self .get_model (model_name )
654649 except KeyError :
0 commit comments