@@ -19,7 +19,9 @@ use crate::semantic_index::{
1919use crate :: types:: bound_super:: BoundSuperError ;
2020use crate :: types:: constraints:: { ConstraintSet , IteratorConstraintsExtension } ;
2121use crate :: types:: context:: InferContext ;
22- use crate :: types:: diagnostic:: { INVALID_TYPE_ALIAS_TYPE , SUPER_CALL_IN_NAMED_TUPLE_METHOD } ;
22+ use crate :: types:: diagnostic:: {
23+ INVALID_TYPE_ALIAS_TYPE , SUPER_CALL_IN_NAMED_TUPLE_METHOD , UNSOUND_DATACLASS_METHOD_OVERRIDE ,
24+ } ;
2325use crate :: types:: enums:: enum_metadata;
2426use crate :: types:: function:: {
2527 DataclassTransformerFlags , DataclassTransformerParams , KnownFunction ,
@@ -1954,6 +1956,69 @@ impl<'db> ClassLiteral<'db> {
19541956 Some ( typed_dict_params_from_class_def ( class_stmt) )
19551957 }
19561958
1959+ /// Returns dataclass params for this class, sourced from both dataclass params and dataclass
1960+ /// transform params
1961+ fn merged_dataclass_params (
1962+ self ,
1963+ db : & ' db dyn Db ,
1964+ field_policy : CodeGeneratorKind < ' db > ,
1965+ ) -> ( Option < DataclassParams < ' db > > , Option < DataclassParams < ' db > > ) {
1966+ let dataclass_params = self . dataclass_params ( db) ;
1967+
1968+ let mut transformer_params =
1969+ if let CodeGeneratorKind :: DataclassLike ( Some ( transformer_params) ) = field_policy {
1970+ Some ( DataclassParams :: from_transformer_params (
1971+ db,
1972+ transformer_params,
1973+ ) )
1974+ } else {
1975+ None
1976+ } ;
1977+
1978+ // Dataclass transformer flags can be overwritten using class arguments.
1979+ if let Some ( transformer_params) = transformer_params. as_mut ( ) {
1980+ if let Some ( class_def) = self . definition ( db) . kind ( db) . as_class ( ) {
1981+ let module = parsed_module ( db, self . file ( db) ) . load ( db) ;
1982+
1983+ if let Some ( arguments) = & class_def. node ( & module) . arguments {
1984+ let mut flags = transformer_params. flags ( db) ;
1985+
1986+ for keyword in & arguments. keywords {
1987+ if let Some ( arg_name) = & keyword. arg {
1988+ if let Some ( is_set) =
1989+ keyword. value . as_boolean_literal_expr ( ) . map ( |b| b. value )
1990+ {
1991+ for ( flag_name, flag) in DATACLASS_FLAGS {
1992+ if arg_name. as_str ( ) == * flag_name {
1993+ flags. set ( * flag, is_set) ;
1994+ }
1995+ }
1996+ }
1997+ }
1998+ }
1999+
2000+ * transformer_params =
2001+ DataclassParams :: new ( db, flags, transformer_params. field_specifiers ( db) ) ;
2002+ }
2003+ }
2004+ }
2005+
2006+ ( dataclass_params, transformer_params)
2007+ }
2008+
2009+ /// Checks if the given dataclass parameter flag is set for this class.
2010+ /// This checks both the `dataclass_params` and `transformer_params`.
2011+ fn has_dataclass_param (
2012+ self ,
2013+ db : & ' db dyn Db ,
2014+ field_policy : CodeGeneratorKind < ' db > ,
2015+ param : DataclassFlags ,
2016+ ) -> bool {
2017+ let ( dataclass_params, transformer_params) = self . merged_dataclass_params ( db, field_policy) ;
2018+ dataclass_params. is_some_and ( |params| params. flags ( db) . contains ( param) )
2019+ || transformer_params. is_some_and ( |params| params. flags ( db) . contains ( param) )
2020+ }
2021+
19572022 /// Return the explicit `metaclass` of this class, if one is defined.
19582023 ///
19592024 /// ## Note
@@ -2367,57 +2432,8 @@ impl<'db> ClassLiteral<'db> {
23672432 inherited_generic_context : Option < GenericContext < ' db > > ,
23682433 name : & str ,
23692434 ) -> Option < Type < ' db > > {
2370- let dataclass_params = self . dataclass_params ( db) ;
2371-
23722435 let field_policy = CodeGeneratorKind :: from_class ( db, self , specialization) ?;
23732436
2374- let mut transformer_params =
2375- if let CodeGeneratorKind :: DataclassLike ( Some ( transformer_params) ) = field_policy {
2376- Some ( DataclassParams :: from_transformer_params (
2377- db,
2378- transformer_params,
2379- ) )
2380- } else {
2381- None
2382- } ;
2383-
2384- // Dataclass transformer flags can be overwritten using class arguments.
2385- // TODO this should be done more generally, not just in `own_synthesized_member`, so that
2386- // `dataclass_params` always reflects the transformer params.
2387- if let Some ( transformer_params) = transformer_params. as_mut ( ) {
2388- if let Some ( class_def) = self . definition ( db) . kind ( db) . as_class ( ) {
2389- let module = parsed_module ( db, self . file ( db) ) . load ( db) ;
2390-
2391- if let Some ( arguments) = & class_def. node ( & module) . arguments {
2392- let mut flags = transformer_params. flags ( db) ;
2393-
2394- for keyword in & arguments. keywords {
2395- if let Some ( arg_name) = & keyword. arg {
2396- if let Some ( is_set) =
2397- keyword. value . as_boolean_literal_expr ( ) . map ( |b| b. value )
2398- {
2399- for ( flag_name, flag) in DATACLASS_FLAGS {
2400- if arg_name. as_str ( ) == * flag_name {
2401- flags. set ( * flag, is_set) ;
2402- }
2403- }
2404- }
2405- }
2406- }
2407-
2408- * transformer_params =
2409- DataclassParams :: new ( db, flags, transformer_params. field_specifiers ( db) ) ;
2410- }
2411- }
2412- }
2413-
2414- let has_dataclass_param = |param| {
2415- dataclass_params. is_some_and ( |params| params. flags ( db) . contains ( param) )
2416- // TODO if we were correctly initializing `dataclass_params` from the
2417- // transformer params, this fallback shouldn't be needed here.
2418- || transformer_params. is_some_and ( |params| params. flags ( db) . contains ( param) )
2419- } ;
2420-
24212437 let instance_ty =
24222438 Type :: instance ( db, self . apply_optional_specialization ( db, specialization) ) ;
24232439
@@ -2536,7 +2552,7 @@ impl<'db> ClassLiteral<'db> {
25362552
25372553 match ( field_policy, name) {
25382554 ( CodeGeneratorKind :: DataclassLike ( _) , "__init__" ) => {
2539- if !has_dataclass_param ( DataclassFlags :: INIT ) {
2555+ if !self . has_dataclass_param ( db , field_policy , DataclassFlags :: INIT ) {
25402556 return None ;
25412557 }
25422558
@@ -2551,7 +2567,7 @@ impl<'db> ClassLiteral<'db> {
25512567 signature_from_fields ( vec ! [ cls_parameter] , Some ( Type :: none ( db) ) )
25522568 }
25532569 ( CodeGeneratorKind :: DataclassLike ( _) , "__lt__" | "__le__" | "__gt__" | "__ge__" ) => {
2554- if !has_dataclass_param ( DataclassFlags :: ORDER ) {
2570+ if !self . has_dataclass_param ( db , field_policy , DataclassFlags :: ORDER ) {
25552571 return None ;
25562572 }
25572573
@@ -2573,9 +2589,10 @@ impl<'db> ClassLiteral<'db> {
25732589 Some ( Type :: function_like_callable ( db, signature) )
25742590 }
25752591 ( CodeGeneratorKind :: DataclassLike ( _) , "__hash__" ) => {
2576- let unsafe_hash = has_dataclass_param ( DataclassFlags :: UNSAFE_HASH ) ;
2577- let frozen = has_dataclass_param ( DataclassFlags :: FROZEN ) ;
2578- let eq = has_dataclass_param ( DataclassFlags :: EQ ) ;
2592+ let unsafe_hash =
2593+ self . has_dataclass_param ( db, field_policy, DataclassFlags :: UNSAFE_HASH ) ;
2594+ let frozen = self . has_dataclass_param ( db, field_policy, DataclassFlags :: FROZEN ) ;
2595+ let eq = self . has_dataclass_param ( db, field_policy, DataclassFlags :: EQ ) ;
25792596
25802597 if unsafe_hash || ( frozen && eq) {
25812598 let signature = Signature :: new (
@@ -2598,11 +2615,12 @@ impl<'db> ClassLiteral<'db> {
25982615 ( CodeGeneratorKind :: DataclassLike ( _) , "__match_args__" )
25992616 if Program :: get ( db) . python_version ( db) >= PythonVersion :: PY310 =>
26002617 {
2601- if !has_dataclass_param ( DataclassFlags :: MATCH_ARGS ) {
2618+ if !self . has_dataclass_param ( db , field_policy , DataclassFlags :: MATCH_ARGS ) {
26022619 return None ;
26032620 }
26042621
2605- let kw_only_default = has_dataclass_param ( DataclassFlags :: KW_ONLY ) ;
2622+ let kw_only_default =
2623+ self . has_dataclass_param ( db, field_policy, DataclassFlags :: KW_ONLY ) ;
26062624
26072625 let fields = self . fields ( db, specialization, field_policy) ;
26082626 let match_args = fields
@@ -2620,8 +2638,8 @@ impl<'db> ClassLiteral<'db> {
26202638 ( CodeGeneratorKind :: DataclassLike ( _) , "__weakref__" )
26212639 if Program :: get ( db) . python_version ( db) >= PythonVersion :: PY311 =>
26222640 {
2623- if !has_dataclass_param ( DataclassFlags :: WEAKREF_SLOT )
2624- || !has_dataclass_param ( DataclassFlags :: SLOTS )
2641+ if !self . has_dataclass_param ( db , field_policy , DataclassFlags :: WEAKREF_SLOT )
2642+ || !self . has_dataclass_param ( db , field_policy , DataclassFlags :: SLOTS )
26252643 {
26262644 return None ;
26272645 }
@@ -2663,7 +2681,7 @@ impl<'db> ClassLiteral<'db> {
26632681 signature_from_fields ( vec ! [ self_parameter] , Some ( instance_ty) )
26642682 }
26652683 ( CodeGeneratorKind :: DataclassLike ( _) , "__setattr__" ) => {
2666- if has_dataclass_param ( DataclassFlags :: FROZEN ) {
2684+ if self . has_dataclass_param ( db , field_policy , DataclassFlags :: FROZEN ) {
26672685 let signature = Signature :: new (
26682686 Parameters :: new (
26692687 db,
@@ -2684,11 +2702,12 @@ impl<'db> ClassLiteral<'db> {
26842702 ( CodeGeneratorKind :: DataclassLike ( _) , "__slots__" )
26852703 if Program :: get ( db) . python_version ( db) >= PythonVersion :: PY310 =>
26862704 {
2687- has_dataclass_param ( DataclassFlags :: SLOTS ) . then ( || {
2688- let fields = self . fields ( db, specialization, field_policy) ;
2689- let slots = fields. keys ( ) . map ( |name| Type :: string_literal ( db, name) ) ;
2690- Type :: heterogeneous_tuple ( db, slots)
2691- } )
2705+ self . has_dataclass_param ( db, field_policy, DataclassFlags :: SLOTS )
2706+ . then ( || {
2707+ let fields = self . fields ( db, specialization, field_policy) ;
2708+ let slots = fields. keys ( ) . map ( |name| Type :: string_literal ( db, name) ) ;
2709+ Type :: heterogeneous_tuple ( db, slots)
2710+ } )
26922711 }
26932712 ( CodeGeneratorKind :: TypedDict , "__setitem__" ) => {
26942713 let fields = self . fields ( db, specialization, field_policy) ;
@@ -3074,6 +3093,42 @@ impl<'db> ClassLiteral<'db> {
30743093 . collect ( )
30753094 }
30763095
3096+ pub ( crate ) fn validate_members ( self , context : & InferContext < ' db , ' _ > ) {
3097+ let db = context. db ( ) ;
3098+ let Some ( field_policy) = CodeGeneratorKind :: from_class ( db, self , None ) else {
3099+ return ;
3100+ } ;
3101+ let class_body_scope = self . body_scope ( db) ;
3102+ let table = place_table ( db, class_body_scope) ;
3103+ let use_def = use_def_map ( db, class_body_scope) ;
3104+ for ( symbol_id, declarations) in use_def. all_end_of_scope_symbol_declarations ( ) {
3105+ let result = place_from_declarations ( db, declarations. clone ( ) ) ;
3106+ let attr = result. ignore_conflicting_declarations ( ) ;
3107+ let symbol = table. symbol ( symbol_id) ;
3108+ let name = symbol. name ( ) ;
3109+ if let Some ( Type :: FunctionLiteral ( literal) ) = attr. place . ignore_possibly_undefined ( )
3110+ && matches ! ( name. as_str( ) , "__setattr__" | "__delattr__" )
3111+ {
3112+ if let Some ( CodeGeneratorKind :: DataclassLike ( _) ) =
3113+ CodeGeneratorKind :: from_class ( db, self , None )
3114+ && self . has_dataclass_param ( db, field_policy, DataclassFlags :: FROZEN )
3115+ {
3116+ if let Some ( builder) = context. report_lint (
3117+ & UNSOUND_DATACLASS_METHOD_OVERRIDE ,
3118+ literal. node ( db, context. file ( ) , context. module ( ) ) ,
3119+ ) {
3120+ let mut diagnostic = builder. into_diagnostic ( format_args ! (
3121+ "Cannot overwrite attribute `{}` in class `{}`" ,
3122+ name,
3123+ self . name( db)
3124+ ) ) ;
3125+ diagnostic. info ( name) ;
3126+ }
3127+ }
3128+ }
3129+ }
3130+ }
3131+
30773132 /// Returns a list of all annotated attributes defined in the body of this class. This is similar
30783133 /// to the `__annotations__` attribute at runtime, but also contains default values.
30793134 ///
0 commit comments