diff --git a/kernelguard.py b/kernelguard.py
index f086847..cef089b 100644
--- a/kernelguard.py
+++ b/kernelguard.py
@@ -524,6 +524,146 @@ def _expr_has_tensor_version(expr: ast.AST | None) -> bool:
     )
 
 
+def _function_param_names(fn: ast.FunctionDef | ast.AsyncFunctionDef) -> tuple[str, ...]:
+    """Return every parameter name of ``fn``.
+
+    Order: positional-only, positional-or-keyword, keyword-only, then the
+    ``*args`` name and the ``**kwargs`` name when present.
+    """
+    names = [arg.arg for arg in fn.args.posonlyargs]
+    names.extend(arg.arg for arg in fn.args.args)
+    names.extend(arg.arg for arg in fn.args.kwonlyargs)
+    if fn.args.vararg is not None:
+        names.append(fn.args.vararg.arg)
+    if fn.args.kwarg is not None:
+        names.append(fn.args.kwarg.arg)
+    return tuple(names)
+
+
+def _depends_on_names(expr: ast.AST | None, names: set[str]) -> bool:
+    """Return True when ``_expr_names(expr)`` shares a name with ``names``."""
+    if expr is None or not names:
+        return False
+    return bool(_expr_names(expr) & names)
+
+
+def _state_read_root(expr: ast.AST | None, state_names: set[str], aliases: dict[str, str]) -> Optional[str]:
+    """Return the state name that ``expr`` reads from, or ``None``.
+
+    A bare name resolves to itself when it is in ``state_names`` and through
+    ``aliases`` otherwise. Subscript/attribute chains and ``<obj>.get(...)``
+    calls resolve through ``_ast_root_name`` and match only when that root
+    is itself in ``state_names``.
+    """
+    if expr is None:
+        return None
+    if isinstance(expr, ast.Name):
+        if expr.id in state_names:
+            return expr.id
+        return aliases.get(expr.id)
+    if isinstance(expr, (ast.Subscript, ast.Attribute)):
+        root = _ast_root_name(expr)
+        if root in state_names:
+            return root
+    if (
+        isinstance(expr, ast.Call)
+        and isinstance(expr.func, ast.Attribute)
+        and expr.func.attr == "get"
+    ):
+        root = _ast_root_name(expr.func.value)
+        if root in state_names:
+            return root
+    return None
+
+
+def _has_state_shortcut_replay(
+    fn: ast.FunctionDef | ast.AsyncFunctionDef,
+    state_names: set[str],
+    input_names: set[str],
+) -> Optional[str]:
+    """Return the state name replayed by an input-independent early return.
+
+    A match needs both of the following inside ``fn``:
+
+    * a write into one of ``state_names`` of a value that depends on
+      ``input_names`` (directly or through single-name assignments), either
+      by assignment or by ``append``/``extend``/``set``/``update``;
+    * an ``if`` whose test names that state (or a local alias of it) but none
+      of ``input_names``, whose body is not flagged by ``_body_has_calls``,
+      and whose body directly contains a ``return`` reading that state.
+
+    Returns ``None`` when no such pair exists.
+    """
+    input_derived = set(input_names)
+    # Local name -> every state root it is assigned from. The sets only grow,
+    # so the fixpoint below terminates even when one local is rebound to
+    # reads of two different state parameters (a single-valued alias map
+    # would flip between them forever).
+    alias_roots: dict[str, set[str]] = {}
+
+    def _read_roots(expr: ast.AST | None) -> set[str]:
+        # Non-state bare names resolve through the alias table; everything
+        # else is a direct read handled by _state_read_root.
+        if isinstance(expr, ast.Name) and expr.id not in state_names:
+            return alias_roots.get(expr.id, set())
+        root = _state_read_root(expr, state_names, {})
+        return set() if root is None else {root}
+
+    changed = True
+    while changed:
+        changed = False
+        for child in ast.walk(fn):
+            if not isinstance(child, ast.Assign):
+                continue
+            if len(child.targets) != 1 or not isinstance(child.targets[0], ast.Name):
+                continue
+            target_name = child.targets[0].id
+            if target_name not in input_derived and _depends_on_names(child.value, input_derived):
+                input_derived.add(target_name)
+                changed = True
+            new_roots = _read_roots(child.value) - alias_roots.get(target_name, set())
+            if new_roots:
+                alias_roots.setdefault(target_name, set()).update(new_roots)
+                changed = True
+
+    state_from_input: set[str] = set()
+    for child in ast.walk(fn):
+        if isinstance(child, ast.Assign):
+            for target in child.targets:
+                root = _ast_root_name(target)
+                if root in state_names and _depends_on_names(child.value, input_derived):
+                    state_from_input.add(root)
+        elif isinstance(child, ast.Call) and isinstance(child.func, ast.Attribute):
+            root = _ast_root_name(child.func.value)
+            if root not in state_names or child.func.attr not in {"append", "extend", "set", "update"}:
+                continue
+            if any(_depends_on_names(arg, input_derived) for arg in child.args):
+                state_from_input.add(root)
+
+    if not state_from_input:
+        return None
+
+    for child in ast.walk(fn):
+        if not isinstance(child, ast.If):
+            continue
+        if _body_has_calls(child.body):
+            continue
+        test_names = _expr_names(child.test)
+        if test_names & input_names:
+            continue
+        for stmt in child.body:
+            if not isinstance(stmt, ast.Return) or stmt.value is None:
+                continue
+            # sorted() keeps the reported root deterministic.
+            for state_root in sorted(_read_roots(stmt.value) & state_from_input):
+                if state_root in test_names or any(
+                    alias in test_names and state_root in roots
+                    for alias, roots in alias_roots.items()
+                ):
+                    return state_root
+    return None
+
+
 _TRIVIAL_GPU_OPS = frozenset({
     "fill_", "zero_", "copy_", "fill", "zero", "record",
 })
@@ -2462,6 +2602,88 @@ def _has_ver(expr: ast.AST | None) -> bool:
     return []
 
 
+def detect_partial_bound_state_replay(code: str | SubmissionFacts) -> list[dict]:
+    """Pattern: functools.partial binds mutable state used for replay.
+
+    Inspects module-level ``custom_kernel = partial(helper, ...)`` assignments
+    where ``helper`` is a module-level function. The helper parameters
+    covered by the positionally bound arguments are treated as state, the
+    remaining parameters as inputs, and both sets are handed to
+    ``_has_state_shortcut_replay``. Returns a one-element list for the first
+    match, otherwise an empty list.
+    """
+    facts = ensure_submission_facts(code)
+    tree = facts.ast_tree
+    if tree is None or not isinstance(tree, ast.Module):
+        return []
+
+    module_aliases = {"functools"}
+    direct_aliases: set[str] = set()
+    for node in facts._imports:
+        for alias in node.names:
+            if alias.name == "functools":
+                module_aliases.add(alias.asname or alias.name)
+    for node in facts._import_froms:
+        if node.module != "functools":
+            continue
+        for alias in node.names:
+            if alias.name == "partial":
+                direct_aliases.add(alias.asname or alias.name)
+
+    functions = {
+        node.name: node
+        for node in tree.body
+        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
+    }
+
+    for stmt in tree.body:
+        if not isinstance(stmt, ast.Assign):
+            continue
+        if not any(isinstance(target, ast.Name) and target.id == "custom_kernel" for target in stmt.targets):
+            continue
+        if not isinstance(stmt.value, ast.Call) or not stmt.value.args:
+            continue
+        call = stmt.value
+        is_partial = False
+        if isinstance(call.func, ast.Name) and call.func.id in direct_aliases:
+            is_partial = True
+        elif (
+            isinstance(call.func, ast.Attribute)
+            and call.func.attr == "partial"
+            and isinstance(call.func.value, ast.Name)
+            and call.func.value.id in module_aliases
+        ):
+            is_partial = True
+        if not is_partial:
+            continue
+        if not isinstance(call.args[0], ast.Name) or call.args[0].id not in functions:
+            continue
+
+        helper = functions[call.args[0].id]
+        params = _function_param_names(helper)
+        # Only positional bindings are counted; keyword arguments passed to
+        # partial() are not inspected here.
+        bound_count = len(call.args) - 1
+        # Skip when nothing is bound or no parameter is left to act as input.
+        if bound_count <= 0 or bound_count >= len(params):
+            continue
+        state_names = set(params[:bound_count])
+        input_names = set(params[bound_count:])
+        state_root = _has_state_shortcut_replay(helper, state_names, input_names)
+        if state_root is None:
+            continue
+        return [{
+            "pattern": "PARTIAL_BOUND_STATE_REPLAY",
+            "severity": "critical",
+            "evidence": (
+                f"custom_kernel binds replay state {state_root} through functools.partial "
+                f"and returns it from {helper.name} on an input-independent fast path"
+            ),
+        }]
+
+    return []
+
+
 RE_OBJECT_ID_DATA = re.compile(r"\bdata_id\s*=\s*id\s*\(\s*data\s*\)")
 RE_RESULT_BANK_SET = re.compile(r"_superbatch_results\s*\[\s*(?:did|data_id)\s*\]\s*=")
 RE_RESULT_BANK_RETURN = re.compile(r"return\s+_superbatch_results\s*\[\s*data_id\s*\]")
@@ -3479,6 +3701,10 @@ class RulePolicy:
         "LAST_CALL_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
         AMD_MANUAL_LAST_CALL_REPLAY_FIXTURES, "keep",
     ),
+    "PARTIAL_BOUND_STATE_REPLAY": RulePolicy(
+        "PARTIAL_BOUND_STATE_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
+        (), "keep",
+    ),
     "SHAPE_OUTPUT_REPLAY": RulePolicy(
         "SHAPE_OUTPUT_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
         AMD_PACKAGE_SHAPE_REPLAY_FIXTURES, "keep",
@@ -3762,6 +3988,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
     detect_decode_mm_ref,
     detect_result_caching,
     detect_last_call_replay,
+    detect_partial_bound_state_replay,
     detect_shape_output_replay,
     detect_timed_input_replay,
     detect_cuda_graph_replay,
@@ -3800,6 +4027,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
     ("decode_mm_ref", detect_decode_mm_ref),
     ("result_caching", detect_result_caching),
     ("last_call_replay", detect_last_call_replay),
+    ("partial_bound_state_replay", detect_partial_bound_state_replay),
     ("shape_output_replay", detect_shape_output_replay),
     ("timed_input_replay", detect_timed_input_replay),
     ("cuda_graph_replay", detect_cuda_graph_replay),