diff --git a/insolver/wrappers_v2/base.py b/insolver/wrappers_v2/base.py index a05de68..377950a 100644 --- a/insolver/wrappers_v2/base.py +++ b/insolver/wrappers_v2/base.py @@ -14,7 +14,7 @@ else: from typing_extensions import Literal -from .utils import get_requirements +from .utils.req_utils import get_requirements class InsolverWrapperWarning(Warning): @@ -58,7 +58,7 @@ def _save_insolver(self, path_or_buf: Union[str, 'PathLike[str]'], method: Calla zip_file.writestr("requirements.txt", get_requirements()) zip_file.writestr( f"model_{os.path.basename(path_or_buf)}", - BytesIO(method(self.model, path_or_buf=None, **kwargs)).getvalue(), + BytesIO(method(self, path_or_buf=None, **kwargs)).getvalue(), ) with open(path_or_buf if str(path_or_buf).endswith('.zip') else f'{path_or_buf}.zip', "wb") as f: @@ -109,24 +109,14 @@ def save_model( if path_or_buf is None: if self._model_cached is None: - return self._backend_saving_methods[self.backend][method](self.model, path_or_buf, **kwargs) + return self._backend_saving_methods[self.backend][method](self, path_or_buf, **kwargs) else: return self._model_cached else: if mode == "insolver": self.metadata.update({"saving_method": method}) - if self._model_cached is None: - self._save_insolver( - path_or_buf, method=self._backend_saving_methods[self.backend][method], **kwargs - ) - else: - self._save_insolver( - path_or_buf, - method=self._backend_saving_methods[self.backend][method], - _model_cached=self._model_cached, - **kwargs, - ) - path_or_buf = f'{path_or_buf}.zip' + self._save_insolver(path_or_buf, method=self._backend_saving_methods[self.backend][method], **kwargs) + path_or_buf = f'{path_or_buf}.zip' if not path_or_buf.endswith('.zip') else path_or_buf else: - self._backend_saving_methods[self.backend][method](self.model, path_or_buf, **kwargs) + self._backend_saving_methods[self.backend][method](self, path_or_buf, **kwargs) return f"Saved model: {os.path.normpath(path_or_buf)}" diff --git a/insolver/wrappers_v2/gbm.py b/insolver/wrappers_v2/gbm.py index ef2a4a4..d7e897e 100644 --- a/insolver/wrappers_v2/gbm.py +++ b/insolver/wrappers_v2/gbm.py @@ -44,7 +44,7 @@ class InsolverGBMWrapper(InsolverBaseWrapper): def __init__( self, backend: Optional[Literal['xgboost', 'lightgbm', 'catboost']], - task: Optional[Literal['class', 'reg']] = 'reg', + task: Literal['class', 'reg'] = 'reg', objective: Union[None, str, Callable] = None, n_estimators: int = 100, **kwargs: Any, diff --git a/insolver/wrappers_v2/glm.py b/insolver/wrappers_v2/glm.py index 59e67f1..24c6e5b 100644 --- a/insolver/wrappers_v2/glm.py +++ b/insolver/wrappers_v2/glm.py @@ -49,7 +49,7 @@ class InsolverGLMWrapper(InsolverBaseWrapper): def __init__( self, backend: Optional[Literal['sklearn', 'h2o']], - task: Optional[Literal['class', 'reg']] = 'reg', + task: Literal['class', 'reg'] = 'reg', family: Optional[str] = None, link: Optional[str] = None, h2o_server_params: Optional[Dict] = None, diff --git a/insolver/wrappers_v2/utils/__init__.py b/insolver/wrappers_v2/utils/__init__.py index 3dabdd8..41c4585 100644 --- a/insolver/wrappers_v2/utils/__init__.py +++ b/insolver/wrappers_v2/utils/__init__.py @@ -1,4 +1,3 @@ from .save_load_utils import load_model from .save_load_utils import save_pickle, save_dill from .h2o_utils import save_h2o -from .req_utils import get_requirements diff --git a/insolver/wrappers_v2/utils/h2o_utils.py b/insolver/wrappers_v2/utils/h2o_utils.py index 96426a8..61540c5 100644 --- a/insolver/wrappers_v2/utils/h2o_utils.py +++ b/insolver/wrappers_v2/utils/h2o_utils.py @@ -10,6 +10,8 @@ from h2o.estimators import H2OEstimator from h2o import no_progress, cluster, remove_all, connect, load_model, save_model +from ..base import InsolverBaseWrapper + def h2o_start(h2o_server_params: Dict[str, Any] = None) -> None: # nthreads=-1, enable_assertions=True, max_mem_size=None, min_mem_size=None, @@ -77,17 +79,21 @@ def x_y_to_h2o_frame( def save_h2o( - model: H2OEstimator, path_or_buf: Union[None, str, 'PathLike[str]'] = None, **kwargs: Any + wrapper: InsolverBaseWrapper, path_or_buf: Union[None, str, 'PathLike[str]'] = None, **kwargs: Any ) -> Optional[bytes]: if not ((path_or_buf is None) or (isinstance(path_or_buf, str))): raise ValueError(f"Invalid file path or buffer object {type(path_or_buf)}") - _model_cached = None if '_model_cached' not in kwargs else kwargs.pop('_model_cached') + if hasattr(wrapper, '_model_cached'): + _model_cached = wrapper._model_cached + else: + _model_cached = None + h2o_start() if path_or_buf is None: # Since there no possibility to save h2o model to a variable, workaround is needed if _model_cached is None: - save_model(model=model, filename='.temp_h2o_model_save', **kwargs) + save_model(model=wrapper.model, filename='.temp_h2o_model_save', **kwargs) with open('.temp_h2o_model_save', 'rb') as file: saved = file.read() os.remove('.temp_h2o_model_save') @@ -97,7 +103,11 @@ def save_h2o( else: path, filename = os.path.split(path_or_buf) # force = False, export_cross_validation_predictions = False - save_model(model=model, path=path, filename=filename, **kwargs) + if _model_cached is None: + save_model(model=wrapper.model, path=path, filename=filename, **kwargs) + else: + with open(path_or_buf, 'wb') as file: + file.write(_model_cached) return None diff --git a/insolver/wrappers_v2/utils/save_load_utils.py b/insolver/wrappers_v2/utils/save_load_utils.py index 3511880..3b3df01 100644 --- a/insolver/wrappers_v2/utils/save_load_utils.py +++ b/insolver/wrappers_v2/utils/save_load_utils.py @@ -9,6 +9,8 @@ from .h2o_utils import load_h2o from .req_utils import check_requirements +from ..base import InsolverBaseWrapper + def load(path_or_buf: Union[str, 'PathLike[str]', bytes], saving_method: str, **kwargs: Any) -> Callable: load_config: Dict[str, Callable] = dict(pickle=load_pickle, dill=load_dill, h2o=load_h2o) @@ -53,15 +55,17 @@ def load_model(path_or_buf: Union[str, 'PathLike[str]', IO[bytes]], **kwargs: An ) -def save_pickle(model: Any, path_or_buf: Union[None, str, 'PathLike[str]'] = None, **kwargs: Any) -> Optional[bytes]: +def save_pickle( + wrapper: InsolverBaseWrapper, path_or_buf: Union[None, str, 'PathLike[str]'] = None, **kwargs: Any +) -> Optional[bytes]: if not ((path_or_buf is None) or (isinstance(path_or_buf, str))): raise ValueError(f"Invalid file path or buffer object {type(path_or_buf)}") if path_or_buf is None: - return pickle.dumps(model, **kwargs) + return pickle.dumps(wrapper.model, **kwargs) else: with open(path_or_buf, "wb") as _file: - pickle.dump(model, _file, **kwargs) + pickle.dump(wrapper.model, _file, **kwargs) return None @@ -73,15 +77,17 @@ def load_pickle(path_or_buf: Union[str, 'PathLike[str]', bytes], **kwargs: Any) return pickle.loads(path_or_buf, **kwargs) -def save_dill(model: Any, path_or_buf: Union[None, str, 'PathLike[str]'] = None, **kwargs: Any) -> Optional[bytes]: +def save_dill( + wrapper: InsolverBaseWrapper, path_or_buf: Union[None, str, 'PathLike[str]'] = None, **kwargs: Any +) -> Optional[bytes]: if not ((path_or_buf is None) or (isinstance(path_or_buf, str))): raise ValueError(f"Invalid file path or buffer object {type(path_or_buf)}") if path_or_buf is None: - return dill.dumps(model, **kwargs) + return dill.dumps(wrapper.model, **kwargs) else: with open(path_or_buf, "wb") as _file: - dill.dump(model, _file, **kwargs) + dill.dump(wrapper.model, _file, **kwargs) return None