From e0bf9868520bcc8a4227f0e9d06afeaa08ac5d42 Mon Sep 17 00:00:00 2001 From: Guillaume Fraux Date: Mon, 20 Jan 2025 11:46:18 +0100 Subject: [PATCH] Make _check_xyz_tensor_map ~3x faster by accessing data once --- python/src/sphericart/metatensor.py | 17 +++++++++++------ .../python/sphericart/torch/metatensor.py | 17 +++++++++++------ 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/python/src/sphericart/metatensor.py b/python/src/sphericart/metatensor.py index 4feebf937..81bfe8eb6 100644 --- a/python/src/sphericart/metatensor.py +++ b/python/src/sphericart/metatensor.py @@ -228,15 +228,21 @@ def compute_with_hessians(self, xyz: TensorMap) -> TensorMap: def _check_xyz_tensor_map(xyz: TensorMap): - if len(xyz.blocks()) != 1: + blocks = xyz.blocks() + if len(blocks) != 1: raise ValueError("`xyz` should have only one block") - if len(xyz.block().components) != 1: + + block = blocks[0] + components = block.components + if len(components) != 1: raise ValueError("`xyz` should have only one component") - if xyz.block().components[0].names != ["xyz"]: + if components[0].names != ["xyz"]: raise ValueError("`xyz` should have only one component named 'xyz'") - if xyz.block().components[0].values.shape[0] != 3: + + values_shape = block.values.shape + if values_shape[1] != 3: raise ValueError("`xyz` should have 3 Cartesian coordinates") - if xyz.block().properties.values.shape[0] != 1: + if values_shape[2] != 1: raise ValueError("`xyz` should have only one property") @@ -251,7 +257,6 @@ def _wrap_into_tensor_map( sh_gradients: Optional[np.ndarray] = None, sh_hessians: Optional[np.ndarray] = None, ) -> TensorMap: - # infer l_max l_max = len(components) - 1 diff --git a/sphericart-torch/python/sphericart/torch/metatensor.py b/sphericart-torch/python/sphericart/torch/metatensor.py index f93a837be..5e704d859 100644 --- a/sphericart-torch/python/sphericart/torch/metatensor.py +++ b/sphericart-torch/python/sphericart/torch/metatensor.py @@ -250,15 +250,21 @@ def _send_precomputed_labels_to_device(self, device): def _check_xyz_tensor_map(xyz: TensorMap): - if len(xyz.blocks()) != 1: + blocks = xyz.blocks() + if len(blocks) != 1: raise ValueError("`xyz` should have only one block") - if len(xyz.block().components) != 1: + + block = blocks[0] + components = block.components + if len(components) != 1: raise ValueError("`xyz` should have only one component") - if xyz.block().components[0].names != ["xyz"]: + if components[0].names != ["xyz"]: raise ValueError("`xyz` should have only one component named 'xyz'") - if xyz.block().components[0].values.shape[0] != 3: + + values_shape = block.values.shape + if values_shape[1] != 3: raise ValueError("`xyz` should have 3 Cartesian coordinates") - if xyz.block().properties.values.shape[0] != 1: + if values_shape[2] != 1: raise ValueError("`xyz` should have only one property") @@ -273,7 +279,6 @@ def _wrap_into_tensor_map( sh_gradients: Optional[torch.Tensor] = None, sh_hessians: Optional[torch.Tensor] = None, ) -> TensorMap: - # infer l_max l_max = len(components) - 1