Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions kernelguard.py
Original file line number Diff line number Diff line change
Expand Up @@ -2459,6 +2459,80 @@ def _has_ver(expr: ast.AST | None) -> bool:
),
}]

# Quaternary pass: flag/counter-gated replay (no identity comparison).
# Pattern:
# _OUT = None
# _N = 0
# def custom_kernel(data):
# global _OUT, _N
# _N += 1
# if _N > 1 and _OUT is not None:
# return _OUT
# _OUT = data.clone()
# return _OUT
# The earlier passes all require an identity-based test that pairs a
# captured slot against the input. This pattern has no such pairing —
# the gate is purely on a module-level counter or flag, and the early
# return drops the input on the floor.
param_names = {a.arg for a in node.args.args}
param_names.update(a.arg for a in getattr(node.args, "posonlyargs", []) or [])
param_names.update(a.arg for a in getattr(node.args, "kwonlyargs", []) or [])

global_names: set[str] = set()
for n in ast.walk(node):
if isinstance(n, ast.Global):
global_names.update(n.names)

if global_names:
global_assigned: set[str] = set()
for child in ast.walk(node):
if isinstance(child, ast.Assign):
for target in child.targets:
root = _ast_root_name(target)
if root and root in global_names:
global_assigned.add(root)
elif isinstance(child, ast.AugAssign):
root = _ast_root_name(child.target)
if root and root in global_names:
global_assigned.add(root)

if global_assigned:
for child in ast.walk(node):
if not isinstance(child, ast.If):
continue
if _body_has_calls(child.body):
continue

test_names = _expr_names(child.test)
# Test must reference at least one mutated global, and must
# NOT reference any input parameter — that's the smoking gun:
# the early-return decision ignores the input entirely.
if not (test_names & global_assigned):
continue
if test_names & param_names:
continue

saved_return = False
for stmt in child.body:
if not isinstance(stmt, ast.Return) or stmt.value is None:
continue
rroot = _ast_root_name(stmt.value)
if rroot and rroot in global_assigned:
saved_return = True
break
if not saved_return:
continue

return [{
"pattern": "LAST_CALL_REPLAY",
"severity": "critical",
"evidence": (
f"{entrypoint_name} returns saved output gated only "
f"by a module-level counter/flag — the early-return "
f"path does not reference the input"
),
}]

return []


Expand Down