Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 153 additions & 0 deletions kernelguard.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,23 @@ def _expr_has_tensor_version(expr: ast.AST | None) -> bool:
)


def _function_param_names(fn: ast.FunctionDef | ast.AsyncFunctionDef) -> tuple[str, ...]:
names = [arg.arg for arg in fn.args.posonlyargs]
names.extend(arg.arg for arg in fn.args.args)
names.extend(arg.arg for arg in fn.args.kwonlyargs)
if fn.args.vararg is not None:
names.append(fn.args.vararg.arg)
if fn.args.kwarg is not None:
names.append(fn.args.kwarg.arg)
return tuple(names)


def _depends_on_names(expr: ast.AST | None, names: set[str]) -> bool:
if expr is None or not names:
return False
return bool(_expr_names(expr) & names)


_TRIVIAL_GPU_OPS = frozenset({
"fill_", "zero_", "copy_", "fill", "zero", "record",
})
Expand Down Expand Up @@ -2462,6 +2479,136 @@ def _has_ver(expr: ast.AST | None) -> bool:
return []


def detect_generator_send_replay(code: str | SubmissionFacts) -> list[dict]:
"""Pattern: persistent primed generator replays yielded state via send()."""
facts = ensure_submission_facts(code)
tree = facts.ast_tree
if tree is None or not isinstance(tree, ast.Module):
return []

functions = {
node.name: node
for node in tree.body
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
}
generator_builders: dict[str, ast.FunctionDef | ast.AsyncFunctionDef] = {}
primed_generators: set[str] = set()

for stmt in tree.body:
if not isinstance(stmt, ast.Assign):
continue
if len(stmt.targets) != 1 or not isinstance(stmt.targets[0], ast.Name):
continue
if not isinstance(stmt.value, ast.Call) or not isinstance(stmt.value.func, ast.Name):
continue
builder = functions.get(stmt.value.func.id)
if builder is None:
continue
if any(isinstance(node, (ast.Yield, ast.YieldFrom)) for node in ast.walk(builder)):
generator_builders[stmt.targets[0].id] = builder

for stmt in tree.body:
if not isinstance(stmt, ast.Expr) or not isinstance(stmt.value, ast.Call):
continue
call = stmt.value
if isinstance(call.func, ast.Name) and call.func.id == "next" and len(call.args) == 1:
if isinstance(call.args[0], ast.Name) and call.args[0].id in generator_builders:
primed_generators.add(call.args[0].id)
continue
if not (
isinstance(call.func, ast.Attribute)
and call.func.attr == "send"
and isinstance(call.func.value, ast.Name)
and call.func.value.id in generator_builders
and len(call.args) == 1
and isinstance(call.args[0], ast.Constant)
and call.args[0].value is None
):
continue
primed_generators.add(call.func.value.id)

if not primed_generators:
return []

for node in ast.walk(tree):
if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) or not is_entrypoint_name(node.name):
continue
params = set(_function_param_names(node))
for stmt in node.body:
if not isinstance(stmt, ast.Return) or not isinstance(stmt.value, ast.Call):
continue
call = stmt.value
if not (
isinstance(call.func, ast.Attribute)
and call.func.attr == "send"
and isinstance(call.func.value, ast.Name)
and call.func.value.id in primed_generators
and len(call.args) == 1
and _expr_names(call.args[0]) & params
):
continue

builder = generator_builders[call.func.value.id]
yielded_names = {
child.value.id
for child in ast.walk(builder)
if isinstance(child, ast.Yield) and isinstance(child.value, ast.Name)
}
resume_names = {
target.id
for child in ast.walk(builder)
if isinstance(child, ast.Assign)
and len(child.targets) == 1
and isinstance(child.targets[0], ast.Name)
and isinstance(child.value, ast.Yield)
for target in child.targets
}
if not yielded_names or not resume_names:
continue

resume_derived = set(resume_names)
changed = True
while changed:
changed = False
for child in ast.walk(builder):
if not isinstance(child, ast.Assign):
continue
if len(child.targets) != 1 or not isinstance(child.targets[0], ast.Name):
continue
if not _depends_on_names(child.value, resume_derived):
continue
if child.targets[0].id in resume_derived:
continue
resume_derived.add(child.targets[0].id)
changed = True

for child in ast.walk(builder):
if not isinstance(child, ast.If):
continue
if _expr_names(child.test) & resume_derived:
continue
for branch in (child.body, child.orelse):
if not any(
isinstance(stmt, ast.Assign)
and len(stmt.targets) == 1
and isinstance(stmt.targets[0], ast.Name)
and stmt.targets[0].id in yielded_names
and _depends_on_names(stmt.value, resume_derived)
for stmt in branch
):
continue
return [{
"pattern": "GENERATOR_SEND_REPLAY",
"severity": "critical",
"evidence": (
f"{entrypoint_label(node.name)} replays state from a primed "
f"generator via {call.func.value.id}.send(...)"
),
}]

return []


RE_OBJECT_ID_DATA = re.compile(r"\bdata_id\s*=\s*id\s*\(\s*data\s*\)")
RE_RESULT_BANK_SET = re.compile(r"_superbatch_results\s*\[\s*(?:did|data_id)\s*\]\s*=")
RE_RESULT_BANK_RETURN = re.compile(r"return\s+_superbatch_results\s*\[\s*data_id\s*\]")
Expand Down Expand Up @@ -3479,6 +3626,10 @@ class RulePolicy:
"LAST_CALL_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
AMD_MANUAL_LAST_CALL_REPLAY_FIXTURES, "keep",
),
"GENERATOR_SEND_REPLAY": RulePolicy(
"GENERATOR_SEND_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
(), "keep",
),
"SHAPE_OUTPUT_REPLAY": RulePolicy(
"SHAPE_OUTPUT_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
AMD_PACKAGE_SHAPE_REPLAY_FIXTURES, "keep",
Expand Down Expand Up @@ -3762,6 +3913,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
detect_decode_mm_ref,
detect_result_caching,
detect_last_call_replay,
detect_generator_send_replay,
detect_shape_output_replay,
detect_timed_input_replay,
detect_cuda_graph_replay,
Expand Down Expand Up @@ -3800,6 +3952,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
("decode_mm_ref", detect_decode_mm_ref),
("result_caching", detect_result_caching),
("last_call_replay", detect_last_call_replay),
("generator_send_replay", detect_generator_send_replay),
("shape_output_replay", detect_shape_output_replay),
("timed_input_replay", detect_timed_input_replay),
("cuda_graph_replay", detect_cuda_graph_replay),
Expand Down