From 976a7edbbe111a1acb1f16f6ff0bd7af00e573d4 Mon Sep 17 00:00:00 2001 From: Prasanna Date: Sat, 2 May 2026 19:41:43 +0530 Subject: [PATCH] Detect replay from mutated captured subscripts --- kernelguard.py | 102 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 102 insertions(+) diff --git a/kernelguard.py b/kernelguard.py index f086847..4ded4cc 100644 --- a/kernelguard.py +++ b/kernelguard.py @@ -2459,6 +2459,108 @@ def _has_ver(expr: ast.AST | None) -> bool: ), }] + # Fourth pass: replay from a captured container that the entrypoint + # mutates after a state-only fast path. + # + # Pattern: if _state: return _state[0] + # _state.append(data.clone()) + # return _state[0] + params = { + arg.arg + for args in (node.args.posonlyargs, node.args.args, node.args.kwonlyargs) + for arg in args + } + if node.args.vararg is not None: + params.add(node.args.vararg.arg) + if node.args.kwarg is not None: + params.add(node.args.kwarg.arg) + + global_names = { + name + for stmt in node.body + if isinstance(stmt, ast.Global) + for name in stmt.names + } + local_names = { + target.id + for stmt in node.body + if isinstance(stmt, (ast.Assign, ast.AnnAssign, ast.AugAssign)) + for target in ( + stmt.targets if isinstance(stmt, ast.Assign) + else [stmt.target] + ) + if isinstance(target, ast.Name) and target.id not in global_names + } + captured = _expr_names(node) - params - local_names + + mutating_methods = { + "add", "append", "clear", "extend", "insert", "pop", "popitem", + "remove", "rotate", "setdefault", "update", + } + mutated_captures: set[str] = set() + subscript_sources: dict[str, str] = {} + + for child in ast.walk(node): + if isinstance(child, ast.Assign): + if isinstance(child.value, ast.Subscript): + source = _ast_root_name(child.value) + if source in captured: + for target in child.targets: + if isinstance(target, ast.Name): + subscript_sources[target.id] = source + for target in child.targets: + target_root = _ast_root_name(target) + if target_root not in captured: + continue + if isinstance(target, ast.Name) and target_root not in global_names: + continue + mutated_captures.add(target_root) + elif isinstance(child, ast.AugAssign): + target_root = _ast_root_name(child.target) + if target_root in captured: + mutated_captures.add(target_root) + elif (isinstance(child, ast.Call) + and isinstance(child.func, ast.Attribute) + and child.func.attr in mutating_methods): + target_root = _ast_root_name(child.func.value) + if target_root in captured: + mutated_captures.add(target_root) + + if mutated_captures: + for child in ast.walk(node): + if not isinstance(child, ast.If): + continue + if _body_has_calls(child.body): + continue + if _expr_names(child.test) & params: + continue + + returns_replayed_subscript = False + for stmt in child.body: + if not isinstance(stmt, ast.Return) or stmt.value is None: + continue + if (isinstance(stmt.value, ast.Name) + and subscript_sources.get(stmt.value.id) in mutated_captures): + returns_replayed_subscript = True + break + if not any(isinstance(part, ast.Subscript) for part in ast.walk(stmt.value)): + continue + if _ast_root_name(stmt.value) in mutated_captures: + returns_replayed_subscript = True + break + + if not returns_replayed_subscript: + continue + + return [{ + "pattern": "LAST_CALL_REPLAY", + "severity": "critical", + "evidence": ( + f"{entrypoint_name} returns a subscripted value from " + "captured state mutated by the entrypoint" + ), + }] + return []