diff --git a/kernelguard.py b/kernelguard.py index f086847..30f0416 100644 --- a/kernelguard.py +++ b/kernelguard.py @@ -2197,11 +2197,75 @@ def _has_ver(expr: ast.AST | None) -> bool: for n in ast.walk(expr) ) + def _input_derived(expr: ast.AST | None, params: set[str]) -> bool: + return bool(_expr_names(expr) & params) + for node in ast.walk(tree): if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): continue if not is_entrypoint_name(node.name): continue + params = {arg.arg for arg in node.args.args} + + # First-call state replay: + # if _state is None: + # _state = compute(input) + # return _state + # + # Or the inverse fast path: + # if _state is not None: + # return _state + # _state = compute(input) + # return _state + # + # The guard is input-independent, the slot appears in that guard, and + # the returned slot is assigned from a live input in the entrypoint. + for child in ast.walk(node): + if not isinstance(child, ast.If): + continue + if _expr_names(child.test) & params: + continue + if child not in node.body: + continue + + test_roots = _expr_names(child.test) + returned_in_body = { + _ast_root_name(stmt.value) + for stmt in child.body + if isinstance(stmt, ast.Return) and stmt.value is not None + } - {None} + assigned_from_input_in_body = { + _ast_root_name(target) + for stmt in child.body + if isinstance(stmt, ast.Assign) and _input_derived(stmt.value, params) + for target in stmt.targets + } - {None} + later_statements = node.body[node.body.index(child) + 1:] + returned_after_if = { + _ast_root_name(stmt.value) + for stmt in later_statements + if isinstance(stmt, ast.Return) and stmt.value is not None + } - {None} + assigned_after_if = { + _ast_root_name(target) + for stmt in later_statements + if isinstance(stmt, ast.Assign) and _input_derived(stmt.value, params) + for target in stmt.targets + } - {None} + + init_then_return = assigned_from_input_in_body & returned_after_if + fast_return_then_init = returned_in_body & assigned_after_if & returned_after_if + replay_roots = (init_then_return | fast_return_then_init) & test_roots + if replay_roots: + root = next(iter(replay_roots)) + return [{ + "pattern": "LAST_CALL_REPLAY", + "severity": "critical", + "evidence": ( + f"{entrypoint_name} returns first-call output from " + f"input-independent state slot {root}" + ), + }] signature_features: dict[str, set[str]] = defaultdict(set) saved_state_features: dict[str, set[str]] = defaultdict(set)