Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 180 additions & 0 deletions kernelguard.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,109 @@ def _expr_has_tensor_version(expr: ast.AST | None) -> bool:
)


def _function_param_names(fn: ast.FunctionDef | ast.AsyncFunctionDef) -> tuple[str, ...]:
names = [arg.arg for arg in fn.args.posonlyargs]
names.extend(arg.arg for arg in fn.args.args)
names.extend(arg.arg for arg in fn.args.kwonlyargs)
if fn.args.vararg is not None:
names.append(fn.args.vararg.arg)
if fn.args.kwarg is not None:
names.append(fn.args.kwarg.arg)
return tuple(names)


def _depends_on_names(expr: ast.AST | None, names: set[str]) -> bool:
if expr is None or not names:
return False
return bool(_expr_names(expr) & names)


def _state_read_root(expr: ast.AST | None, state_names: set[str], aliases: dict[str, str]) -> Optional[str]:
if expr is None:
return None
if isinstance(expr, ast.Name):
if expr.id in state_names:
return expr.id
return aliases.get(expr.id)
if isinstance(expr, (ast.Subscript, ast.Attribute)):
root = _ast_root_name(expr)
if root in state_names:
return root
if (
isinstance(expr, ast.Call)
and isinstance(expr.func, ast.Attribute)
and expr.func.attr == "get"
):
root = _ast_root_name(expr.func.value)
if root in state_names:
return root
return None


def _has_state_shortcut_replay(
fn: ast.FunctionDef | ast.AsyncFunctionDef,
state_names: set[str],
input_names: set[str],
) -> Optional[str]:
input_derived = set(input_names)
aliases: dict[str, str] = {}

changed = True
while changed:
changed = False
for child in ast.walk(fn):
if not isinstance(child, ast.Assign):
continue
if len(child.targets) != 1 or not isinstance(child.targets[0], ast.Name):
continue
target_name = child.targets[0].id
if _depends_on_names(child.value, input_derived) and target_name not in input_derived:
input_derived.add(target_name)
changed = True
state_root = _state_read_root(child.value, state_names, aliases)
if state_root is not None and aliases.get(target_name) != state_root:
aliases[target_name] = state_root
changed = True

state_from_input: set[str] = set()
for child in ast.walk(fn):
if isinstance(child, ast.Assign):
for target in child.targets:
root = _ast_root_name(target)
if root in state_names and _depends_on_names(child.value, input_derived):
state_from_input.add(root)
elif isinstance(child, ast.Call) and isinstance(child.func, ast.Attribute):
root = _ast_root_name(child.func.value)
if root not in state_names or child.func.attr not in {"append", "extend", "set", "update"}:
continue
if any(_depends_on_names(arg, input_derived) for arg in child.args):
state_from_input.add(root)

if not state_from_input:
return None

for child in ast.walk(fn):
if not isinstance(child, ast.If):
continue
if _body_has_calls(child.body):
continue
if _expr_names(child.test) & input_names:
continue
test_names = _expr_names(child.test)
for stmt in child.body:
if not isinstance(stmt, ast.Return) or stmt.value is None:
continue
state_root = _state_read_root(stmt.value, state_names, aliases)
if state_root not in state_from_input:
continue
if state_root in test_names or any(
alias in test_names and alias_root == state_root
for alias, alias_root in aliases.items()
):
return state_root
return None


_TRIVIAL_GPU_OPS = frozenset({
"fill_", "zero_", "copy_", "fill", "zero", "record",
})
Expand Down Expand Up @@ -2462,6 +2565,77 @@ def _has_ver(expr: ast.AST | None) -> bool:
return []


def detect_partial_bound_state_replay(code: str | SubmissionFacts) -> list[dict]:
"""Pattern: functools.partial binds mutable state used for replay."""
facts = ensure_submission_facts(code)
tree = facts.ast_tree
if tree is None or not isinstance(tree, ast.Module):
return []

module_aliases = {"functools"}
direct_aliases: set[str] = set()
for node in facts._imports:
for alias in node.names:
if alias.name == "functools":
module_aliases.add(alias.asname or alias.name)
for node in facts._import_froms:
if node.module != "functools":
continue
for alias in node.names:
if alias.name == "partial":
direct_aliases.add(alias.asname or alias.name)

functions = {
node.name: node
for node in tree.body
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
}

for stmt in tree.body:
if not isinstance(stmt, ast.Assign):
continue
if not any(isinstance(target, ast.Name) and target.id == "custom_kernel" for target in stmt.targets):
continue
if not isinstance(stmt.value, ast.Call) or not stmt.value.args:
continue
call = stmt.value
is_partial = False
if isinstance(call.func, ast.Name) and call.func.id in direct_aliases:
is_partial = True
elif (
isinstance(call.func, ast.Attribute)
and call.func.attr == "partial"
and isinstance(call.func.value, ast.Name)
and call.func.value.id in module_aliases
):
is_partial = True
if not is_partial:
continue
if not isinstance(call.args[0], ast.Name) or call.args[0].id not in functions:
continue

helper = functions[call.args[0].id]
params = _function_param_names(helper)
bound_count = len(call.args) - 1
if bound_count <= 0 or bound_count >= len(params):
continue
state_names = set(params[:bound_count])
input_names = set(params[bound_count:])
state_root = _has_state_shortcut_replay(helper, state_names, input_names)
if state_root is None:
continue
return [{
"pattern": "PARTIAL_BOUND_STATE_REPLAY",
"severity": "critical",
"evidence": (
f"custom_kernel binds replay state {state_root} through functools.partial "
f"and returns it from {helper.name} on an input-independent fast path"
),
}]

return []


RE_OBJECT_ID_DATA = re.compile(r"\bdata_id\s*=\s*id\s*\(\s*data\s*\)")
RE_RESULT_BANK_SET = re.compile(r"_superbatch_results\s*\[\s*(?:did|data_id)\s*\]\s*=")
RE_RESULT_BANK_RETURN = re.compile(r"return\s+_superbatch_results\s*\[\s*data_id\s*\]")
Expand Down Expand Up @@ -3479,6 +3653,10 @@ class RulePolicy:
"LAST_CALL_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
AMD_MANUAL_LAST_CALL_REPLAY_FIXTURES, "keep",
),
"PARTIAL_BOUND_STATE_REPLAY": RulePolicy(
"PARTIAL_BOUND_STATE_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
(), "keep",
),
"SHAPE_OUTPUT_REPLAY": RulePolicy(
"SHAPE_OUTPUT_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
AMD_PACKAGE_SHAPE_REPLAY_FIXTURES, "keep",
Expand Down Expand Up @@ -3762,6 +3940,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
detect_decode_mm_ref,
detect_result_caching,
detect_last_call_replay,
detect_partial_bound_state_replay,
detect_shape_output_replay,
detect_timed_input_replay,
detect_cuda_graph_replay,
Expand Down Expand Up @@ -3800,6 +3979,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
("decode_mm_ref", detect_decode_mm_ref),
("result_caching", detect_result_caching),
("last_call_replay", detect_last_call_replay),
("partial_bound_state_replay", detect_partial_bound_state_replay),
("shape_output_replay", detect_shape_output_replay),
("timed_input_replay", detect_timed_input_replay),
("cuda_graph_replay", detect_cuda_graph_replay),
Expand Down