@@ -936,45 +936,56 @@ def _split_by_compound_indices(self, compound, stable_sort=False):
936
936
n_compounds : int
937
937
The number of individual compounds.
938
938
"""
939
- # Caching would help here, especially when repeating the operation
940
- # over different frames, since these masks are coordinate-independent.
941
- # However, cache must be invalidated whenever new compound indices are
942
- # modified, which is not yet implemented.
943
- # Also, should we include here the grouping for 'group', which is
939
+ # Should we include here the grouping for 'group', which is
944
940
# essentially a non-split?
945
941
942
+ cache_key = f"{ compound } _masks"
946
943
compound_indices = self ._get_compound_indices (compound )
947
- compound_sizes = np .bincount (compound_indices )
948
- size_per_atom = compound_sizes [compound_indices ]
949
- compound_sizes = compound_sizes [compound_sizes != 0 ]
950
- unique_compound_sizes = unique_int_1d (compound_sizes )
951
-
952
- # Are we already sorted? argsorting and fancy-indexing can be expensive
953
- # so we do a quick pre-check.
954
- needs_sorting = np .any (np .diff (compound_indices ) < 0 )
955
- if needs_sorting :
956
- # stable sort ensures reproducibility, especially concerning who
957
- # gets to be a compound's atom[0] and be a reference for unwrap.
958
- if stable_sort :
959
- sort_indices = np .argsort (compound_indices , kind = 'stable' )
960
- else :
961
- # Quicksort
962
- sort_indices = np .argsort (compound_indices )
963
- # We must sort size_per_atom accordingly (Issue #3352).
964
- size_per_atom = size_per_atom [sort_indices ]
965
-
966
- compound_masks = []
967
- atom_masks = []
968
- for compound_size in unique_compound_sizes :
969
- compound_masks .append (compound_sizes == compound_size )
944
+
945
+ # create new cache or invalidate cache when compound indices changed
946
+ if (
947
+ cache_key not in self ._cache
948
+ or np .all (self ._cache [cache_key ]["compound_indices" ]
949
+ != compound_indices )):
950
+ compound_sizes = np .bincount (compound_indices )
951
+ size_per_atom = compound_sizes [compound_indices ]
952
+ compound_sizes = compound_sizes [compound_sizes != 0 ]
953
+ unique_compound_sizes = unique_int_1d (compound_sizes )
954
+
955
+ # Are we already sorted? argsorting and fancy-indexing can be
956
+ # expensive so we do a quick pre-check.
957
+ needs_sorting = np .any (np .diff (compound_indices ) < 0 )
970
958
if needs_sorting :
971
- atom_masks .append (sort_indices [size_per_atom == compound_size ]
972
- .reshape (- 1 , compound_size ))
973
- else :
974
- atom_masks .append (np .where (size_per_atom == compound_size )[0 ]
975
- .reshape (- 1 , compound_size ))
959
+ # stable sort ensures reproducibility, especially concerning
960
+ # who gets to be a compound's atom[0] and be a reference for
961
+ # unwrap.
962
+ if stable_sort :
963
+ sort_indices = np .argsort (compound_indices , kind = 'stable' )
964
+ else :
965
+ # Quicksort
966
+ sort_indices = np .argsort (compound_indices )
967
+ # We must sort size_per_atom accordingly (Issue #3352).
968
+ size_per_atom = size_per_atom [sort_indices ]
969
+
970
+ compound_masks = []
971
+ atom_masks = []
972
+ for compound_size in unique_compound_sizes :
973
+ compound_masks .append (compound_sizes == compound_size )
974
+ if needs_sorting :
975
+ atom_masks .append (sort_indices [size_per_atom
976
+ == compound_size ]
977
+ .reshape (- 1 , compound_size ))
978
+ else :
979
+ atom_masks .append (np .where (size_per_atom
980
+ == compound_size )[0 ]
981
+ .reshape (- 1 , compound_size ))
982
+
983
+ self ._cache [cache_key ] = {
984
+ "compound_indices" : compound_indices ,
985
+ "data" : (atom_masks , compound_masks , len (compound_sizes ))
986
+ }
976
987
977
- return atom_masks , compound_masks , len ( compound_sizes )
988
+ return self . _cache [ cache_key ][ "data" ]
978
989
979
990
@warn_if_not_unique
980
991
@_pbc_to_wrap
0 commit comments