Commit

🚨 Ran black to fix linter issues.
Adjorn committed Dec 19, 2024
1 parent 514a1c3 commit 5e534c2
Showing 5 changed files with 61 additions and 43 deletions.
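All five diffs below are mechanical rewrites of the kind black applies: quotes normalized to double quotes, long signatures and calls exploded one argument per line, long conditional expressions wrapped in parentheses, and a blank line inserted after function-level imports. A minimal sketch of reproducing such a rewrite through black's Python API follows; the 100-character line length is an assumption inferred from the resulting lines, not recorded in the commit:

```python
# Sketch only: black.format_str applies the same rules as the `black` CLI.
# line_length=100 is inferred from this diff, not from the repo's config.
import black

src = (
    "def dump_parameters(self, foldername: str, prefix: str = 'model',"
    " file_extension='.pkl') -> None: ..."
)

formatted = black.format_str(src, mode=black.Mode(line_length=100))
print(formatted)  # single quotes become double, the signature is exploded
```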
39 changes: 23 additions & 16 deletions sam/models/base_model.py
@@ -637,7 +637,9 @@ def score(self, X: pd.DataFrame, y: pd.Series) -> float:
         return score
 
     @abstractmethod
-    def dump_parameters(self, foldername: str, prefix: str = "model", file_extension='.pkl') -> None:
+    def dump_parameters(
+        self, foldername: str, prefix: str = "model", file_extension=".pkl"
+    ) -> None:
         """
         Save a model to disk
@@ -656,11 +658,11 @@ def dump_parameters(self, foldername: str, prefix: str = "model", file_extension
         ...
 
     def dump(
-        self,
-        foldername: str,
-        prefix: str = "model",
-        model_file_extension: str = ".pkl",
-        weights_file_extension: str = None
+        self,
+        foldername: str,
+        prefix: str = "model",
+        model_file_extension: str = ".pkl",
+        weights_file_extension: str = None,
     ):
         """
         Writes the following files:
@@ -687,7 +689,11 @@ def dump(
         if hasattr(self, "model_"):
             check_is_fitted(self, "model_")
         # Dirty but we need to get the default file extension of the inheritor when file_extension is None
-        dump_kwargs = {} if weights_file_extension is None else {"file_extension": weights_file_extension}
+        dump_kwargs = (
+            {}
+            if weights_file_extension is None
+            else {"file_extension": weights_file_extension}
+        )
         self.dump_parameters(foldername=foldername, prefix=prefix, **dump_kwargs)
         # Set the models to None temporarily, because they can't be pickled
         backup, self.model_ = self.model_, None
@@ -696,10 +702,12 @@ def dump(
         match model_file_extension:
             case ".json":
                 import json
+
                 with open(foldername / (prefix + ".json"), "w") as file:
                     json.dump(self.to_dict(), file)
             case ".pkl":
                 import cloudpickle
+
                 with open(foldername / (prefix + ".pkl"), "wb") as file:
                     cloudpickle.dump(self, file)
 
@@ -729,6 +737,7 @@ def load(cls, foldername: str, prefix: str = "model"):
             The SAM model that has been loaded from disk
         """
         import os
+
         foldername = Path(foldername)
         file_path = foldername / prefix
         obj = None
@@ -739,6 +748,7 @@ def load(cls, foldername: str, prefix: str = "model"):
         if os.path.exists(file_path := file_path.with_suffix(".pkl")):
             with open(file_path, "rb") as f:
                 import cloudpickle
+
                 obj = cloudpickle.load(f)
 
         if obj is None:
@@ -753,20 +763,17 @@ def to_dict(self):
         """
         Creates a dictionary used to recreate the BaseTimeseriesRegressor for prediction.
         """
-        required_objects = {n: getattr(self, n) for n in self.to_save_objects if
-                            hasattr(self, n)}
+        required_objects = {n: getattr(self, n) for n in self.to_save_objects if hasattr(self, n)}
 
         object_data = {}
         for name, obj in required_objects.items():
             data = object_to_dict(obj)
             object_data[name] = data
 
-        class_data = {name: getattr(self, name) for name in self.to_save_parameters if
-                      hasattr(self, name)}
-        return {
-            "objects": object_data,
-            "class_parameters": class_data
-        }
+        class_data = {
+            name: getattr(self, name) for name in self.to_save_parameters if hasattr(self, name)
+        }
+        return {"objects": object_data, "class_parameters": class_data}
 
     @classmethod
     def from_dict(cls, params: dict[str, Any]):
@@ -775,7 +782,7 @@ def from_dict(cls, params: dict[str, Any]):
         """
         # Initialize the saved objects
         initialized_objects = {}
-        for name, data in params['objects'].items():
+        for name, data in params["objects"].items():
             if data is None:
                 initialized_objects[name] = None
                 continue
@@ -784,7 +791,7 @@ def from_dict(cls, params: dict[str, Any]):
             initialized_objects[name] = obj
         class_object = cls()
 
-        to_set = params['class_parameters'] | initialized_objects
+        to_set = params["class_parameters"] | initialized_objects
 
         for name, value in to_set.items():
             if hasattr(class_object, name):
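The dump/load pair above is unchanged in behavior; only its layout moved. A minimal round-trip sketch of how these methods are used, mirroring the tests further down this diff; the import path and the get_dataset helper are assumptions taken from the test module:

```python
# Hedged sketch: assumes MLPTimeseriesRegressor is exported from sam.models
# and that get_dataset() (the helper used in test_mlp_model.py below) is in scope.
from pathlib import Path

from sam.models import MLPTimeseriesRegressor

Path("artifacts").mkdir(exist_ok=True)  # dump() may not create the folder itself

X, y = get_dataset()
model = MLPTimeseriesRegressor(epochs=1)
model.fit(X, y)

# dump() first hands the weights to dump_parameters(), then serializes the
# estimator shell with model_ temporarily set to None (it cannot be pickled).
model.dump(foldername="artifacts", prefix="model", model_file_extension=".pkl")

# load() is a classmethod: it restores the shell (pickle, or JSON via from_dict),
# then reattaches the weights through load_parameters().
restored = MLPTimeseriesRegressor.load(foldername="artifacts", prefix="model")
```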
20 changes: 13 additions & 7 deletions sam/models/constant_model.py
@@ -291,28 +291,33 @@ def predict(
         else:
             return prediction
 
-    def dump_parameters(self, foldername: str, prefix: str = "model",
-                        file_extension='.json') -> None:
+    def dump_parameters(
+        self, foldername: str, prefix: str = "model", file_extension=".json"
+    ) -> None:
         match file_extension:
-            case '.json':
+            case ".json":
                 parameters = vars(self.model_)
-                parameters['model_quantiles_'] = parameters['model_quantiles_'].tolist()
+                parameters["model_quantiles_"] = parameters["model_quantiles_"].tolist()
                 with open(Path(foldername) / f"{prefix}_params.json", "w") as f:
                     json.dump(obj=parameters, fp=f)
-            case '.pkl':
+            case ".pkl":
                 import cloudpickle
+
                 with open(Path(foldername) / f"{prefix}_params.pkl", "wb") as f:
                     cloudpickle.dump(self.model_, f)
             case _:
-                raise ValueError(f"The file extension: {file_extension} is not supported choose '.pkl' or '.json'")
+                raise ValueError(
+                    f"The file extension: {file_extension} is not supported choose '.pkl' or '.json'"
+                )
 
     @staticmethod
     def load_parameters(obj, foldername: str, prefix: str = "model") -> Any:
         import os
+
         foldername = Path(foldername)
         file_path = foldername / (prefix + "_params")
         if os.path.exists(file_path := file_path.with_suffix(".json")):
-            with open(file_path, 'r') as f:
+            with open(file_path, "r") as f:
                 parameters = json.load(f)
                 model = ConstantTemplate()
                 for name, value in parameters.items():
@@ -323,6 +328,7 @@ def load_parameters(obj, foldername: str, prefix: str = "model") -> Any:
             return model
         if os.path.exists(file_path := file_path.with_suffix(".pkl")):
             import cloudpickle
+
             with open(file_path, "rb") as f:
                 model = cloudpickle.load(f)
             return model
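For reference, the two branches of dump_parameters above write differently named weight files next to the model shell. A hedged usage sketch, with `model` standing in for a fitted instance of the regressor defined in this file (its class name is not shown in this hunk):

```python
# Sketch only; `model` is an assumed fitted instance from constant_model.py.
model.dump_parameters(foldername="artifacts", prefix="model", file_extension=".json")
# -> artifacts/model_params.json: vars(model.model_) with model_quantiles_ as a list

model.dump_parameters(foldername="artifacts", prefix="model", file_extension=".pkl")
# -> artifacts/model_params.pkl: cloudpickle dump of model.model_

# load_parameters() probes model_params.json first, then model_params.pkl,
# rebuilding a ConstantTemplate attribute by attribute in the JSON case.
model.model_ = model.load_parameters(obj=model, foldername="artifacts")
```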
18 changes: 11 additions & 7 deletions sam/models/mlp_model.py
@@ -333,6 +333,7 @@ def predict(
             The transformed input data, when return_data is True, otherwise None
         """
         import onnxruntime as ort
+
         self.validate_data(X)
 
         if y is None and self.use_diff_of_y:
@@ -356,7 +357,9 @@ def predict(
         else:
             return prediction
 
-    def dump_parameters(self, foldername: str, prefix: str = "model", file_extension='.h5') -> None:
+    def dump_parameters(
+        self, foldername: str, prefix: str = "model", file_extension=".h5"
+    ) -> None:
         """
         Writes the following files:
         * prefix.h5
@@ -378,23 +381,23 @@ def dump_parameters(self, foldername: str, prefix: str = "model", file_extension
         import tf2onnx
         import onnx
         import tensorflow as tf
+
         check_is_fitted(self, "model_")
         foldername = Path(foldername)
         match file_extension:
             case ".onnx":
-                input_signature = [tf.TensorSpec((None, *self.input_shape), name='X')]
+                input_signature = [tf.TensorSpec((None, *self.input_shape), name="X")]
                 onnx_model, _ = tf2onnx.convert.from_keras(
-                    self.model_,
-                    input_signature=input_signature,
-                    opset=13
+                    self.model_, input_signature=input_signature, opset=13
                 )
-                onnx.save(onnx_model, foldername / (prefix + '.onnx'))
+                onnx.save(onnx_model, foldername / (prefix + ".onnx"))
             case ".h5":
                 self.model_.save(foldername / (prefix + ".h5"))
 
             case _:
                 raise ValueError(
-                    f"The file extension: {file_extension} is not supported choose '.pkl' or '.json'")
+                    f"The file extension: {file_extension} is not supported choose '.pkl' or '.json'"
+                )
 
     @staticmethod
     def load_parameters(obj, foldername: str, prefix: str = "model") -> Any:
@@ -411,6 +414,7 @@ def load_parameters(obj, foldername: str, prefix: str = "model") -> Any:
         import keras
         import os
         import onnxruntime as ort
+
         foldername = Path(foldername)
         loss = obj._get_loss()
         file_path = foldername / prefix
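The `.onnx` branch above converts the fitted Keras model with tf2onnx at opset 13 and saves it as prefix.onnx; predict() imports onnxruntime, so reloaded `.onnx` weights are presumably served by an ONNX runtime session rather than by Keras. A hedged sketch of that export/reload cycle, following the pattern in the tests below:

```python
# Sketch only: assumes `model` is a fitted MLPTimeseriesRegressor and X is
# the DataFrame it was fitted on; mirrors test_dump_load_parameters below.
model.dump_parameters(foldername="artifacts", prefix="model", file_extension=".onnx")
# -> artifacts/model.onnx via tf2onnx.convert.from_keras(..., opset=13)

# load_parameters() returns the reloaded model object; the tests reassign it
# to model.model_ before predicting again.
model.model_ = model.load_parameters(obj=model, foldername="artifacts")
y_pred_onnx = model.predict(X=X)
```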
15 changes: 8 additions & 7 deletions sam/models/tests/test_mlp_model.py
@@ -150,15 +150,15 @@ class TestPipelineFeatureEngineer(unittest.TestCase):
     def test_get_feature_names_out(self):
         from sklearn.pipeline import Pipeline
         from sam.feature_engineering import BuildRollingFeatures
+
         X, y = get_dataset()
-        fe = Pipeline([
-            ("roll", BuildRollingFeatures(window_size='1h')),
-            ("scaler", StandardScaler())
-        ])
+        fe = Pipeline(
+            [("roll", BuildRollingFeatures(window_size="1h")), ("scaler", StandardScaler())]
+        )
         model = MLPTimeseriesRegressor(epochs=1, feature_engineer=fe)
         model.fit(X, y)
         feature_names = model.get_feature_names_out()
-        self.assertListEqual(list(feature_names), ['x', 'x#mean_1h'])
+        self.assertListEqual(list(feature_names), ["x", "x#mean_1h"])
 
 
 class TestLoadDump(unittest.TestCase):
@@ -167,6 +167,7 @@ class TestLoadDump(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
         import os
+
         os.makedirs(cls.file_dir, exist_ok=True)
 
     @classmethod
@@ -187,7 +188,7 @@ def test_dump_load_parameters(self):
         model = MLPTimeseriesRegressor(epochs=1, feature_engineer=fe)
         model.fit(X, y)
 
-        model.dump_parameters(foldername=self.file_dir, file_extension='.onnx')
+        model.dump_parameters(foldername=self.file_dir, file_extension=".onnx")
         y_pred_tf = model.predict(X=X)
         self.assertIsInstance(model.model_, keras.Model)
         model.model_ = model.load_parameters(obj=model, foldername=self.file_dir)
@@ -206,7 +207,7 @@ def test_to_from_dict(self):
         model = MLPTimeseriesRegressor(epochs=1, feature_engineer=fe, y_scaler=StandardScaler())
         model.fit(X, y)
 
-        model.dump_parameters(foldername=self.file_dir, file_extension='.onnx')
+        model.dump_parameters(foldername=self.file_dir, file_extension=".onnx")
         params = model.to_dict()
         y_pred_tf = model.predict(X=X)
         self.assertIsInstance(model.model_, keras.Model)
12 changes: 6 additions & 6 deletions sam/utils/json_helpers.py
@@ -36,7 +36,7 @@ def object_to_dict(obj):
 
     # Usually this is overkill
     if hasattr(obj, "get_params"):
-        data['params'] = obj.get_params(deep=True)
+        data["params"] = obj.get_params(deep=True)
 
     return data
 
@@ -56,15 +56,15 @@ def object_from_dict(data):
     if result:
         return result
 
-    obj_class = getattr(importlib.import_module(data['module']), data['class'])
+    obj_class = getattr(importlib.import_module(data["module"]), data["class"])
     # Get the arguments expected by the __init__ of AquasuiteModel
     signature = inspect.signature(obj_class.__init__)
     # Get the arguments which are in the __init__
     arguments = [param.name for param in signature.parameters.values()]
     # Get the arguments which are in the vars and in the __init__
-    found_arguments = {k: v for k, v in data['vars'].items() if k in arguments}
+    found_arguments = {k: v for k, v in data["vars"].items() if k in arguments}
     # Get the attributes which are in the vars and not in the __init__
-    found_attributes = {k: v for k, v in data['vars'].items() if k not in arguments}
+    found_attributes = {k: v for k, v in data["vars"].items() if k not in arguments}
     obj = obj_class(**found_arguments)
 
     # Set the found attributes
@@ -74,6 +74,6 @@ def object_from_dict(data):
         setattr(obj, param_name, value)
 
     # Set the params if it has the `set_params` function.
-    if 'params' in data.keys() and hasattr(obj_class, 'set_params'):
-        obj.set_params(**data['params'])
+    if "params" in data.keys() and hasattr(obj_class, "set_params"):
+        obj.set_params(**data["params"])
     return obj
