Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 106 additions & 1 deletion kernelguard.py
Original file line number Diff line number Diff line change
Expand Up @@ -2462,6 +2462,105 @@ def _has_ver(expr: ast.AST | None) -> bool:
return []


def detect_nonlocal_value_replay(code: str | SubmissionFacts) -> list[dict]:
"""Pattern: factory-built entrypoint replays a nonlocal saved output."""
facts = ensure_submission_facts(code)
tree = facts.ast_tree
if tree is None:
return []
entrypoint_name = entrypoint_label(facts.entrypoint_name)

top_functions = {
stmt.name: stmt
for stmt in tree.body
if isinstance(stmt, ast.FunctionDef)
}
builders = {
stmt.value.func.id
for stmt in tree.body
if isinstance(stmt, ast.Assign)
and any(isinstance(target, ast.Name) and is_entrypoint_name(target.id) for target in stmt.targets)
and isinstance(stmt.value, ast.Call)
and isinstance(stmt.value.func, ast.Name)
and stmt.value.func.id in top_functions
}

def _param_names(node: ast.FunctionDef | ast.AsyncFunctionDef) -> set[str]:
params = {arg.arg for arg in node.args.args}
params.update(arg.arg for arg in node.args.posonlyargs)
params.update(arg.arg for arg in node.args.kwonlyargs)
return params

for builder_name in builders:
builder = top_functions[builder_name]
initialized = {
target.id
for stmt in builder.body
if isinstance(stmt, ast.Assign)
and isinstance(stmt.value, ast.Constant)
and stmt.value.value is None
for target in stmt.targets
if isinstance(target, ast.Name)
}
returned_inner = {
stmt.value.id
for stmt in builder.body
if isinstance(stmt, ast.Return)
and isinstance(stmt.value, ast.Name)
}

for inner in builder.body:
if not isinstance(inner, ast.FunctionDef) or inner.name not in returned_inner:
continue
params = _param_names(inner)
nonlocals = {
name
for stmt in inner.body
if isinstance(stmt, ast.Nonlocal)
for name in stmt.names
} & initialized
if not nonlocals:
continue

input_derived_assigns = {
target.id
for stmt in ast.walk(inner)
if isinstance(stmt, ast.Assign)
and bool(_expr_names(stmt.value) & params)
for target in stmt.targets
if isinstance(target, ast.Name) and target.id in nonlocals
}
if not input_derived_assigns:
continue

for child in ast.walk(inner):
if not isinstance(child, ast.If):
continue
if _expr_names(child.test) & params:
continue
if not (_expr_names(child.test) & input_derived_assigns):
continue
if _body_has_calls(child.body):
continue
returned = {
stmt.value.id
for stmt in child.body
if isinstance(stmt, ast.Return)
and isinstance(stmt.value, ast.Name)
}
if returned & input_derived_assigns:
return [{
"pattern": "NONLOCAL_VALUE_REPLAY",
"severity": "critical",
"evidence": (
f"{entrypoint_name} returns a nonlocal output "
"captured by its factory"
),
}]

return []


RE_OBJECT_ID_DATA = re.compile(r"\bdata_id\s*=\s*id\s*\(\s*data\s*\)")
RE_RESULT_BANK_SET = re.compile(r"_superbatch_results\s*\[\s*(?:did|data_id)\s*\]\s*=")
RE_RESULT_BANK_RETURN = re.compile(r"return\s+_superbatch_results\s*\[\s*data_id\s*\]")
Expand Down Expand Up @@ -3479,6 +3578,10 @@ class RulePolicy:
"LAST_CALL_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
AMD_MANUAL_LAST_CALL_REPLAY_FIXTURES, "keep",
),
"NONLOCAL_VALUE_REPLAY": RulePolicy(
"NONLOCAL_VALUE_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
(), "keep",
),
"SHAPE_OUTPUT_REPLAY": RulePolicy(
"SHAPE_OUTPUT_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
AMD_PACKAGE_SHAPE_REPLAY_FIXTURES, "keep",
Expand Down Expand Up @@ -3762,6 +3865,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
detect_decode_mm_ref,
detect_result_caching,
detect_last_call_replay,
detect_nonlocal_value_replay,
detect_shape_output_replay,
detect_timed_input_replay,
detect_cuda_graph_replay,
Expand Down Expand Up @@ -3800,6 +3904,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
("decode_mm_ref", detect_decode_mm_ref),
("result_caching", detect_result_caching),
("last_call_replay", detect_last_call_replay),
("nonlocal_value_replay", detect_nonlocal_value_replay),
("shape_output_replay", detect_shape_output_replay),
("timed_input_replay", detect_timed_input_replay),
("cuda_graph_replay", detect_cuda_graph_replay),
Expand Down Expand Up @@ -4696,7 +4801,7 @@ def _worker_parquet(args: tuple) -> dict:
"EVALUATOR_EXPLOIT", "HARNESS_RUNTIME_PATCHING", "MODULE_MUTATION", "GLOBALS_MUTATION", "CODE_REPLACEMENT",
"FRAME_WALK_ACCESS", "FRAME_WALK_MUTATION", "SYS_MODULES_ACCESS", "GLOBALS_ACCESS", "CODE_ACCESS",
"TRUSTED_MODULE_IMPORT",
"OUTPUT_REPLAY_CACHE", "LAST_CALL_REPLAY", "SHAPE_OUTPUT_REPLAY", "TIMED_INPUT_REPLAY", "CONFIG_CACHE_EXPLOIT", "POINTER_REPLAY", "RESULT_BANK_REPLAY", "PREPROCESS_CACHE", "WORKSPACE_CACHE",
"OUTPUT_REPLAY_CACHE", "LAST_CALL_REPLAY", "NONLOCAL_VALUE_REPLAY", "SHAPE_OUTPUT_REPLAY", "TIMED_INPUT_REPLAY", "CONFIG_CACHE_EXPLOIT", "POINTER_REPLAY", "RESULT_BANK_REPLAY", "PREPROCESS_CACHE", "WORKSPACE_CACHE",
"RUNNER_PLAN_CACHE", "CUDA_GRAPH_PYTHON", "CUDA_GRAPH_REPLAY",
"TIMER_MONKEYPATCH", "FAKE_BENCHMARK_EMIT", "STDIO_REDIRECT", "UNSYNC_MULTISTREAM", "CUDA_EVENT_DISABLE_TIMING",
"SCALED_MM_REF", "DECODE_MM_REF", "SILENT_FALLBACK", "REFERENCE_PRECOMPUTE_REPLAY", "TORCH_COMPILE_CACHE",
Expand Down