Description
The two objects have `to` implemented but not `_apply`, which causes problems when they are used as submodules in a torch model. For example, if a MACE model using OpenEquivariance as the backend is constructed with float32 and later converted via `mace_model.to(torch.float64)`, it throws: `RuntimeError: Dtype mismatch for tensor 'L1_in'. Expected: Float. Got: double.` This happens because PyTorch uses `_apply` to convert submodules when `to` is called on the parent model.

An `_apply` override can be added that delegates to the modules' existing `to` method, for example:
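A minimal sketch of why this happens, using a hypothetical `CachedDtypeModule` (not the OpenEquivariance classes) that caches a tensor outside PyTorch's parameter/buffer machinery. `Module.to` converts only registered parameters and buffers via `_apply`, so the untracked cache keeps its original dtype:

```python
import torch
import torch.nn as nn

class CachedDtypeModule(nn.Module):
    """Hypothetical stand-in for an extension module with untracked state."""
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(3))  # tracked: converted by .to()
        self.L1_in_cache = torch.zeros(3)           # untracked: _apply never sees it

model = nn.Sequential(CachedDtypeModule())
model.to(torch.float64)

child = model[0]
print(child.weight.dtype)       # torch.float64 -- converted via _apply
print(child.L1_in_cache.dtype)  # torch.float32 -- stale, the source of the mismatch
```

The cached tensor is left at float32 while inputs arrive as float64, which mirrors the `Dtype mismatch for tensor 'L1_in'` error above.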
```python
def _apply(self, fn, recurse=True):
    # Re-entrancy guard: the self.to() call below triggers _apply again.
    if getattr(self, '_applying', False):
        return super()._apply(fn, recurse)
    problem: TPProblem = self.input_args["problem"]
    irrep_dtype = problem.irrep_dtype
    if irrep_dtype in dtype_to_enum:
        irrep_dtype = dtype_to_enum[irrep_dtype]
    current_dtype = enum_to_torch_dtype[irrep_dtype]
    # Probe fn with a dummy tensor to discover the requested dtype.
    dummy = torch.tensor(0.0, dtype=current_dtype)
    result = fn(dummy)
    if result.dtype != current_dtype:
        self._applying = True
        self.to(result.dtype)  # the existing to() rebuilds internal state
        self._applying = False
    return super()._apply(fn, recurse)
```