def detect_bound_method_replay(code: str | SubmissionFacts) -> list[dict]:
    """Pattern: bound-method entrypoint replays an instance attribute.

    Looks for a module-level assignment that binds an entrypoint name (as
    judged by ``is_entrypoint_name``) to ``ClassName().method`` or
    ``instance.method``, where ``instance`` was itself assigned from
    ``ClassName()`` at module level.  The method is flagged when both hold:

    * somewhere in the method an attribute of the receiver is assigned from
      an expression that contains a call and references a method input, and
    * the method contains an ``if`` whose test reads that same attribute and
      whose body returns it directly, with ``_body_has_calls`` reporting no
      calls in that body.

    Args:
        code: Submission source text or an already-built ``SubmissionFacts``;
            normalised through ``ensure_submission_facts``.

    Returns:
        A single-element list holding a ``BOUND_METHOD_REPLAY`` finding, or an
        empty list when the source has no AST or nothing matches.
    """
    facts = ensure_submission_facts(code)
    tree = facts.ast_tree
    if tree is None:
        return []
    entrypoint_name = entrypoint_label(facts.entrypoint_name)

    # Only classes and instances defined directly at module level are tracked.
    classes = {
        stmt.name: stmt
        for stmt in tree.body
        if isinstance(stmt, ast.ClassDef)
    }
    instances: dict[str, str] = {}
    for stmt in tree.body:
        if not isinstance(stmt, ast.Assign):
            continue
        if len(stmt.targets) != 1 or not isinstance(stmt.targets[0], ast.Name):
            continue
        if not (isinstance(stmt.value, ast.Call) and isinstance(stmt.value.func, ast.Name)):
            continue
        if stmt.value.func.id in classes:
            instances[stmt.targets[0].id] = stmt.value.func.id

    def _bound_method(value: ast.AST) -> tuple[str, str] | None:
        """Return ``(class_name, method_name)`` for ``Cls().m`` / ``inst.m``."""
        if not isinstance(value, ast.Attribute):
            return None
        receiver = value.value
        if isinstance(receiver, ast.Call) and isinstance(receiver.func, ast.Name):
            if receiver.func.id in classes:
                return receiver.func.id, value.attr
        if isinstance(receiver, ast.Name) and receiver.id in instances:
            return instances[receiver.id], value.attr
        return None

    def _self_attr(expr: ast.AST | None, self_name: str) -> str | None:
        """Return ``attr`` when *expr* is exactly ``<self_name>.attr``."""
        if not (
            isinstance(expr, ast.Attribute)
            and isinstance(expr.value, ast.Name)
            and expr.value.id == self_name
        ):
            return None
        return expr.attr

    def _stores_input_result(
        method: ast.FunctionDef | ast.AsyncFunctionDef,
        self_name: str,
        params: set[str],
    ) -> set[str]:
        """Collect receiver attributes assigned from a call involving an input."""
        attrs = set()
        for stmt in ast.walk(method):
            if not isinstance(stmt, ast.Assign):
                continue
            # NOTE(review): `_expr_names` presumably yields the identifiers
            # referenced by the expression; confirm against its definition.
            if not (_expr_names(stmt.value) & params):
                continue
            if not any(isinstance(node, ast.Call) for node in ast.walk(stmt.value)):
                continue
            for target in stmt.targets:
                attr = _self_attr(target, self_name)
                if attr:
                    attrs.add(attr)
        return attrs

    for stmt in tree.body:
        if not isinstance(stmt, ast.Assign):
            continue
        if not any(
            isinstance(target, ast.Name) and is_entrypoint_name(target.id)
            for target in stmt.targets
        ):
            continue
        bound = _bound_method(stmt.value)
        if bound is None:
            continue
        class_name, method_name = bound
        class_node = classes[class_name]
        method = next(
            (
                item for item in class_node.body
                if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef))
                and item.name == method_name
                # The receiver may be declared positional-only, in which case
                # it lives in `posonlyargs` rather than `args`.
                and (item.args.posonlyargs or item.args.args)
            ),
            None,
        )
        if method is None:
            continue

        # Positional-only parameters precede regular ones in the signature,
        # so the receiver is the first entry of the combined list.
        positional = [*method.args.posonlyargs, *method.args.args]
        self_name = positional[0].arg
        params = {arg.arg for arg in positional[1:]}
        params.update(arg.arg for arg in method.args.kwonlyargs)
        # `*args` / `**kwargs` also carry caller inputs into the method.
        for variadic in (method.args.vararg, method.args.kwarg):
            if variadic is not None:
                params.add(variadic.arg)
        stored_attrs = _stores_input_result(method, self_name, params)
        if not stored_attrs:
            continue

        for child in ast.walk(method):
            if not isinstance(child, ast.If):
                continue
            # A branch that still makes calls is not treated as a pure replay.
            if _body_has_calls(child.body):
                continue
            test_attrs = {
                node.attr
                for node in ast.walk(child.test)
                if isinstance(node, ast.Attribute)
                and isinstance(node.value, ast.Name)
                and node.value.id == self_name
            }
            returned_attrs = {
                attr
                for stmt in child.body
                if isinstance(stmt, ast.Return)
                for attr in [_self_attr(stmt.value, self_name)]
                if attr is not None
            }
            # The same attribute must be guarded on, returned, and previously
            # stored from an input-dependent call.
            if test_attrs & returned_attrs & stored_attrs:
                return [{
                    "pattern": "BOUND_METHOD_REPLAY",
                    "severity": "critical",
                    "evidence": (
                        f"{entrypoint_name} is a bound method that returns "
                        "an instance attribute cached from a prior input"
                    ),
                }]

    return []


RE_OBJECT_ID_DATA = re.compile(r"\bdata_id\s*=\s*id\s*\(\s*data\s*\)")
RE_RESULT_BANK_SET = re.compile(r"_superbatch_results\s*\[\s*(?:did|data_id)\s*\]\s*=")
RE_RESULT_BANK_RETURN = re.compile(r"return\s+_superbatch_results\s*\[\s*data_id\s*\]")
"LAST_CALL_REPLAY", "result_reuse", "hard", AUTO_FILTER, (), AMD_MANUAL_LAST_CALL_REPLAY_FIXTURES, "keep", ), + "BOUND_METHOD_REPLAY": RulePolicy( + "BOUND_METHOD_REPLAY", "result_reuse", "hard", AUTO_FILTER, (), + (), "keep", + ), "SHAPE_OUTPUT_REPLAY": RulePolicy( "SHAPE_OUTPUT_REPLAY", "result_reuse", "hard", AUTO_FILTER, (), AMD_PACKAGE_SHAPE_REPLAY_FIXTURES, "keep", @@ -3762,6 +3887,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool: detect_decode_mm_ref, detect_result_caching, detect_last_call_replay, + detect_bound_method_replay, detect_shape_output_replay, detect_timed_input_replay, detect_cuda_graph_replay, @@ -3800,6 +3926,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool: ("decode_mm_ref", detect_decode_mm_ref), ("result_caching", detect_result_caching), ("last_call_replay", detect_last_call_replay), + ("bound_method_replay", detect_bound_method_replay), ("shape_output_replay", detect_shape_output_replay), ("timed_input_replay", detect_timed_input_replay), ("cuda_graph_replay", detect_cuda_graph_replay), @@ -4696,7 +4823,7 @@ def _worker_parquet(args: tuple) -> dict: "EVALUATOR_EXPLOIT", "HARNESS_RUNTIME_PATCHING", "MODULE_MUTATION", "GLOBALS_MUTATION", "CODE_REPLACEMENT", "FRAME_WALK_ACCESS", "FRAME_WALK_MUTATION", "SYS_MODULES_ACCESS", "GLOBALS_ACCESS", "CODE_ACCESS", "TRUSTED_MODULE_IMPORT", - "OUTPUT_REPLAY_CACHE", "LAST_CALL_REPLAY", "SHAPE_OUTPUT_REPLAY", "TIMED_INPUT_REPLAY", "CONFIG_CACHE_EXPLOIT", "POINTER_REPLAY", "RESULT_BANK_REPLAY", "PREPROCESS_CACHE", "WORKSPACE_CACHE", + "OUTPUT_REPLAY_CACHE", "LAST_CALL_REPLAY", "BOUND_METHOD_REPLAY", "SHAPE_OUTPUT_REPLAY", "TIMED_INPUT_REPLAY", "CONFIG_CACHE_EXPLOIT", "POINTER_REPLAY", "RESULT_BANK_REPLAY", "PREPROCESS_CACHE", "WORKSPACE_CACHE", "RUNNER_PLAN_CACHE", "CUDA_GRAPH_PYTHON", "CUDA_GRAPH_REPLAY", "TIMER_MONKEYPATCH", "FAKE_BENCHMARK_EMIT", "STDIO_REDIRECT", "UNSYNC_MULTISTREAM", "CUDA_EVENT_DISABLE_TIMING", "SCALED_MM_REF", "DECODE_MM_REF", 
"SILENT_FALLBACK", "REFERENCE_PRECOMPUTE_REPLAY", "TORCH_COMPILE_CACHE",