Skip to content

Commit 2f537e4

Browse files
committed
Cache creation of compound masks
1 parent 347a0c0 commit 2f537e4

File tree

2 files changed

+82
-36
lines changed

2 files changed

+82
-36
lines changed

package/MDAnalysis/core/groups.py

+45-34
Original file line numberDiff line numberDiff line change
@@ -936,45 +936,56 @@ def _split_by_compound_indices(self, compound, stable_sort=False):
936936
n_compounds : int
937937
The number of individual compounds.
938938
"""
939-
# Caching would help here, especially when repeating the operation
940-
# over different frames, since these masks are coordinate-independent.
941-
# However, cache must be invalidated whenever new compound indices are
942-
# modified, which is not yet implemented.
943-
# Also, should we include here the grouping for 'group', which is
939+
# Should we include here the grouping for 'group', which is
944940
# essentially a non-split?
945941

942+
cache_key = f"{compound}_masks"
946943
compound_indices = self._get_compound_indices(compound)
947-
compound_sizes = np.bincount(compound_indices)
948-
size_per_atom = compound_sizes[compound_indices]
949-
compound_sizes = compound_sizes[compound_sizes != 0]
950-
unique_compound_sizes = unique_int_1d(compound_sizes)
951-
952-
# Are we already sorted? argsorting and fancy-indexing can be expensive
953-
# so we do a quick pre-check.
954-
needs_sorting = np.any(np.diff(compound_indices) < 0)
955-
if needs_sorting:
956-
# stable sort ensures reproducibility, especially concerning who
957-
# gets to be a compound's atom[0] and be a reference for unwrap.
958-
if stable_sort:
959-
sort_indices = np.argsort(compound_indices, kind='stable')
960-
else:
961-
# Quicksort
962-
sort_indices = np.argsort(compound_indices)
963-
# We must sort size_per_atom accordingly (Issue #3352).
964-
size_per_atom = size_per_atom[sort_indices]
965-
966-
compound_masks = []
967-
atom_masks = []
968-
for compound_size in unique_compound_sizes:
969-
compound_masks.append(compound_sizes == compound_size)
944+
945+
# create new cache or invalidate cache when compound indices changed
946+
if (
947+
cache_key not in self._cache
948+
or np.all(self._cache[cache_key]["compound_indices"]
949+
!= compound_indices)):
950+
compound_sizes = np.bincount(compound_indices)
951+
size_per_atom = compound_sizes[compound_indices]
952+
compound_sizes = compound_sizes[compound_sizes != 0]
953+
unique_compound_sizes = unique_int_1d(compound_sizes)
954+
955+
# Are we already sorted? argsorting and fancy-indexing can be
956+
# expensive so we do a quick pre-check.
957+
needs_sorting = np.any(np.diff(compound_indices) < 0)
970958
if needs_sorting:
971-
atom_masks.append(sort_indices[size_per_atom == compound_size]
972-
.reshape(-1, compound_size))
973-
else:
974-
atom_masks.append(np.where(size_per_atom == compound_size)[0]
975-
.reshape(-1, compound_size))
959+
# stable sort ensures reproducibility, especially concerning
960+
# who gets to be a compound's atom[0] and be a reference for
961+
# unwrap.
962+
if stable_sort:
963+
sort_indices = np.argsort(compound_indices, kind='stable')
964+
else:
965+
# Quicksort
966+
sort_indices = np.argsort(compound_indices)
967+
# We must sort size_per_atom accordingly (Issue #3352).
968+
size_per_atom = size_per_atom[sort_indices]
969+
970+
compound_masks = []
971+
atom_masks = []
972+
for compound_size in unique_compound_sizes:
973+
compound_masks.append(compound_sizes == compound_size)
974+
if needs_sorting:
975+
atom_masks.append(sort_indices[size_per_atom
976+
== compound_size]
977+
.reshape(-1, compound_size))
978+
else:
979+
atom_masks.append(np.where(size_per_atom
980+
== compound_size)[0]
981+
.reshape(-1, compound_size))
982+
983+
self._cache[cache_key] = {
984+
"compound_indices": compound_indices,
985+
"data": (atom_masks, compound_masks, len(compound_sizes))
986+
}
976987

977-
return atom_masks, compound_masks, len(compound_sizes)
988+
return self._cache[cache_key]["data"]
978989

979990
@warn_if_not_unique
980991
@_pbc_to_wrap

testsuite/MDAnalysisTests/core/test_accumulate.py

+37-2
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
# J. Comput. Chem. 32 (2011), 2319--2327, doi:10.1002/jcc.21787
2222
#
2323
import numpy as np
24-
from numpy.testing import assert_equal, assert_almost_equal
24+
from numpy.testing import assert_equal, assert_almost_equal, assert_allclose
2525

2626
import MDAnalysis as mda
2727
from MDAnalysis.exceptions import DuplicateWarning, NoDataError
@@ -99,7 +99,6 @@ def test_accumulate_array_attribute_compounds(self, name, compound, level):
9999
ref = [np.ones((len(a), 2, 5)).sum(axis=0) for a in group.atoms.groupby(name).values()]
100100
assert_equal(group.accumulate(np.ones((len(group.atoms), 2, 5)), compound=compound), ref)
101101

102-
103102
class TestTotals(object):
104103
"""Tests the functionality of *Group.total*() like total_mass
105104
and total_charge.
@@ -291,3 +290,39 @@ def test_quadrupole_moment_fragments(self, group):
291290
assert_almost_equal(quadrupoles,
292291
np.array([0., 0.0011629, 0.1182701, 0.6891748
293292
])) and len(quadrupoles) == n_compounds
293+
294+
295+
class TestCache:
296+
@pytest.fixture()
297+
def group(self):
298+
return mda.Universe(PSF, DCD).atoms
299+
300+
def test_cache(self, group):
301+
"""Test that one cache per compound is created."""
302+
group_nocache = group.copy()
303+
group_cache = group.copy()
304+
305+
for compound in ['residues', 'fragments']:
306+
actual = group_nocache.accumulate("masses", compound=compound)
307+
desired = group_cache.accumulate("masses", compound=compound)
308+
309+
assert_allclose(actual, desired)
310+
group_nocache._cache.pop(f"{compound}_masks")
311+
312+
@pytest.mark.parametrize("compound",
313+
['residues', 'fragments'])
314+
def test_cache_updating(self, group, compound):
315+
"""Test caching of compound_masks for updating atomgroups."""
316+
kwargs = {"attribute": "masses", "compound": compound}
317+
318+
group_nocache = group.select_atoms("prop z < 1.0", updating=True)
319+
group_cache = group.select_atoms("prop z < 1.0", updating=True)
320+
321+
assert_allclose(group_nocache.accumulate(**kwargs),
322+
group_cache.accumulate(**kwargs))
323+
324+
# Clear cache and forward to next frame
325+
group_nocache._cache.pop(f"{compound}_masks")
326+
group.universe.trajectory.next()
327+
assert_allclose(group_nocache.accumulate(**kwargs),
328+
group_cache.accumulate(**kwargs))

0 commit comments

Comments
 (0)