Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 126 additions & 1 deletion kernelguard.py
Original file line number Diff line number Diff line change
Expand Up @@ -2462,6 +2462,125 @@ def _has_ver(expr: ast.AST | None) -> bool:
return []


def detect_object_dict_replay(code: str | SubmissionFacts) -> list[dict]:
"""Pattern: object __dict__ alias stores and replays prior input."""
facts = ensure_submission_facts(code)
tree = facts.ast_tree
if tree is None:
return []
entrypoint_name = entrypoint_label(facts.entrypoint_name)

objects = set()
for stmt in tree.body:
if not (
isinstance(stmt, ast.Assign)
and len(stmt.targets) == 1
and isinstance(stmt.targets[0], ast.Name)
and isinstance(stmt.value, ast.Call)
):
continue
func_name = _ast_root_name(stmt.value.func)
if func_name == "ModuleType":
continue
objects.add(stmt.targets[0].id)

def _param_names(node: ast.FunctionDef | ast.AsyncFunctionDef) -> set[str]:
params = {arg.arg for arg in node.args.args}
params.update(arg.arg for arg in node.args.posonlyargs)
params.update(arg.arg for arg in node.args.kwonlyargs)
return params

def _dict_alias(stmt: ast.stmt) -> tuple[str, str] | None:
if not (
isinstance(stmt, ast.Assign)
and len(stmt.targets) == 1
and isinstance(stmt.targets[0], ast.Name)
):
return None
if (
isinstance(stmt.value, ast.Attribute)
and stmt.value.attr == "__dict__"
and isinstance(stmt.value.value, ast.Name)
and stmt.value.value.id in objects
):
return stmt.targets[0].id, stmt.value.value.id
if (
isinstance(stmt.value, ast.Call)
and isinstance(stmt.value.func, ast.Name)
and stmt.value.func.id == "vars"
and len(stmt.value.args) == 1
and isinstance(stmt.value.args[0], ast.Name)
and stmt.value.args[0].id in objects
):
return stmt.targets[0].id, stmt.value.args[0].id
return None

def _alias_subscript(expr: ast.AST | None, aliases: set[str]) -> str | None:
if not isinstance(expr, ast.Subscript):
return None
root = _ast_root_name(expr)
if root in aliases:
return root
return None

for node in ast.walk(tree):
if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
continue
if not is_entrypoint_name(node.name):
continue

params = _param_names(node)
aliases = {
alias
for stmt in ast.walk(node)
for pair in [_dict_alias(stmt)]
if pair is not None
for alias in [pair[0]]
}
if not aliases:
continue

mutated = {
alias
for stmt in ast.walk(node)
if isinstance(stmt, ast.Assign)
and bool(_expr_names(stmt.value) & params)
and any(isinstance(expr, ast.Call) for expr in ast.walk(stmt.value))
for target in stmt.targets
for alias in [_alias_subscript(target, aliases)]
if alias is not None
}
if not mutated:
continue

for child in ast.walk(node):
if not isinstance(child, ast.If):
continue
if _body_has_calls(child.body):
continue
if _expr_names(child.test) & params:
continue
tested = _expr_names(child.test) & mutated
returned = {
alias
for stmt in child.body
if isinstance(stmt, ast.Return)
for alias in [_alias_subscript(stmt.value, mutated)]
if alias is not None
}
if tested & returned:
return [{
"pattern": "OBJECT_DICT_REPLAY",
"severity": "critical",
"evidence": (
f"{entrypoint_name} returns object __dict__ state "
"populated from a prior input"
),
}]

return []


RE_OBJECT_ID_DATA = re.compile(r"\bdata_id\s*=\s*id\s*\(\s*data\s*\)")
RE_RESULT_BANK_SET = re.compile(r"_superbatch_results\s*\[\s*(?:did|data_id)\s*\]\s*=")
RE_RESULT_BANK_RETURN = re.compile(r"return\s+_superbatch_results\s*\[\s*data_id\s*\]")
Expand Down Expand Up @@ -3479,6 +3598,10 @@ class RulePolicy:
"LAST_CALL_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
AMD_MANUAL_LAST_CALL_REPLAY_FIXTURES, "keep",
),
"OBJECT_DICT_REPLAY": RulePolicy(
"OBJECT_DICT_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
(), "keep",
),
"SHAPE_OUTPUT_REPLAY": RulePolicy(
"SHAPE_OUTPUT_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
AMD_PACKAGE_SHAPE_REPLAY_FIXTURES, "keep",
Expand Down Expand Up @@ -3762,6 +3885,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
detect_decode_mm_ref,
detect_result_caching,
detect_last_call_replay,
detect_object_dict_replay,
detect_shape_output_replay,
detect_timed_input_replay,
detect_cuda_graph_replay,
Expand Down Expand Up @@ -3800,6 +3924,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
("decode_mm_ref", detect_decode_mm_ref),
("result_caching", detect_result_caching),
("last_call_replay", detect_last_call_replay),
("object_dict_replay", detect_object_dict_replay),
("shape_output_replay", detect_shape_output_replay),
("timed_input_replay", detect_timed_input_replay),
("cuda_graph_replay", detect_cuda_graph_replay),
Expand Down Expand Up @@ -4696,7 +4821,7 @@ def _worker_parquet(args: tuple) -> dict:
"EVALUATOR_EXPLOIT", "HARNESS_RUNTIME_PATCHING", "MODULE_MUTATION", "GLOBALS_MUTATION", "CODE_REPLACEMENT",
"FRAME_WALK_ACCESS", "FRAME_WALK_MUTATION", "SYS_MODULES_ACCESS", "GLOBALS_ACCESS", "CODE_ACCESS",
"TRUSTED_MODULE_IMPORT",
"OUTPUT_REPLAY_CACHE", "LAST_CALL_REPLAY", "SHAPE_OUTPUT_REPLAY", "TIMED_INPUT_REPLAY", "CONFIG_CACHE_EXPLOIT", "POINTER_REPLAY", "RESULT_BANK_REPLAY", "PREPROCESS_CACHE", "WORKSPACE_CACHE",
"OUTPUT_REPLAY_CACHE", "LAST_CALL_REPLAY", "OBJECT_DICT_REPLAY", "SHAPE_OUTPUT_REPLAY", "TIMED_INPUT_REPLAY", "CONFIG_CACHE_EXPLOIT", "POINTER_REPLAY", "RESULT_BANK_REPLAY", "PREPROCESS_CACHE", "WORKSPACE_CACHE",
"RUNNER_PLAN_CACHE", "CUDA_GRAPH_PYTHON", "CUDA_GRAPH_REPLAY",
"TIMER_MONKEYPATCH", "FAKE_BENCHMARK_EMIT", "STDIO_REDIRECT", "UNSYNC_MULTISTREAM", "CUDA_EVENT_DISABLE_TIMING",
"SCALED_MM_REF", "DECODE_MM_REF", "SILENT_FALLBACK", "REFERENCE_PRECOMPUTE_REPLAY", "TORCH_COMPILE_CACHE",
Expand Down