diff --git a/kernelguard.py b/kernelguard.py
index f086847..eb8830b 100644
--- a/kernelguard.py
+++ b/kernelguard.py
@@ -524,6 +524,28 @@ def _expr_has_tensor_version(expr: ast.AST | None) -> bool:
     )
 
 
+def _function_param_names(fn: ast.FunctionDef | ast.AsyncFunctionDef) -> tuple[str, ...]:
+    """Return every parameter name declared in ``fn``'s signature.
+
+    Names are ordered positional-only, regular, keyword-only, then the
+    ``*args`` and ``**kwargs`` names when present.
+    """
+    spec = fn.args
+    names = [arg.arg for arg in (*spec.posonlyargs, *spec.args, *spec.kwonlyargs)]
+    names.extend(arg.arg for arg in (spec.vararg, spec.kwarg) if arg is not None)
+    return tuple(names)
+
+
+def _depends_on_names(expr: ast.AST | None, names: set[str]) -> bool:
+    """Return True when the names ``_expr_names`` finds in ``expr`` overlap ``names``.
+
+    A missing expression or an empty ``names`` set never counts as a dependency.
+    """
+    if expr is None or not names:
+        return False
+    return bool(_expr_names(expr) & names)
+
+
 _TRIVIAL_GPU_OPS = frozenset({
     "fill_", "zero_", "copy_", "fill", "zero", "record",
 })
@@ -2462,6 +2484,187 @@ def _has_ver(expr: ast.AST | None) -> bool:
     return []
 
 
+def _is_name_send_call(call: ast.Call, names: set[str]) -> bool:
+    """Return True for ``gen.send(value)`` where ``gen`` is a bare name in ``names``.
+
+    Exactly one positional argument is required; its value is not inspected.
+    """
+    func = call.func
+    return (
+        isinstance(func, ast.Attribute)
+        and func.attr == "send"
+        and isinstance(func.value, ast.Name)
+        and func.value.id in names
+        and len(call.args) == 1
+    )
+
+
+def _send_replay_generator_bindings(
+    tree: ast.Module,
+) -> dict[str, ast.FunctionDef | ast.AsyncFunctionDef]:
+    """Map module-level ``name = builder()`` targets to their yielding builder.
+
+    Only single-name assignments in the module body count, and ``builder`` must
+    be a function defined in the module body whose subtree contains a ``yield``
+    or ``yield from`` node.
+    """
+    functions = {
+        node.name: node
+        for node in tree.body
+        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
+    }
+    bindings: dict[str, ast.FunctionDef | ast.AsyncFunctionDef] = {}
+    for stmt in tree.body:
+        if not isinstance(stmt, ast.Assign):
+            continue
+        if len(stmt.targets) != 1 or not isinstance(stmt.targets[0], ast.Name):
+            continue
+        if not isinstance(stmt.value, ast.Call) or not isinstance(stmt.value.func, ast.Name):
+            continue
+        builder = functions.get(stmt.value.func.id)
+        if builder is None:
+            continue
+        if any(isinstance(node, (ast.Yield, ast.YieldFrom)) for node in ast.walk(builder)):
+            bindings[stmt.targets[0].id] = builder
+    return bindings
+
+
+def _send_replay_primed_names(tree: ast.Module, generator_names: set[str]) -> set[str]:
+    """Return the names in ``generator_names`` that the module body primes.
+
+    A generator counts as primed when a bare module-level expression statement
+    calls ``next(name)`` or ``name.send(None)``.
+    """
+    primed: set[str] = set()
+    for stmt in tree.body:
+        if not isinstance(stmt, ast.Expr) or not isinstance(stmt.value, ast.Call):
+            continue
+        call = stmt.value
+        if isinstance(call.func, ast.Name) and call.func.id == "next" and len(call.args) == 1:
+            operand = call.args[0]
+            if isinstance(operand, ast.Name) and operand.id in generator_names:
+                primed.add(operand.id)
+            continue
+        if not _is_name_send_call(call, generator_names):
+            continue
+        sent = call.args[0]
+        if isinstance(sent, ast.Constant) and sent.value is None:
+            primed.add(call.func.value.id)
+    return primed
+
+
+def _send_replay_builder_is_gated(builder: ast.FunctionDef | ast.AsyncFunctionDef) -> bool:
+    """Return True when ``builder`` updates a yielded name behind an input-blind ``if``.
+
+    A name is "resume-derived" when it is assigned directly from a ``yield``
+    expression (the value handed to ``send``) or, transitively, from an
+    expression that uses a resume-derived name.  The builder matches when an
+    ``if`` whose test uses no resume-derived name has a branch that directly
+    assigns a yielded name from a resume-derived expression.
+    """
+    yielded_names: set[str] = set()
+    assignments: list[tuple[str, ast.expr]] = []
+    conditionals: list[ast.If] = []
+    # One walk gathers everything the later passes need.
+    for child in ast.walk(builder):
+        if isinstance(child, ast.Yield) and isinstance(child.value, ast.Name):
+            yielded_names.add(child.value.id)
+        elif isinstance(child, ast.If):
+            conditionals.append(child)
+        elif (
+            isinstance(child, ast.Assign)
+            and len(child.targets) == 1
+            and isinstance(child.targets[0], ast.Name)
+        ):
+            assignments.append((child.targets[0].id, child.value))
+
+    resume_derived = {name for name, value in assignments if isinstance(value, ast.Yield)}
+    if not yielded_names or not resume_derived:
+        return False
+
+    # Propagate to a fixed point: sweep until no assignment adds a new name.
+    changed = True
+    while changed:
+        changed = False
+        for name, value in assignments:
+            if name not in resume_derived and _depends_on_names(value, resume_derived):
+                resume_derived.add(name)
+                changed = True
+
+    for conditional in conditionals:
+        if _expr_names(conditional.test) & resume_derived:
+            continue  # test reads a resume-derived name
+        for branch in (conditional.body, conditional.orelse):
+            if any(
+                isinstance(stmt, ast.Assign)
+                and len(stmt.targets) == 1
+                and isinstance(stmt.targets[0], ast.Name)
+                and stmt.targets[0].id in yielded_names
+                and _depends_on_names(stmt.value, resume_derived)
+                for stmt in branch
+            ):
+                return True
+    return False
+
+
+def detect_generator_send_replay(code: str | SubmissionFacts) -> list[dict]:
+    """Pattern: persistent primed generator replays yielded state via send().
+
+    Reports an entrypoint that directly returns ``gen.send(expr)`` when
+    ``expr`` uses one of the entrypoint's parameters, ``gen`` is a
+    module-level generator primed in the module body, and the generator's
+    builder assigns a yielded name from the sent value under an ``if`` whose
+    test ignores that value.
+
+    Args:
+        code: Submission source or ``SubmissionFacts``; normalized through
+            ``ensure_submission_facts``.
+
+    Returns:
+        A one-element list with the ``GENERATOR_SEND_REPLAY`` finding, or an
+        empty list when nothing matches or no module AST is available.
+    """
+    facts = ensure_submission_facts(code)
+    tree = facts.ast_tree
+    if tree is None or not isinstance(tree, ast.Module):
+        return []
+
+    generator_builders = _send_replay_generator_bindings(tree)
+    primed_generators = _send_replay_primed_names(tree, set(generator_builders))
+    if not primed_generators:
+        return []
+
+    # The builder verdict depends only on the builder, so compute it once per generator.
+    gated: dict[str, bool] = {}
+    for node in ast.walk(tree):
+        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) or not is_entrypoint_name(node.name):
+            continue
+        params = set(_function_param_names(node))
+        for stmt in node.body:
+            if not isinstance(stmt, ast.Return) or not isinstance(stmt.value, ast.Call):
+                continue
+            call = stmt.value
+            if not _is_name_send_call(call, primed_generators):
+                continue
+            if not _expr_names(call.args[0]) & params:
+                continue
+            gen_name = call.func.value.id
+            if gen_name not in gated:
+                gated[gen_name] = _send_replay_builder_is_gated(generator_builders[gen_name])
+            if not gated[gen_name]:
+                continue
+            return [{
+                "pattern": "GENERATOR_SEND_REPLAY",
+                "severity": "critical",
+                "evidence": (
+                    f"{entrypoint_label(node.name)} replays state from a primed "
+                    f"generator via {gen_name}.send(...)"
+                ),
+            }]
+
+    return []
+
+
 RE_OBJECT_ID_DATA = re.compile(r"\bdata_id\s*=\s*id\s*\(\s*data\s*\)")
 RE_RESULT_BANK_SET = re.compile(r"_superbatch_results\s*\[\s*(?:did|data_id)\s*\]\s*=")
 RE_RESULT_BANK_RETURN = re.compile(r"return\s+_superbatch_results\s*\[\s*data_id\s*\]")
@@ -3479,6 +3682,10 @@ class RulePolicy:
         "LAST_CALL_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
         AMD_MANUAL_LAST_CALL_REPLAY_FIXTURES, "keep",
     ),
+    "GENERATOR_SEND_REPLAY": RulePolicy(
+        "GENERATOR_SEND_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
+        (), "keep",
+    ),
     "SHAPE_OUTPUT_REPLAY": RulePolicy(
         "SHAPE_OUTPUT_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
         AMD_PACKAGE_SHAPE_REPLAY_FIXTURES, "keep",
@@ -3762,6 +3969,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
     detect_decode_mm_ref,
     detect_result_caching,
     detect_last_call_replay,
+    detect_generator_send_replay,
     detect_shape_output_replay,
     detect_timed_input_replay,
     detect_cuda_graph_replay,
@@ -3800,6 +4008,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
     ("decode_mm_ref", detect_decode_mm_ref),
     ("result_caching", detect_result_caching),
     ("last_call_replay", detect_last_call_replay),
+    ("generator_send_replay", detect_generator_send_replay),
     ("shape_output_replay", detect_shape_output_replay),
     ("timed_input_replay", detect_timed_input_replay),
     ("cuda_graph_replay", detect_cuda_graph_replay),