Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 160 additions & 1 deletion kernelguard.py
Original file line number Diff line number Diff line change
Expand Up @@ -2462,6 +2462,159 @@ def _has_ver(expr: ast.AST | None) -> bool:
return []


def detect_mutated_capture_subscript_replay(code: str | SubmissionFacts) -> list[dict]:
"""Pattern: captured container mutation followed by subscript replay."""
facts = ensure_submission_facts(code)
tree = facts.ast_tree
if tree is None:
return []
entrypoint_name = entrypoint_label(facts.entrypoint_name)
mutating_methods = {
"__setitem__", "append", "extend", "insert", "update", "setdefault",
"add", "push", "put", "extendleft",
}

def _param_names(node: ast.FunctionDef | ast.AsyncFunctionDef) -> set[str]:
params = {arg.arg for arg in node.args.args}
params.update(arg.arg for arg in node.args.posonlyargs)
params.update(arg.arg for arg in node.args.kwonlyargs)
if node.args.vararg:
params.add(node.args.vararg.arg)
if node.args.kwarg:
params.add(node.args.kwarg.arg)
return params

def _top_level_targets(node: ast.FunctionDef | ast.AsyncFunctionDef) -> set[str]:
targets = set()
global_names = {
name
for stmt in node.body
if isinstance(stmt, ast.Global)
for name in stmt.names
}
for stmt in node.body:
if isinstance(stmt, ast.Assign):
for target in stmt.targets:
if isinstance(target, ast.Name) and target.id not in global_names:
targets.add(target.id)
if isinstance(stmt, ast.AnnAssign):
if isinstance(stmt.target, ast.Name) and stmt.target.id not in global_names:
targets.add(stmt.target.id)
if isinstance(stmt, ast.AugAssign):
if isinstance(stmt.target, ast.Name) and stmt.target.id not in global_names:
targets.add(stmt.target.id)
if isinstance(stmt, (ast.For, ast.AsyncFor)):
if isinstance(stmt.target, ast.Name) and stmt.target.id not in global_names:
targets.add(stmt.target.id)
if isinstance(stmt, ast.With):
for item in stmt.items:
if isinstance(item.optional_vars, ast.Name) and item.optional_vars.id not in global_names:
targets.add(item.optional_vars.id)
return targets

def _captured_names(node: ast.FunctionDef | ast.AsyncFunctionDef, params: set[str]) -> set[str]:
local_targets = _top_level_targets(node)
referenced = {
expr.id
for expr in ast.walk(node)
if isinstance(expr, ast.Name) and isinstance(expr.ctx, ast.Load)
}
return referenced - params - local_targets

def _subscript_root(expr: ast.AST | None, roots: set[str]) -> str | None:
if isinstance(expr, ast.Subscript):
root = _ast_root_name(expr)
if root in roots:
return root
return None

def _input_derived(expr: ast.AST | None, params: set[str]) -> bool:
return bool(_expr_names(expr) & params)

for node in ast.walk(tree):
if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
continue
if not is_entrypoint_name(node.name):
continue

params = _param_names(node)
captures = _captured_names(node, params)
mutated = set()
global_names = {
name
for stmt in node.body
if isinstance(stmt, ast.Global)
for name in stmt.names
}

for expr in ast.walk(node):
if isinstance(expr, ast.AugAssign):
root = _ast_root_name(expr.target)
if root in captures and _input_derived(expr.value, params):
mutated.add(root)
if isinstance(expr, ast.Assign):
if not _input_derived(expr.value, params):
continue
for target in expr.targets:
root = _ast_root_name(target)
if root in captures and root in global_names:
mutated.add(root)
if not isinstance(expr, ast.Call):
continue
if not (
isinstance(expr.func, ast.Attribute)
and expr.func.attr in mutating_methods
):
continue
root = _ast_root_name(expr.func.value)
if root in captures and any(_input_derived(arg, params) for arg in expr.args):
mutated.add(root)

if not mutated:
continue

local_subscripts = {
target.id
for stmt in ast.walk(node)
if isinstance(stmt, ast.Assign)
and len(stmt.targets) == 1
and isinstance(stmt.targets[0], ast.Name)
and _subscript_root(stmt.value, mutated) is not None
for target in stmt.targets
}

for child in ast.walk(node):
if not isinstance(child, ast.If):
continue
if _body_has_calls(child.body):
continue
if _expr_names(child.test) & params:
continue
for stmt in child.body:
if not isinstance(stmt, ast.Return):
continue
if _subscript_root(stmt.value, mutated) is not None:
return [{
"pattern": "MUTATED_CAPTURE_SUBSCRIPT_REPLAY",
"severity": "critical",
"evidence": (
f"{entrypoint_name} returns subscripted state "
"from a mutated captured container"
),
}]
if isinstance(stmt.value, ast.Name) and stmt.value.id in local_subscripts:
return [{
"pattern": "MUTATED_CAPTURE_SUBSCRIPT_REPLAY",
"severity": "critical",
"evidence": (
f"{entrypoint_name} returns a local alias of "
"subscripted mutated captured state"
),
}]

return []


RE_OBJECT_ID_DATA = re.compile(r"\bdata_id\s*=\s*id\s*\(\s*data\s*\)")
RE_RESULT_BANK_SET = re.compile(r"_superbatch_results\s*\[\s*(?:did|data_id)\s*\]\s*=")
RE_RESULT_BANK_RETURN = re.compile(r"return\s+_superbatch_results\s*\[\s*data_id\s*\]")
Expand Down Expand Up @@ -3479,6 +3632,10 @@ class RulePolicy:
"LAST_CALL_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
AMD_MANUAL_LAST_CALL_REPLAY_FIXTURES, "keep",
),
"MUTATED_CAPTURE_SUBSCRIPT_REPLAY": RulePolicy(
"MUTATED_CAPTURE_SUBSCRIPT_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
(), "keep",
),
"SHAPE_OUTPUT_REPLAY": RulePolicy(
"SHAPE_OUTPUT_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
AMD_PACKAGE_SHAPE_REPLAY_FIXTURES, "keep",
Expand Down Expand Up @@ -3762,6 +3919,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
detect_decode_mm_ref,
detect_result_caching,
detect_last_call_replay,
detect_mutated_capture_subscript_replay,
detect_shape_output_replay,
detect_timed_input_replay,
detect_cuda_graph_replay,
Expand Down Expand Up @@ -3800,6 +3958,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
("decode_mm_ref", detect_decode_mm_ref),
("result_caching", detect_result_caching),
("last_call_replay", detect_last_call_replay),
("mutated_capture_subscript_replay", detect_mutated_capture_subscript_replay),
("shape_output_replay", detect_shape_output_replay),
("timed_input_replay", detect_timed_input_replay),
("cuda_graph_replay", detect_cuda_graph_replay),
Expand Down Expand Up @@ -4696,7 +4855,7 @@ def _worker_parquet(args: tuple) -> dict:
"EVALUATOR_EXPLOIT", "HARNESS_RUNTIME_PATCHING", "MODULE_MUTATION", "GLOBALS_MUTATION", "CODE_REPLACEMENT",
"FRAME_WALK_ACCESS", "FRAME_WALK_MUTATION", "SYS_MODULES_ACCESS", "GLOBALS_ACCESS", "CODE_ACCESS",
"TRUSTED_MODULE_IMPORT",
"OUTPUT_REPLAY_CACHE", "LAST_CALL_REPLAY", "SHAPE_OUTPUT_REPLAY", "TIMED_INPUT_REPLAY", "CONFIG_CACHE_EXPLOIT", "POINTER_REPLAY", "RESULT_BANK_REPLAY", "PREPROCESS_CACHE", "WORKSPACE_CACHE",
"OUTPUT_REPLAY_CACHE", "LAST_CALL_REPLAY", "MUTATED_CAPTURE_SUBSCRIPT_REPLAY", "SHAPE_OUTPUT_REPLAY", "TIMED_INPUT_REPLAY", "CONFIG_CACHE_EXPLOIT", "POINTER_REPLAY", "RESULT_BANK_REPLAY", "PREPROCESS_CACHE", "WORKSPACE_CACHE",
"RUNNER_PLAN_CACHE", "CUDA_GRAPH_PYTHON", "CUDA_GRAPH_REPLAY",
"TIMER_MONKEYPATCH", "FAKE_BENCHMARK_EMIT", "STDIO_REDIRECT", "UNSYNC_MULTISTREAM", "CUDA_EVENT_DISABLE_TIMING",
"SCALED_MM_REF", "DECODE_MM_REF", "SILENT_FALLBACK", "REFERENCE_PRECOMPUTE_REPLAY", "TORCH_COMPILE_CACHE",
Expand Down