Skip to content

Commit

Permalink
Merge pull request #185 from MindSetLib/dev
Browse files Browse the repository at this point in the history
Fixing loading and saving models
  • Loading branch information
alexmindset authored Aug 8, 2022
2 parents b9dcbe3 + 8cb6eb6 commit 6d5a8c8
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 29 deletions.
22 changes: 6 additions & 16 deletions insolver/wrappers_v2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
else:
from typing_extensions import Literal

from .utils import get_requirements
from .utils.req_utils import get_requirements


class InsolverWrapperWarning(Warning):
Expand Down Expand Up @@ -58,7 +58,7 @@ def _save_insolver(self, path_or_buf: Union[str, 'PathLike[str]'], method: Calla
zip_file.writestr("requirements.txt", get_requirements())
zip_file.writestr(
f"model_{os.path.basename(path_or_buf)}",
BytesIO(method(self.model, path_or_buf=None, **kwargs)).getvalue(),
BytesIO(method(self, path_or_buf=None, **kwargs)).getvalue(),
)

with open(path_or_buf if str(path_or_buf).endswith('.zip') else f'{path_or_buf}.zip', "wb") as f:
Expand Down Expand Up @@ -109,24 +109,14 @@ def save_model(

if path_or_buf is None:
if self._model_cached is None:
return self._backend_saving_methods[self.backend][method](self.model, path_or_buf, **kwargs)
return self._backend_saving_methods[self.backend][method](self, path_or_buf, **kwargs)
else:
return self._model_cached
else:
if mode == "insolver":
self.metadata.update({"saving_method": method})
if self._model_cached is None:
self._save_insolver(
path_or_buf, method=self._backend_saving_methods[self.backend][method], **kwargs
)
else:
self._save_insolver(
path_or_buf,
method=self._backend_saving_methods[self.backend][method],
_model_cached=self._model_cached,
**kwargs,
)
path_or_buf = f'{path_or_buf}.zip'
self._save_insolver(path_or_buf, method=self._backend_saving_methods[self.backend][method], **kwargs)
path_or_buf = f'{path_or_buf}.zip' if not path_or_buf.endswith('.zip') else path_or_buf
else:
self._backend_saving_methods[self.backend][method](self.model, path_or_buf, **kwargs)
self._backend_saving_methods[self.backend][method](self, path_or_buf, **kwargs)
return f"Saved model: {os.path.normpath(path_or_buf)}"
2 changes: 1 addition & 1 deletion insolver/wrappers_v2/gbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class InsolverGBMWrapper(InsolverBaseWrapper):
def __init__(
self,
backend: Optional[Literal['xgboost', 'lightgbm', 'catboost']],
task: Optional[Literal['class', 'reg']] = 'reg',
task: Literal['class', 'reg'] = 'reg',
objective: Union[None, str, Callable] = None,
n_estimators: int = 100,
**kwargs: Any,
Expand Down
2 changes: 1 addition & 1 deletion insolver/wrappers_v2/glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class InsolverGLMWrapper(InsolverBaseWrapper):
def __init__(
self,
backend: Optional[Literal['sklearn', 'h2o']],
task: Optional[Literal['class', 'reg']] = 'reg',
task: Literal['class', 'reg'] = 'reg',
family: Optional[str] = None,
link: Optional[str] = None,
h2o_server_params: Optional[Dict] = None,
Expand Down
1 change: 0 additions & 1 deletion insolver/wrappers_v2/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from .save_load_utils import load_model
from .save_load_utils import save_pickle, save_dill
from .h2o_utils import save_h2o
from .req_utils import get_requirements
18 changes: 14 additions & 4 deletions insolver/wrappers_v2/utils/h2o_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from h2o.estimators import H2OEstimator
from h2o import no_progress, cluster, remove_all, connect, load_model, save_model

from ..base import InsolverBaseWrapper


def h2o_start(h2o_server_params: Dict[str, Any] = None) -> None:
# nthreads=-1, enable_assertions=True, max_mem_size=None, min_mem_size=None,
Expand Down Expand Up @@ -77,17 +79,21 @@ def x_y_to_h2o_frame(


def save_h2o(
model: H2OEstimator, path_or_buf: Union[None, str, 'PathLike[str]'] = None, **kwargs: Any
wrapper: InsolverBaseWrapper, path_or_buf: Union[None, str, 'PathLike[str]'] = None, **kwargs: Any
) -> Optional[bytes]:
if not ((path_or_buf is None) or (isinstance(path_or_buf, str))):
raise ValueError(f"Invalid file path or buffer object {type(path_or_buf)}")

_model_cached = None if '_model_cached' not in kwargs else kwargs.pop('_model_cached')
if hasattr(wrapper, '_model_cached'):
_model_cached = wrapper._model_cached
else:
_model_cached = None
h2o_start()

if path_or_buf is None:
# Since there no possibility to save h2o model to a variable, workaround is needed
if _model_cached is None:
save_model(model=model, filename='.temp_h2o_model_save', **kwargs)
save_model(model=wrapper.model, filename='.temp_h2o_model_save', **kwargs)
with open('.temp_h2o_model_save', 'rb') as file:
saved = file.read()
os.remove('.temp_h2o_model_save')
Expand All @@ -97,7 +103,11 @@ def save_h2o(
else:
path, filename = os.path.split(path_or_buf)
# force = False, export_cross_validation_predictions = False
save_model(model=model, path=path, filename=filename, **kwargs)
if _model_cached is None:
save_model(model=wrapper.model, path=path, filename=filename, **kwargs)
else:
with open(path_or_buf, 'wb') as file:
file.write(_model_cached)
return None


Expand Down
18 changes: 12 additions & 6 deletions insolver/wrappers_v2/utils/save_load_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from .h2o_utils import load_h2o
from .req_utils import check_requirements

from ..base import InsolverBaseWrapper


def load(path_or_buf: Union[str, 'PathLike[str]', bytes], saving_method: str, **kwargs: Any) -> Callable:
load_config: Dict[str, Callable] = dict(pickle=load_pickle, dill=load_dill, h2o=load_h2o)
Expand Down Expand Up @@ -53,15 +55,17 @@ def load_model(path_or_buf: Union[str, 'PathLike[str]', IO[bytes]], **kwargs: An
)


def save_pickle(model: Any, path_or_buf: Union[None, str, 'PathLike[str]'] = None, **kwargs: Any) -> Optional[bytes]:
def save_pickle(
wrapper: InsolverBaseWrapper, path_or_buf: Union[None, str, 'PathLike[str]'] = None, **kwargs: Any
) -> Optional[bytes]:
if not ((path_or_buf is None) or (isinstance(path_or_buf, str))):
raise ValueError(f"Invalid file path or buffer object {type(path_or_buf)}")

if path_or_buf is None:
return pickle.dumps(model, **kwargs)
return pickle.dumps(wrapper.model, **kwargs)
else:
with open(path_or_buf, "wb") as _file:
pickle.dump(model, _file, **kwargs)
pickle.dump(wrapper.model, _file, **kwargs)
return None


Expand All @@ -73,15 +77,17 @@ def load_pickle(path_or_buf: Union[str, 'PathLike[str]', bytes], **kwargs: Any)
return pickle.loads(path_or_buf, **kwargs)


def save_dill(model: Any, path_or_buf: Union[None, str, 'PathLike[str]'] = None, **kwargs: Any) -> Optional[bytes]:
def save_dill(
wrapper: InsolverBaseWrapper, path_or_buf: Union[None, str, 'PathLike[str]'] = None, **kwargs: Any
) -> Optional[bytes]:
if not ((path_or_buf is None) or (isinstance(path_or_buf, str))):
raise ValueError(f"Invalid file path or buffer object {type(path_or_buf)}")

if path_or_buf is None:
return dill.dumps(model, **kwargs)
return dill.dumps(wrapper.model, **kwargs)
else:
with open(path_or_buf, "wb") as _file:
dill.dump(model, _file, **kwargs)
dill.dump(wrapper.model, _file, **kwargs)
return None


Expand Down

0 comments on commit 6d5a8c8

Please sign in to comment.