
Add _apply method to TensorProduct and TensorProductConv for proper dtype conversion #180

@laphysique

Description


Both classes implement to but not _apply, which causes problems when they are used as submodules in a torch model. For example, if a MACE model using OpenEquivariance as its backend is constructed with float32 and later converted with mace_model.to(torch.float64), it throws: RuntimeError: Dtype mismatch for tensor 'L1_in'. Expected: Float. Got: double. This happens because PyTorch uses _apply, not to, to convert submodules when to is called on the parent model.
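To illustrate the mechanism, here is a minimal self-contained sketch (hypothetical Parent/Child modules, not the OpenEquivariance classes) showing that calling to on a parent module converts a child's buffers through _apply without ever invoking the child's own to override:

```python
import torch
import torch.nn as nn

class Child(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("buf", torch.zeros(2, dtype=torch.float32))
        self.to_called = False  # records whether our to() override ran

    def to(self, *args, **kwargs):
        self.to_called = True
        return super().to(*args, **kwargs)

class Parent(nn.Module):
    def __init__(self):
        super().__init__()
        self.child = Child()

p = Parent()
p.to(torch.float64)
# p.child.buf is now float64, yet p.child.to_called is still False:
# the parent converted the child via _apply, bypassing the child's to()
```

Any dtype-tracking logic placed only in to is therefore skipped during a parent-level conversion, which is exactly why overriding _apply is needed.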

An _apply override can be added that delegates to the modules' existing to method, for example:

def _apply(self, fn, recurse=True):
    # self.to() below re-enters _apply; let the re-entrant call fall through
    if getattr(self, '_applying', False):
        return super()._apply(fn, recurse)

    problem: TPProblem = self.input_args["problem"]
    irrep_dtype = problem.irrep_dtype

    if irrep_dtype in dtype_to_enum:
        irrep_dtype = dtype_to_enum[irrep_dtype]

    current_dtype = enum_to_torch_dtype[irrep_dtype]

    # Probe fn with a dummy tensor to discover the requested dtype
    dummy = torch.tensor(0.0, dtype=current_dtype)
    result = fn(dummy)

    if result.dtype != current_dtype:
        self._applying = True
        self.to(result.dtype)  # rebuild internal state for the new dtype
        self._applying = False

    return super()._apply(fn, recurse)
