diff --git a/kernelguard.py b/kernelguard.py index f086847..661df1d 100644 --- a/kernelguard.py +++ b/kernelguard.py @@ -2459,6 +2459,88 @@ def _has_ver(expr: ast.AST | None) -> bool: ), }] + # Fourth pass: try-fast-path replay where the exception path fills + # the captured state. + # + # Pattern: try: return _state[0] + # except IndexError: _state.append(data.clone()) + params = { + arg.arg + for args in (node.args.posonlyargs, node.args.args, node.args.kwonlyargs) + for arg in args + } + if node.args.vararg is not None: + params.add(node.args.vararg.arg) + if node.args.kwarg is not None: + params.add(node.args.kwarg.arg) + + global_names = { + name + for stmt in node.body + if isinstance(stmt, ast.Global) + for name in stmt.names + } + local_names = { + target.id + for stmt in node.body + if isinstance(stmt, (ast.Assign, ast.AnnAssign, ast.AugAssign)) + for target in ( + stmt.targets if isinstance(stmt, ast.Assign) + else [stmt.target] + ) + if isinstance(target, ast.Name) and target.id not in global_names + } + captured = _expr_names(node) - params - local_names + + mutating_methods = { + "add", "append", "clear", "extend", "insert", "pop", "popitem", + "remove", "rotate", "setdefault", "update", + } + for child in ast.walk(node): + if not isinstance(child, ast.Try): + continue + if _body_has_calls(child.body): + continue + returned = { + _ast_root_name(stmt.value) + for stmt in child.body + if isinstance(stmt, ast.Return) + and isinstance(stmt.value, ast.Subscript) + } - {None} + replay_roots = returned & captured + if not replay_roots: + continue + + mutated = set() + for handler in child.handlers: + for stmt in ast.walk(handler): + if isinstance(stmt, ast.Assign): + for target in stmt.targets: + target_root = _ast_root_name(target) + if target_root in replay_roots: + mutated.add(target_root) + elif isinstance(stmt, ast.AugAssign): + target_root = _ast_root_name(stmt.target) + if target_root in replay_roots: + mutated.add(target_root) + elif (isinstance(stmt, ast.Call) + and isinstance(stmt.func, ast.Attribute) + and stmt.func.attr in mutating_methods): + target_root = _ast_root_name(stmt.func.value) + if target_root in replay_roots: + mutated.add(target_root) + if not mutated: + continue + + return [{ + "pattern": "LAST_CALL_REPLAY", + "severity": "critical", + "evidence": ( + f"{entrypoint_name} tries to return captured state and " + "fills it in the exception path" + ), + }] + return []