Skip to content

Commit a1cdc0c

Browse files
committed
[ty] diagnostic on overridden __setattr__ and __delattr__ in frozen dataclasses
astral-sh/ty#111
1 parent 7bb5dd8 commit a1cdc0c

7 files changed

Lines changed: 270 additions & 143 deletions

File tree

crates/ty/docs/rules.md

Lines changed: 107 additions & 75 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/ty_python_semantic/resources/mdtest/dataclasses/dataclasses.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,7 @@ frozen_instance = MyFrozenClass(1)
443443
frozen_instance.x = 2 # error: [invalid-assignment]
444444
```
445445

446-
If `__setattr__()` or `__delattr__()` is defined in the class, we should emit a diagnostic.
446+
If `__setattr__()` or `__delattr__()` is defined in the class, a diagnostic is emitted.
447447

448448
```py
449449
from dataclasses import dataclass
@@ -452,10 +452,10 @@ from dataclasses import dataclass
452452
class MyFrozenClass:
453453
x: int
454454

455-
# TODO: Emit a diagnostic here
455+
# error: [unsound-dataclass-method-override] "Cannot overwrite attribute `__setattr__` in class `MyFrozenClass`"
456456
def __setattr__(self, name: str, value: object) -> None: ...
457457

458-
# TODO: Emit a diagnostic here
458+
# error: [unsound-dataclass-method-override] "Cannot overwrite attribute `__delattr__` in class `MyFrozenClass`"
459459
def __delattr__(self, name: str) -> None: ...
460460
```
461461

crates/ty_python_semantic/src/types/class.rs

Lines changed: 120 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ use crate::semantic_index::{
1919
use crate::types::bound_super::BoundSuperError;
2020
use crate::types::constraints::{ConstraintSet, IteratorConstraintsExtension};
2121
use crate::types::context::InferContext;
22-
use crate::types::diagnostic::{INVALID_TYPE_ALIAS_TYPE, SUPER_CALL_IN_NAMED_TUPLE_METHOD};
22+
use crate::types::diagnostic::{
23+
INVALID_TYPE_ALIAS_TYPE, SUPER_CALL_IN_NAMED_TUPLE_METHOD, UNSOUND_DATACLASS_METHOD_OVERRIDE,
24+
};
2325
use crate::types::enums::enum_metadata;
2426
use crate::types::function::{
2527
DataclassTransformerFlags, DataclassTransformerParams, KnownFunction,
@@ -1954,6 +1956,69 @@ impl<'db> ClassLiteral<'db> {
19541956
Some(typed_dict_params_from_class_def(class_stmt))
19551957
}
19561958

1959+
/// Returns dataclass params for this class, sourced from both dataclass params and dataclass
1960+
/// transform params
1961+
fn merged_dataclass_params(
1962+
self,
1963+
db: &'db dyn Db,
1964+
field_policy: CodeGeneratorKind<'db>,
1965+
) -> (Option<DataclassParams<'db>>, Option<DataclassParams<'db>>) {
1966+
let dataclass_params = self.dataclass_params(db);
1967+
1968+
let mut transformer_params =
1969+
if let CodeGeneratorKind::DataclassLike(Some(transformer_params)) = field_policy {
1970+
Some(DataclassParams::from_transformer_params(
1971+
db,
1972+
transformer_params,
1973+
))
1974+
} else {
1975+
None
1976+
};
1977+
1978+
// Dataclass transformer flags can be overwritten using class arguments.
1979+
if let Some(transformer_params) = transformer_params.as_mut() {
1980+
if let Some(class_def) = self.definition(db).kind(db).as_class() {
1981+
let module = parsed_module(db, self.file(db)).load(db);
1982+
1983+
if let Some(arguments) = &class_def.node(&module).arguments {
1984+
let mut flags = transformer_params.flags(db);
1985+
1986+
for keyword in &arguments.keywords {
1987+
if let Some(arg_name) = &keyword.arg {
1988+
if let Some(is_set) =
1989+
keyword.value.as_boolean_literal_expr().map(|b| b.value)
1990+
{
1991+
for (flag_name, flag) in DATACLASS_FLAGS {
1992+
if arg_name.as_str() == *flag_name {
1993+
flags.set(*flag, is_set);
1994+
}
1995+
}
1996+
}
1997+
}
1998+
}
1999+
2000+
*transformer_params =
2001+
DataclassParams::new(db, flags, transformer_params.field_specifiers(db));
2002+
}
2003+
}
2004+
}
2005+
2006+
(dataclass_params, transformer_params)
2007+
}
2008+
2009+
/// Checks if the given dataclass parameter flag is set for this class.
2010+
/// This checks both the `dataclass_params` and `transformer_params`.
2011+
fn has_dataclass_param(
2012+
self,
2013+
db: &'db dyn Db,
2014+
field_policy: CodeGeneratorKind<'db>,
2015+
param: DataclassFlags,
2016+
) -> bool {
2017+
let (dataclass_params, transformer_params) = self.merged_dataclass_params(db, field_policy);
2018+
dataclass_params.is_some_and(|params| params.flags(db).contains(param))
2019+
|| transformer_params.is_some_and(|params| params.flags(db).contains(param))
2020+
}
2021+
19572022
/// Return the explicit `metaclass` of this class, if one is defined.
19582023
///
19592024
/// ## Note
@@ -2367,57 +2432,8 @@ impl<'db> ClassLiteral<'db> {
23672432
inherited_generic_context: Option<GenericContext<'db>>,
23682433
name: &str,
23692434
) -> Option<Type<'db>> {
2370-
let dataclass_params = self.dataclass_params(db);
2371-
23722435
let field_policy = CodeGeneratorKind::from_class(db, self, specialization)?;
23732436

2374-
let mut transformer_params =
2375-
if let CodeGeneratorKind::DataclassLike(Some(transformer_params)) = field_policy {
2376-
Some(DataclassParams::from_transformer_params(
2377-
db,
2378-
transformer_params,
2379-
))
2380-
} else {
2381-
None
2382-
};
2383-
2384-
// Dataclass transformer flags can be overwritten using class arguments.
2385-
// TODO this should be done more generally, not just in `own_synthesized_member`, so that
2386-
// `dataclass_params` always reflects the transformer params.
2387-
if let Some(transformer_params) = transformer_params.as_mut() {
2388-
if let Some(class_def) = self.definition(db).kind(db).as_class() {
2389-
let module = parsed_module(db, self.file(db)).load(db);
2390-
2391-
if let Some(arguments) = &class_def.node(&module).arguments {
2392-
let mut flags = transformer_params.flags(db);
2393-
2394-
for keyword in &arguments.keywords {
2395-
if let Some(arg_name) = &keyword.arg {
2396-
if let Some(is_set) =
2397-
keyword.value.as_boolean_literal_expr().map(|b| b.value)
2398-
{
2399-
for (flag_name, flag) in DATACLASS_FLAGS {
2400-
if arg_name.as_str() == *flag_name {
2401-
flags.set(*flag, is_set);
2402-
}
2403-
}
2404-
}
2405-
}
2406-
}
2407-
2408-
*transformer_params =
2409-
DataclassParams::new(db, flags, transformer_params.field_specifiers(db));
2410-
}
2411-
}
2412-
}
2413-
2414-
let has_dataclass_param = |param| {
2415-
dataclass_params.is_some_and(|params| params.flags(db).contains(param))
2416-
// TODO if we were correctly initializing `dataclass_params` from the
2417-
// transformer params, this fallback shouldn't be needed here.
2418-
|| transformer_params.is_some_and(|params| params.flags(db).contains(param))
2419-
};
2420-
24212437
let instance_ty =
24222438
Type::instance(db, self.apply_optional_specialization(db, specialization));
24232439

@@ -2536,7 +2552,7 @@ impl<'db> ClassLiteral<'db> {
25362552

25372553
match (field_policy, name) {
25382554
(CodeGeneratorKind::DataclassLike(_), "__init__") => {
2539-
if !has_dataclass_param(DataclassFlags::INIT) {
2555+
if !self.has_dataclass_param(db, field_policy, DataclassFlags::INIT) {
25402556
return None;
25412557
}
25422558

@@ -2551,7 +2567,7 @@ impl<'db> ClassLiteral<'db> {
25512567
signature_from_fields(vec![cls_parameter], Some(Type::none(db)))
25522568
}
25532569
(CodeGeneratorKind::DataclassLike(_), "__lt__" | "__le__" | "__gt__" | "__ge__") => {
2554-
if !has_dataclass_param(DataclassFlags::ORDER) {
2570+
if !self.has_dataclass_param(db, field_policy, DataclassFlags::ORDER) {
25552571
return None;
25562572
}
25572573

@@ -2573,9 +2589,10 @@ impl<'db> ClassLiteral<'db> {
25732589
Some(Type::function_like_callable(db, signature))
25742590
}
25752591
(CodeGeneratorKind::DataclassLike(_), "__hash__") => {
2576-
let unsafe_hash = has_dataclass_param(DataclassFlags::UNSAFE_HASH);
2577-
let frozen = has_dataclass_param(DataclassFlags::FROZEN);
2578-
let eq = has_dataclass_param(DataclassFlags::EQ);
2592+
let unsafe_hash =
2593+
self.has_dataclass_param(db, field_policy, DataclassFlags::UNSAFE_HASH);
2594+
let frozen = self.has_dataclass_param(db, field_policy, DataclassFlags::FROZEN);
2595+
let eq = self.has_dataclass_param(db, field_policy, DataclassFlags::EQ);
25792596

25802597
if unsafe_hash || (frozen && eq) {
25812598
let signature = Signature::new(
@@ -2598,11 +2615,12 @@ impl<'db> ClassLiteral<'db> {
25982615
(CodeGeneratorKind::DataclassLike(_), "__match_args__")
25992616
if Program::get(db).python_version(db) >= PythonVersion::PY310 =>
26002617
{
2601-
if !has_dataclass_param(DataclassFlags::MATCH_ARGS) {
2618+
if !self.has_dataclass_param(db, field_policy, DataclassFlags::MATCH_ARGS) {
26022619
return None;
26032620
}
26042621

2605-
let kw_only_default = has_dataclass_param(DataclassFlags::KW_ONLY);
2622+
let kw_only_default =
2623+
self.has_dataclass_param(db, field_policy, DataclassFlags::KW_ONLY);
26062624

26072625
let fields = self.fields(db, specialization, field_policy);
26082626
let match_args = fields
@@ -2620,8 +2638,8 @@ impl<'db> ClassLiteral<'db> {
26202638
(CodeGeneratorKind::DataclassLike(_), "__weakref__")
26212639
if Program::get(db).python_version(db) >= PythonVersion::PY311 =>
26222640
{
2623-
if !has_dataclass_param(DataclassFlags::WEAKREF_SLOT)
2624-
|| !has_dataclass_param(DataclassFlags::SLOTS)
2641+
if !self.has_dataclass_param(db, field_policy, DataclassFlags::WEAKREF_SLOT)
2642+
|| !self.has_dataclass_param(db, field_policy, DataclassFlags::SLOTS)
26252643
{
26262644
return None;
26272645
}
@@ -2663,7 +2681,7 @@ impl<'db> ClassLiteral<'db> {
26632681
signature_from_fields(vec![self_parameter], Some(instance_ty))
26642682
}
26652683
(CodeGeneratorKind::DataclassLike(_), "__setattr__") => {
2666-
if has_dataclass_param(DataclassFlags::FROZEN) {
2684+
if self.has_dataclass_param(db, field_policy, DataclassFlags::FROZEN) {
26672685
let signature = Signature::new(
26682686
Parameters::new(
26692687
db,
@@ -2684,11 +2702,12 @@ impl<'db> ClassLiteral<'db> {
26842702
(CodeGeneratorKind::DataclassLike(_), "__slots__")
26852703
if Program::get(db).python_version(db) >= PythonVersion::PY310 =>
26862704
{
2687-
has_dataclass_param(DataclassFlags::SLOTS).then(|| {
2688-
let fields = self.fields(db, specialization, field_policy);
2689-
let slots = fields.keys().map(|name| Type::string_literal(db, name));
2690-
Type::heterogeneous_tuple(db, slots)
2691-
})
2705+
self.has_dataclass_param(db, field_policy, DataclassFlags::SLOTS)
2706+
.then(|| {
2707+
let fields = self.fields(db, specialization, field_policy);
2708+
let slots = fields.keys().map(|name| Type::string_literal(db, name));
2709+
Type::heterogeneous_tuple(db, slots)
2710+
})
26922711
}
26932712
(CodeGeneratorKind::TypedDict, "__setitem__") => {
26942713
let fields = self.fields(db, specialization, field_policy);
@@ -3074,6 +3093,42 @@ impl<'db> ClassLiteral<'db> {
30743093
.collect()
30753094
}
30763095

3096+
pub(crate) fn validate_members(self, context: &InferContext<'db, '_>) {
3097+
let db = context.db();
3098+
let Some(field_policy) = CodeGeneratorKind::from_class(db, self, None) else {
3099+
return;
3100+
};
3101+
let class_body_scope = self.body_scope(db);
3102+
let table = place_table(db, class_body_scope);
3103+
let use_def = use_def_map(db, class_body_scope);
3104+
for (symbol_id, declarations) in use_def.all_end_of_scope_symbol_declarations() {
3105+
let result = place_from_declarations(db, declarations.clone());
3106+
let attr = result.ignore_conflicting_declarations();
3107+
let symbol = table.symbol(symbol_id);
3108+
let name = symbol.name();
3109+
if let Some(Type::FunctionLiteral(literal)) = attr.place.ignore_possibly_undefined()
3110+
&& matches!(name.as_str(), "__setattr__" | "__delattr__")
3111+
{
3112+
if let Some(CodeGeneratorKind::DataclassLike(_)) =
3113+
CodeGeneratorKind::from_class(db, self, None)
3114+
&& self.has_dataclass_param(db, field_policy, DataclassFlags::FROZEN)
3115+
{
3116+
if let Some(builder) = context.report_lint(
3117+
&UNSOUND_DATACLASS_METHOD_OVERRIDE,
3118+
literal.node(db, context.file(), context.module()),
3119+
) {
3120+
let mut diagnostic = builder.into_diagnostic(format_args!(
3121+
"Cannot overwrite attribute `{}` in class `{}`",
3122+
name,
3123+
self.name(db)
3124+
));
3125+
diagnostic.info(name);
3126+
}
3127+
}
3128+
}
3129+
}
3130+
}
3131+
30773132
/// Returns a list of all annotated attributes defined in the body of this class. This is similar
30783133
/// to the `__annotations__` attribute at runtime, but also contains default values.
30793134
///

crates/ty_python_semantic/src/types/diagnostic.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ pub(crate) fn register_lints(registry: &mut LintRegistryBuilder) {
5050
registry.register_lint(&AMBIGUOUS_PROTOCOL_MEMBER);
5151
registry.register_lint(&CALL_NON_CALLABLE);
5252
registry.register_lint(&POSSIBLY_MISSING_IMPLICIT_CALL);
53+
registry.register_lint(&UNSOUND_DATACLASS_METHOD_OVERRIDE);
5354
registry.register_lint(&CONFLICTING_ARGUMENT_FORMS);
5455
registry.register_lint(&CONFLICTING_DECLARATIONS);
5556
registry.register_lint(&CONFLICTING_METACLASS);
@@ -393,6 +394,32 @@ declare_lint! {
393394
}
394395
}
395396

397+
declare_lint! {
398+
/// ## What it does
399+
/// Checks for dataclass definitions that have both `frozen=True` and a custom `__setattr__` or
400+
/// `__delattr__` method defined.
401+
///
402+
/// ## Why is this bad?
403+
/// Frozen dataclasses synthesize `__setattr__` and `__delattr__` methods which raise a
404+
/// `FrozenInstanceError` to emulate immutability.
405+
///
406+
/// Defining either of these methods in a frozen dataclass raises a `TypeError` at class creation time.
407+
///
408+
/// ## Examples
409+
/// ```python
410+
/// from dataclasses import dataclass
411+
///
412+
/// @dataclass(frozen=True)
413+
/// class A:
414+
/// def __setattr__(self, name: str, value: object) -> None: ...
415+
/// ```
416+
pub(crate) static UNSOUND_DATACLASS_METHOD_OVERRIDE = {
417+
summary: "detects dataclasses with `frozen=True` that have a custom `__setattr__` or `__delattr__` implementation",
418+
status: LintStatus::preview("1.0.0"),
419+
default_level: Level::Error,
420+
}
421+
}
422+
396423
declare_lint! {
397424
/// ## What it does
398425
/// Checks for class definitions which will fail at runtime due to

crates/ty_python_semantic/src/types/infer/builder.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1052,6 +1052,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
10521052
if let Some(protocol) = class.into_protocol_class(self.db()) {
10531053
protocol.validate_members(&self.context);
10541054
}
1055+
1056+
class.validate_members(&self.context);
10551057
}
10561058
}
10571059

crates/ty_server/tests/e2e/snapshots/e2e__commands__debug_command.snap

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ Settings: Settings {
104104
"unresolved-global": Warning (Default),
105105
"unresolved-import": Error (Default),
106106
"unresolved-reference": Error (Default),
107+
"unsound-dataclass-method-override": Error (Default),
107108
"unsupported-base": Warning (Default),
108109
"unsupported-bool-conversion": Error (Default),
109110
"unsupported-operator": Error (Default),

ty.schema.json

Lines changed: 10 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)