Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 128 additions & 1 deletion kernelguard.py
Original file line number Diff line number Diff line change
Expand Up @@ -2462,6 +2462,127 @@ def _has_ver(expr: ast.AST | None) -> bool:
return []


def detect_bound_method_replay(code: str | SubmissionFacts) -> list[dict]:
"""Pattern: bound-method entrypoint replays an instance attribute."""
facts = ensure_submission_facts(code)
tree = facts.ast_tree
if tree is None:
return []
entrypoint_name = entrypoint_label(facts.entrypoint_name)

classes = {
stmt.name: stmt
for stmt in tree.body
if isinstance(stmt, ast.ClassDef)
}
instances: dict[str, str] = {}
for stmt in tree.body:
if not isinstance(stmt, ast.Assign):
continue
if len(stmt.targets) != 1 or not isinstance(stmt.targets[0], ast.Name):
continue
if not (isinstance(stmt.value, ast.Call) and isinstance(stmt.value.func, ast.Name)):
continue
if stmt.value.func.id in classes:
instances[stmt.targets[0].id] = stmt.value.func.id

def _bound_method(value: ast.AST) -> tuple[str, str] | None:
if not isinstance(value, ast.Attribute):
return None
receiver = value.value
if isinstance(receiver, ast.Call) and isinstance(receiver.func, ast.Name):
if receiver.func.id in classes:
return receiver.func.id, value.attr
if isinstance(receiver, ast.Name) and receiver.id in instances:
return instances[receiver.id], value.attr
return None

def _self_attr(expr: ast.AST | None, self_name: str) -> str | None:
if not (
isinstance(expr, ast.Attribute)
and isinstance(expr.value, ast.Name)
and expr.value.id == self_name
):
return None
return expr.attr

def _stores_input_result(method: ast.FunctionDef | ast.AsyncFunctionDef, self_name: str, params: set[str]) -> set[str]:
attrs = set()
for stmt in ast.walk(method):
if not isinstance(stmt, ast.Assign):
continue
if not (_expr_names(stmt.value) & params):
continue
if not any(isinstance(node, ast.Call) for node in ast.walk(stmt.value)):
continue
for target in stmt.targets:
attr = _self_attr(target, self_name)
if attr:
attrs.add(attr)
return attrs

for stmt in tree.body:
if not isinstance(stmt, ast.Assign):
continue
if not any(isinstance(target, ast.Name) and is_entrypoint_name(target.id) for target in stmt.targets):
continue
bound = _bound_method(stmt.value)
if bound is None:
continue
class_name, method_name = bound
class_node = classes[class_name]
method = next(
(
item for item in class_node.body
if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef))
and item.name == method_name
and item.args.args
),
None,
)
if method is None:
continue

self_name = method.args.args[0].arg
params = {arg.arg for arg in method.args.args[1:]}
params.update(arg.arg for arg in method.args.posonlyargs)
params.update(arg.arg for arg in method.args.kwonlyargs)
stored_attrs = _stores_input_result(method, self_name, params)
if not stored_attrs:
continue

for child in ast.walk(method):
if not isinstance(child, ast.If):
continue
if _body_has_calls(child.body):
continue
test_attrs = {
node.attr
for node in ast.walk(child.test)
if isinstance(node, ast.Attribute)
and isinstance(node.value, ast.Name)
and node.value.id == self_name
}
returned_attrs = {
attr
for stmt in child.body
if isinstance(stmt, ast.Return)
for attr in [_self_attr(stmt.value, self_name)]
if attr is not None
}
if test_attrs & returned_attrs & stored_attrs:
return [{
"pattern": "BOUND_METHOD_REPLAY",
"severity": "critical",
"evidence": (
f"{entrypoint_name} is a bound method that returns "
"an instance attribute cached from a prior input"
),
}]

return []


RE_OBJECT_ID_DATA = re.compile(r"\bdata_id\s*=\s*id\s*\(\s*data\s*\)")
RE_RESULT_BANK_SET = re.compile(r"_superbatch_results\s*\[\s*(?:did|data_id)\s*\]\s*=")
RE_RESULT_BANK_RETURN = re.compile(r"return\s+_superbatch_results\s*\[\s*data_id\s*\]")
Expand Down Expand Up @@ -3479,6 +3600,10 @@ class RulePolicy:
"LAST_CALL_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
AMD_MANUAL_LAST_CALL_REPLAY_FIXTURES, "keep",
),
"BOUND_METHOD_REPLAY": RulePolicy(
"BOUND_METHOD_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
(), "keep",
),
"SHAPE_OUTPUT_REPLAY": RulePolicy(
"SHAPE_OUTPUT_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
AMD_PACKAGE_SHAPE_REPLAY_FIXTURES, "keep",
Expand Down Expand Up @@ -3762,6 +3887,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
detect_decode_mm_ref,
detect_result_caching,
detect_last_call_replay,
detect_bound_method_replay,
detect_shape_output_replay,
detect_timed_input_replay,
detect_cuda_graph_replay,
Expand Down Expand Up @@ -3800,6 +3926,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
("decode_mm_ref", detect_decode_mm_ref),
("result_caching", detect_result_caching),
("last_call_replay", detect_last_call_replay),
("bound_method_replay", detect_bound_method_replay),
("shape_output_replay", detect_shape_output_replay),
("timed_input_replay", detect_timed_input_replay),
("cuda_graph_replay", detect_cuda_graph_replay),
Expand Down Expand Up @@ -4696,7 +4823,7 @@ def _worker_parquet(args: tuple) -> dict:
"EVALUATOR_EXPLOIT", "HARNESS_RUNTIME_PATCHING", "MODULE_MUTATION", "GLOBALS_MUTATION", "CODE_REPLACEMENT",
"FRAME_WALK_ACCESS", "FRAME_WALK_MUTATION", "SYS_MODULES_ACCESS", "GLOBALS_ACCESS", "CODE_ACCESS",
"TRUSTED_MODULE_IMPORT",
"OUTPUT_REPLAY_CACHE", "LAST_CALL_REPLAY", "SHAPE_OUTPUT_REPLAY", "TIMED_INPUT_REPLAY", "CONFIG_CACHE_EXPLOIT", "POINTER_REPLAY", "RESULT_BANK_REPLAY", "PREPROCESS_CACHE", "WORKSPACE_CACHE",
"OUTPUT_REPLAY_CACHE", "LAST_CALL_REPLAY", "BOUND_METHOD_REPLAY", "SHAPE_OUTPUT_REPLAY", "TIMED_INPUT_REPLAY", "CONFIG_CACHE_EXPLOIT", "POINTER_REPLAY", "RESULT_BANK_REPLAY", "PREPROCESS_CACHE", "WORKSPACE_CACHE",
"RUNNER_PLAN_CACHE", "CUDA_GRAPH_PYTHON", "CUDA_GRAPH_REPLAY",
"TIMER_MONKEYPATCH", "FAKE_BENCHMARK_EMIT", "STDIO_REDIRECT", "UNSYNC_MULTISTREAM", "CUDA_EVENT_DISABLE_TIMING",
"SCALED_MM_REF", "DECODE_MM_REF", "SILENT_FALLBACK", "REFERENCE_PRECOMPUTE_REPLAY", "TORCH_COMPILE_CACHE",
Expand Down