Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions kernelguard.py
Original file line number Diff line number Diff line change
Expand Up @@ -2197,11 +2197,75 @@ def _has_ver(expr: ast.AST | None) -> bool:
for n in ast.walk(expr)
)

def _input_derived(expr: ast.AST | None, params: set[str]) -> bool:
return bool(_expr_names(expr) & params)

for node in ast.walk(tree):
if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
continue
if not is_entrypoint_name(node.name):
continue
params = {arg.arg for arg in node.args.args}

# First-call state replay:
# if _state is None:
# _state = compute(input)
# return _state
#
# Or the inverse fast path:
# if _state is not None:
# return _state
# _state = compute(input)
# return _state
#
# The guard is input-independent, the slot appears in that guard, and
# the returned slot is assigned from a live input in the entrypoint.
for child in ast.walk(node):
if not isinstance(child, ast.If):
continue
if _expr_names(child.test) & params:
continue
if child not in node.body:
continue

test_roots = _expr_names(child.test)
returned_in_body = {
_ast_root_name(stmt.value)
for stmt in child.body
if isinstance(stmt, ast.Return) and stmt.value is not None
} - {None}
assigned_from_input_in_body = {
_ast_root_name(target)
for stmt in child.body
if isinstance(stmt, ast.Assign) and _input_derived(stmt.value, params)
for target in stmt.targets
} - {None}
later_statements = node.body[node.body.index(child) + 1:]
returned_after_if = {
_ast_root_name(stmt.value)
for stmt in later_statements
if isinstance(stmt, ast.Return) and stmt.value is not None
} - {None}
assigned_after_if = {
_ast_root_name(target)
for stmt in later_statements
if isinstance(stmt, ast.Assign) and _input_derived(stmt.value, params)
for target in stmt.targets
} - {None}

init_then_return = assigned_from_input_in_body & returned_after_if
fast_return_then_init = returned_in_body & assigned_after_if & returned_after_if
replay_roots = (init_then_return | fast_return_then_init) & test_roots
if replay_roots:
root = next(iter(replay_roots))
return [{
"pattern": "LAST_CALL_REPLAY",
"severity": "critical",
"evidence": (
f"{entrypoint_name} returns first-call output from "
f"input-independent state slot {root}"
),
}]

signature_features: dict[str, set[str]] = defaultdict(set)
saved_state_features: dict[str, set[str]] = defaultdict(set)
Expand Down