Skip to content

Commit

Permalink
Support concatenation of more than two AtomArray objects
Browse files Browse the repository at this point in the history
  • Loading branch information
padix-key committed Dec 20, 2024
1 parent 0acb143 commit f15cf89
Show file tree
Hide file tree
Showing 5 changed files with 215 additions and 51 deletions.
1 change: 1 addition & 0 deletions doc/apidoc.json
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@
"Atom",
"AtomArray",
"AtomArrayStack",
"concatenate",
"array",
"stack",
"repeat",
Expand Down
150 changes: 113 additions & 37 deletions src/biotite/structure/atoms.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"Atom",
"AtomArray",
"AtomArrayStack",
"concatenate",
"array",
"stack",
"repeat",
Expand All @@ -22,6 +23,7 @@

import abc
import numbers
from collections.abc import Sequence
import numpy as np
from biotite.copyable import Copyable
from biotite.structure.bonds import BondList
Expand Down Expand Up @@ -420,42 +422,7 @@ def __len__(self):
return self._array_length

def __add__(self, array):
if not isinstance(self, type(array)):
raise TypeError("Can only concatenate two arrays or two stacks")
# Create either new array or stack, depending of the own type
if isinstance(self, AtomArray):
concat = AtomArray(length=self._array_length + array._array_length)
if isinstance(self, AtomArrayStack):
concat = AtomArrayStack(
self.stack_depth(), self._array_length + array._array_length
)

concat._coord = np.concatenate((self._coord, array.coord), axis=-2)

# Transfer only annotations,
# which are existent in both operands
arr_categories = list(array._annot.keys())
for category in self._annot.keys():
if category in arr_categories:
annot = self._annot[category]
arr_annot = array._annot[category]
concat._annot[category] = np.concatenate((annot, arr_annot))

# Concatenate bonds lists,
# if at least one of them contains bond information
if self._bonds is not None or array._bonds is not None:
bonds1 = self._bonds
bonds2 = array._bonds
if bonds1 is None:
bonds1 = BondList(self._array_length)
if bonds2 is None:
bonds2 = BondList(array._array_length)
concat._bonds = bonds1 + bonds2

# Copy box
if self._box is not None:
concat._box = np.copy(self._box)
return concat
return concatenate([self, array])

def __copy_fill__(self, clone):
super().__copy_fill__(clone)
Expand Down Expand Up @@ -619,6 +586,7 @@ class AtomArray(_AtomArrayBase):
:class:`AtomArray` is done with the '+' operator.
Only the annotation categories, which are existing in both arrays,
are transferred to the new array.
For a list of :class:`AtomArray` objects, use :func:`concatenate()`.
Optionally, an :class:`AtomArray` can store chemical bond
information via a :class:`BondList` object.
Expand Down Expand Up @@ -891,7 +859,9 @@ class AtomArrayStack(_AtomArrayBase):
:class:`AtomArray` instance.
Concatenation of atoms for each array in the stack is done using the
'+' operator. For addition of atom arrays onto the stack use the
'+' operator.
For a list of :class:`AtomArray` objects, use :func:`concatenate()`.
For addition of atom arrays onto the stack use the
:func:`stack()` method.
The :attr:`box` attribute has the shape *m x 3 x 3*, as the cell
Expand Down Expand Up @@ -1305,6 +1275,112 @@ def stack(arrays):
return array_stack


def concatenate(atoms):
"""
Concatenate multiple :class:`AtomArray` or :class:`AtomArrayStack` objects into
a single :class:`AtomArray` or :class:`AtomArrayStack`, respectively.
Parameters
----------
atoms : iterable object of AtomArray or AtomArrayStack
The atoms to be concatenated.
:class:`AtomArray` cannot be mixed with :class:`AtomArrayStack`.
Returns
-------
concatenated_atoms : AtomArray or AtomArrayStack
The concatenated atoms, i.e. its ``array_length()`` is the sum of the
``array_length()`` of the input ``atoms``.
Notes
-----
The following rules apply:
- Only the annotation categories that exist in all elements are transferred.
- The box of the first element that has a box is transferred, if any.
- The bonds of all elements are concatenated, if any element has associated bonds.
For elements without a :class:`BondList` an empty :class:`BondList` is assumed.
Examples
--------
>>> atoms1 = array([
... Atom([1,2,3], res_id=1, atom_name="N"),
... Atom([4,5,6], res_id=1, atom_name="CA"),
... Atom([7,8,9], res_id=1, atom_name="C")
... ])
>>> atoms2 = array([
... Atom([1,2,3], res_id=2, atom_name="N"),
... Atom([4,5,6], res_id=2, atom_name="CA"),
... Atom([7,8,9], res_id=2, atom_name="C")
... ])
>>> print(concatenate([atoms1, atoms2]))
1 N 1.000 2.000 3.000
1 CA 4.000 5.000 6.000
1 C 7.000 8.000 9.000
2 N 1.000 2.000 3.000
2 CA 4.000 5.000 6.000
2 C 7.000 8.000 9.000
"""
# Ensure that the atoms can be iterated over multiple times
if not isinstance(atoms, Sequence):
atoms = list(atoms)

length = 0
depth = None
element_type = None
common_categories = set(atoms[0].get_annotation_categories())
box = None
has_bonds = False
for element in atoms:
if element_type is None:
element_type = type(element)
else:
if not isinstance(element, element_type):
raise TypeError(
f"Cannot concatenate '{type(element).__name__}' "
f"with '{element_type.__name__}'"
)
length += element.array_length()
if isinstance(element, AtomArrayStack):
if depth is None:
depth = element.stack_depth()
else:
if element.stack_depth() != depth:
raise IndexError("The stack depths are not equal")
common_categories &= set(element.get_annotation_categories())
if element.box is not None and box is None:
box = element.box
if element.bonds is not None:
has_bonds = True

if element_type == AtomArray:
concat_atoms = AtomArray(length)
elif element_type == AtomArrayStack:
concat_atoms = AtomArrayStack(depth, length)
concat_atoms.coord = np.concatenate([element.coord for element in atoms], axis=-2)
for category in common_categories:
concat_atoms.set_annotation(
category,
np.concatenate(
[element.get_annotation(category) for element in atoms], axis=0
),
)
concat_atoms.box = box
if has_bonds:
# Concatenate bonds of all elements
concat_atoms.bonds = BondList.concatenate(
[
element.bonds
if element.bonds is not None
else BondList(element.array_length())
for element in atoms
]
)

return concat_atoms


def repeat(atoms, coord):
"""
Repeat atoms (:class:`AtomArray` or :class:`AtomArrayStack`)
Expand Down
71 changes: 57 additions & 14 deletions src/biotite/structure/bonds.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ cimport cython
cimport numpy as np
from libc.stdlib cimport free, realloc

from collections.abc import Sequence
import itertools
import numbers
from enum import IntEnum
Expand Down Expand Up @@ -309,6 +310,61 @@ class BondList(Copyable):
self._bonds = np.zeros((0, 3), dtype=np.uint32)
self._max_bonds_per_atom = 0

@staticmethod
def concatenate(bonds_lists):
"""
Concatenate multiple :class:`BondList` objects into a single
:class:`BondList`, respectively.
Parameters
----------
bonds_lists : iterable object of BondList
The bond lists to be concatenated.
Returns
-------
concatenated_bonds : BondList
The concatenated bond lists.
Examples
--------
>>> bonds1 = BondList(2, np.array([(0, 1)]))
>>> bonds2 = BondList(3, np.array([(0, 1), (0, 2)]))
>>> merged_bonds = BondList.concatenate([bonds1, bonds2])
>>> print(merged_bonds.get_atom_count())
5
>>> print(merged_bonds.as_array()[:, :2])
[[0 1]
[2 3]
[2 4]]
"""
# Ensure that the bonds_lists can be iterated over multiple times
if not isinstance(bonds_lists, Sequence):
bonds_lists = list(bonds_lists)

cdef np.ndarray merged_bonds = np.concatenate(
[bond_list._bonds for bond_list in bonds_lists]
)
# Offset the indices of appended bonds list
# (consistent with addition of AtomArray)
cdef int start = 0, stop = 0
cdef int cum_atom_count = 0
for bond_list in bonds_lists:
stop = start + bond_list._bonds.shape[0]
merged_bonds[start : stop, :2] += cum_atom_count
cum_atom_count += bond_list._atom_count
start = stop

cdef merged_bond_list = BondList(cum_atom_count)
# Array is not used in constructor to prevent unnecessary
# maximum and redundant bond calculation
merged_bond_list._bonds = merged_bonds
merged_bond_list._max_bonds_per_atom = max(
[bond_list._max_bonds_per_atom for bond_list in bonds_lists]
)
return merged_bond_list

def __copy_create__(self):
# Create empty bond list to prevent
# unnecessary removal of redundant atoms
Expand Down Expand Up @@ -1002,20 +1058,7 @@ class BondList(Copyable):
)

def __add__(self, bond_list):
cdef np.ndarray merged_bonds \
= np.concatenate([self._bonds, bond_list._bonds])
# Offset the indices of appended bonds list
# (consistent with addition of AtomArray)
merged_bonds[len(self._bonds):, :2] += self._atom_count
cdef uint32 merged_count = self._atom_count + bond_list._atom_count
cdef merged_bond_list = BondList(merged_count)
# Array is not used in constructor to prevent unnecessary
# maximum and redundant bond calculation
merged_bond_list._bonds = merged_bonds
merged_bond_list._max_bonds_per_atom = max(
self._max_bonds_per_atom, bond_list._max_bonds_per_atom
)
return merged_bond_list
return BondList.concatenate([self, bond_list])

def __getitem__(self, index):
## Variables for both, integer and boolean index arrays
Expand Down
8 changes: 8 additions & 0 deletions tests/structure/test_atoms.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,14 @@ def test_stack_indexing(stack):
assert filtered_stack.array_length() == 1


def test_concatenate_single(array, stack):
"""
Concatenation of a single array or stack should return the same object.
"""
assert array == struc.concatenate([array])
assert stack == struc.concatenate([stack])


def test_concatenation(array, stack):
concat_array = array[2:] + array[:2]
assert concat_array.chain_id.tolist() == ["B", "B", "B", "A", "A"]
Expand Down
36 changes: 36 additions & 0 deletions tests/structure/test_bonds.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# under the 3-Clause BSD License. Please see 'LICENSE.rst' for further
# information.

import itertools
from os.path import join
import numpy as np
import pytest
Expand Down Expand Up @@ -132,6 +133,41 @@ def test_modification(bond_list):
assert bond_list.as_array().tolist() == [[1, 3, 1], [3, 4, 0], [4, 6, 0], [1, 4, 0]]


@pytest.mark.parametrize("seed", range(10))
def test_concatenation_and_splitting(seed):
"""
Repeatedly concatenating and splitting a `BondList` with the same indices
should recover the same object.
"""
N_BOND_LISTS = 5
MAX_ATOMS = 10
MAX_BONDS = 10

rng = np.random.default_rng(seed)
split_bond_lists = []
starts = [0]
for _ in range(N_BOND_LISTS):
n_atoms = rng.integers(1, MAX_ATOMS)
bonds = rng.integers(0, n_atoms, size=(MAX_BONDS, 2))
bond_types = rng.integers(0, len(struc.BondType), size=MAX_BONDS)
split_bond_lists.append(
struc.BondList(
n_atoms, np.concatenate([bonds, bond_types[:, np.newaxis]], axis=1)
)
)
starts.append(starts[-1] + n_atoms)

concatenated_bond_list = struc.BondList.concatenate(split_bond_lists)
resplit_bond_lists = [
concatenated_bond_list[start:stop] for start, stop in itertools.pairwise(starts)
]

for ref_bond_list, test_bond_list in zip(
split_bond_lists, resplit_bond_lists, strict=True
):
assert ref_bond_list == test_bond_list


def test_add_two_bond_list():
"""
Test adding two `BondList` objects.
Expand Down

0 comments on commit f15cf89

Please sign in to comment.