Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 108 additions & 1 deletion kernelguard.py
Original file line number Diff line number Diff line change
Expand Up @@ -2462,6 +2462,107 @@ def _has_ver(expr: ast.AST | None) -> bool:
return []


def detect_mappingproxy_replay(code: str | SubmissionFacts) -> list[dict]:
"""Pattern: mappingproxy reads replay state from a mutated backing dict."""
facts = ensure_submission_facts(code)
tree = facts.ast_tree
if tree is None:
return []
entrypoint_name = entrypoint_label(facts.entrypoint_name)

backing_dicts = {
stmt.targets[0].id
for stmt in tree.body
if isinstance(stmt, ast.Assign)
and len(stmt.targets) == 1
and isinstance(stmt.targets[0], ast.Name)
and isinstance(stmt.value, ast.Dict)
}
proxies = {}
for stmt in tree.body:
if not (
isinstance(stmt, ast.Assign)
and len(stmt.targets) == 1
and isinstance(stmt.targets[0], ast.Name)
and isinstance(stmt.value, ast.Call)
and len(stmt.value.args) == 1
and isinstance(stmt.value.args[0], ast.Name)
and stmt.value.args[0].id in backing_dicts
):
continue
func = stmt.value.func
if (
(isinstance(func, ast.Name) and func.id == "MappingProxyType")
or (isinstance(func, ast.Attribute) and func.attr == "MappingProxyType")
):
proxies[stmt.targets[0].id] = stmt.value.args[0].id

def _param_names(node: ast.FunctionDef | ast.AsyncFunctionDef) -> set[str]:
params = {arg.arg for arg in node.args.args}
params.update(arg.arg for arg in node.args.posonlyargs)
params.update(arg.arg for arg in node.args.kwonlyargs)
return params

def _subscript_root(expr: ast.AST | None, names: set[str]) -> str | None:
if isinstance(expr, ast.Subscript):
root = _ast_root_name(expr)
if root in names:
return root
return None

for node in ast.walk(tree):
if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
continue
if not is_entrypoint_name(node.name):
continue

params = _param_names(node)
mutated_backing = {
root
for stmt in ast.walk(node)
if isinstance(stmt, ast.Assign)
and bool(_expr_names(stmt.value) & params)
and any(isinstance(expr, ast.Call) for expr in ast.walk(stmt.value))
for target in stmt.targets
for root in [_subscript_root(target, backing_dicts)]
if root is not None
}
if not mutated_backing:
continue

mutated_proxies = {
proxy
for proxy, backing in proxies.items()
if backing in mutated_backing
}
for child in ast.walk(node):
if not isinstance(child, ast.If):
continue
if _body_has_calls(child.body):
continue
if _expr_names(child.test) & params:
continue
tested = _expr_names(child.test) & mutated_proxies
returned = {
root
for stmt in child.body
if isinstance(stmt, ast.Return)
for root in [_subscript_root(stmt.value, mutated_proxies)]
if root is not None
}
if tested & returned:
return [{
"pattern": "MAPPINGPROXY_REPLAY",
"severity": "critical",
"evidence": (
f"{entrypoint_name} returns mappingproxy state "
"populated through its backing dict"
),
}]

return []


RE_OBJECT_ID_DATA = re.compile(r"\bdata_id\s*=\s*id\s*\(\s*data\s*\)")
RE_RESULT_BANK_SET = re.compile(r"_superbatch_results\s*\[\s*(?:did|data_id)\s*\]\s*=")
RE_RESULT_BANK_RETURN = re.compile(r"return\s+_superbatch_results\s*\[\s*data_id\s*\]")
Expand Down Expand Up @@ -3479,6 +3580,10 @@ class RulePolicy:
"LAST_CALL_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
AMD_MANUAL_LAST_CALL_REPLAY_FIXTURES, "keep",
),
"MAPPINGPROXY_REPLAY": RulePolicy(
"MAPPINGPROXY_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
(), "keep",
),
"SHAPE_OUTPUT_REPLAY": RulePolicy(
"SHAPE_OUTPUT_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
AMD_PACKAGE_SHAPE_REPLAY_FIXTURES, "keep",
Expand Down Expand Up @@ -3762,6 +3867,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
detect_decode_mm_ref,
detect_result_caching,
detect_last_call_replay,
detect_mappingproxy_replay,
detect_shape_output_replay,
detect_timed_input_replay,
detect_cuda_graph_replay,
Expand Down Expand Up @@ -3800,6 +3906,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
("decode_mm_ref", detect_decode_mm_ref),
("result_caching", detect_result_caching),
("last_call_replay", detect_last_call_replay),
("mappingproxy_replay", detect_mappingproxy_replay),
("shape_output_replay", detect_shape_output_replay),
("timed_input_replay", detect_timed_input_replay),
("cuda_graph_replay", detect_cuda_graph_replay),
Expand Down Expand Up @@ -4696,7 +4803,7 @@ def _worker_parquet(args: tuple) -> dict:
"EVALUATOR_EXPLOIT", "HARNESS_RUNTIME_PATCHING", "MODULE_MUTATION", "GLOBALS_MUTATION", "CODE_REPLACEMENT",
"FRAME_WALK_ACCESS", "FRAME_WALK_MUTATION", "SYS_MODULES_ACCESS", "GLOBALS_ACCESS", "CODE_ACCESS",
"TRUSTED_MODULE_IMPORT",
"OUTPUT_REPLAY_CACHE", "LAST_CALL_REPLAY", "SHAPE_OUTPUT_REPLAY", "TIMED_INPUT_REPLAY", "CONFIG_CACHE_EXPLOIT", "POINTER_REPLAY", "RESULT_BANK_REPLAY", "PREPROCESS_CACHE", "WORKSPACE_CACHE",
"OUTPUT_REPLAY_CACHE", "LAST_CALL_REPLAY", "MAPPINGPROXY_REPLAY", "SHAPE_OUTPUT_REPLAY", "TIMED_INPUT_REPLAY", "CONFIG_CACHE_EXPLOIT", "POINTER_REPLAY", "RESULT_BANK_REPLAY", "PREPROCESS_CACHE", "WORKSPACE_CACHE",
"RUNNER_PLAN_CACHE", "CUDA_GRAPH_PYTHON", "CUDA_GRAPH_REPLAY",
"TIMER_MONKEYPATCH", "FAKE_BENCHMARK_EMIT", "STDIO_REDIRECT", "UNSYNC_MULTISTREAM", "CUDA_EVENT_DISABLE_TIMING",
"SCALED_MM_REF", "DECODE_MM_REF", "SILENT_FALLBACK", "REFERENCE_PRECOMPUTE_REPLAY", "TORCH_COMPILE_CACHE",
Expand Down