Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 135 additions & 1 deletion kernelguard.py
Original file line number Diff line number Diff line change
Expand Up @@ -2462,6 +2462,134 @@ def _has_ver(expr: ast.AST | None) -> bool:
return []


def detect_builtins_attr_replay(code: str | SubmissionFacts) -> list[dict]:
"""Pattern: builtins attribute stores and replays prior input."""
facts = ensure_submission_facts(code)
tree = facts.ast_tree
if tree is None:
return []
entrypoint_name = entrypoint_label(facts.entrypoint_name)

strings = {
stmt.targets[0].id: stmt.value.value
for stmt in tree.body
if isinstance(stmt, ast.Assign)
and len(stmt.targets) == 1
and isinstance(stmt.targets[0], ast.Name)
and isinstance(stmt.value, ast.Constant)
and isinstance(stmt.value.value, str)
}

def _param_names(node: ast.FunctionDef | ast.AsyncFunctionDef) -> set[str]:
params = {arg.arg for arg in node.args.args}
params.update(arg.arg for arg in node.args.posonlyargs)
params.update(arg.arg for arg in node.args.kwonlyargs)
return params

def _string_value(expr: ast.AST | None) -> str | None:
if isinstance(expr, ast.Constant) and isinstance(expr.value, str):
return expr.value
if isinstance(expr, ast.Name):
return strings.get(expr.id)
return None

def _builtins_getattr(expr: ast.AST | None) -> str | None:
if not (
isinstance(expr, ast.Call)
and isinstance(expr.func, ast.Name)
and expr.func.id == "getattr"
and len(expr.args) >= 2
and isinstance(expr.args[0], ast.Name)
and expr.args[0].id == "builtins"
):
return None
return _string_value(expr.args[1])

def _builtins_hasattr(expr: ast.AST | None) -> str | None:
if not (
isinstance(expr, ast.Call)
and isinstance(expr.func, ast.Name)
and expr.func.id == "hasattr"
and len(expr.args) >= 2
and isinstance(expr.args[0], ast.Name)
and expr.args[0].id == "builtins"
):
return None
return _string_value(expr.args[1])

def _builtins_setattr(expr: ast.AST | None, params: set[str]) -> str | None:
if not (
isinstance(expr, ast.Call)
and isinstance(expr.func, ast.Name)
and expr.func.id == "setattr"
and len(expr.args) >= 3
and isinstance(expr.args[0], ast.Name)
and expr.args[0].id == "builtins"
and bool(_expr_names(expr.args[2]) & params)
and any(isinstance(node, ast.Call) for node in ast.walk(expr.args[2]))
):
return None
return _string_value(expr.args[1])

def _body_has_only_builtins_getattr(body: list[ast.stmt]) -> bool:
for stmt in body:
for expr in ast.walk(stmt):
if not isinstance(expr, ast.Call):
continue
if _builtins_getattr(expr) is not None:
continue
return False
return True

for node in ast.walk(tree):
if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
continue
if not is_entrypoint_name(node.name):
continue

params = _param_names(node)
mutated = {
attr
for expr in ast.walk(node)
for attr in [_builtins_setattr(expr, params)]
if attr is not None
}
if not mutated:
continue

for child in ast.walk(node):
if not isinstance(child, ast.If):
continue
if not _body_has_only_builtins_getattr(child.body):
continue
if _expr_names(child.test) & params:
continue
tested = {
attr
for expr in ast.walk(child.test)
for attr in (_builtins_hasattr(expr), _builtins_getattr(expr))
if attr is not None
}
returned = {
attr
for stmt in child.body
if isinstance(stmt, ast.Return)
for attr in [_builtins_getattr(stmt.value)]
if attr is not None
}
if tested & returned & mutated:
return [{
"pattern": "BUILTINS_ATTR_REPLAY",
"severity": "critical",
"evidence": (
f"{entrypoint_name} returns builtins attribute state "
"populated from a prior input"
),
}]

return []


RE_OBJECT_ID_DATA = re.compile(r"\bdata_id\s*=\s*id\s*\(\s*data\s*\)")
RE_RESULT_BANK_SET = re.compile(r"_superbatch_results\s*\[\s*(?:did|data_id)\s*\]\s*=")
RE_RESULT_BANK_RETURN = re.compile(r"return\s+_superbatch_results\s*\[\s*data_id\s*\]")
Expand Down Expand Up @@ -3479,6 +3607,10 @@ class RulePolicy:
"LAST_CALL_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
AMD_MANUAL_LAST_CALL_REPLAY_FIXTURES, "keep",
),
"BUILTINS_ATTR_REPLAY": RulePolicy(
"BUILTINS_ATTR_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
(), "keep",
),
"SHAPE_OUTPUT_REPLAY": RulePolicy(
"SHAPE_OUTPUT_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
AMD_PACKAGE_SHAPE_REPLAY_FIXTURES, "keep",
Expand Down Expand Up @@ -3762,6 +3894,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
detect_decode_mm_ref,
detect_result_caching,
detect_last_call_replay,
detect_builtins_attr_replay,
detect_shape_output_replay,
detect_timed_input_replay,
detect_cuda_graph_replay,
Expand Down Expand Up @@ -3800,6 +3933,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
("decode_mm_ref", detect_decode_mm_ref),
("result_caching", detect_result_caching),
("last_call_replay", detect_last_call_replay),
("builtins_attr_replay", detect_builtins_attr_replay),
("shape_output_replay", detect_shape_output_replay),
("timed_input_replay", detect_timed_input_replay),
("cuda_graph_replay", detect_cuda_graph_replay),
Expand Down Expand Up @@ -4696,7 +4830,7 @@ def _worker_parquet(args: tuple) -> dict:
"EVALUATOR_EXPLOIT", "HARNESS_RUNTIME_PATCHING", "MODULE_MUTATION", "GLOBALS_MUTATION", "CODE_REPLACEMENT",
"FRAME_WALK_ACCESS", "FRAME_WALK_MUTATION", "SYS_MODULES_ACCESS", "GLOBALS_ACCESS", "CODE_ACCESS",
"TRUSTED_MODULE_IMPORT",
"OUTPUT_REPLAY_CACHE", "LAST_CALL_REPLAY", "SHAPE_OUTPUT_REPLAY", "TIMED_INPUT_REPLAY", "CONFIG_CACHE_EXPLOIT", "POINTER_REPLAY", "RESULT_BANK_REPLAY", "PREPROCESS_CACHE", "WORKSPACE_CACHE",
"OUTPUT_REPLAY_CACHE", "LAST_CALL_REPLAY", "BUILTINS_ATTR_REPLAY", "SHAPE_OUTPUT_REPLAY", "TIMED_INPUT_REPLAY", "CONFIG_CACHE_EXPLOIT", "POINTER_REPLAY", "RESULT_BANK_REPLAY", "PREPROCESS_CACHE", "WORKSPACE_CACHE",
"RUNNER_PLAN_CACHE", "CUDA_GRAPH_PYTHON", "CUDA_GRAPH_REPLAY",
"TIMER_MONKEYPATCH", "FAKE_BENCHMARK_EMIT", "STDIO_REDIRECT", "UNSYNC_MULTISTREAM", "CUDA_EVENT_DISABLE_TIMING",
"SCALED_MM_REF", "DECODE_MM_REF", "SILENT_FALLBACK", "REFERENCE_PRECOMPUTE_REPLAY", "TORCH_COMPILE_CACHE",
Expand Down