diff --git a/kernelguard.py b/kernelguard.py index f086847..5447a41 100644 --- a/kernelguard.py +++ b/kernelguard.py @@ -2459,6 +2459,95 @@ def _has_ver(expr: ast.AST | None) -> bool: ), }] + # Fourth pass: replay from an attribute on captured state that the + # entrypoint fills after a state-only fast path. + # + # Pattern: if hasattr(_state, "value"): return _state.value + # _state.value = data.clone() + # return _state.value + params = { + arg.arg + for args in (node.args.posonlyargs, node.args.args, node.args.kwonlyargs) + for arg in args + } + if node.args.vararg is not None: + params.add(node.args.vararg.arg) + if node.args.kwarg is not None: + params.add(node.args.kwarg.arg) + + global_names = { + name + for stmt in node.body + if isinstance(stmt, ast.Global) + for name in stmt.names + } + local_names = { + target.id + for stmt in node.body + if isinstance(stmt, (ast.Assign, ast.AnnAssign, ast.AugAssign)) + for target in ( + stmt.targets if isinstance(stmt, ast.Assign) + else [stmt.target] + ) + if isinstance(target, ast.Name) and target.id not in global_names + } + captured = _expr_names(node) - params - local_names + + mutated_attrs: set[tuple[str, str]] = set() + attr_sources: dict[str, tuple[str, str]] = {} + + for child in ast.walk(node): + if not isinstance(child, ast.Assign): + continue + if isinstance(child.value, ast.Attribute): + source_root = _ast_root_name(child.value) + if source_root in captured: + for target in child.targets: + if isinstance(target, ast.Name): + attr_sources[target.id] = (source_root, child.value.attr) + for target in child.targets: + if not isinstance(target, ast.Attribute): + continue + target_root = _ast_root_name(target) + if target_root in captured: + mutated_attrs.add((target_root, target.attr)) + + if mutated_attrs: + for child in ast.walk(node): + if not isinstance(child, ast.If): + continue + if _body_has_calls(child.body): + continue + if _expr_names(child.test) & params: + continue + + returns_replayed_attr = False + for stmt in child.body: + if not isinstance(stmt, ast.Return) or stmt.value is None: + continue + if (isinstance(stmt.value, ast.Name) + and attr_sources.get(stmt.value.id) in mutated_attrs): + returns_replayed_attr = True + break + if not isinstance(stmt.value, ast.Attribute): + continue + attr_key = (_ast_root_name(stmt.value), stmt.value.attr) + if attr_key in mutated_attrs: + returns_replayed_attr = True + break + + if not returns_replayed_attr: + continue + + return [{ + "pattern": "LAST_CALL_REPLAY", + "severity": "critical", + "evidence": ( + f"{entrypoint_name} returns an attribute from captured " + "state mutated by the entrypoint" + ), + }] + return []