Skip to content

Commit

Permalink
🐛 Removed match case statements to support python 3.9
Browse files Browse the repository at this point in the history
  • Loading branch information
Adjorn committed Dec 19, 2024
1 parent 8f4dbf2 commit 0a00de6
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 40 deletions.
18 changes: 9 additions & 9 deletions sam/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,17 +697,17 @@ def dump(
backup, self.model_ = self.model_, None

foldername = Path(foldername)
match model_file_extension:
case ".json":
import json
if model_file_extension == ".json":
import json

with open(foldername / (prefix + ".json"), "w") as file:
json.dump(self.to_dict(), file)
case ".pkl":
import cloudpickle
with open(foldername / (prefix + ".json"), "w") as file:
json.dump(self.to_dict(), file)

if model_file_extension == ".pkl":
import cloudpickle

with open(foldername / (prefix + ".pkl"), "wb") as file:
cloudpickle.dump(self, file)
with open(foldername / (prefix + ".pkl"), "wb") as file:
cloudpickle.dump(self, file)

if backup is not None:
self.model_ = backup
Expand Down
31 changes: 15 additions & 16 deletions sam/models/constant_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,22 +294,21 @@ def predict(
def dump_parameters(
self, foldername: str, prefix: str = "model", file_extension=".json"
) -> None:
match file_extension:
case ".json":
parameters = vars(self.model_)
parameters["model_quantiles_"] = parameters["model_quantiles_"].tolist()
with open(Path(foldername) / f"{prefix}_params.json", "w") as f:
json.dump(obj=parameters, fp=f)
case ".pkl":
import cloudpickle

with open(Path(foldername) / f"{prefix}_params.pkl", "wb") as f:
cloudpickle.dump(self.model_, f)
case _:
raise ValueError(
f"The file extension: {file_extension} "
f"is not supported choose '.pkl' or '.json'"
)
if file_extension == ".json":
parameters = vars(self.model_)
parameters["model_quantiles_"] = parameters["model_quantiles_"].tolist()
with open(Path(foldername) / f"{prefix}_params.json", "w") as f:
json.dump(obj=parameters, fp=f)
return
if file_extension == ".pkl":
import cloudpickle

with open(Path(foldername) / f"{prefix}_params.pkl", "wb") as f:
cloudpickle.dump(self.model_, f)
return
raise ValueError(
f"The file extension: {file_extension} " f"is not supported choose '.pkl' or '.json'"
)

@staticmethod
def load_parameters(obj, foldername: str, prefix: str = "model") -> Any:
Expand Down
29 changes: 14 additions & 15 deletions sam/models/mlp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,21 +384,20 @@ def dump_parameters(

check_is_fitted(self, "model_")
foldername = Path(foldername)
match file_extension:
case ".onnx":
input_signature = [tf.TensorSpec((None, *self.input_shape), name="X")]
onnx_model, _ = tf2onnx.convert.from_keras(
self.model_, input_signature=input_signature, opset=13
)
onnx.save(onnx_model, foldername / (prefix + ".onnx"))
case ".h5":
self.model_.save(foldername / (prefix + ".h5"))

case _:
raise ValueError(
f"The file extension: {file_extension} "
f"is not supported choose '.pkl' or '.json'"
)
if file_extension == ".onnx":
input_signature = [tf.TensorSpec((None, *self.input_shape), name="X")]
onnx_model, _ = tf2onnx.convert.from_keras(
self.model_, input_signature=input_signature, opset=13
)
onnx.save(onnx_model, foldername / (prefix + ".onnx"))
return
if file_extension == ".h5":
self.model_.save(foldername / (prefix + ".h5"))
return

raise ValueError(
f"The file extension: {file_extension} " f"is not supported choose '.pkl' or '.json'"
)

@staticmethod
def load_parameters(obj, foldername: str, prefix: str = "model") -> Any:
Expand Down

0 comments on commit 0a00de6

Please sign in to comment.