Skip to content

Commit 1b03f0b

Browse files
Pass kwargs to post init in dataclasses (#3771)
* kwargs to post init * default post init in strict dataclass --------- Co-authored-by: Lucain Pouget <lucainp@gmail.com>
1 parent 67542bd commit 1b03f0b

2 files changed

Lines changed: 35 additions & 2 deletions

File tree

src/huggingface_hub/dataclasses.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,13 +166,27 @@ def __init__(self, **kwargs: Any) -> None:
166166
# Call the original __init__ with standard fields
167167
original_init(self, **standard_kwargs)
168168

169-
# Add any additional kwargs as attributes
169+
# Pass any additional kwargs to `__post_init__` and let the object
170+
# decide whether to set the attr or use for different purposes (e.g. BC checks)
171+
additional_kwargs = {}
170172
for name, value in kwargs.items():
171173
if name not in dataclass_fields:
172-
self.__setattr__(name, value)
174+
additional_kwargs[name] = value
175+
176+
self.__post_init__(**additional_kwargs)
173177

174178
cls.__init__ = __init__ # type: ignore[method-assign]
175179

180+
# Define a default __post_init__ if not defined
181+
if not hasattr(cls, "__post_init__"):
182+
183+
def __post_init__(self, **kwargs: Any) -> None:
184+
"""Default __post_init__ to accept additional kwargs."""
185+
for name, value in kwargs.items():
186+
setattr(self, name, value)
187+
188+
cls.__post_init__ = __post_init__ # type: ignore[method-assign]
189+
176190
# (optional) Override __repr__ to include additional kwargs
177191
original_repr = cls.__repr__
178192

tests/test_utils_strict_dataclass.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,18 @@ class ConfigWithKwargs:
8484
vocab_size: int = validated_field(validator=positive_int, default=16)
8585

8686

87+
@strict(accept_kwargs=True)
88+
@dataclass
89+
class ConfigWithKwargsAndPostInit:
90+
model_type: str
91+
vocab_size: int = validated_field(validator=positive_int, default=16)
92+
93+
def __post_init__(self, **kwargs: Any) -> None:
94+
"""Custom __post_init__ that also accepts additional kwargs."""
95+
for name, value in kwargs.items():
96+
setattr(self, name.upper(), value) # store additional kwargs in uppercase (just for testing)
97+
98+
8799
class DummyClass:
88100
pass
89101

@@ -376,6 +388,13 @@ class Config:
376388
Config(model_type="bert", vocab_size=30000)
377389

378390

391+
def test_post_init_with_kwargs():
392+
config = ConfigWithKwargsAndPostInit(model_type="bert", vocab_size=30000, extra_param="extra_value")
393+
assert config.model_type == "bert"
394+
assert config.vocab_size == 30000
395+
assert config.EXTRA_PARAM == "extra_value" # stored in uppercase by custom __post_init__
396+
397+
379398
def test_is_recognized_as_dataclass():
380399
# Check that dataclasses module recognizes it as a dataclass
381400
assert is_dataclass(Config)

0 commit comments

Comments
 (0)