From bb0ef504bdab3edb89eb99e708b0fd8842a7e5b4 Mon Sep 17 00:00:00 2001 From: alexgo Date: Tue, 24 Aug 2021 12:40:05 +0200 Subject: [PATCH] add cpp information to rascal model json * cpp representation and kernel information is now included in rascal's model json file. * weights are now always flattened to one dimension for consistency in the models format. --- bindings/rascal/models/kernels.py | 14 +++++++++++++- bindings/rascal/models/krr.py | 9 +++++++-- .../representations/spherical_covariants.py | 16 +++++++++++++++- .../representations/spherical_expansion.py | 16 +++++++++++++++- .../representations/spherical_invariants.py | 16 +++++++++++++++- 5 files changed, 65 insertions(+), 6 deletions(-) diff --git a/bindings/rascal/models/kernels.py b/bindings/rascal/models/kernels.py index 96801794d..b2140b4d0 100644 --- a/bindings/rascal/models/kernels.py +++ b/bindings/rascal/models/kernels.py @@ -107,9 +107,21 @@ def _get_init_params(self): def _set_data(self, data): super()._set_data(data) + # allows to load deprecated models + if "cpp_kernel" in data.keys(): + self._kernel = self._kernel.from_dict(data["cpp_kernel"]) + else: + print( + "WARNING: a deprecated model was loaded. Key 'cpp_kernel' " + "was not found in model. Please dump and reload the model " + "to update it. The model parameters will not change, only " + "the format." + ) def _get_data(self): - return super()._get_data() + data = super()._get_data() + data.update(cpp_kernel=self._kernel.to_dict()) + return data def __call__(self, X, Y=None, grad=(False, False), compute_neg_stress=False): """ diff --git a/bindings/rascal/models/krr.py b/bindings/rascal/models/krr.py index 6ec674554..12a9de84e 100644 --- a/bindings/rascal/models/krr.py +++ b/bindings/rascal/models/krr.py @@ -209,8 +209,13 @@ def predict_stress(self, managers, KNM=None): return -neg_stress - def get_weights(self): - return self.weights + @property + def weights(self): + return self._weights + + @weights.setter + def weights(self, weights): + self._weights = weights.reshape(-1) def _get_init_params(self): init_params = dict( diff --git a/bindings/rascal/representations/spherical_covariants.py b/bindings/rascal/representations/spherical_covariants.py index 5393ae6de..3bf8e6ec3 100644 --- a/bindings/rascal/representations/spherical_covariants.py +++ b/bindings/rascal/representations/spherical_covariants.py @@ -377,6 +377,20 @@ def _get_init_params(self): def _set_data(self, data): super()._set_data(data) + # allows to load deprecated models + if "cpp_representation" in data.keys(): + self._representation = self._representation.from_dict( + data["cpp_representation"] + ) + else: + print( + "WARNING: a deprecated model was loaded. Key " + "'cpp_representation' was not found in model. Please dump and " + "reload the model to update it. The model parameters will not " + "change, only the format." + ) def _get_data(self): - return super()._get_data() + data = super()._get_data() + data.update(cpp_representation=self._representation.to_dict()) + return data diff --git a/bindings/rascal/representations/spherical_expansion.py b/bindings/rascal/representations/spherical_expansion.py index 81e0a20dc..1a8af2b51 100644 --- a/bindings/rascal/representations/spherical_expansion.py +++ b/bindings/rascal/representations/spherical_expansion.py @@ -322,6 +322,20 @@ def _get_init_params(self): def _set_data(self, data): super()._set_data(data) + # allows to load deprecated models + if "cpp_representation" in data.keys(): + self._representation = self._representation.from_dict( + data["cpp_representation"] + ) + else: + print( + "WARNING: a deprecated model was loaded. Key " + "'cpp_representation' was not found in model. Please dump and " + "reload the model to update it. The model parameters will not " + "change, only the format." + ) def _get_data(self): - return super()._get_data() + data = super()._get_data() + data.update(cpp_representation=self._representation.to_dict()) + return data diff --git a/bindings/rascal/representations/spherical_invariants.py b/bindings/rascal/representations/spherical_invariants.py index f966628dd..ffcbb6219 100644 --- a/bindings/rascal/representations/spherical_invariants.py +++ b/bindings/rascal/representations/spherical_invariants.py @@ -439,6 +439,20 @@ def _get_init_params(self): def _set_data(self, data): super()._set_data(data) + # allows to load deprecated models + if "cpp_representation" in data.keys(): + self._representation = self._representation.from_dict( + data["cpp_representation"] + ) + else: + print( + "WARNING: a deprecated model was loaded. Key " + "'cpp_representation' was not found in model. Please dump and " + "reload the model to update it. The model parameters will not " + "change, only the format." + ) def _get_data(self): - return super()._get_data() + data = super()._get_data() + data.update(cpp_representation=self._representation.to_dict()) + return data