diff --git a/kernelguard.py b/kernelguard.py index f086847..ae36091 100644 --- a/kernelguard.py +++ b/kernelguard.py @@ -2197,10 +2197,51 @@ def _has_ver(expr: ast.AST | None) -> bool: for n in ast.walk(expr) ) + # Collect inner-function names that are aliased to a top-level entrypoint + # via a factory return. Pattern: + # def _make_kernel(): + # ... + # def k(data): + # ... + # return k + # custom_kernel = _make_kernel() + # The replay logic lives in `k`, not `custom_kernel`. The straight + # ``is_entrypoint_name(node.name)`` check would skip it because + # ``k`` is not in ENTRYPOINT_NAMES, so the module-level alias is + # invisible to the analyzer. Trace those aliases here. + entrypoint_aliases: set[str] = set() + if isinstance(tree, ast.Module): + factory_returns: dict[str, set[str]] = {} + for stmt in tree.body: + if not isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef)): + continue + returned: set[str] = set() + for sub in ast.walk(stmt): + if isinstance(sub, ast.Return) and isinstance(sub.value, ast.Name): + returned.add(sub.value.id) + if returned: + factory_returns[stmt.name] = returned + + for stmt in tree.body: + if not isinstance(stmt, ast.Assign): + continue + for target in stmt.targets: + if not (isinstance(target, ast.Name) and is_entrypoint_name(target.id)): + continue + value = stmt.value + if isinstance(value, ast.Name): + entrypoint_aliases.add(value.id) + elif ( + isinstance(value, ast.Call) + and isinstance(value.func, ast.Name) + and value.func.id in factory_returns + ): + entrypoint_aliases.update(factory_returns[value.func.id]) + for node in ast.walk(tree): if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): continue - if not is_entrypoint_name(node.name): + if not (is_entrypoint_name(node.name) or node.name in entrypoint_aliases): continue signature_features: dict[str, set[str]] = defaultdict(set)