Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 60 additions & 1 deletion kernelguard.py
Original file line number Diff line number Diff line change
Expand Up @@ -3025,6 +3025,59 @@ def detect_precision_downgrade(code: str | SubmissionFacts) -> list[dict]:
return matches


def detect_lambda_dict_replay(code: str | SubmissionFacts) -> list[dict]:
"""Pattern: lambda entrypoint with operator dict dispatch + data_ptr.

Catches RED #44:
custom_kernel = lambda data: (
operator.setitem(_DISPATCH, data.data_ptr(), torch.mm(data, data.t()))
if data.data_ptr() not in _DISPATCH else None,
_DISPATCH.get(data.data_ptr())
)[-1]

Uses lambda (no def), operator.setitem/operator.getitem, and
data_ptr() as dict keys.
"""
facts = ensure_submission_facts(code)
tree = facts.ast_tree
if tree is None:
return []

for node in ast.walk(tree):
if not isinstance(node, ast.Assign):
continue
if (
len(node.targets) == 1
and isinstance(node.targets[0], ast.Name)
and is_entrypoint_name(node.targets[0].id)
and isinstance(node.value, ast.Lambda)
):
lambda_body = node.value.body
has_data_ptr_in_lambda = _expr_has_data_ptr_fast(
lambda_body, facts._nodes_with_data_ptr
)
has_operator = False
for sub in ast.walk(lambda_body):
if (
isinstance(sub, ast.Attribute)
and isinstance(sub.value, ast.Name)
and sub.value.id == "operator"
):
has_operator = True
break

if has_data_ptr_in_lambda and has_operator:
return [{
"pattern": "LAMBDA_DICT_REPLAY",
"severity": "critical",
"evidence": (
f"lambda entrypoint with operator module and "
f"data_ptr() dict dispatch"
),
}]
return []


# ---------------------------------------------------------------------------
# Score anomaly detection
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -3645,6 +3698,10 @@ class RulePolicy:
"SCORE_EXTREME_SPEEDUP", "score_anomaly", "support", SUSPICIOUS_ONLY, (),
(), "downgrade",
),
"LAMBDA_DICT_REPLAY": RulePolicy(
"LAMBDA_DICT_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
(), "keep",
),
"DUPLICATE_SPAM": RulePolicy(
"DUPLICATE_SPAM", "administrative", "support", SUSPICIOUS_ONLY, (),
(), "downgrade",
Expand Down Expand Up @@ -3790,6 +3847,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
detect_thread_injection,
detect_lazy_tensor,
detect_precision_downgrade,
detect_lambda_dict_replay,
]

BASE_DETECTOR_SPECS = [
Expand Down Expand Up @@ -3827,6 +3885,7 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
("thread_injection", detect_thread_injection),
("lazy_tensor", detect_lazy_tensor),
("precision_downgrade", detect_precision_downgrade),
("lambda_dict_replay", detect_lambda_dict_replay),
]

VALID_RULE_OUTCOMES = {AUTO_FILTER, SUSPICIOUS_ONLY, TELEMETRY_ONLY}
Expand Down Expand Up @@ -4703,7 +4762,7 @@ def _worker_parquet(args: tuple) -> dict:
"HARDCODED_SHAPES", "TRIVIAL_PROBE",
"OBFUSCATED_EXEC", "DYNAMIC_EXECUTION", "MODULE_RELOAD", "THREAD_INJECTION", "LAZY_TENSOR",
"TOKEN_PASTE_CUDA_API", "SEQUENCE_BATCH_GRAPH", "PARTIAL_GRAPH_KEY", "RUNTIME_PACKAGE_INSTALL",
"PRECISION_DOWNGRADE", "SCORE_PHYSICS_FLOOR", "SCORE_IMPOSSIBLE", "SCORE_SUSPECT_FLOOR",
"PRECISION_DOWNGRADE", "LAMBDA_DICT_REPLAY", "SCORE_PHYSICS_FLOOR", "SCORE_IMPOSSIBLE", "SCORE_SUSPECT_FLOOR",
"SCORE_BROKEN", "SCORE_EXTREME_SPEEDUP", "DUPLICATE_SPAM", "NEAR_CLONE_SPAM",
]

Expand Down