From 2bdb5f9da719674d555829d0cc47bd23330e853d Mon Sep 17 00:00:00 2001 From: Emiliy Feldman Date: Sat, 30 Nov 2024 12:32:11 +0100 Subject: [PATCH 01/25] refactored --- rectools/models/base.py | 25 ++----------------------- rectools/models/implicit_als.py | 2 +- rectools/models/lightfm.py | 3 ++- rectools/utils/serialization.py | 33 +++++++++++++++++++++++++++++++++ 4 files changed, 38 insertions(+), 25 deletions(-) create mode 100644 rectools/utils/serialization.py diff --git a/rectools/models/base.py b/rectools/models/base.py index 60e02c37..379429cf 100644 --- a/rectools/models/base.py +++ b/rectools/models/base.py @@ -22,7 +22,6 @@ import numpy as np import pandas as pd import typing_extensions as tpe -from pydantic import PlainSerializer from pydantic_core import PydanticSerializationError from rectools import Columns, ExternalIds, InternalIds @@ -32,6 +31,7 @@ from rectools.types import ExternalIdsArray, InternalIdsArray from rectools.utils.config import BaseConfig from rectools.utils.misc import make_dict_flat +from rectools.utils.serialization import PICKLE_PROTOCOL, FileLike, read_bytes T = tp.TypeVar("T", bound="ModelBase") ScoresArray = np.ndarray @@ -44,24 +44,6 @@ RecoTriplet_T = tp.TypeVar("RecoTriplet_T", InternalRecoTriplet, SemiInternalRecoTriplet, ExternalRecoTriplet) -FileLike = tp.Union[str, Path, tp.IO[bytes]] - -PICKLE_PROTOCOL = 5 - - -def _serialize_random_state(rs: tp.Optional[tp.Union[None, int, np.random.RandomState]]) -> tp.Union[None, int]: - if rs is None or isinstance(rs, int): - return rs - - # NOBUG: We can add serialization using get/set_state, but it's not human readable - raise TypeError("`random_state` must be ``None`` or have ``int`` type to convert it to simple type") - - -RandomState = tpe.Annotated[ - tp.Union[None, int, np.random.RandomState], - PlainSerializer(func=_serialize_random_state, when_used="json"), -] - class ModelConfig(BaseConfig): """Base model config.""" @@ -244,10 +226,7 @@ def load(cls, f: FileLike) -> tpe.Self: model Model instance. """ - if isinstance(f, (str, Path)): - data = Path(f).read_bytes() - else: - data = f.read() + data = read_bytes(f) return cls.loads(data) diff --git a/rectools/models/implicit_als.py b/rectools/models/implicit_als.py index 7e06b9a9..737ea202 100644 --- a/rectools/models/implicit_als.py +++ b/rectools/models/implicit_als.py @@ -31,8 +31,8 @@ from rectools.models.base import ModelConfig from rectools.utils.config import BaseConfig from rectools.utils.misc import get_class_or_function_full_path, import_object +from rectools.utils.serialization import RandomState -from .base import RandomState from .rank import Distance from .vector import Factors, VectorModel diff --git a/rectools/models/lightfm.py b/rectools/models/lightfm.py index 5ae2630d..40b93189 100644 --- a/rectools/models/lightfm.py +++ b/rectools/models/lightfm.py @@ -27,8 +27,9 @@ from rectools.types import InternalIds, InternalIdsArray from rectools.utils.config import BaseConfig from rectools.utils.misc import get_class_or_function_full_path, import_object +from rectools.utils.serialization import RandomState -from .base import FixedColdRecoModelMixin, InternalRecoTriplet, ModelConfig, RandomState, Scores +from .base import FixedColdRecoModelMixin, InternalRecoTriplet, ModelConfig, Scores from .rank import Distance from .vector import Factors, VectorModel diff --git a/rectools/utils/serialization.py b/rectools/utils/serialization.py new file mode 100644 index 00000000..577047fa --- /dev/null +++ b/rectools/utils/serialization.py @@ -0,0 +1,33 @@ +import typing as tp +from pathlib import Path + +import numpy as np +import typing_extensions as tpe +from pydantic import PlainSerializer + +FileLike = tp.Union[str, Path, tp.IO[bytes]] + +PICKLE_PROTOCOL = 5 + + +def _serialize_random_state(rs: tp.Optional[tp.Union[None, int, np.random.RandomState]]) -> tp.Union[None, int]: + if rs is None or isinstance(rs, int): + return rs + + # NOBUG: We can add serialization using get/set_state, but it's not human readable + raise TypeError("`random_state` must be ``None`` or have ``int`` type to convert it to simple type") + + +RandomState = tpe.Annotated[ + tp.Union[None, int, np.random.RandomState], + PlainSerializer(func=_serialize_random_state, when_used="json"), +] + + +def read_bytes(f: FileLike) -> bytes: + """Read bytes from a file.""" + if isinstance(f, (str, Path)): + data = Path(f).read_bytes() + else: + data = f.read() + return data From e29cb342de87b392751c68a602e177a78155cb46 Mon Sep 17 00:00:00 2001 From: Emiliy Feldman Date: Sat, 30 Nov 2024 12:32:22 +0100 Subject: [PATCH 02/25] added load_model function --- rectools/models/__init__.py | 2 ++ rectools/models/serialization.py | 23 +++++++++++++++++ tests/models/test_serialization.py | 41 ++++++++++++++++++++++++++++++ tests/models/utils.py | 13 ++++++++++ 4 files changed, 79 insertions(+) create mode 100644 rectools/models/serialization.py create mode 100644 tests/models/test_serialization.py diff --git a/rectools/models/__init__.py b/rectools/models/__init__.py index 53e25817..c511e755 100644 --- a/rectools/models/__init__.py +++ b/rectools/models/__init__.py @@ -43,6 +43,7 @@ from .popular_in_category import PopularInCategoryModel from .pure_svd import PureSVDModel from .random import RandomModel +from .serialization import load_model try: from .lightfm import LightFMWrapperModel @@ -65,4 +66,5 @@ "PureSVDModel", "RandomModel", "DSSMModel", + "load_model", ) diff --git a/rectools/models/serialization.py b/rectools/models/serialization.py new file mode 100644 index 00000000..5341bb75 --- /dev/null +++ b/rectools/models/serialization.py @@ -0,0 +1,23 @@ +import pickle + +from rectools.models.base import ModelBase +from rectools.utils.serialization import FileLike, read_bytes + + +def load_model(f: FileLike) -> ModelBase: + """ + Load model from file. + + Parameters + ---------- + f : str or Path or file-like object + Path to file or file-like object. + + Returns + ------- + model + Model instance. + """ + data = read_bytes(f) + loaded = pickle.loads(data) + return loaded diff --git a/tests/models/test_serialization.py b/tests/models/test_serialization.py new file mode 100644 index 00000000..e1f2ad11 --- /dev/null +++ b/tests/models/test_serialization.py @@ -0,0 +1,41 @@ +import typing as tp +from tempfile import NamedTemporaryFile + +import pytest +from implicit.als import AlternatingLeastSquares +from implicit.nearest_neighbours import ItemItemRecommender +from lightfm import LightFM + +from rectools.models import ( + ImplicitALSWrapperModel, + ImplicitItemKNNWrapperModel, + LightFMWrapperModel, + PopularInCategoryModel, + load_model, +) +from rectools.models.base import ModelBase + +from .utils import get_final_successors + +MODEL_CLASSES = [cls for cls in get_final_successors(ModelBase) if cls.__module__.startswith("rectools.models")] + + +def init_default_model(model_cls: tp.Type[ModelBase]) -> ModelBase: + mandatory_params = { + ImplicitItemKNNWrapperModel: {"model": ItemItemRecommender()}, + ImplicitALSWrapperModel: {"model": AlternatingLeastSquares()}, + LightFMWrapperModel: {"model": LightFM()}, + PopularInCategoryModel: {"category_feature": "some_feature"}, + } + params = mandatory_params.get(model_cls, {}) + model = model_cls(**params) + return model + + +@pytest.mark.parametrize("model_cls", MODEL_CLASSES) +def test_load_model(model_cls: tp.Type[ModelBase]) -> None: + model = init_default_model(model_cls) + with NamedTemporaryFile() as f: + model.save(f.name) + loaded_model = load_model(f.name) + assert isinstance(loaded_model, model_cls) diff --git a/tests/models/utils.py b/tests/models/utils.py index 7aca04fb..34f09662 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -95,3 +95,16 @@ def get_reco(model: ModelBase) -> pd.DataFrame: assert config_1 == config_2 pd.testing.assert_frame_equal(reco_1, reco_2) + + +def get_final_successors(cls: tp.Type) -> tp.List[tp.Type]: + final_classes = [] + subclasses = cls.__subclasses__() + + if not subclasses: # If there are no subclasses, it's a final class + final_classes.append(cls) + else: + for subclass in subclasses: + final_classes.extend(get_final_successors(subclass)) # Recursively check subclasses + + return final_classes From 8ee764ffec395cc720ecbe968761c330b80295ae Mon Sep 17 00:00:00 2001 From: Emiliy Feldman Date: Sat, 30 Nov 2024 12:33:30 +0100 Subject: [PATCH 03/25] updated changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index d424f512..84f3317c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Model configs example ([#207](https://github.com/MobileTeleSystems/RecTools/pull/207)) - `use_gpu` argument to `ImplicitRanker.rank` method ([#201](https://github.com/MobileTeleSystems/RecTools/pull/201)) - `keep_extra_cols` argument to `Dataset.construct` and `Interactions.from_raw` methods. `include_extra_cols` argument to `Dataset.get_raw_interactions` and `Interactions.to_external` methods ([#208](https://github.com/MobileTeleSystems/RecTools/pull/208)) +- `load_model` function ([#213](https://github.com/MobileTeleSystems/RecTools/pull/213)) ## [0.8.0] - 28.08.2024 From 5af9c65e94d64e428e67cbcc0bb47dc6f283b303 Mon Sep 17 00:00:00 2001 From: Emiliy Feldman Date: Sat, 30 Nov 2024 13:03:24 +0100 Subject: [PATCH 04/25] disabled lightfm tests for python >= 3.12 --- tests/models/test_serialization.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/tests/models/test_serialization.py b/tests/models/test_serialization.py index e1f2ad11..febfc7fb 100644 --- a/tests/models/test_serialization.py +++ b/tests/models/test_serialization.py @@ -4,7 +4,13 @@ import pytest from implicit.als import AlternatingLeastSquares from implicit.nearest_neighbours import ItemItemRecommender -from lightfm import LightFM + +try: + from lightfm import LightFM +except ImportError: + LightFM = object # it's ok in case we're skipping the tests + +import sys from rectools.models import ( ImplicitALSWrapperModel, @@ -17,7 +23,11 @@ from .utils import get_final_successors -MODEL_CLASSES = [cls for cls in get_final_successors(ModelBase) if cls.__module__.startswith("rectools.models")] +MODEL_CLASSES = [ + cls + for cls in get_final_successors(ModelBase) + if cls.__module__.startswith("rectools.models") and not (sys.version_info >= (3, 12) and cls is LightFMWrapperModel) +] def init_default_model(model_cls: tp.Type[ModelBase]) -> ModelBase: From 5bb25a25c8dc99547a2f26ad2d77db8d63e95261 Mon Sep 17 00:00:00 2001 From: Emiliy Feldman Date: Mon, 2 Dec 2024 23:21:13 +0100 Subject: [PATCH 05/25] added model_from_config --- rectools/models/__init__.py | 3 ++- rectools/models/serialization.py | 34 ++++++++++++++++++++++++++++++-- 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/rectools/models/__init__.py b/rectools/models/__init__.py index c511e755..fa4308a3 100644 --- a/rectools/models/__init__.py +++ b/rectools/models/__init__.py @@ -43,7 +43,7 @@ from .popular_in_category import PopularInCategoryModel from .pure_svd import PureSVDModel from .random import RandomModel -from .serialization import load_model +from .serialization import load_model, model_from_config try: from .lightfm import LightFMWrapperModel @@ -67,4 +67,5 @@ "RandomModel", "DSSMModel", "load_model", + "model_from_config", ) diff --git a/rectools/models/serialization.py b/rectools/models/serialization.py index 5341bb75..305d0fca 100644 --- a/rectools/models/serialization.py +++ b/rectools/models/serialization.py @@ -1,6 +1,6 @@ import pickle - -from rectools.models.base import ModelBase +import typing as tp +from rectools.models.base import ModelBase, ModelConfig, deserialize_model_class from rectools.utils.serialization import FileLike, read_bytes @@ -21,3 +21,33 @@ def load_model(f: FileLike) -> ModelBase: data = read_bytes(f) loaded = pickle.loads(data) return loaded + + +def model_from_config(config: tp.Union[dict, ModelConfig]) -> ModelBase: + """ + Create model from config. + + Parameters + ---------- + config : ModelConfig + Model config. + + Returns + ------- + model + Model instance. + """ + def raise_on_none(model_cls: tp.Any) -> None: + if model_cls is None: + raise ValueError("`cls` must be provided in the config to load the model") + + if isinstance(config, dict): + model_cls = deserialize_model_class(config.get("cls")) + raise_on_none(model_cls) + if not issubclass(model_cls, ModelBase): + raise TypeError("`cls` must be (or refer to) a subclass of `ModelBase`") + else: + model_cls = config.cls + raise_on_none(model_cls) + + return model_cls.from_config(config) \ No newline at end of file From 96f84d3971257eed14c6c018233bf9aeb658c05b Mon Sep 17 00:00:00 2001 From: Emiliy Feldman Date: Mon, 2 Dec 2024 23:21:30 +0100 Subject: [PATCH 06/25] added some configs --- rectools/models/base.py | 33 +++++++++++++++++++++++++++++++-- rectools/models/ease.py | 2 +- rectools/models/pure_svd.py | 1 + 3 files changed, 33 insertions(+), 3 deletions(-) diff --git a/rectools/models/base.py b/rectools/models/base.py index 379429cf..0101f144 100644 --- a/rectools/models/base.py +++ b/rectools/models/base.py @@ -21,6 +21,7 @@ import numpy as np import pandas as pd +from pydantic import BeforeValidator, PlainSerializer import typing_extensions as tpe from pydantic_core import PydanticSerializationError @@ -30,7 +31,7 @@ from rectools.exceptions import NotFittedError from rectools.types import ExternalIdsArray, InternalIdsArray from rectools.utils.config import BaseConfig -from rectools.utils.misc import make_dict_flat +from rectools.utils.misc import get_class_or_function_full_path, import_object, make_dict_flat from rectools.utils.serialization import PICKLE_PROTOCOL, FileLike, read_bytes T = tp.TypeVar("T", bound="ModelBase") @@ -45,9 +46,32 @@ RecoTriplet_T = tp.TypeVar("RecoTriplet_T", InternalRecoTriplet, SemiInternalRecoTriplet, ExternalRecoTriplet) +def deserialize_model_class(spec: tp.Any) -> tp.Any: + if not isinstance(spec, str): + return spec + # TODO: add short names for built-in models + return import_object(spec) + + +def _serialize_model_class(cls: tp.Type["ModelBase"]) -> str: + # TODO: add short names for built-in models + return get_class_or_function_full_path(cls) + + +ModelClass = tpe.Annotated[ + tp.Type["ModelBase"], + BeforeValidator(deserialize_model_class), + PlainSerializer( + func=_serialize_model_class, + return_type=str, + when_used="json", + ), +] + + class ModelConfig(BaseConfig): """Base model config.""" - + cls: tp.Optional[ModelClass] = None verbose: int = 0 @@ -173,6 +197,10 @@ def from_config(cls, config: tp.Union[dict, ModelConfig_T]) -> tpe.Self: config_obj = cls.config_class.model_validate(config) else: config_obj = config + + if config_obj.cls is not None and config_obj.cls is not cls: + raise TypeError(f"`{cls.__name__}` is used, but config is for `{config_obj.cls.__name__}`") + return cls._from_config(config_obj) @classmethod @@ -735,6 +763,7 @@ def _recommend_i2i( ) -> InternalRecoTriplet: raise NotImplementedError() +ModelConfig.model_rebuild() class FixedColdRecoModelMixin: """ diff --git a/rectools/models/ease.py b/rectools/models/ease.py index 3139a72e..d8dc3f9c 100644 --- a/rectools/models/ease.py +++ b/rectools/models/ease.py @@ -75,7 +75,7 @@ def __init__( self.num_threads = num_threads def _get_config(self) -> EASEModelConfig: - return EASEModelConfig(regularization=self.regularization, num_threads=self.num_threads, verbose=self.verbose) + return EASEModelConfig(cls=self.__class__, regularization=self.regularization, num_threads=self.num_threads, verbose=self.verbose) @classmethod def _from_config(cls, config: EASEModelConfig) -> tpe.Self: diff --git a/rectools/models/pure_svd.py b/rectools/models/pure_svd.py index 9984bcff..dc2ca816 100644 --- a/rectools/models/pure_svd.py +++ b/rectools/models/pure_svd.py @@ -84,6 +84,7 @@ def __init__( def _get_config(self) -> PureSVDModelConfig: return PureSVDModelConfig( + cls=self.__class__, factors=self.factors, tol=self.tol, maxiter=self.maxiter, From 4ea7f53cc2fa2165463a785a996485aaa359f34c Mon Sep 17 00:00:00 2001 From: Emiliy Feldman Date: Mon, 2 Dec 2024 23:21:37 +0100 Subject: [PATCH 07/25] added some tests --- tests/models/test_serialization.py | 35 ++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/tests/models/test_serialization.py b/tests/models/test_serialization.py index febfc7fb..fa3b5536 100644 --- a/tests/models/test_serialization.py +++ b/tests/models/test_serialization.py @@ -18,6 +18,8 @@ LightFMWrapperModel, PopularInCategoryModel, load_model, + model_from_config, + PureSVDModel ) from rectools.models.base import ModelBase @@ -49,3 +51,36 @@ def test_load_model(model_cls: tp.Type[ModelBase]) -> None: model.save(f.name) loaded_model = load_model(f.name) assert isinstance(loaded_model, model_cls) + + +class TestModelFromConfig: + @pytest.mark.parametrize("model_cls", MODEL_CLASSES) + @pytest.mark.parametrize( + "mode, simple_types", + ( + ("pydantic", False), + ("dict", False), + ("dict", True), + ), + ) + def test_standard_model_creation(self, model_cls: tp.Type[ModelBase], mode: tp.Literal["pydantic", "dict"], simple_types: bool) -> None: + model = init_default_model(model_cls) + config = model.get_config(mode=mode, simple_types=simple_types) + + new_model = model_from_config(config) + + assert isinstance(new_model, model_cls) + assert new_model.get_config(mode=mode, simple_types=simple_types) == config + + def test_custom_model_creation(self) -> None: + pass + + def test_fails_on_missing_cls(self) -> None: + # check none and missing cls + pass + + def test_fails_on_nonexistent_cls(self) -> None: + pass + + def test_fails_on_non_model_cls(self) -> None: + pass \ No newline at end of file From 7dfb39f87626cda7447bf6c812239fe6377d2dbe Mon Sep 17 00:00:00 2001 From: Emiliy Feldman Date: Mon, 2 Dec 2024 23:22:14 +0100 Subject: [PATCH 08/25] formatted --- rectools/models/base.py | 7 +++++-- rectools/models/ease.py | 4 +++- rectools/models/serialization.py | 6 ++++-- tests/models/test_serialization.py | 12 +++++++----- 4 files changed, 19 insertions(+), 10 deletions(-) diff --git a/rectools/models/base.py b/rectools/models/base.py index 0101f144..005710a1 100644 --- a/rectools/models/base.py +++ b/rectools/models/base.py @@ -21,8 +21,8 @@ import numpy as np import pandas as pd -from pydantic import BeforeValidator, PlainSerializer import typing_extensions as tpe +from pydantic import BeforeValidator, PlainSerializer from pydantic_core import PydanticSerializationError from rectools import Columns, ExternalIds, InternalIds @@ -71,6 +71,7 @@ def _serialize_model_class(cls: tp.Type["ModelBase"]) -> str: class ModelConfig(BaseConfig): """Base model config.""" + cls: tp.Optional[ModelClass] = None verbose: int = 0 @@ -200,7 +201,7 @@ def from_config(cls, config: tp.Union[dict, ModelConfig_T]) -> tpe.Self: if config_obj.cls is not None and config_obj.cls is not cls: raise TypeError(f"`{cls.__name__}` is used, but config is for `{config_obj.cls.__name__}`") - + return cls._from_config(config_obj) @classmethod @@ -763,8 +764,10 @@ def _recommend_i2i( ) -> InternalRecoTriplet: raise NotImplementedError() + ModelConfig.model_rebuild() + class FixedColdRecoModelMixin: """ Mixin for models that have fixed cold recommendations. diff --git a/rectools/models/ease.py b/rectools/models/ease.py index d8dc3f9c..7c90bc25 100644 --- a/rectools/models/ease.py +++ b/rectools/models/ease.py @@ -75,7 +75,9 @@ def __init__( self.num_threads = num_threads def _get_config(self) -> EASEModelConfig: - return EASEModelConfig(cls=self.__class__, regularization=self.regularization, num_threads=self.num_threads, verbose=self.verbose) + return EASEModelConfig( + cls=self.__class__, regularization=self.regularization, num_threads=self.num_threads, verbose=self.verbose + ) @classmethod def _from_config(cls, config: EASEModelConfig) -> tpe.Self: diff --git a/rectools/models/serialization.py b/rectools/models/serialization.py index 305d0fca..de735dde 100644 --- a/rectools/models/serialization.py +++ b/rectools/models/serialization.py @@ -1,5 +1,6 @@ import pickle import typing as tp + from rectools.models.base import ModelBase, ModelConfig, deserialize_model_class from rectools.utils.serialization import FileLike, read_bytes @@ -37,6 +38,7 @@ def model_from_config(config: tp.Union[dict, ModelConfig]) -> ModelBase: model Model instance. """ + def raise_on_none(model_cls: tp.Any) -> None: if model_cls is None: raise ValueError("`cls` must be provided in the config to load the model") @@ -49,5 +51,5 @@ def raise_on_none(model_cls: tp.Any) -> None: else: model_cls = config.cls raise_on_none(model_cls) - - return model_cls.from_config(config) \ No newline at end of file + + return model_cls.from_config(config) diff --git a/tests/models/test_serialization.py b/tests/models/test_serialization.py index fa3b5536..e0038be6 100644 --- a/tests/models/test_serialization.py +++ b/tests/models/test_serialization.py @@ -17,9 +17,9 @@ ImplicitItemKNNWrapperModel, LightFMWrapperModel, PopularInCategoryModel, + PureSVDModel, load_model, model_from_config, - PureSVDModel ) from rectools.models.base import ModelBase @@ -56,14 +56,16 @@ def test_load_model(model_cls: tp.Type[ModelBase]) -> None: class TestModelFromConfig: @pytest.mark.parametrize("model_cls", MODEL_CLASSES) @pytest.mark.parametrize( - "mode, simple_types", + "mode, simple_types", ( ("pydantic", False), ("dict", False), - ("dict", True), + ("dict", True), ), ) - def test_standard_model_creation(self, model_cls: tp.Type[ModelBase], mode: tp.Literal["pydantic", "dict"], simple_types: bool) -> None: + def test_standard_model_creation( + self, model_cls: tp.Type[ModelBase], mode: tp.Literal["pydantic", "dict"], simple_types: bool + ) -> None: model = init_default_model(model_cls) config = model.get_config(mode=mode, simple_types=simple_types) @@ -83,4 +85,4 @@ def test_fails_on_nonexistent_cls(self) -> None: pass def test_fails_on_non_model_cls(self) -> None: - pass \ No newline at end of file + pass From fb321dd4933921fb0522cf796d5676552a39aa7d Mon Sep 17 00:00:00 2001 From: Emiliy Feldman Date: Tue, 3 Dec 2024 08:04:13 +0100 Subject: [PATCH 09/25] added cls to all configs --- rectools/models/implicit_als.py | 1 + rectools/models/implicit_knn.py | 1 + rectools/models/lightfm.py | 1 + rectools/models/popular.py | 1 + rectools/models/popular_in_category.py | 1 + rectools/models/random.py | 2 +- 6 files changed, 6 insertions(+), 1 deletion(-) diff --git a/rectools/models/implicit_als.py b/rectools/models/implicit_als.py index 737ea202..319fc2dc 100644 --- a/rectools/models/implicit_als.py +++ b/rectools/models/implicit_als.py @@ -173,6 +173,7 @@ def _make_config( model_cls = model.__class__ return ImplicitALSWrapperModelConfig( + cls=cls, model=AlternatingLeastSquaresConfig( cls=( model_cls diff --git a/rectools/models/implicit_knn.py b/rectools/models/implicit_knn.py index 3989146f..89e8354b 100644 --- a/rectools/models/implicit_knn.py +++ b/rectools/models/implicit_knn.py @@ -116,6 +116,7 @@ def _get_config(self) -> ImplicitItemKNNWrapperModelConfig: # NOBUG: If it's a custom class, we don't know its params params.update({"K1": inner_model.K1, "B": inner_model.B}) return ImplicitItemKNNWrapperModelConfig( + cls=self.__class__, model=ItemItemRecommenderConfig( cls=inner_model.__class__, params=params, diff --git a/rectools/models/lightfm.py b/rectools/models/lightfm.py index 40b93189..bfd970ca 100644 --- a/rectools/models/lightfm.py +++ b/rectools/models/lightfm.py @@ -157,6 +157,7 @@ def _get_config(self) -> LightFMWrapperModelConfig: } inner_model_cls = inner_model.__class__ return LightFMWrapperModelConfig( + cls=self.__class__, model=LightFMConfig( cls=inner_model_cls, params=tp.cast(LightFMParams, params), # https://github.com/python/mypy/issues/8890 diff --git a/rectools/models/popular.py b/rectools/models/popular.py index 29708b10..7792ee7e 100644 --- a/rectools/models/popular.py +++ b/rectools/models/popular.py @@ -187,6 +187,7 @@ def __init__( def _get_config(self) -> PopularModelConfig: return PopularModelConfig( + cls=self.__class__, popularity=self.popularity, period=self.period, begin_from=self.begin_from, diff --git a/rectools/models/popular_in_category.py b/rectools/models/popular_in_category.py index 4f6416c4..d93a763b 100644 --- a/rectools/models/popular_in_category.py +++ b/rectools/models/popular_in_category.py @@ -162,6 +162,7 @@ def __init__( def _get_config(self) -> PopularInCategoryModelConfig: return PopularInCategoryModelConfig( + cls=self.__class__, category_feature=self.category_feature, n_categories=self.n_categories, mixing_strategy=self.mixing_strategy, diff --git a/rectools/models/random.py b/rectools/models/random.py index 3b3ed4e9..ace645b9 100644 --- a/rectools/models/random.py +++ b/rectools/models/random.py @@ -88,7 +88,7 @@ def __init__(self, random_state: tp.Optional[int] = None, verbose: int = 0): self.all_item_ids: np.ndarray def _get_config(self) -> RandomModelConfig: - return RandomModelConfig(random_state=self.random_state, verbose=self.verbose) + return RandomModelConfig(cls=self.__class__, random_state=self.random_state, verbose=self.verbose) @classmethod def _from_config(cls, config: RandomModelConfig) -> tpe.Self: From b781f904354b6933a961cd4cb57b2464cbca8af4 Mon Sep 17 00:00:00 2001 From: Emiliy Feldman Date: Tue, 3 Dec 2024 11:20:41 +0100 Subject: [PATCH 10/25] added some tests --- rectools/models/base.py | 9 ++++ tests/models/test_serialization.py | 67 +++++++++++++++++++++++------- 2 files changed, 60 insertions(+), 16 deletions(-) diff --git a/rectools/models/base.py b/rectools/models/base.py index cd21f511..844fb1c0 100644 --- a/rectools/models/base.py +++ b/rectools/models/base.py @@ -47,6 +47,15 @@ def deserialize_model_class(spec: tp.Any) -> tp.Any: + """ + Get model class from specification. + + Parameters + ---------- + spec : str or type + Specification of model class. + + """ if not isinstance(spec, str): return spec # TODO: add short names for built-in models diff --git a/tests/models/test_serialization.py b/tests/models/test_serialization.py index c394b08e..971d30dd 100644 --- a/tests/models/test_serialization.py +++ b/tests/models/test_serialization.py @@ -2,6 +2,7 @@ import typing as tp from tempfile import NamedTemporaryFile +from pydantic import ValidationError import pytest from implicit.als import AlternatingLeastSquares from implicit.nearest_neighbours import ItemItemRecommender @@ -13,11 +14,12 @@ from rectools.models import ( + DSSMModel, ImplicitALSWrapperModel, ImplicitItemKNNWrapperModel, LightFMWrapperModel, PopularInCategoryModel, - PureSVDModel, + PopularModel, load_model, model_from_config, ) @@ -30,6 +32,7 @@ for cls in get_final_successors(ModelBase) if cls.__module__.startswith("rectools.models") and not (sys.version_info >= (3, 12) and cls is LightFMWrapperModel) ] +CONFIGURABLE_MODEL_CLASSES = [cls for cls in MODEL_CLASSES if cls not in (DSSMModel,)] def init_default_model(model_cls: tp.Type[ModelBase]) -> ModelBase: @@ -54,15 +57,9 @@ def test_load_model(model_cls: tp.Type[ModelBase]) -> None: class TestModelFromConfig: - @pytest.mark.parametrize("model_cls", MODEL_CLASSES) - @pytest.mark.parametrize( - "mode, simple_types", - ( - ("pydantic", False), - ("dict", False), - ("dict", True), - ), - ) + + @pytest.mark.parametrize("mode, simple_types", (("pydantic", False), ("dict", False), ("dict", True))) + @pytest.mark.parametrize("model_cls", CONFIGURABLE_MODEL_CLASSES) def test_standard_model_creation( self, model_cls: tp.Type[ModelBase], mode: tp.Literal["pydantic", "dict"], simple_types: bool ) -> None: @@ -77,12 +74,50 @@ def test_standard_model_creation( def test_custom_model_creation(self) -> None: pass - def test_fails_on_missing_cls(self) -> None: - # check none and missing cls - pass - - def test_fails_on_nonexistent_cls(self) -> None: - pass + @pytest.mark.parametrize("simple_types", (False, True)) + def test_fails_on_missing_cls(self, simple_types: bool) -> None: + model = PopularModel() + config = model.get_config(mode="dict", simple_types=simple_types) + config.pop("cls") + with pytest.raises(ValueError, match="`cls` must be provided in the config to load the model"): + model_from_config(config) + + @pytest.mark.parametrize("mode, simple_types", (("pydantic", False), ("dict", False), ("dict", True))) + def test_fails_on_none_cls(self, mode: tp.Literal["pydantic", "dict"], simple_types: bool) -> None: + model = PopularModel() + config = model.get_config(mode=mode, simple_types=simple_types) + if mode == "pydantic": + config.cls = None + else: + config["cls"] = None + with pytest.raises(ValueError, match="`cls` must be provided in the config to load the model"): + model_from_config(config) + + def test_fails_on_nonexistent_cls(self, ) -> None: + model = PopularModel() + config = model.get_config(mode=mode, simple_types=simple_types) + if mode == "pydantic": + config.cls = None + else: + config["cls"] = None def test_fails_on_non_model_cls(self) -> None: pass + + @pytest.mark.parametrize("mode, simple_types", (("pydantic", False), ("dict", False), ("dict", True))) + def test_fails_on_incorrect_model_cls(self, mode: tp.Literal["pydantic", "dict"], simple_types: bool) -> None: + model = PopularModel() + config = model.get_config(mode=mode, simple_types=simple_types) + if mode == "pydantic": + config.cls = LightFMWrapperModel + else: + if simple_types: + config["cls"] = "rectools.models.LightFMWrapperModel" + else: + config["cls"] = LightFMWrapperModel + with pytest.raises(ValidationError): + model_from_config(config) + + def test_fails_on_model_cls_without_from_config_support(self) -> None: + # DSSM + pass \ No newline at end of file From 3dab3847c64b5396dddb9811ebbc711712a11e8b Mon Sep 17 00:00:00 2001 From: Emiliy Feldman Date: Tue, 3 Dec 2024 13:24:37 +0100 Subject: [PATCH 11/25] finished model_from_config tests --- tests/models/test_serialization.py | 65 +++++++++++++++++++++++------- 1 file changed, 50 insertions(+), 15 deletions(-) diff --git a/tests/models/test_serialization.py b/tests/models/test_serialization.py index 971d30dd..5fed1a30 100644 --- a/tests/models/test_serialization.py +++ b/tests/models/test_serialization.py @@ -1,3 +1,4 @@ +import re import sys import typing as tp from tempfile import NamedTemporaryFile @@ -13,6 +14,7 @@ LightFM = object # it's ok in case we're skipping the tests +from rectools.metrics import NDCG from rectools.models import ( DSSMModel, ImplicitALSWrapperModel, @@ -23,7 +25,7 @@ load_model, model_from_config, ) -from rectools.models.base import ModelBase +from rectools.models.base import ModelBase, ModelConfig from .utils import get_final_successors @@ -56,6 +58,21 @@ def test_load_model(model_cls: tp.Type[ModelBase]) -> None: assert isinstance(loaded_model, model_cls) +class CustomModelConfig(ModelConfig): + some_param: int = 1 + + +class CustomModel(ModelBase[CustomModelConfig]): + config_class = CustomModelConfig + + def __init__(self, some_param: int = 1, verbose: int = 0): + self.some_param = some_param + + @classmethod + def _from_config(cls, config: CustomModelConfig) -> "CustomModel": + return cls(some_param=config.some_param, verbose=config.verbose) + + class TestModelFromConfig: @pytest.mark.parametrize("mode, simple_types", (("pydantic", False), ("dict", False), ("dict", True))) @@ -71,8 +88,17 @@ def test_standard_model_creation( assert isinstance(new_model, model_cls) assert new_model.get_config(mode=mode, simple_types=simple_types) == config - def test_custom_model_creation(self) -> None: - pass + @pytest.mark.parametrize( + "config", + ( + CustomModelConfig(cls=CustomModel, some_param=2), + {"cls": "tests.models.test_serialization.CustomModel", "some_param": 2}, + ) + ) + def test_custom_model_creation(self, config: tp.Union[dict, CustomModelConfig]) -> None: + model = model_from_config(config) + assert isinstance(model, CustomModel) + assert model.some_param == 2 @pytest.mark.parametrize("simple_types", (False, True)) def test_fails_on_missing_cls(self, simple_types: bool) -> None: @@ -93,16 +119,23 @@ def test_fails_on_none_cls(self, mode: tp.Literal["pydantic", "dict"], simple_ty with pytest.raises(ValueError, match="`cls` must be provided in the config to load the model"): model_from_config(config) - def test_fails_on_nonexistent_cls(self, ) -> None: - model = PopularModel() - config = model.get_config(mode=mode, simple_types=simple_types) - if mode == "pydantic": - config.cls = None - else: - config["cls"] = None + @pytest.mark.parametrize( + "model_cls_path, error_cls", + ( + ("nonexistent_module.SomeModel", ModuleNotFoundError), + ("rectools.models.NonexistentModel", AttributeError), + ) + ) + def test_fails_on_nonexistent_cls(self, model_cls_path: str, error_cls: tp.Type[Exception]) -> None: + config = {"cls": model_cls_path} + with pytest.raises(error_cls): + model_from_config(config) - def test_fails_on_non_model_cls(self) -> None: - pass + @pytest.mark.parametrize("model_cls", ("rectools.metrics.NDCG", NDCG)) + def test_fails_on_non_model_cls(self, model_cls: tp.Any) -> None: + config = {"cls": model_cls} + with pytest.raises(TypeError, match=re.escape("`cls` must be (or refer to) a subclass of `ModelBase`")): + model_from_config(config) @pytest.mark.parametrize("mode, simple_types", (("pydantic", False), ("dict", False), ("dict", True))) def test_fails_on_incorrect_model_cls(self, mode: tp.Literal["pydantic", "dict"], simple_types: bool) -> None: @@ -118,6 +151,8 @@ def test_fails_on_incorrect_model_cls(self, mode: tp.Literal["pydantic", "dict"] with pytest.raises(ValidationError): model_from_config(config) - def test_fails_on_model_cls_without_from_config_support(self) -> None: - # DSSM - pass \ No newline at end of file + @pytest.mark.parametrize("model_cls", ("rectools.models.DSSMModel", DSSMModel)) + def test_fails_on_model_cls_without_from_config_support(self, model_cls: tp.Any) -> None: + config = {"cls": model_cls} + with pytest.raises(NotImplementedError, match="`from_config` method is not implemented for `DSSMModel` model"): + model_from_config(config) From 062eefb343acdd36f523c0dabade361ec0af864a Mon Sep 17 00:00:00 2001 From: Emiliy Feldman Date: Tue, 3 Dec 2024 13:32:40 +0100 Subject: [PATCH 12/25] used type adapter --- rectools/models/serialization.py | 18 ++++++++---------- tests/models/test_serialization.py | 2 +- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/rectools/models/serialization.py b/rectools/models/serialization.py index de735dde..bdeb2224 100644 --- a/rectools/models/serialization.py +++ b/rectools/models/serialization.py @@ -1,7 +1,9 @@ import pickle import typing as tp -from rectools.models.base import ModelBase, ModelConfig, deserialize_model_class +from pydantic import TypeAdapter + +from rectools.models.base import ModelBase, ModelConfig, deserialize_model_class, ModelClass from rectools.utils.serialization import FileLike, read_bytes @@ -39,17 +41,13 @@ def model_from_config(config: tp.Union[dict, ModelConfig]) -> ModelBase: Model instance. """ - def raise_on_none(model_cls: tp.Any) -> None: - if model_cls is None: - raise ValueError("`cls` must be provided in the config to load the model") - if isinstance(config, dict): - model_cls = deserialize_model_class(config.get("cls")) - raise_on_none(model_cls) - if not issubclass(model_cls, ModelBase): - raise TypeError("`cls` must be (or refer to) a subclass of `ModelBase`") + model_cls = config.get("cls") + model_cls = TypeAdapter(tp.Optional[ModelClass]).validate_python(model_cls) else: model_cls = config.cls - raise_on_none(model_cls) + + if model_cls is None: + raise ValueError("`cls` must be provided in the config to load the model") return model_cls.from_config(config) diff --git a/tests/models/test_serialization.py b/tests/models/test_serialization.py index 5fed1a30..66a100b6 100644 --- a/tests/models/test_serialization.py +++ b/tests/models/test_serialization.py @@ -134,7 +134,7 @@ def test_fails_on_nonexistent_cls(self, model_cls_path: str, error_cls: tp.Type[ @pytest.mark.parametrize("model_cls", ("rectools.metrics.NDCG", NDCG)) def test_fails_on_non_model_cls(self, model_cls: tp.Any) -> None: config = {"cls": model_cls} - with pytest.raises(TypeError, match=re.escape("`cls` must be (or refer to) a subclass of `ModelBase`")): + with pytest.raises(ValidationError): model_from_config(config) @pytest.mark.parametrize("mode, simple_types", (("pydantic", False), ("dict", False), ("dict", True))) From 4d25e3aa2f0d3a35714cd2c06230c5b62be845e3 Mon Sep 17 00:00:00 2001 From: Emiliy Feldman Date: Tue, 3 Dec 2024 13:33:23 +0100 Subject: [PATCH 13/25] made deserialization func private again --- rectools/models/base.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/rectools/models/base.py b/rectools/models/base.py index 844fb1c0..02e2ab53 100644 --- a/rectools/models/base.py +++ b/rectools/models/base.py @@ -46,16 +46,7 @@ RecoTriplet_T = tp.TypeVar("RecoTriplet_T", InternalRecoTriplet, SemiInternalRecoTriplet, ExternalRecoTriplet) -def deserialize_model_class(spec: tp.Any) -> tp.Any: - """ - Get model class from specification. - - Parameters - ---------- - spec : str or type - Specification of model class. - - """ +def _deserialize_model_class(spec: tp.Any) -> tp.Any: if not isinstance(spec, str): return spec # TODO: add short names for built-in models From 7d204e756be0fb4894677f0eb19c20fecc0ae74a Mon Sep 17 00:00:00 2001 From: Emiliy Feldman Date: Wed, 4 Dec 2024 23:54:53 +0100 Subject: [PATCH 14/25] improved tests --- tests/models/test_serialization.py | 21 ++++++++++++++------- tests/models/utils.py | 16 ++++++---------- 2 files changed, 20 insertions(+), 17 deletions(-) diff --git a/tests/models/test_serialization.py b/tests/models/test_serialization.py index 66a100b6..5ebcc7f7 100644 --- a/tests/models/test_serialization.py +++ b/tests/models/test_serialization.py @@ -26,15 +26,22 @@ model_from_config, ) from rectools.models.base import ModelBase, ModelConfig +from rectools.models.vector import VectorModel -from .utils import get_final_successors +from .utils import get_successors -MODEL_CLASSES = [ +INTERMEDIATE_MODEL_CLASSES = (VectorModel,) + +EXPOSABLE_MODEL_CLASSES = tuple( cls - for cls in get_final_successors(ModelBase) - if cls.__module__.startswith("rectools.models") and not (sys.version_info >= (3, 12) and cls is LightFMWrapperModel) -] -CONFIGURABLE_MODEL_CLASSES = [cls for cls in MODEL_CLASSES if cls not in (DSSMModel,)] + for cls in get_successors(ModelBase) + if ( + cls.__module__.startswith("rectools.models") + and cls not in INTERMEDIATE_MODEL_CLASSES + and not (sys.version_info >= (3, 12) and cls is LightFMWrapperModel) + ) +) +CONFIGURABLE_MODEL_CLASSES = tuple(cls for cls in EXPOSABLE_MODEL_CLASSES if cls not in (DSSMModel,)) def init_default_model(model_cls: tp.Type[ModelBase]) -> ModelBase: @@ -49,7 +56,7 @@ def init_default_model(model_cls: tp.Type[ModelBase]) -> ModelBase: return model -@pytest.mark.parametrize("model_cls", MODEL_CLASSES) +@pytest.mark.parametrize("model_cls", EXPOSABLE_MODEL_CLASSES) def test_load_model(model_cls: tp.Type[ModelBase]) -> None: model = init_default_model(model_cls) with NamedTemporaryFile() as f: diff --git a/tests/models/utils.py b/tests/models/utils.py index 34f09662..e66d823a 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -97,14 +97,10 @@ def get_reco(model: ModelBase) -> pd.DataFrame: pd.testing.assert_frame_equal(reco_1, reco_2) -def get_final_successors(cls: tp.Type) -> tp.List[tp.Type]: - final_classes = [] +def get_successors(cls: tp.Type) -> tp.List[tp.Type]: + successors = [] subclasses = cls.__subclasses__() - - if not subclasses: # If there are no subclasses, it's a final class - final_classes.append(cls) - else: - for subclass in subclasses: - final_classes.extend(get_final_successors(subclass)) # Recursively check subclasses - - return final_classes + for subclass in subclasses: + successors.append(subclass) + successors.extend(get_successors(subclass)) + return successors From 33b7ba81a27e97249beefc08da244c220c3ee056 Mon Sep 17 00:00:00 2001 From: Emiliy Feldman Date: Wed, 4 Dec 2024 23:55:56 +0100 Subject: [PATCH 15/25] added short paths for standard models --- rectools/models/base.py | 13 +++++++++---- rectools/models/serialization.py | 2 +- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/rectools/models/base.py b/rectools/models/base.py index 02e2ab53..23997ea4 100644 --- a/rectools/models/base.py +++ b/rectools/models/base.py @@ -46,21 +46,26 @@ RecoTriplet_T = tp.TypeVar("RecoTriplet_T", InternalRecoTriplet, SemiInternalRecoTriplet, ExternalRecoTriplet) +STANDARD_MODEL_PATH_PREFIX = "rectools.models" + + def _deserialize_model_class(spec: tp.Any) -> tp.Any: if not isinstance(spec, str): return spec - # TODO: add short names for built-in models + if "." not in spec: + spec = f"{STANDARD_MODEL_PATH_PREFIX}.{spec}" # EaseModel -> rectools.models.EaseModel return import_object(spec) def _serialize_model_class(cls: tp.Type["ModelBase"]) -> str: - # TODO: add short names for built-in models - return get_class_or_function_full_path(cls) + path = get_class_or_function_full_path(cls) + if path.startswith(STANDARD_MODEL_PATH_PREFIX): + return path.split(".")[-1] # rectools.models.ease.EASEModel -> EASEModel ModelClass = tpe.Annotated[ tp.Type["ModelBase"], - BeforeValidator(deserialize_model_class), + BeforeValidator(_deserialize_model_class), PlainSerializer( func=_serialize_model_class, return_type=str, diff --git a/rectools/models/serialization.py b/rectools/models/serialization.py index bdeb2224..408f5239 100644 --- a/rectools/models/serialization.py +++ b/rectools/models/serialization.py @@ -3,7 +3,7 @@ from pydantic import TypeAdapter -from rectools.models.base import ModelBase, ModelConfig, deserialize_model_class, ModelClass +from rectools.models.base import ModelBase, ModelConfig, ModelClass from rectools.utils.serialization import FileLike, read_bytes From 7a665fc762fb240e4f89536eed59932aecfad446 Mon Sep 17 00:00:00 2001 From: Emiliy Feldman Date: Thu, 5 Dec 2024 08:19:06 +0100 Subject: [PATCH 16/25] adjusted tests --- rectools/models/base.py | 1 + tests/models/test_base.py | 12 ++++++++---- tests/models/test_ease.py | 6 ++++-- tests/models/test_implicit_als.py | 1 + tests/models/test_implicit_knn.py | 1 + tests/models/test_lightfm.py | 1 + tests/models/test_popular.py | 1 + tests/models/test_popular_in_category.py | 1 + tests/models/test_pure_svd.py | 6 ++++-- tests/models/test_random.py | 6 ++++-- 10 files changed, 26 insertions(+), 10 deletions(-) diff --git a/rectools/models/base.py b/rectools/models/base.py index 23997ea4..c7569918 100644 --- a/rectools/models/base.py +++ b/rectools/models/base.py @@ -61,6 +61,7 @@ def _serialize_model_class(cls: tp.Type["ModelBase"]) -> str: path = get_class_or_function_full_path(cls) if path.startswith(STANDARD_MODEL_PATH_PREFIX): return path.split(".")[-1] # rectools.models.ease.EASEModel -> EASEModel + return path ModelClass = tpe.Annotated[ diff --git a/tests/models/test_base.py b/tests/models/test_base.py index 21641784..8c359da8 100644 --- a/tests/models/test_base.py +++ b/tests/models/test_base.py @@ -452,7 +452,7 @@ def __init__(self, x: int, td: tp.Optional[timedelta] = None, verbose: int = 0): def _get_config(self) -> SomeModelConfig: sc = None if self.td is None else SomeModelSubConfig(td=self.td) - return SomeModelConfig(x=self.x, sc=sc, verbose=self.verbose) + return SomeModelConfig(cls=self.__class__, x=self.x, sc=sc, verbose=self.verbose) @classmethod def _from_config(cls, config: SomeModelConfig) -> tpe.Self: @@ -461,6 +461,7 @@ def _from_config(cls, config: SomeModelConfig) -> tpe.Self: self.config_class = SomeModelConfig self.model_class = SomeModel + self.model_class_path = "tests.models.test_base.TestConfiguration.setup_method..SomeModel" def test_from_pydantic_config(self) -> None: config = self.config_class(x=10, verbose=1) @@ -503,7 +504,8 @@ def test_raises_on_pydantic_with_simple_types(self) -> None: def test_get_config_dict(self, simple_types: bool, expected_td: tp.Union[timedelta, str]) -> None: model = self.model_class(x=10, verbose=1, td=timedelta(days=2, hours=3)) config = model.get_config(mode="dict", simple_types=simple_types) - assert config == {"x": 10, "verbose": 1, "sc": {"td": expected_td}} + expected_cls = self.model_class_path if simple_types else self.model_class + assert config == {"cls": expected_cls, "x": 10, "verbose": 1, "sc": {"td": expected_td}} def test_raises_on_incorrect_format(self) -> None: model = self.model_class(x=10, verbose=1) @@ -514,13 +516,15 @@ def test_raises_on_incorrect_format(self) -> None: def test_get_params(self, simple_types: bool, expected_td: tp.Union[timedelta, str]) -> None: model = self.model_class(x=10, verbose=1, td=timedelta(days=2, hours=3)) config = model.get_params(simple_types=simple_types) - assert config == {"x": 10, "verbose": 1, "sc.td": expected_td} + expected_cls = self.model_class_path if simple_types else self.model_class + assert config == {"cls": expected_cls, "x": 10, "verbose": 1, "sc.td": expected_td} @pytest.mark.parametrize("simple_types", (False, True)) def test_get_params_with_empty_subconfig(self, simple_types: bool) -> None: model = self.model_class(x=10, verbose=1, td=None) config = model.get_params(simple_types=simple_types) - assert config == {"x": 10, "verbose": 1, "sc": None} + expected_cls = self.model_class_path if simple_types else self.model_class + assert config == {"cls": expected_cls, "x": 10, "verbose": 1, "sc": None} def test_model_without_implemented_config_from_config(self) -> None: class MyModelWithoutConfig(ModelBase): diff --git a/tests/models/test_ease.py b/tests/models/test_ease.py index 20fc1701..15e26dbb 100644 --- a/tests/models/test_ease.py +++ b/tests/models/test_ease.py @@ -244,14 +244,16 @@ def test_from_config(self) -> None: assert model.verbose == 1 assert model.regularization == 500 - def test_get_config(self) -> None: + @pytest.mark.parametrize("simple_types", (False, True)) + def test_get_config(self, simple_types: bool) -> None: model = EASEModel( regularization=500, num_threads=1, verbose=1, ) - config = model.get_config() + config = model.get_config(simple_types=simple_types) expected = { + "cls": "EASEModel" if simple_types else EASEModel, "regularization": 500, "num_threads": 1, "verbose": 1, diff --git a/tests/models/test_implicit_als.py b/tests/models/test_implicit_als.py index 91cc9514..df85ea24 100644 --- a/tests/models/test_implicit_als.py +++ b/tests/models/test_implicit_als.py @@ -490,6 +490,7 @@ def test_to_config(self, use_gpu: bool, random_state: tp.Optional[int], simple_t } ) expected = { + "cls": "ImplicitALSWrapperModel" if simple_types else ImplicitALSWrapperModel, "model": { "cls": "AlternatingLeastSquares", "params": expected_model_params, diff --git a/tests/models/test_implicit_knn.py b/tests/models/test_implicit_knn.py index 942743af..04213c16 100644 --- a/tests/models/test_implicit_knn.py +++ b/tests/models/test_implicit_knn.py @@ -329,6 +329,7 @@ def test_to_config( } ) expected = { + "cls": "ImplicitItemKNNWrapperModel" if simple_types else ImplicitItemKNNWrapperModel, "model": { "cls": model_class if not simple_types else model_class_str, "params": expected_model_params, diff --git a/tests/models/test_lightfm.py b/tests/models/test_lightfm.py index a527013b..de3aaf95 100644 --- a/tests/models/test_lightfm.py +++ b/tests/models/test_lightfm.py @@ -401,6 +401,7 @@ def test_to_config(self, random_state: tp.Optional[int], simple_types: bool) -> "random_state": random_state, } expected = { + "cls": "LightFMWrapperModel" if simple_types else LightFMWrapperModel, "model": { "cls": "LightFM" if simple_types else LightFM, "params": expected_model_params, diff --git a/tests/models/test_popular.py b/tests/models/test_popular.py index cb1ab8d7..7dbd6798 100644 --- a/tests/models/test_popular.py +++ b/tests/models/test_popular.py @@ -299,6 +299,7 @@ def test_get_config( ) config = model.get_config() expected = { + "cls": PopularModel, "popularity": Popularity("n_users"), "period": expected_period, "begin_from": begin_from, diff --git a/tests/models/test_popular_in_category.py b/tests/models/test_popular_in_category.py index f30533d5..f6109089 100644 --- a/tests/models/test_popular_in_category.py +++ b/tests/models/test_popular_in_category.py @@ -545,6 +545,7 @@ def test_get_config( ) config = model.get_config() expected = { + "cls": PopularInCategoryModel, "category_feature": "f2", "n_categories": 3, "mixing_strategy": MixingStrategy("rotate"), diff --git a/tests/models/test_pure_svd.py b/tests/models/test_pure_svd.py index a197c150..5d1471a6 100644 --- a/tests/models/test_pure_svd.py +++ b/tests/models/test_pure_svd.py @@ -283,7 +283,8 @@ def test_from_config(self) -> None: assert model.verbose == 0 @pytest.mark.parametrize("random_state", (None, 42)) - def test_get_config(self, random_state: tp.Optional[int]) -> None: + @pytest.mark.parametrize("simple_types", (False, True)) + def test_get_config(self, random_state: tp.Optional[int], simple_types: bool) -> None: model = PureSVDModel( factors=100, tol=1, @@ -291,8 +292,9 @@ def test_get_config(self, random_state: tp.Optional[int]) -> None: random_state=random_state, verbose=1, ) - config = model.get_config() + config = model.get_config(simple_types=simple_types) expected = { + "cls": "PureSVDModel" if simple_types else PureSVDModel, "factors": 100, "tol": 1, "maxiter": 100, diff --git a/tests/models/test_random.py b/tests/models/test_random.py index ab79526c..cde90faf 100644 --- a/tests/models/test_random.py +++ b/tests/models/test_random.py @@ -202,13 +202,15 @@ def test_from_config(self) -> None: assert model.verbose == 0 @pytest.mark.parametrize("random_state", (None, 42)) - def test_get_config(self, random_state: tp.Optional[int]) -> None: + @pytest.mark.parametrize("simple_types", (False, True)) + def test_get_config(self, random_state: tp.Optional[int], simple_types: bool) -> None: model = RandomModel( random_state=random_state, verbose=1, ) - config = model.get_config() + config = model.get_config(simple_types=simple_types) expected = { + "cls": "RandomModel" if simple_types else RandomModel, "random_state": random_state, "verbose": 1, } From 6b3a785f01e3be75847d8f2731ae97d1fac1eff1 Mon Sep 17 00:00:00 2001 From: Emiliy Feldman Date: Thu, 5 Dec 2024 08:38:01 +0100 Subject: [PATCH 17/25] fixed popular model config serialization --- rectools/models/popular.py | 8 ++++---- tests/models/test_popular.py | 20 ++++++++++++-------- tests/models/test_popular_in_category.py | 24 ++++++++++++++---------- 3 files changed, 30 insertions(+), 22 deletions(-) diff --git a/rectools/models/popular.py b/rectools/models/popular.py index 7792ee7e..98de5f26 100644 --- a/rectools/models/popular.py +++ b/rectools/models/popular.py @@ -21,7 +21,7 @@ import numpy as np import pandas as pd import typing_extensions as tpe -from pydantic import PlainSerializer, PlainValidator +from pydantic import BeforeValidator, PlainSerializer from tqdm.auto import tqdm from rectools import Columns, InternalIds @@ -43,7 +43,7 @@ class Popularity(Enum): SUM_WEIGHT = "sum_weight" -def _deserialize_timedelta(td: tp.Union[dict, timedelta]) -> timedelta: +def _deserialize_timedelta(td: tp.Any) -> tp.Any: if isinstance(td, dict): return timedelta(**td) return td @@ -60,8 +60,8 @@ def _serialize_timedelta(td: timedelta) -> dict: TimeDelta = tpe.Annotated[ timedelta, - PlainValidator(func=_deserialize_timedelta), - PlainSerializer(func=_serialize_timedelta), + BeforeValidator(func=_deserialize_timedelta), + PlainSerializer(func=_serialize_timedelta, return_type=dict, when_used="json") ] diff --git a/tests/models/test_popular.py b/tests/models/test_popular.py index 7dbd6798..ca935dfb 100644 --- a/tests/models/test_popular.py +++ b/tests/models/test_popular.py @@ -273,21 +273,25 @@ def test_from_config( assert model.verbose == 0 @pytest.mark.parametrize( - "begin_from,period,expected_period", + "begin_from,period,simple_begin_from,simple_period", ( ( None, timedelta(weeks=2, days=7, hours=23, milliseconds=12345), + None, {"days": 21, "microseconds": 345000, "seconds": 82812}, ), - (datetime(2021, 11, 23, 10, 20, 30, 400000), None, None), + (datetime(2024, 11, 23, 10, 20, 30, 400000), None, "2024-11-23T10:20:30.400000", None), ), ) + @pytest.mark.parametrize("simple_types", (True, False)) def test_get_config( self, period: tp.Optional[timedelta], begin_from: tp.Optional[datetime], - expected_period: tp.Optional[timedelta], + simple_begin_from: tp.Optional[str], + simple_period: tp.Optional[dict], + simple_types: bool, ) -> None: model = PopularModel( popularity="n_users", @@ -297,12 +301,12 @@ def test_get_config( inverse=False, verbose=1, ) - config = model.get_config() + config = model.get_config(simple_types=simple_types) expected = { - "cls": PopularModel, - "popularity": Popularity("n_users"), - "period": expected_period, - "begin_from": begin_from, + "cls": "PopularModel" if simple_types else PopularModel, + "popularity": "n_users" if simple_types else Popularity("n_users"), + "period": simple_period if simple_types else period, + "begin_from": simple_begin_from if simple_types else begin_from, "add_cold": False, "inverse": False, "verbose": 1, diff --git a/tests/models/test_popular_in_category.py b/tests/models/test_popular_in_category.py index f6109089..233f5bd7 100644 --- a/tests/models/test_popular_in_category.py +++ b/tests/models/test_popular_in_category.py @@ -515,21 +515,25 @@ def test_from_config( assert model.verbose == 0 @pytest.mark.parametrize( - "begin_from,period,expected_period", + "begin_from,period,simple_begin_from,simple_period", ( ( None, timedelta(weeks=2, days=7, hours=23, milliseconds=12345), + None, {"days": 21, "microseconds": 345000, "seconds": 82812}, ), - (datetime(2021, 11, 23, 10, 20, 30, 400000), None, None), + (datetime(2024, 11, 23, 10, 20, 30, 400000), None, "2024-11-23T10:20:30.400000", None), ), ) + @pytest.mark.parametrize("simple_types", (True, False)) def test_get_config( self, period: tp.Optional[timedelta], begin_from: tp.Optional[datetime], - expected_period: tp.Optional[timedelta], + simple_begin_from: tp.Optional[str], + simple_period: tp.Optional[dict], + simple_types: bool, ) -> None: model = PopularInCategoryModel( category_feature="f2", @@ -543,16 +547,16 @@ def test_get_config( inverse=False, verbose=1, ) - config = model.get_config() + config = model.get_config(simple_types=simple_types) expected = { - "cls": PopularInCategoryModel, + "cls": "PopularInCategoryModel" if simple_types else PopularInCategoryModel, "category_feature": "f2", "n_categories": 3, - "mixing_strategy": MixingStrategy("rotate"), - "ratio_strategy": RatioStrategy("proportional"), - "popularity": Popularity("n_users"), - "period": expected_period, - "begin_from": begin_from, + "mixing_strategy": "rotate" if simple_types else MixingStrategy("rotate"), + "ratio_strategy": "proportional" if simple_types else RatioStrategy("proportional"), + "popularity": "n_users" if simple_types else Popularity("n_users"), + "period": simple_period if simple_types else period, + "begin_from": simple_begin_from if simple_types else begin_from, "add_cold": False, "inverse": False, "verbose": 1, From 34fe896bdae5d04f23e22166ba82870913674e01 Mon Sep 17 00:00:00 2001 From: Emiliy Feldman Date: Thu, 5 Dec 2024 08:42:06 +0100 Subject: [PATCH 18/25] formatted --- rectools/models/popular.py | 2 +- rectools/models/serialization.py | 4 ++-- tests/models/test_serialization.py | 14 +++++++------- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/rectools/models/popular.py b/rectools/models/popular.py index 98de5f26..c64be4fc 100644 --- a/rectools/models/popular.py +++ b/rectools/models/popular.py @@ -61,7 +61,7 @@ def _serialize_timedelta(td: timedelta) -> dict: TimeDelta = tpe.Annotated[ timedelta, BeforeValidator(func=_deserialize_timedelta), - PlainSerializer(func=_serialize_timedelta, return_type=dict, when_used="json") + PlainSerializer(func=_serialize_timedelta, return_type=dict, when_used="json"), ] diff --git a/rectools/models/serialization.py b/rectools/models/serialization.py index 408f5239..19fe924d 100644 --- a/rectools/models/serialization.py +++ b/rectools/models/serialization.py @@ -3,7 +3,7 @@ from pydantic import TypeAdapter -from rectools.models.base import ModelBase, ModelConfig, ModelClass +from rectools.models.base import ModelBase, ModelClass, ModelConfig from rectools.utils.serialization import FileLike, read_bytes @@ -46,7 +46,7 @@ def model_from_config(config: tp.Union[dict, ModelConfig]) -> ModelBase: model_cls = TypeAdapter(tp.Optional[ModelClass]).validate_python(model_cls) else: model_cls = config.cls - + if model_cls is None: raise ValueError("`cls` must be provided in the config to load the model") diff --git a/tests/models/test_serialization.py b/tests/models/test_serialization.py index 5ebcc7f7..7b34bf85 100644 --- a/tests/models/test_serialization.py +++ b/tests/models/test_serialization.py @@ -3,10 +3,10 @@ import typing as tp from tempfile import NamedTemporaryFile -from pydantic import ValidationError import pytest from implicit.als import AlternatingLeastSquares from implicit.nearest_neighbours import ItemItemRecommender +from pydantic import ValidationError try: from lightfm import LightFM @@ -36,8 +36,8 @@ cls for cls in get_successors(ModelBase) if ( - cls.__module__.startswith("rectools.models") - and cls not in INTERMEDIATE_MODEL_CLASSES + cls.__module__.startswith("rectools.models") + and cls not in INTERMEDIATE_MODEL_CLASSES and not (sys.version_info >= (3, 12) and cls is LightFMWrapperModel) ) ) @@ -67,7 +67,7 @@ def test_load_model(model_cls: tp.Type[ModelBase]) -> None: class CustomModelConfig(ModelConfig): some_param: int = 1 - + class CustomModel(ModelBase[CustomModelConfig]): config_class = CustomModelConfig @@ -78,7 +78,7 @@ def __init__(self, some_param: int = 1, verbose: int = 0): @classmethod def _from_config(cls, config: CustomModelConfig) -> "CustomModel": return cls(some_param=config.some_param, verbose=config.verbose) - + class TestModelFromConfig: @@ -100,7 +100,7 @@ def test_standard_model_creation( ( CustomModelConfig(cls=CustomModel, some_param=2), {"cls": "tests.models.test_serialization.CustomModel", "some_param": 2}, - ) + ), ) def test_custom_model_creation(self, config: tp.Union[dict, CustomModelConfig]) -> None: model = model_from_config(config) @@ -131,7 +131,7 @@ def test_fails_on_none_cls(self, mode: tp.Literal["pydantic", "dict"], simple_ty ( ("nonexistent_module.SomeModel", ModuleNotFoundError), ("rectools.models.NonexistentModel", AttributeError), - ) + ), ) def test_fails_on_nonexistent_cls(self, model_cls_path: str, error_cls: tp.Type[Exception]) -> None: config = {"cls": model_cls_path} From a1928bd56d076059245fc8b5f1d59130ff774421 Mon Sep 17 00:00:00 2001 From: Emiliy Feldman Date: Thu, 5 Dec 2024 08:51:07 +0100 Subject: [PATCH 19/25] refactored --- rectools/models/serialization.py | 1 - tests/models/test_serialization.py | 13 +++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/rectools/models/serialization.py b/rectools/models/serialization.py index 19fe924d..ae0cc428 100644 --- a/rectools/models/serialization.py +++ b/rectools/models/serialization.py @@ -40,7 +40,6 @@ def model_from_config(config: tp.Union[dict, ModelConfig]) -> ModelBase: model Model instance. """ - if isinstance(config, dict): model_cls = config.get("cls") model_cls = TypeAdapter(tp.Optional[ModelClass]).validate_python(model_cls) diff --git a/tests/models/test_serialization.py b/tests/models/test_serialization.py index 7b34bf85..a02885b5 100644 --- a/tests/models/test_serialization.py +++ b/tests/models/test_serialization.py @@ -1,4 +1,3 @@ -import re import sys import typing as tp from tempfile import NamedTemporaryFile @@ -73,6 +72,7 @@ class CustomModel(ModelBase[CustomModelConfig]): config_class = CustomModelConfig def __init__(self, some_param: int = 1, verbose: int = 0): + super().__init__(verbose=verbose) self.some_param = some_param @classmethod @@ -120,9 +120,9 @@ def test_fails_on_none_cls(self, mode: tp.Literal["pydantic", "dict"], simple_ty model = PopularModel() config = model.get_config(mode=mode, simple_types=simple_types) if mode == "pydantic": - config.cls = None + config.cls = None # type: ignore else: - config["cls"] = None + config["cls"] = None # type: ignore # pylint: disable=unsupported-assignment-operation with pytest.raises(ValueError, match="`cls` must be provided in the config to load the model"): model_from_config(config) @@ -149,12 +149,13 @@ def test_fails_on_incorrect_model_cls(self, mode: tp.Literal["pydantic", "dict"] model = PopularModel() config = model.get_config(mode=mode, simple_types=simple_types) if mode == "pydantic": - config.cls = LightFMWrapperModel + config.cls = LightFMWrapperModel # type: ignore else: if simple_types: - config["cls"] = "rectools.models.LightFMWrapperModel" + # pylint: disable=unsupported-assignment-operation + config["cls"] = "rectools.models.LightFMWrapperModel" # type: ignore else: - config["cls"] = LightFMWrapperModel + config["cls"] = LightFMWrapperModel # type: ignore # pylint: disable=unsupported-assignment-operation with pytest.raises(ValidationError): model_from_config(config) From 1a18391ec23f38ca0dee40c4ea3f5550ee82df36 Mon Sep 17 00:00:00 2001 From: Emiliy Feldman Date: Thu, 5 Dec 2024 12:00:10 +0100 Subject: [PATCH 20/25] improved model configs for als and lightfm --- rectools/models/implicit_als.py | 45 +++++++++++-------------------- rectools/models/lightfm.py | 31 ++++++++------------- tests/models/test_implicit_als.py | 24 +++++++---------- tests/models/test_lightfm.py | 18 +++++-------- 4 files changed, 42 insertions(+), 76 deletions(-) diff --git a/rectools/models/implicit_als.py b/rectools/models/implicit_als.py index 319fc2dc..3d7666be 100644 --- a/rectools/models/implicit_als.py +++ b/rectools/models/implicit_als.py @@ -29,7 +29,6 @@ from rectools.dataset import Dataset, Features from rectools.exceptions import NotFittedError from rectools.models.base import ModelConfig -from rectools.utils.config import BaseConfig from rectools.utils.misc import get_class_or_function_full_path, import_object from rectools.utils.serialization import RandomState @@ -74,9 +73,10 @@ def _serialize_alternating_least_squares_class( ] -class AlternatingLeastSquaresParams(tpe.TypedDict): - """Params for implicit `AlternatingLeastSquares` model.""" +class AlternatingLeastSquaresConfig(tpe.TypedDict): + """Config for implicit `AlternatingLeastSquares` model.""" + cls: tpe.NotRequired[AlternatingLeastSquaresClass] factors: tpe.NotRequired[int] regularization: tpe.NotRequired[float] alpha: tpe.NotRequired[float] @@ -90,18 +90,11 @@ class AlternatingLeastSquaresParams(tpe.TypedDict): random_state: tpe.NotRequired[RandomState] -class AlternatingLeastSquaresConfig(BaseConfig): - """Config for implicit `AlternatingLeastSquares` model.""" - - model_config = ConfigDict(arbitrary_types_allowed=True) - - cls: AlternatingLeastSquaresClass = "AlternatingLeastSquares" - params: AlternatingLeastSquaresParams = {} - - class ImplicitALSWrapperModelConfig(ModelConfig): """Config for `ImplicitALSWrapperModel`.""" + model_config = ConfigDict(arbitrary_types_allowed=True) + model: AlternatingLeastSquaresConfig fit_features_together: bool = False @@ -150,7 +143,8 @@ def __init__(self, model: AnyAlternatingLeastSquares, verbose: int = 0, fit_feat def _make_config( cls, model: AnyAlternatingLeastSquares, verbose: int, fit_features_together: bool ) -> ImplicitALSWrapperModelConfig: - params = { + inner_model_config = { + "cls": model.__class__, "factors": model.factors, "regularization": model.regularization, "alpha": model.alpha, @@ -160,9 +154,9 @@ def _make_config( "random_state": model.random_state, } if isinstance(model, GPUAlternatingLeastSquares): - params.update({"use_gpu": True}) + inner_model_config.update({"use_gpu": True}) else: - params.update( + inner_model_config.update( { "use_gpu": False, "use_native": model.use_native, @@ -171,17 +165,10 @@ def _make_config( } ) - model_cls = model.__class__ return ImplicitALSWrapperModelConfig( cls=cls, - model=AlternatingLeastSquaresConfig( - cls=( - model_cls - if model_cls not in (CPUAlternatingLeastSquares, GPUAlternatingLeastSquares) - else "AlternatingLeastSquares" - ), - params=tp.cast(AlternatingLeastSquaresParams, params), # https://github.com/python/mypy/issues/8890 - ), + # https://github.com/python/mypy/issues/8890 + model=tp.cast(AlternatingLeastSquaresConfig, inner_model_config), verbose=verbose, fit_features_together=fit_features_together, ) @@ -191,11 +178,11 @@ def _get_config(self) -> ImplicitALSWrapperModelConfig: @classmethod def _from_config(cls, config: ImplicitALSWrapperModelConfig) -> tpe.Self: - if config.model.cls == ALS_STRING: - model_cls = AlternatingLeastSquares # Not actually a class, but it's ok - else: - model_cls = config.model.cls - model = model_cls(**config.model.params) + inner_model_params = config.model.copy() + inner_model_cls = inner_model_params.pop("cls", AlternatingLeastSquares) + if inner_model_cls == ALS_STRING: + inner_model_cls = AlternatingLeastSquares # Not actually a class, but it's ok + model = inner_model_cls(**inner_model_params) # type: ignore # mypy misses we replaced str with a func return cls(model=model, verbose=config.verbose, fit_features_together=config.fit_features_together) def _fit(self, dataset: Dataset) -> None: diff --git a/rectools/models/lightfm.py b/rectools/models/lightfm.py index bfd970ca..c231c749 100644 --- a/rectools/models/lightfm.py +++ b/rectools/models/lightfm.py @@ -25,7 +25,6 @@ from rectools.exceptions import NotFittedError from rectools.models.utils import recommend_from_scores from rectools.types import InternalIds, InternalIdsArray -from rectools.utils.config import BaseConfig from rectools.utils.misc import get_class_or_function_full_path, import_object from rectools.utils.serialization import RandomState @@ -61,9 +60,10 @@ def _serialize_light_fm_class(cls: tp.Type[LightFM]) -> str: ] -class LightFMParams(tpe.TypedDict): - """Params for `LightFM` model.""" +class LightFMConfig(tpe.TypedDict): + """Config for `LightFM` model.""" + cls: tpe.NotRequired[LightFMClass] no_components: tpe.NotRequired[int] k: tpe.NotRequired[int] n: tpe.NotRequired[int] @@ -78,18 +78,11 @@ class LightFMParams(tpe.TypedDict): random_state: tpe.NotRequired[RandomState] -class LightFMConfig(BaseConfig): - """Config for `LightFM` model.""" - - model_config = ConfigDict(arbitrary_types_allowed=True) - - cls: LightFMClass = LightFM - params: LightFMParams = {} - - class LightFMWrapperModelConfig(ModelConfig): """Config for `LightFMWrapperModel`.""" + model_config = ConfigDict(arbitrary_types_allowed=True) + model: LightFMConfig epochs: int = 1 num_threads: int = 1 @@ -141,7 +134,8 @@ def __init__( def _get_config(self) -> LightFMWrapperModelConfig: inner_model = self._model - params = { + inner_config = { + "cls": inner_model.__class__, "no_components": inner_model.no_components, "k": inner_model.k, "n": inner_model.n, @@ -155,13 +149,9 @@ def _get_config(self) -> LightFMWrapperModelConfig: "max_sampled": inner_model.max_sampled, "random_state": inner_model.initial_random_state, # random_state is an object and can't be serialized } - inner_model_cls = inner_model.__class__ return LightFMWrapperModelConfig( cls=self.__class__, - model=LightFMConfig( - cls=inner_model_cls, - params=tp.cast(LightFMParams, params), # https://github.com/python/mypy/issues/8890 - ), + model=tp.cast(LightFMConfig, inner_config), # https://github.com/python/mypy/issues/8890 epochs=self.n_epochs, num_threads=self.n_threads, verbose=self.verbose, @@ -169,8 +159,9 @@ def _get_config(self) -> LightFMWrapperModelConfig: @classmethod def _from_config(cls, config: LightFMWrapperModelConfig) -> tpe.Self: - model_cls = config.model.cls - model = model_cls(**config.model.params) + params = config.model.copy() + model_cls = params.pop("cls", LightFM) + model = model_cls(**params) return cls(model=model, epochs=config.epochs, num_threads=config.num_threads, verbose=config.verbose) def _fit(self, dataset: Dataset) -> None: # type: ignore diff --git a/tests/models/test_implicit_als.py b/tests/models/test_implicit_als.py index df85ea24..581504c6 100644 --- a/tests/models/test_implicit_als.py +++ b/tests/models/test_implicit_als.py @@ -438,12 +438,10 @@ def setup_method(self) -> None: def test_from_config(self, use_gpu: bool, cls: tp.Any) -> None: config: tp.Dict = { "model": { - "params": { - "factors": 16, - "num_threads": 2, - "iterations": 100, - "use_gpu": use_gpu, - }, + "factors": 16, + "num_threads": 2, + "iterations": 100, + "use_gpu": use_gpu, }, "fit_features_together": True, "verbose": 1, @@ -471,7 +469,8 @@ def test_to_config(self, use_gpu: bool, random_state: tp.Optional[int], simple_t verbose=1, ) config = model.get_config(simple_types=simple_types) - expected_model_params = { + expected_inner_model_config = { + "cls": "AlternatingLeastSquares", "factors": 16, "regularization": 0.01, "alpha": 1.0, @@ -482,7 +481,7 @@ def test_to_config(self, use_gpu: bool, random_state: tp.Optional[int], simple_t "use_gpu": use_gpu, } if not use_gpu: - expected_model_params.update( + expected_inner_model_config.update( { "use_native": True, "use_cg": True, @@ -491,10 +490,7 @@ def test_to_config(self, use_gpu: bool, random_state: tp.Optional[int], simple_t ) expected = { "cls": "ImplicitALSWrapperModel" if simple_types else ImplicitALSWrapperModel, - "model": { - "cls": "AlternatingLeastSquares", - "params": expected_model_params, - }, + "model": expected_inner_model_config, "fit_features_together": True, "verbose": 1, } @@ -528,9 +524,7 @@ def test_custom_model_class(self) -> None: @pytest.mark.parametrize("simple_types", (False, True)) def test_get_config_and_from_config_compatibility(self, simple_types: bool) -> None: initial_config = { - "model": { - "params": {"factors": 16, "num_threads": 2, "iterations": 3, "random_state": 42}, - }, + "model": {"factors": 16, "num_threads": 2, "iterations": 3, "random_state": 42}, "verbose": 1, } assert_get_config_and_from_config_compatibility(ImplicitALSWrapperModel, DATASET, initial_config, simple_types) diff --git a/tests/models/test_lightfm.py b/tests/models/test_lightfm.py index de3aaf95..d4b34f79 100644 --- a/tests/models/test_lightfm.py +++ b/tests/models/test_lightfm.py @@ -357,10 +357,8 @@ class TestLightFMWrapperModelConfiguration: def test_from_config(self, add_cls: bool) -> None: config: tp.Dict = { "model": { - "params": { - "no_components": 16, - "learning_rate": 0.03, - }, + "no_components": 16, + "learning_rate": 0.03, }, "epochs": 2, "num_threads": 3, @@ -386,7 +384,8 @@ def test_to_config(self, random_state: tp.Optional[int], simple_types: bool) -> verbose=1, ) config = model.get_config(simple_types=simple_types) - expected_model_params = { + expected_inner_model_config = { + "cls": "LightFM" if simple_types else LightFM, "no_components": 16, "k": 5, "n": 10, @@ -402,10 +401,7 @@ def test_to_config(self, random_state: tp.Optional[int], simple_types: bool) -> } expected = { "cls": "LightFMWrapperModel" if simple_types else LightFMWrapperModel, - "model": { - "cls": "LightFM" if simple_types else LightFM, - "params": expected_model_params, - }, + "model": expected_inner_model_config, "epochs": 2, "num_threads": 3, "verbose": 1, @@ -440,9 +436,7 @@ def test_custom_model_class(self) -> None: @pytest.mark.parametrize("simple_types", (False, True)) def test_get_config_and_from_config_compatibility(self, simple_types: bool) -> None: initial_config = { - "model": { - "params": {"no_components": 16, "learning_rate": 0.03, "random_state": 42}, - }, + "model": {"no_components": 16, "learning_rate": 0.03, "random_state": 42}, "verbose": 1, } assert_get_config_and_from_config_compatibility(LightFMWrapperModel, DATASET, initial_config, simple_types) From 38750fa4ad0a6c0cf834db1fe5e9bb1617c93ce7 Mon Sep 17 00:00:00 2001 From: Emiliy Feldman Date: Thu, 5 Dec 2024 16:10:16 +0100 Subject: [PATCH 21/25] improved config for knn --- rectools/models/implicit_knn.py | 28 +++++++++++++++++----------- tests/models/test_implicit_knn.py | 23 +++++++++-------------- 2 files changed, 26 insertions(+), 25 deletions(-) diff --git a/rectools/models/implicit_knn.py b/rectools/models/implicit_knn.py index 89e8354b..e17ec889 100644 --- a/rectools/models/implicit_knn.py +++ b/rectools/models/implicit_knn.py @@ -71,18 +71,21 @@ def _serialize_item_item_recommender_class(cls: tp.Type[ItemItemRecommender]) -> ] -class ItemItemRecommenderConfig(BaseConfig): +class ItemItemRecommenderConfig(tpe.TypedDict): """Config for `implicit` `ItemItemRecommender` model and its successors.""" - model_config = ConfigDict(arbitrary_types_allowed=True) - cls: ItemItemRecommenderClass - params: tp.Dict[str, tp.Any] = {} + K: tpe.NotRequired[int] + K1: tpe.NotRequired[float] + B: tpe.NotRequired[float] + num_threads: tpe.NotRequired[int] class ImplicitItemKNNWrapperModelConfig(ModelConfig): """Config for `ImplicitItemKNNWrapperModel`.""" + model_config = ConfigDict(arbitrary_types_allowed=True) + model: ItemItemRecommenderConfig @@ -111,22 +114,25 @@ def __init__(self, model: ItemItemRecommender, verbose: int = 0): def _get_config(self) -> ImplicitItemKNNWrapperModelConfig: inner_model = self._model - params = {"K": inner_model.K, "num_threads": inner_model.num_threads} + inner_model_config = { + "cls": inner_model.__class__, + "K": inner_model.K, + "num_threads": inner_model.num_threads, + } if isinstance(inner_model, BM25Recommender): # NOBUG: If it's a custom class, we don't know its params - params.update({"K1": inner_model.K1, "B": inner_model.B}) + inner_model_config.update({"K1": inner_model.K1, "B": inner_model.B}) return ImplicitItemKNNWrapperModelConfig( cls=self.__class__, - model=ItemItemRecommenderConfig( - cls=inner_model.__class__, - params=params, - ), + model=tp.cast(ItemItemRecommenderConfig, inner_model_config), verbose=self.verbose, ) @classmethod def _from_config(cls, config: ImplicitItemKNNWrapperModelConfig) -> tpe.Self: - model = config.model.cls(**config.model.params) + params = config.model.copy() + model_cls = params.pop("cls") + model = model_cls(**params) return cls(model=model, verbose=config.verbose) def _fit(self, dataset: Dataset) -> None: # type: ignore diff --git a/tests/models/test_implicit_knn.py b/tests/models/test_implicit_knn.py index 04213c16..354c01e6 100644 --- a/tests/models/test_implicit_knn.py +++ b/tests/models/test_implicit_knn.py @@ -276,14 +276,11 @@ class TestImplicitItemKNNWrapperModelConfiguration: ), ) def test_from_config(self, model_class: tp.Union[tp.Type[ItemItemRecommender], str]) -> None: - params: tp.Dict[str, tp.Any] = {"K": 5} + inner_model_config: tp.Dict[str, tp.Any] = {"cls": model_class, "K": 5} if model_class == "BM25Recommender": - params.update({"K1": 0.33}) + inner_model_config.update({"K1": 0.33}) config = { - "model": { - "cls": model_class, - "params": params, - }, + "model": inner_model_config, "verbose": 1, } model = ImplicitItemKNNWrapperModel.from_config(config) @@ -317,12 +314,13 @@ def test_to_config( verbose=1, ) config = model.get_config(simple_types=simple_types) - expected_model_params: tp.Dict[str, tp.Any] = { + expected_inner_model_config: tp.Dict[str, tp.Any] = { + "cls": model_class if not simple_types else model_class_str, "K": 5, "num_threads": 0, } if model_class is BM25Recommender: - expected_model_params.update( + expected_inner_model_config.update( { "K1": 1.2, "B": 0.75, @@ -330,10 +328,7 @@ def test_to_config( ) expected = { "cls": "ImplicitItemKNNWrapperModel" if simple_types else ImplicitItemKNNWrapperModel, - "model": { - "cls": model_class if not simple_types else model_class_str, - "params": expected_model_params, - }, + "model": expected_inner_model_config, "verbose": 1, } assert config == expected @@ -343,7 +338,7 @@ def test_get_config_and_from_config_compatibility(self, simple_types: bool) -> N initial_config = { "model": { "cls": TFIDFRecommender, - "params": {"K": 3}, + "K": 3, }, "verbose": 1, } @@ -352,6 +347,6 @@ def test_get_config_and_from_config_compatibility(self, simple_types: bool) -> N ) def test_default_config_and_default_model_params_are_the_same(self) -> None: - default_config: tp.Dict[str, tp.Any] = {"model": {"cls": ItemItemRecommender, "params": {}}} + default_config: tp.Dict[str, tp.Any] = {"model": {"cls": ItemItemRecommender}} model = ImplicitItemKNNWrapperModel(model=ItemItemRecommender()) assert_default_config_and_default_model_params_are_the_same(model, default_config) From 3ff56c8ecc45bac2bd2b140185adf82d5f6cb2e8 Mon Sep 17 00:00:00 2001 From: Emiliy Feldman Date: Thu, 5 Dec 2024 16:33:18 +0100 Subject: [PATCH 22/25] fixed errors --- rectools/models/implicit_als.py | 7 ++++++- rectools/models/implicit_knn.py | 6 +++--- tests/models/test_base.py | 2 +- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/rectools/models/implicit_als.py b/rectools/models/implicit_als.py index 3d7666be..72753e11 100644 --- a/rectools/models/implicit_als.py +++ b/rectools/models/implicit_als.py @@ -143,8 +143,13 @@ def __init__(self, model: AnyAlternatingLeastSquares, verbose: int = 0, fit_feat def _make_config( cls, model: AnyAlternatingLeastSquares, verbose: int, fit_features_together: bool ) -> ImplicitALSWrapperModelConfig: + model_cls = ( + model.__class__ + if model.__class__ not in (CPUAlternatingLeastSquares, GPUAlternatingLeastSquares) + else "AlternatingLeastSquares" + ) inner_model_config = { - "cls": model.__class__, + "cls": model_cls, "factors": model.factors, "regularization": model.regularization, "alpha": model.alpha, diff --git a/rectools/models/implicit_knn.py b/rectools/models/implicit_knn.py index e17ec889..ae645775 100644 --- a/rectools/models/implicit_knn.py +++ b/rectools/models/implicit_knn.py @@ -29,7 +29,6 @@ from rectools.dataset import Dataset from rectools.types import InternalId, InternalIdsArray from rectools.utils import fast_isin_for_sorted_test_elements -from rectools.utils.config import BaseConfig from rectools.utils.misc import get_class_or_function_full_path, import_object from .base import ModelBase, ModelConfig, Scores @@ -130,8 +129,9 @@ def _get_config(self) -> ImplicitItemKNNWrapperModelConfig: @classmethod def _from_config(cls, config: ImplicitItemKNNWrapperModelConfig) -> tpe.Self: - params = config.model.copy() - model_cls = params.pop("cls") + model_cls = config.model["cls"] + params = dict(config.model.copy()) # `cls` param is required and cannot be popped + del params["cls"] model = model_cls(**params) return cls(model=model, verbose=config.verbose) diff --git a/tests/models/test_base.py b/tests/models/test_base.py index 8c359da8..a5e3a8e0 100644 --- a/tests/models/test_base.py +++ b/tests/models/test_base.py @@ -493,7 +493,7 @@ def test_from_config_dict_with_extra_keys(self) -> None: def test_get_config_pydantic(self) -> None: model = self.model_class(x=10, verbose=1) config = model.get_config(mode="pydantic") - assert config == self.config_class(x=10, verbose=1) + assert config == self.config_class(cls=self.model_class, x=10, verbose=1) def test_raises_on_pydantic_with_simple_types(self) -> None: model = self.model_class(x=10, verbose=1) From 74976a338e3ef75d23abe2206a5e43b58f3bce36 Mon Sep 17 00:00:00 2001 From: Emiliy Feldman Date: Thu, 5 Dec 2024 17:34:57 +0100 Subject: [PATCH 23/25] fixed error for python 3.12 --- tests/models/test_serialization.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/test_serialization.py b/tests/models/test_serialization.py index a02885b5..5bd701ee 100644 --- a/tests/models/test_serialization.py +++ b/tests/models/test_serialization.py @@ -149,13 +149,13 @@ def test_fails_on_incorrect_model_cls(self, mode: tp.Literal["pydantic", "dict"] model = PopularModel() config = model.get_config(mode=mode, simple_types=simple_types) if mode == "pydantic": - config.cls = LightFMWrapperModel # type: ignore + config.cls = ImplicitALSWrapperModel # type: ignore else: if simple_types: # pylint: disable=unsupported-assignment-operation config["cls"] = "rectools.models.LightFMWrapperModel" # type: ignore else: - config["cls"] = LightFMWrapperModel # type: ignore # pylint: disable=unsupported-assignment-operation + config["cls"] = ImplicitALSWrapperModel # type: ignore # pylint: disable=unsupported-assignment-operation with pytest.raises(ValidationError): model_from_config(config) From 2f06dca4ac737f728fac19b0710b772262d06045 Mon Sep 17 00:00:00 2001 From: Emiliy Feldman Date: Thu, 5 Dec 2024 17:45:22 +0100 Subject: [PATCH 24/25] fixed coverage --- tests/models/test_base.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/models/test_base.py b/tests/models/test_base.py index a5e3a8e0..5ab7d68e 100644 --- a/tests/models/test_base.py +++ b/tests/models/test_base.py @@ -459,9 +459,17 @@ def _from_config(cls, config: SomeModelConfig) -> tpe.Self: td = None if config.sc is None else config.sc.td return cls(x=config.x, td=td, verbose=config.verbose) + class OtherModelConfig(ModelConfig): + y: int + + class OtherModel(ModelBase[OtherModelConfig]): + pass + self.config_class = SomeModelConfig self.model_class = SomeModel self.model_class_path = "tests.models.test_base.TestConfiguration.setup_method..SomeModel" + self.other_config_class = OtherModelConfig + self.other_model_class = OtherModel def test_from_pydantic_config(self) -> None: config = self.config_class(x=10, verbose=1) @@ -544,6 +552,11 @@ class MyModelWithoutConfig(ModelBase): ): MyModelWithoutConfig().get_config() + def test_incorrct_model_class_in_config(self) -> None: + config = self.config_class(cls=self.other_model_class, x=1) + with pytest.raises(TypeError, match="`SomeModel` is used, but config is for `OtherModel`"): + self.model_class.from_config(config) + class MyModel(ModelBase): def __init__(self, x: int = 10, verbose: int = 0): From 1a47946848e4ba69cbd25c236ec5d42ad7b07204 Mon Sep 17 00:00:00 2001 From: Emiliy Feldman Date: Thu, 5 Dec 2024 17:47:30 +0100 Subject: [PATCH 25/25] formatted --- tests/models/test_serialization.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/models/test_serialization.py b/tests/models/test_serialization.py index 5bd701ee..66786540 100644 --- a/tests/models/test_serialization.py +++ b/tests/models/test_serialization.py @@ -16,6 +16,7 @@ from rectools.metrics import NDCG from rectools.models import ( DSSMModel, + EASEModel, ImplicitALSWrapperModel, ImplicitItemKNNWrapperModel, LightFMWrapperModel, @@ -149,13 +150,13 @@ def test_fails_on_incorrect_model_cls(self, mode: tp.Literal["pydantic", "dict"] model = PopularModel() config = model.get_config(mode=mode, simple_types=simple_types) if mode == "pydantic": - config.cls = ImplicitALSWrapperModel # type: ignore + config.cls = EASEModel # type: ignore else: if simple_types: # pylint: disable=unsupported-assignment-operation config["cls"] = "rectools.models.LightFMWrapperModel" # type: ignore else: - config["cls"] = ImplicitALSWrapperModel # type: ignore # pylint: disable=unsupported-assignment-operation + config["cls"] = EASEModel # type: ignore # pylint: disable=unsupported-assignment-operation with pytest.raises(ValidationError): model_from_config(config)