From 0c09263a76850faafb36796ce39515a47a9b30c6 Mon Sep 17 00:00:00 2001 From: Prasanna Date: Sat, 2 May 2026 10:42:53 +0530 Subject: [PATCH] Detect partial namespace replay --- kernelguard.py | 120 ++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 119 insertions(+), 1 deletion(-) diff --git a/kernelguard.py b/kernelguard.py index f086847..7eeedd3 100644 --- a/kernelguard.py +++ b/kernelguard.py @@ -2462,6 +2462,118 @@ def _has_ver(expr: ast.AST | None) -> bool: return [] +def detect_partial_namespace_replay(code: str | SubmissionFacts) -> list[dict]: + """Pattern: functools.partial binds a namespace used as replay state.""" + facts = ensure_submission_facts(code) + tree = facts.ast_tree + if tree is None: + return [] + entrypoint_name = entrypoint_label(facts.entrypoint_name) + + functions = { + stmt.name: stmt + for stmt in tree.body + if isinstance(stmt, ast.FunctionDef) + } + + def _is_partial(expr: ast.AST | None) -> bool: + return ( + isinstance(expr, ast.Name) and expr.id == "partial" + ) or ( + isinstance(expr, ast.Attribute) + and expr.attr == "partial" + and isinstance(expr.value, ast.Name) + and expr.value.id == "functools" + ) + + def _is_namespace_none(expr: ast.AST | None) -> bool: + return ( + isinstance(expr, ast.Call) + and isinstance(expr.func, ast.Name) + and expr.func.id == "SimpleNamespace" + and any( + kw.arg is not None + and isinstance(kw.value, ast.Constant) + and kw.value.value is None + for kw in expr.keywords + ) + ) + + partials = [] + for stmt in tree.body: + if not ( + isinstance(stmt, ast.Assign) + and any(isinstance(target, ast.Name) and is_entrypoint_name(target.id) for target in stmt.targets) + and isinstance(stmt.value, ast.Call) + and _is_partial(stmt.value.func) + and len(stmt.value.args) >= 2 + and isinstance(stmt.value.args[0], ast.Name) + and stmt.value.args[0].id in functions + and _is_namespace_none(stmt.value.args[1]) + ): + continue + partials.append(functions[stmt.value.args[0].id]) + + for fn in partials: + if len(fn.args.args) < 2: + continue + state_param = fn.args.args[0].arg + data_params = {arg.arg for arg in fn.args.args[1:]} + data_params.update(arg.arg for arg in fn.args.posonlyargs) + data_params.update(arg.arg for arg in fn.args.kwonlyargs) + + mutated = { + target.attr + for stmt in ast.walk(fn) + if isinstance(stmt, ast.Assign) + and bool(_expr_names(stmt.value) & data_params) + for target in stmt.targets + if isinstance(target, ast.Attribute) + and isinstance(target.value, ast.Name) + and target.value.id == state_param + } + if not mutated: + continue + + for child in ast.walk(fn): + if not isinstance(child, ast.If): + continue + if _expr_names(child.test) & data_params: + continue + tested = { + expr.attr + for expr in ast.walk(child.test) + if isinstance(expr, ast.Attribute) + and isinstance(expr.value, ast.Name) + and expr.value.id == state_param + and expr.attr in mutated + } + if not tested: + continue + if _body_has_calls(child.body): + continue + returned = { + stmt.value.attr + for stmt in child.body + if isinstance(stmt, ast.Return) + and isinstance(stmt.value, ast.Attribute) + and isinstance(stmt.value.value, ast.Name) + and stmt.value.value.id == state_param + and stmt.value.attr in mutated + } + if tested & returned: + return [{ + "pattern": "PARTIAL_NAMESPACE_REPLAY", + "severity": "critical", + "evidence": ( + f"{entrypoint_name} replays output through a " + "SimpleNamespace bound by partial" + ), + }] + + return [] + + RE_OBJECT_ID_DATA = re.compile(r"\bdata_id\s*=\s*id\s*\(\s*data\s*\)") RE_RESULT_BANK_SET = re.compile(r"_superbatch_results\s*\[\s*(?:did|data_id)\s*\]\s*=") RE_RESULT_BANK_RETURN = re.compile(r"return\s+_superbatch_results\s*\[\s*data_id\s*\]") @@ -3479,6 +3591,10 @@ class RulePolicy: "LAST_CALL_REPLAY", "result_reuse", "hard", AUTO_FILTER, (), AMD_MANUAL_LAST_CALL_REPLAY_FIXTURES, "keep", ), + "PARTIAL_NAMESPACE_REPLAY": RulePolicy( + "PARTIAL_NAMESPACE_REPLAY", "result_reuse", "hard", AUTO_FILTER, (), + (), "keep", + ), "SHAPE_OUTPUT_REPLAY": RulePolicy( "SHAPE_OUTPUT_REPLAY", "result_reuse", "hard", AUTO_FILTER, (), AMD_PACKAGE_SHAPE_REPLAY_FIXTURES, "keep", @@ -3762,6 +3878,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool: detect_decode_mm_ref, detect_result_caching, detect_last_call_replay, + detect_partial_namespace_replay, detect_shape_output_replay, detect_timed_input_replay, detect_cuda_graph_replay, @@ -3800,6 +3917,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool: ("decode_mm_ref", detect_decode_mm_ref), ("result_caching", detect_result_caching), ("last_call_replay", detect_last_call_replay), + ("partial_namespace_replay", detect_partial_namespace_replay), ("shape_output_replay", detect_shape_output_replay), ("timed_input_replay", detect_timed_input_replay), ("cuda_graph_replay", detect_cuda_graph_replay), @@ -4696,7 +4814,7 @@ def _worker_parquet(args: tuple) -> dict: "EVALUATOR_EXPLOIT", "HARNESS_RUNTIME_PATCHING", "MODULE_MUTATION", "GLOBALS_MUTATION", "CODE_REPLACEMENT", "FRAME_WALK_ACCESS", "FRAME_WALK_MUTATION", "SYS_MODULES_ACCESS", "GLOBALS_ACCESS", "CODE_ACCESS", "TRUSTED_MODULE_IMPORT", - "OUTPUT_REPLAY_CACHE", "LAST_CALL_REPLAY", "SHAPE_OUTPUT_REPLAY", "TIMED_INPUT_REPLAY", "CONFIG_CACHE_EXPLOIT", "POINTER_REPLAY", "RESULT_BANK_REPLAY", "PREPROCESS_CACHE", "WORKSPACE_CACHE", + "OUTPUT_REPLAY_CACHE", "LAST_CALL_REPLAY", "PARTIAL_NAMESPACE_REPLAY", "SHAPE_OUTPUT_REPLAY", "TIMED_INPUT_REPLAY", "CONFIG_CACHE_EXPLOIT", "POINTER_REPLAY", "RESULT_BANK_REPLAY", "PREPROCESS_CACHE", "WORKSPACE_CACHE", "RUNNER_PLAN_CACHE", "CUDA_GRAPH_PYTHON", "CUDA_GRAPH_REPLAY", "TIMER_MONKEYPATCH", "FAKE_BENCHMARK_EMIT", "STDIO_REDIRECT", "UNSYNC_MULTISTREAM", "CUDA_EVENT_DISABLE_TIMING", "SCALED_MM_REF", "DECODE_MM_REF", "SILENT_FALLBACK", "REFERENCE_PRECOMPUTE_REPLAY", "TORCH_COMPILE_CACHE",