Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 164 additions & 1 deletion kernelguard.py
Original file line number Diff line number Diff line change
Expand Up @@ -2462,6 +2462,163 @@ def _has_ver(expr: ast.AST | None) -> bool:
return []


def detect_setattr_box_replay(code: str | SubmissionFacts) -> list[dict]:
    """Pattern: replay state stored on an object through setattr.

    Flags an entrypoint function that

    1. stores a call result derived from one of its own parameters on a
       top-level object via ``setattr(obj, name, value)`` or
       ``object.__setattr__(obj, name, value)``, and
    2. contains an ``if`` whose test inspects that same attribute (through
       ``hasattr``, ``getattr`` or plain attribute access) and whose body
       returns the attribute without making any other call.

    Only names bound at module top level are tracked: "objects" are names
    assigned from a call expression, and attribute names may be given either
    as string literals or as top-level string constants.

    Args:
        code: Submission source text or an already-built ``SubmissionFacts``;
            normalised through ``ensure_submission_facts``.

    Returns:
        A one-element list with the ``SETATTR_BOX_REPLAY`` finding, or an
        empty list when no AST is available or nothing matches.
    """
    facts = ensure_submission_facts(code)
    tree = facts.ast_tree
    if tree is None:
        return []
    entrypoint_name = entrypoint_label(facts.entrypoint_name)

    # Top-level ``NAME = "literal"`` bindings, so attribute names passed to
    # setattr/getattr/hasattr through a constant can still be resolved.
    strings = {
        stmt.targets[0].id: stmt.value.value
        for stmt in tree.body
        if isinstance(stmt, ast.Assign)
        and len(stmt.targets) == 1
        and isinstance(stmt.targets[0], ast.Name)
        and isinstance(stmt.value, ast.Constant)
        and isinstance(stmt.value.value, str)
    }

    # Top-level ``NAME = some_call(...)`` bindings: candidate state holders.
    objects = {
        stmt.targets[0].id
        for stmt in tree.body
        if isinstance(stmt, ast.Assign)
        and len(stmt.targets) == 1
        and isinstance(stmt.targets[0], ast.Name)
        and isinstance(stmt.value, ast.Call)
    }

    def _string_value(expr: ast.AST | None) -> str | None:
        """Resolve *expr* to a string literal or a known top-level string."""
        if isinstance(expr, ast.Constant) and isinstance(expr.value, str):
            return expr.value
        if isinstance(expr, ast.Name):
            return strings.get(expr.id)
        return None

    def _object_attr(expr: ast.AST | None) -> tuple[str, str] | None:
        """Return ``(object, attr)`` for ``obj.attr`` or ``getattr(obj, name)``."""
        if (
            isinstance(expr, ast.Attribute)
            and isinstance(expr.value, ast.Name)
            and expr.value.id in objects
        ):
            return expr.value.id, expr.attr
        if (
            isinstance(expr, ast.Call)
            and isinstance(expr.func, ast.Name)
            and expr.func.id == "getattr"
            and len(expr.args) >= 2
            and isinstance(expr.args[0], ast.Name)
            and expr.args[0].id in objects
        ):
            attr = _string_value(expr.args[1])
            if attr:
                return expr.args[0].id, attr
        return None

    def _hasattr_key(expr: ast.AST) -> tuple[str, str] | None:
        """Return ``(object, attr)`` for ``hasattr(obj, name)`` on a tracked object."""
        if not (
            isinstance(expr, ast.Call)
            and isinstance(expr.func, ast.Name)
            and expr.func.id == "hasattr"
            and len(expr.args) >= 2
            and isinstance(expr.args[0], ast.Name)
            and expr.args[0].id in objects
        ):
            return None
        attr = _string_value(expr.args[1])
        if attr:
            return expr.args[0].id, attr
        return None

    def _setattr_key(expr: ast.AST, params: set[str]) -> tuple[str, str] | None:
        """Return ``(object, attr)`` when *expr* stores parameter-derived call output."""
        if not isinstance(expr, ast.Call):
            return None

        args = expr.args
        direct_setattr = isinstance(expr.func, ast.Name) and expr.func.id == "setattr"
        object_setattr = (
            isinstance(expr.func, ast.Attribute)
            and expr.func.attr == "__setattr__"
            and isinstance(expr.func.value, ast.Name)
            and expr.func.value.id == "object"
        )
        if not (direct_setattr or object_setattr):
            return None
        if len(args) < 3 or not isinstance(args[0], ast.Name) or args[0].id not in objects:
            return None
        # The stored value must reference at least one entrypoint parameter...
        if not (_expr_names(args[2]) & params):
            return None
        # ...and must involve a call, i.e. it is computed rather than copied.
        if not any(isinstance(node, ast.Call) for node in ast.walk(args[2])):
            return None

        attr = _string_value(args[1])
        if attr:
            return args[0].id, attr
        return None

    def _body_has_real_calls(body: list[ast.stmt]) -> bool:
        """Report whether *body* makes any call besides ``getattr`` on a tracked object."""
        for stmt in body:
            for expr in ast.walk(stmt):
                if not isinstance(expr, ast.Call):
                    continue
                if _object_attr(expr) is not None:
                    continue
                return True
        return False

    for node in ast.walk(tree):
        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        if not is_entrypoint_name(node.name):
            continue

        signature = node.args
        params = {
            arg.arg
            for arg in (*signature.posonlyargs, *signature.args, *signature.kwonlyargs)
        }
        # ``*args`` / ``**kwargs`` carry caller input as well; without them an
        # entrypoint written as ``def f(*args)`` would never match.
        for variadic in (signature.vararg, signature.kwarg):
            if variadic is not None:
                params.add(variadic.arg)

        mutated = {
            key
            for child in ast.walk(node)
            for key in [_setattr_key(child, params)]
            if key is not None
        }
        if not mutated:
            continue

        for child in ast.walk(node):
            if not isinstance(child, ast.If):
                continue
            # A branch that performs other calls is not a pure replay path.
            if _body_has_real_calls(child.body):
                continue

            tested = {
                key
                for expr in ast.walk(child.test)
                for key in (_hasattr_key(expr), _object_attr(expr))
                if key is not None
            }
            # Only ``return`` statements directly in the branch body count.
            returned = {
                key
                for stmt in child.body
                if isinstance(stmt, ast.Return)
                for key in [_object_attr(stmt.value)]
                if key is not None
            }
            # Same (object, attr) pair must be tested, returned and written.
            if tested & returned & mutated:
                return [{
                    "pattern": "SETATTR_BOX_REPLAY",
                    "severity": "critical",
                    "evidence": (
                        f"{entrypoint_name} returns object attribute state "
                        "populated from a prior input through setattr"
                    ),
                }]

    return []


# Matches the statement ``data_id = id(data)`` with arbitrary whitespace
# around the ``=`` and inside the call parentheses.
RE_OBJECT_ID_DATA = re.compile(r"\bdata_id\s*=\s*id\s*\(\s*data\s*\)")
# Matches ``_superbatch_results[did] =`` or ``_superbatch_results[data_id] =``
# (a subscript on that name followed by an equals sign).
RE_RESULT_BANK_SET = re.compile(r"_superbatch_results\s*\[\s*(?:did|data_id)\s*\]\s*=")
# Matches ``return _superbatch_results[data_id]``.
RE_RESULT_BANK_RETURN = re.compile(r"return\s+_superbatch_results\s*\[\s*data_id\s*\]")
Expand Down Expand Up @@ -3479,6 +3636,10 @@ class RulePolicy:
"LAST_CALL_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
AMD_MANUAL_LAST_CALL_REPLAY_FIXTURES, "keep",
),
"SETATTR_BOX_REPLAY": RulePolicy(
"SETATTR_BOX_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
(), "keep",
),
"SHAPE_OUTPUT_REPLAY": RulePolicy(
"SHAPE_OUTPUT_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
AMD_PACKAGE_SHAPE_REPLAY_FIXTURES, "keep",
Expand Down Expand Up @@ -3762,6 +3923,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
detect_decode_mm_ref,
detect_result_caching,
detect_last_call_replay,
detect_setattr_box_replay,
detect_shape_output_replay,
detect_timed_input_replay,
detect_cuda_graph_replay,
Expand Down Expand Up @@ -3800,6 +3962,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
("decode_mm_ref", detect_decode_mm_ref),
("result_caching", detect_result_caching),
("last_call_replay", detect_last_call_replay),
("setattr_box_replay", detect_setattr_box_replay),
("shape_output_replay", detect_shape_output_replay),
("timed_input_replay", detect_timed_input_replay),
("cuda_graph_replay", detect_cuda_graph_replay),
Expand Down Expand Up @@ -4696,7 +4859,7 @@ def _worker_parquet(args: tuple) -> dict:
"EVALUATOR_EXPLOIT", "HARNESS_RUNTIME_PATCHING", "MODULE_MUTATION", "GLOBALS_MUTATION", "CODE_REPLACEMENT",
"FRAME_WALK_ACCESS", "FRAME_WALK_MUTATION", "SYS_MODULES_ACCESS", "GLOBALS_ACCESS", "CODE_ACCESS",
"TRUSTED_MODULE_IMPORT",
"OUTPUT_REPLAY_CACHE", "LAST_CALL_REPLAY", "SHAPE_OUTPUT_REPLAY", "TIMED_INPUT_REPLAY", "CONFIG_CACHE_EXPLOIT", "POINTER_REPLAY", "RESULT_BANK_REPLAY", "PREPROCESS_CACHE", "WORKSPACE_CACHE",
"OUTPUT_REPLAY_CACHE", "LAST_CALL_REPLAY", "SETATTR_BOX_REPLAY", "SHAPE_OUTPUT_REPLAY", "TIMED_INPUT_REPLAY", "CONFIG_CACHE_EXPLOIT", "POINTER_REPLAY", "RESULT_BANK_REPLAY", "PREPROCESS_CACHE", "WORKSPACE_CACHE",
"RUNNER_PLAN_CACHE", "CUDA_GRAPH_PYTHON", "CUDA_GRAPH_REPLAY",
"TIMER_MONKEYPATCH", "FAKE_BENCHMARK_EMIT", "STDIO_REDIRECT", "UNSYNC_MULTISTREAM", "CUDA_EVENT_DISABLE_TIMING",
"SCALED_MM_REF", "DECODE_MM_REF", "SILENT_FALLBACK", "REFERENCE_PRECOMPUTE_REPLAY", "TORCH_COMPILE_CACHE",
Expand Down