From f15cf89b821c492ee7bf8fd1172629cf61e3f89a Mon Sep 17 00:00:00 2001 From: Patrick Kunzmann Date: Tue, 19 Nov 2024 10:37:04 +0100 Subject: [PATCH] Support concatenation of more than two `AtomArray` objects --- doc/apidoc.json | 1 + src/biotite/structure/atoms.py | 150 ++++++++++++++++++++++++-------- src/biotite/structure/bonds.pyx | 71 ++++++++++++--- tests/structure/test_atoms.py | 8 ++ tests/structure/test_bonds.py | 36 ++++++++ 5 files changed, 215 insertions(+), 51 deletions(-) diff --git a/doc/apidoc.json b/doc/apidoc.json index 7432602f4..3fc5a98d7 100644 --- a/doc/apidoc.json +++ b/doc/apidoc.json @@ -204,6 +204,7 @@ "Atom", "AtomArray", "AtomArrayStack", + "concatenate", "array", "stack", "repeat", diff --git a/src/biotite/structure/atoms.py b/src/biotite/structure/atoms.py index 0be02e172..dc763c114 100644 --- a/src/biotite/structure/atoms.py +++ b/src/biotite/structure/atoms.py @@ -13,6 +13,7 @@ "Atom", "AtomArray", "AtomArrayStack", + "concatenate", "array", "stack", "repeat", @@ -22,6 +23,7 @@ import abc import numbers +from collections.abc import Sequence import numpy as np from biotite.copyable import Copyable from biotite.structure.bonds import BondList @@ -420,42 +422,7 @@ def __len__(self): return self._array_length def __add__(self, array): - if not isinstance(self, type(array)): - raise TypeError("Can only concatenate two arrays or two stacks") - # Create either new array or stack, depending of the own type - if isinstance(self, AtomArray): - concat = AtomArray(length=self._array_length + array._array_length) - if isinstance(self, AtomArrayStack): - concat = AtomArrayStack( - self.stack_depth(), self._array_length + array._array_length - ) - - concat._coord = np.concatenate((self._coord, array.coord), axis=-2) - - # Transfer only annotations, - # which are existent in both operands - arr_categories = list(array._annot.keys()) - for category in self._annot.keys(): - if category in arr_categories: - annot = self._annot[category] - arr_annot = array._annot[category] - concat._annot[category] = np.concatenate((annot, arr_annot)) - - # Concatenate bonds lists, - # if at least one of them contains bond information - if self._bonds is not None or array._bonds is not None: - bonds1 = self._bonds - bonds2 = array._bonds - if bonds1 is None: - bonds1 = BondList(self._array_length) - if bonds2 is None: - bonds2 = BondList(array._array_length) - concat._bonds = bonds1 + bonds2 - - # Copy box - if self._box is not None: - concat._box = np.copy(self._box) - return concat + return concatenate([self, array]) def __copy_fill__(self, clone): super().__copy_fill__(clone) @@ -619,6 +586,7 @@ class AtomArray(_AtomArrayBase): :class:`AtomArray` is done with the '+' operator. Only the annotation categories, which are existing in both arrays, are transferred to the new array. + For a list of :class:`AtomArray` objects, use :func:`concatenate()`. Optionally, an :class:`AtomArray` can store chemical bond information via a :class:`BondList` object. @@ -891,7 +859,9 @@ class AtomArrayStack(_AtomArrayBase): :class:`AtomArray` instance. Concatenation of atoms for each array in the stack is done using the - '+' operator. For addition of atom arrays onto the stack use the + '+' operator. + For a list of :class:`AtomArray` objects, use :func:`concatenate()`. + For addition of atom arrays onto the stack use the :func:`stack()` method. The :attr:`box` attribute has the shape *m x 3 x 3*, as the cell @@ -1305,6 +1275,112 @@ def stack(arrays): return array_stack +def concatenate(atoms): + """ + Concatenate multiple :class:`AtomArray` or :class:`AtomArrayStack` objects into + a single :class:`AtomArray` or :class:`AtomArrayStack`, respectively. + + Parameters + ---------- + atoms : iterable object of AtomArray or AtomArrayStack + The atoms to be concatenated. + :class:`AtomArray` cannot be mixed with :class:`AtomArrayStack`. + + Returns + ------- + concatenated_atoms : AtomArray or AtomArrayStack + The concatenated atoms, i.e. its ``array_length()`` is the sum of the + ``array_length()`` of the input ``atoms``. + + Notes + ----- + The following rules apply: + + - Only the annotation categories that exist in all elements are transferred. + - The box of the first element that has a box is transferred, if any. + - The bonds of all elements are concatenated, if any element has associated bonds. + For elements without a :class:`BondList` an empty :class:`BondList` is assumed. + + Examples + -------- + + >>> atoms1 = array([ + ... Atom([1,2,3], res_id=1, atom_name="N"), + ... Atom([4,5,6], res_id=1, atom_name="CA"), + ... Atom([7,8,9], res_id=1, atom_name="C") + ... ]) + >>> atoms2 = array([ + ... Atom([1,2,3], res_id=2, atom_name="N"), + ... Atom([4,5,6], res_id=2, atom_name="CA"), + ... Atom([7,8,9], res_id=2, atom_name="C") + ... ]) + >>> print(concatenate([atoms1, atoms2])) + 1 N 1.000 2.000 3.000 + 1 CA 4.000 5.000 6.000 + 1 C 7.000 8.000 9.000 + 2 N 1.000 2.000 3.000 + 2 CA 4.000 5.000 6.000 + 2 C 7.000 8.000 9.000 + """ + # Ensure that the atoms can be iterated over multiple times + if not isinstance(atoms, Sequence): + atoms = list(atoms) + + length = 0 + depth = None + element_type = None + common_categories = set(atoms[0].get_annotation_categories()) + box = None + has_bonds = False + for element in atoms: + if element_type is None: + element_type = type(element) + else: + if not isinstance(element, element_type): + raise TypeError( + f"Cannot concatenate '{type(element).__name__}' " + f"with '{element_type.__name__}'" + ) + length += element.array_length() + if isinstance(element, AtomArrayStack): + if depth is None: + depth = element.stack_depth() + else: + if element.stack_depth() != depth: + raise IndexError("The stack depths are not equal") + common_categories &= set(element.get_annotation_categories()) + if element.box is not None and box is None: + box = element.box + if element.bonds is not None: + has_bonds = True + + if element_type == AtomArray: + concat_atoms = AtomArray(length) + elif element_type == AtomArrayStack: + concat_atoms = AtomArrayStack(depth, length) + concat_atoms.coord = np.concatenate([element.coord for element in atoms], axis=-2) + for category in common_categories: + concat_atoms.set_annotation( + category, + np.concatenate( + [element.get_annotation(category) for element in atoms], axis=0 + ), + ) + concat_atoms.box = box + if has_bonds: + # Concatenate bonds of all elements + concat_atoms.bonds = BondList.concatenate( + [ + element.bonds + if element.bonds is not None + else BondList(element.array_length()) + for element in atoms + ] + ) + + return concat_atoms + + def repeat(atoms, coord): """ Repeat atoms (:class:`AtomArray` or :class:`AtomArrayStack`) diff --git a/src/biotite/structure/bonds.pyx b/src/biotite/structure/bonds.pyx index 5415e3f20..a869fcfd5 100644 --- a/src/biotite/structure/bonds.pyx +++ b/src/biotite/structure/bonds.pyx @@ -17,6 +17,7 @@ cimport cython cimport numpy as np from libc.stdlib cimport free, realloc +from collections.abc import Sequence import itertools import numbers from enum import IntEnum @@ -309,6 +310,61 @@ class BondList(Copyable): self._bonds = np.zeros((0, 3), dtype=np.uint32) self._max_bonds_per_atom = 0 + @staticmethod + def concatenate(bonds_lists): + """ + Concatenate multiple :class:`BondList` objects into a single + :class:`BondList`, respectively. + + Parameters + ---------- + bonds_lists : iterable object of BondList + The bond lists to be concatenated. + + Returns + ------- + concatenated_bonds : BondList + The concatenated bond lists. + + Examples + -------- + + >>> bonds1 = BondList(2, np.array([(0, 1)])) + >>> bonds2 = BondList(3, np.array([(0, 1), (0, 2)])) + >>> merged_bonds = BondList.concatenate([bonds1, bonds2]) + >>> print(merged_bonds.get_atom_count()) + 5 + >>> print(merged_bonds.as_array()[:, :2]) + [[0 1] + [2 3] + [2 4]] + """ + # Ensure that the bonds_lists can be iterated over multiple times + if not isinstance(bonds_lists, Sequence): + bonds_lists = list(bonds_lists) + + cdef np.ndarray merged_bonds = np.concatenate( + [bond_list._bonds for bond_list in bonds_lists] + ) + # Offset the indices of appended bonds list + # (consistent with addition of AtomArray) + cdef int start = 0, stop = 0 + cdef int cum_atom_count = 0 + for bond_list in bonds_lists: + stop = start + bond_list._bonds.shape[0] + merged_bonds[start : stop, :2] += cum_atom_count + cum_atom_count += bond_list._atom_count + start = stop + + cdef merged_bond_list = BondList(cum_atom_count) + # Array is not used in constructor to prevent unnecessary + # maximum and redundant bond calculation + merged_bond_list._bonds = merged_bonds + merged_bond_list._max_bonds_per_atom = max( + [bond_list._max_bonds_per_atom for bond_list in bonds_lists] + ) + return merged_bond_list + def __copy_create__(self): # Create empty bond list to prevent # unnecessary removal of redundant atoms @@ -1002,20 +1058,7 @@ class BondList(Copyable): ) def __add__(self, bond_list): - cdef np.ndarray merged_bonds \ - = np.concatenate([self._bonds, bond_list._bonds]) - # Offset the indices of appended bonds list - # (consistent with addition of AtomArray) - merged_bonds[len(self._bonds):, :2] += self._atom_count - cdef uint32 merged_count = self._atom_count + bond_list._atom_count - cdef merged_bond_list = BondList(merged_count) - # Array is not used in constructor to prevent unnecessary - # maximum and redundant bond calculation - merged_bond_list._bonds = merged_bonds - merged_bond_list._max_bonds_per_atom = max( - self._max_bonds_per_atom, bond_list._max_bonds_per_atom - ) - return merged_bond_list + return BondList.concatenate([self, bond_list]) def __getitem__(self, index): ## Variables for both, integer and boolean index arrays diff --git a/tests/structure/test_atoms.py b/tests/structure/test_atoms.py index 6b5d48f1d..304e25469 100644 --- a/tests/structure/test_atoms.py +++ b/tests/structure/test_atoms.py @@ -122,6 +122,14 @@ def test_stack_indexing(stack): assert filtered_stack.array_length() == 1 +def test_concatenate_single(array, stack): + """ + Concatenation of a single array or stack should return the same object. + """ + assert array == struc.concatenate([array]) + assert stack == struc.concatenate([stack]) + + def test_concatenation(array, stack): concat_array = array[2:] + array[:2] assert concat_array.chain_id.tolist() == ["B", "B", "B", "A", "A"] diff --git a/tests/structure/test_bonds.py b/tests/structure/test_bonds.py index 01a48ba17..ddb9e9187 100644 --- a/tests/structure/test_bonds.py +++ b/tests/structure/test_bonds.py @@ -2,6 +2,7 @@ # under the 3-Clause BSD License. Please see 'LICENSE.rst' for further # information. +import itertools from os.path import join import numpy as np import pytest @@ -132,6 +133,41 @@ def test_modification(bond_list): assert bond_list.as_array().tolist() == [[1, 3, 1], [3, 4, 0], [4, 6, 0], [1, 4, 0]] +@pytest.mark.parametrize("seed", range(10)) +def test_concatenation_and_splitting(seed): + """ + Repeatedly concatenating and splitting a `BondList` with the same indices + should recover the same object. + """ + N_BOND_LISTS = 5 + MAX_ATOMS = 10 + MAX_BONDS = 10 + + rng = np.random.default_rng(seed) + split_bond_lists = [] + starts = [0] + for _ in range(N_BOND_LISTS): + n_atoms = rng.integers(1, MAX_ATOMS) + bonds = rng.integers(0, n_atoms, size=(MAX_BONDS, 2)) + bond_types = rng.integers(0, len(struc.BondType), size=MAX_BONDS) + split_bond_lists.append( + struc.BondList( + n_atoms, np.concatenate([bonds, bond_types[:, np.newaxis]], axis=1) + ) + ) + starts.append(starts[-1] + n_atoms) + + concatenated_bond_list = struc.BondList.concatenate(split_bond_lists) + resplit_bond_lists = [ + concatenated_bond_list[start:stop] for start, stop in itertools.pairwise(starts) + ] + + for ref_bond_list, test_bond_list in zip( + split_bond_lists, resplit_bond_lists, strict=True + ): + assert ref_bond_list == test_bond_list + + def test_add_two_bond_list(): """ Test adding two `BondList` objects.