Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 119 additions & 1 deletion kernelguard.py
Original file line number Diff line number Diff line change
Expand Up @@ -2462,6 +2462,118 @@ def _has_ver(expr: ast.AST | None) -> bool:
return []


def detect_partial_namespace_replay(code: str | SubmissionFacts) -> list[dict]:
"""Pattern: functools.partial binds a namespace used as replay state."""
facts = ensure_submission_facts(code)
tree = facts.ast_tree
if tree is None:
return []
entrypoint_name = entrypoint_label(facts.entrypoint_name)

functions = {
stmt.name: stmt
for stmt in tree.body
if isinstance(stmt, ast.FunctionDef)
}

def _is_partial(expr: ast.AST | None) -> bool:
return (
isinstance(expr, ast.Name) and expr.id == "partial"
) or (
isinstance(expr, ast.Attribute)
and expr.attr == "partial"
and isinstance(expr.value, ast.Name)
and expr.value.id == "functools"
)

def _is_namespace_none(expr: ast.AST | None) -> bool:
return (
isinstance(expr, ast.Call)
and isinstance(expr.func, ast.Name)
and expr.func.id == "SimpleNamespace"
and any(
kw.arg is not None
and isinstance(kw.value, ast.Constant)
and kw.value.value is None
for kw in expr.keywords
)
)

partials = []
for stmt in tree.body:
if not (
isinstance(stmt, ast.Assign)
and any(isinstance(target, ast.Name) and is_entrypoint_name(target.id) for target in stmt.targets)
and isinstance(stmt.value, ast.Call)
and _is_partial(stmt.value.func)
and len(stmt.value.args) >= 2
and isinstance(stmt.value.args[0], ast.Name)
and stmt.value.args[0].id in functions
and _is_namespace_none(stmt.value.args[1])
):
continue
partials.append(functions[stmt.value.args[0].id])

for fn in partials:
if len(fn.args.args) < 2:
continue
state_param = fn.args.args[0].arg
data_params = {arg.arg for arg in fn.args.args[1:]}
data_params.update(arg.arg for arg in fn.args.posonlyargs)
data_params.update(arg.arg for arg in fn.args.kwonlyargs)

mutated = {
target.attr
for stmt in ast.walk(fn)
if isinstance(stmt, ast.Assign)
and bool(_expr_names(stmt.value) & data_params)
for target in stmt.targets
if isinstance(target, ast.Attribute)
and isinstance(target.value, ast.Name)
and target.value.id == state_param
}
if not mutated:
continue

for child in ast.walk(fn):
if not isinstance(child, ast.If):
continue
if _expr_names(child.test) & data_params:
continue
tested = {
expr.attr
for expr in ast.walk(child.test)
if isinstance(expr, ast.Attribute)
and isinstance(expr.value, ast.Name)
and expr.value.id == state_param
and expr.attr in mutated
}
if not tested:
continue
if _body_has_calls(child.body):
continue
returned = {
stmt.value.attr
for stmt in child.body
if isinstance(stmt, ast.Return)
and isinstance(stmt.value, ast.Attribute)
and isinstance(stmt.value.value, ast.Name)
and stmt.value.value.id == state_param
and stmt.value.attr in mutated
}
if tested & returned:
return [{
"pattern": "PARTIAL_NAMESPACE_REPLAY",
"severity": "critical",
"evidence": (
f"{entrypoint_name} replays output through a "
"SimpleNamespace bound by partial"
),
}]

return []


RE_OBJECT_ID_DATA = re.compile(r"\bdata_id\s*=\s*id\s*\(\s*data\s*\)")
RE_RESULT_BANK_SET = re.compile(r"_superbatch_results\s*\[\s*(?:did|data_id)\s*\]\s*=")
RE_RESULT_BANK_RETURN = re.compile(r"return\s+_superbatch_results\s*\[\s*data_id\s*\]")
Expand Down Expand Up @@ -3479,6 +3591,10 @@ class RulePolicy:
"LAST_CALL_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
AMD_MANUAL_LAST_CALL_REPLAY_FIXTURES, "keep",
),
"PARTIAL_NAMESPACE_REPLAY": RulePolicy(
"PARTIAL_NAMESPACE_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
(), "keep",
),
"SHAPE_OUTPUT_REPLAY": RulePolicy(
"SHAPE_OUTPUT_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
AMD_PACKAGE_SHAPE_REPLAY_FIXTURES, "keep",
Expand Down Expand Up @@ -3762,6 +3878,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
detect_decode_mm_ref,
detect_result_caching,
detect_last_call_replay,
detect_partial_namespace_replay,
detect_shape_output_replay,
detect_timed_input_replay,
detect_cuda_graph_replay,
Expand Down Expand Up @@ -3800,6 +3917,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
("decode_mm_ref", detect_decode_mm_ref),
("result_caching", detect_result_caching),
("last_call_replay", detect_last_call_replay),
("partial_namespace_replay", detect_partial_namespace_replay),
("shape_output_replay", detect_shape_output_replay),
("timed_input_replay", detect_timed_input_replay),
("cuda_graph_replay", detect_cuda_graph_replay),
Expand Down Expand Up @@ -4696,7 +4814,7 @@ def _worker_parquet(args: tuple) -> dict:
"EVALUATOR_EXPLOIT", "HARNESS_RUNTIME_PATCHING", "MODULE_MUTATION", "GLOBALS_MUTATION", "CODE_REPLACEMENT",
"FRAME_WALK_ACCESS", "FRAME_WALK_MUTATION", "SYS_MODULES_ACCESS", "GLOBALS_ACCESS", "CODE_ACCESS",
"TRUSTED_MODULE_IMPORT",
"OUTPUT_REPLAY_CACHE", "LAST_CALL_REPLAY", "SHAPE_OUTPUT_REPLAY", "TIMED_INPUT_REPLAY", "CONFIG_CACHE_EXPLOIT", "POINTER_REPLAY", "RESULT_BANK_REPLAY", "PREPROCESS_CACHE", "WORKSPACE_CACHE",
"OUTPUT_REPLAY_CACHE", "LAST_CALL_REPLAY", "PARTIAL_NAMESPACE_REPLAY", "SHAPE_OUTPUT_REPLAY", "TIMED_INPUT_REPLAY", "CONFIG_CACHE_EXPLOIT", "POINTER_REPLAY", "RESULT_BANK_REPLAY", "PREPROCESS_CACHE", "WORKSPACE_CACHE",
"RUNNER_PLAN_CACHE", "CUDA_GRAPH_PYTHON", "CUDA_GRAPH_REPLAY",
"TIMER_MONKEYPATCH", "FAKE_BENCHMARK_EMIT", "STDIO_REDIRECT", "UNSYNC_MULTISTREAM", "CUDA_EVENT_DISABLE_TIMING",
"SCALED_MM_REF", "DECODE_MM_REF", "SILENT_FALLBACK", "REFERENCE_PRECOMPUTE_REPLAY", "TORCH_COMPILE_CACHE",
Expand Down