Skip to content

Commit

Permalink
Merge pull request #414 from lab-cosmo/feat/qickdirty-cpp-paramaters-…
Browse files Browse the repository at this point in the history
…to-mlmodel

Adapt cpp representation input paramaters compatible with the current lammps-rascal interface
  • Loading branch information
PicoCentauri authored Jun 15, 2022
2 parents 3df1cd1 + bb0ef50 commit 5f6dd49
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 6 deletions.
14 changes: 13 additions & 1 deletion bindings/rascal/models/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,21 @@ def _get_init_params(self):

def _set_data(self, data):
super()._set_data(data)
# allows to load deprecated models
if "cpp_kernel" in data.keys():
self._kernel = self._kernel.from_dict(data["cpp_kernel"])
else:
print(
"WARNING: a deprecated model was loaded. Key 'cpp_kernel' "
"was not found in model. Please dump and reload the model "
"to update it. The model parameters will not change, only "
"the format."
)

def _get_data(self):
return super()._get_data()
data = super()._get_data()
data.update(cpp_kernel=self._kernel.to_dict())
return data

def __call__(self, X, Y=None, grad=(False, False), compute_neg_stress=False):
"""
Expand Down
9 changes: 7 additions & 2 deletions bindings/rascal/models/krr.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,13 @@ def predict_stress(self, managers, KNM=None):

return -neg_stress

def get_weights(self):
return self.weights
@property
def weights(self):
return self._weights

@weights.setter
def weights(self, weights):
self._weights = weights.reshape(-1)

def _get_init_params(self):
init_params = dict(
Expand Down
16 changes: 15 additions & 1 deletion bindings/rascal/representations/spherical_covariants.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,20 @@ def _get_init_params(self):

def _set_data(self, data):
super()._set_data(data)
# allows to load deprecated models
if "cpp_representation" in data.keys():
self._representation = self._representation.from_dict(
data["cpp_representation"]
)
else:
print(
"WARNING: a deprecated model was loaded. Key "
"'cpp_representation' was not found in model. Please dump and "
"reload the model to update it. The model parameters will not "
"change, only the format."
)

def _get_data(self):
return super()._get_data()
data = super()._get_data()
data.update(cpp_representation=self._representation.to_dict())
return data
16 changes: 15 additions & 1 deletion bindings/rascal/representations/spherical_expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,20 @@ def _get_init_params(self):

def _set_data(self, data):
super()._set_data(data)
# allows to load deprecated models
if "cpp_representation" in data.keys():
self._representation = self._representation.from_dict(
data["cpp_representation"]
)
else:
print(
"WARNING: a deprecated model was loaded. Key "
"'cpp_representation' was not found in model. Please dump and "
"reload the model to update it. The model parameters will not "
"change, only the format."
)

def _get_data(self):
return super()._get_data()
data = super()._get_data()
data.update(cpp_representation=self._representation.to_dict())
return data
16 changes: 15 additions & 1 deletion bindings/rascal/representations/spherical_invariants.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,20 @@ def _get_init_params(self):

def _set_data(self, data):
super()._set_data(data)
# allows to load deprecated models
if "cpp_representation" in data.keys():
self._representation = self._representation.from_dict(
data["cpp_representation"]
)
else:
print(
"WARNING: a deprecated model was loaded. Key "
"'cpp_representation' was not found in model. Please dump and "
"reload the model to update it. The model parameters will not "
"change, only the format."
)

def _get_data(self):
return super()._get_data()
data = super()._get_data()
data.update(cpp_representation=self._representation.to_dict())
return data

0 comments on commit 5f6dd49

Please sign in to comment.