Skip to content

Commit

Permalink
Move guess_elements() to public API
Browse files Browse the repository at this point in the history
  • Loading branch information
padix-key committed May 30, 2024
1 parent bdefa86 commit 4dc6c91
Show file tree
Hide file tree
Showing 8 changed files with 125 additions and 100 deletions.
3 changes: 2 additions & 1 deletion doc/apidoc.json
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,8 @@
"Repair" : [
"renumber_atom_ids",
"renumber_res_ids",
"create_continuous_res_ids"
"create_continuous_res_ids",
"infer_elements"
],
"Residue level utility" : [
"get_residue_starts",
Expand Down
41 changes: 1 addition & 40 deletions src/biotite/structure/io/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import os.path
import io
from ..atoms import AtomArray, AtomArrayStack
from ..atoms import AtomArrayStack


def load_structure(file_path, template=None, **kwargs):
Expand Down Expand Up @@ -224,42 +224,3 @@ def _as_single_model_if_possible(atoms):
return atoms[0]
else:
return atoms


# Helper function to estimate elements from atom names
# Upper-case symbols of the chemical elements
_elements = [elem.upper() for elem in
    ["H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne", "Na", "Mg",
    "Al", "Si", "P", "S", "Cl", "Ar", "K", "Ca", "Sc", "Ti", "V", "Cr", "Mn", "Fe",
    "Co", "Ni", "Cu", "Zn", "Ga", "Ge", "As", "Se", "Br", "Kr", "Rb", "Sr", "Y",
    "Zr", "Nb", "Mo", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd", "In", "Sn", "Sb", "Te",
    "I", "Xe", "Cs", "Ba", "La", "Ce", "Pr", "Nd", "Pm", "Sm", "Eu", "Gd", "Tb",
    "Dy", "Ho", "Er", "Tm", "Yb", "Lu", "Hf", "Ta", "W", "Re", "Os", "Ir", "Pt",
    "Au", "Hg", "Tl", "Pb", "Bi", "Po", "At", "Rn", "Fr", "Ra", "Ac", "Th", "Pa",
    "U", "Np", "Pu", "Am", "Cm", "Bk", "Cf", "Es", "Fm", "Md", "No", "Lr", "Rf",
    "Db", "Sg", "Bh", "Hs", "Mt", "Ds", "Rg", "Cn", "Nh", "Fl", "Mc", "Lv", "Ts",
    "Og"]
]


def _guess_element(atom_name):
    """
    Guess the element of a single atom from its name.

    Parameters
    ----------
    atom_name : str
        The atom name, e.g. ``"CA"`` or ``"1H"``.

    Returns
    -------
    element : str
        The upper-case element symbol.
        An empty string, if no element could be guessed.

    Notes
    -----
    Names starting with ``C``, ``N``, ``O``, ``S`` or ``H`` are always
    assigned to that single-letter element, so for example ``"CL"``
    gives ``"C"`` and ``"HE"`` gives ``"H"``.
    """
    # Remove digits (1H -> H)
    elem = "".join(char for char in atom_name if not char.isdigit()).upper()
    if not elem:
        return ""

    # Some often used elements for biomolecules take precedence over
    # any two-letter symbol starting with the same letter
    if elem.startswith(("C", "N", "O", "S", "H")):
        return elem[0]

    # Exactly match element abbreviations:
    # the two-letter symbol is tried before the one-letter symbol
    for symbol in (elem[:2], elem[0]):
        if symbol in _elements:
            return symbol

    return ""

32 changes: 16 additions & 16 deletions src/biotite/structure/io/gro/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ...atoms import AtomArray, AtomArrayStack
from ...box import is_orthogonal
from ....file import TextFile, InvalidFileError
from ..general import _guess_element as guess_element
from ...repair import infer_elements
from ...error import BadStructureError
import copy
from datetime import datetime
Expand Down Expand Up @@ -38,15 +38,15 @@ class GROFile(TextFile):
--------
Load a `\\*.gro` file, modify the structure and save the new
structure into a new file:
>>> import os.path
>>> file = GROFile.read(os.path.join(path_to_structures, "1l2y.gro"))
>>> array_stack = file.get_structure()
>>> array_stack_mod = rotate(array_stack, [1,2,3])
>>> file = GROFile()
>>> file.set_structure(array_stack_mod)
>>> file.write(os.path.join(path_to_directory, "1l2y_mod.gro"))
"""
def get_model_count(self):
"""
Expand All @@ -68,7 +68,7 @@ def get_structure(self, model=None):
"""
Get an :class:`AtomArray` or :class:`AtomArrayStack` from the
GRO file.
Parameters
----------
model : int, optional
Expand All @@ -80,21 +80,21 @@ def get_structure(self, model=None):
If this parameter is omitted, an :class:`AtomArrayStack`
containing all models will be returned, even if the
structure contains only one model.
Returns
-------
array : AtomArray or AtomArrayStack
The return type depends on the `model` parameter.
"""

def get_atom_line_i(model_start_i, model_atom_counts):
    """
    Helper function to get the line indices of all atoms of a model,
    i.e. the `model_atom_counts` consecutive indices that begin one
    line after `model_start_i`.
    """
    first_atom_line = model_start_i + 1
    return np.arange(first_atom_line, first_atom_line + model_atom_counts)

def set_box_dimen(box_param):
"""
Helper function to create the box vectors from the values
Expand All @@ -104,7 +104,7 @@ def set_box_dimen(box_param):
----------
box_param : list of float
The box dimensions in the GRO file.
Returns
-------
box_vectors : ndarray, dtype=float, shape=(3,3)
Expand Down Expand Up @@ -171,7 +171,7 @@ def set_box_dimen(box_param):
array.res_id[i] = int(line[0:5])
array.res_name[i] = line[5:10].strip()
array.atom_name[i] = line[10:15].strip()
array.element[i] = guess_element(line[10:15].strip())
array.element = infer_elements(array.atom_name)

# Fill in coordinates and boxes
if isinstance(array, AtomArray):
Expand All @@ -186,7 +186,7 @@ def set_box_dimen(box_param):
box_i = atom_i[-1] + 1
box_param = [float(e)*10 for e in self.lines[box_i].split()]
array.box = set_box_dimen(box_param)

elif isinstance(array, AtomArrayStack):
for m in range(len(model_start_i)):
atom_i = get_atom_line_i(
Expand All @@ -204,18 +204,18 @@ def set_box_dimen(box_param):
# Create a box in the stack if not already existing
# and the box is not a dummy
if box is not None:
if array.box is None:
if array.box is None:
array.box = np.zeros((array.stack_depth(), 3, 3))
array.box[m] = box

return array


def set_structure(self, array):
"""
Set the :class:`AtomArray` or :class:`AtomArrayStack` for the
file.
Parameters
----------
array : AtomArray or AtomArrayStack
Expand All @@ -235,7 +235,7 @@ def get_box_dimen(array):
----------
array : AtomArray
The atom array to get the box dimensions from.
Returns
-------
box : str
Expand All @@ -259,7 +259,7 @@ def get_box_dimen(array):
box[2,0], box[2,1],
)
return " ".join([f"{e:>9.5f}" for e in box_elements])

if "atom_id" in array.get_annotation_categories():
atom_id = array.atom_id
else:
Expand Down
17 changes: 8 additions & 9 deletions src/biotite/structure/io/pdb/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ...bonds import BondList, connect_via_residue_names
from ...box import vectors_from_unitcell, unitcell_from_vectors
from ....file import TextFile, InvalidFileError
from ..general import _guess_element as guess_element
from ...repair import infer_elements
from ...error import BadStructureError
from ...filter import (
filter_first_altloc,
Expand Down Expand Up @@ -460,15 +460,14 @@ def get_structure(self, model=None, altloc="first", extra_fields=[],

# Replace empty strings for elements with guessed types
# This is used e.g. for PDB files created by Gromacs
if "" in array.element:
rep_num = 0
for idx in range(len(array.element)):
if not array.element[idx]:
atom_name = array.atom_name[idx]
array.element[idx] = guess_element(atom_name)
rep_num += 1
empty_element_mask = array.element == ""
if empty_element_mask.any():
warnings.warn(
"{} elements were guessed from atom_name.".format(rep_num)
f"{np.count_nonzero(empty_element_mask)} elements "
"were guessed from atom name"
)
array.element[empty_element_mask] = infer_elements(
array.atom_name[empty_element_mask]
)

# Fill in coordinates
Expand Down
72 changes: 70 additions & 2 deletions src/biotite/structure/repair.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
"""

__name__ = "biotite.structure"
__author__ = "Patrick Kunzmann"
__author__ = "Patrick Kunzmann, Daniel Bauer"
__all__ = ["renumber_atom_ids", "renumber_res_ids",
"create_continuous_res_ids"]
"create_continuous_res_ids", "infer_elements"]

import warnings
import numpy as np
from .atoms import AtomArray, AtomArrayStack
from .residues import get_residue_starts
from .chains import get_chain_starts

Expand Down Expand Up @@ -125,3 +126,70 @@ def create_continuous_res_ids(atoms, restart_each_chain=True):
res_ids[start:] -= res_ids[start] - 1

return res_ids


def infer_elements(atoms):
    """
    Infer the element of each atom based on its name.

    Parameters
    ----------
    atoms : AtomArray or AtomArrayStack or array-like of str
        The atoms for which the elements should be inferred.
        Alternatively the atom names can be passed directly.

    Returns
    -------
    elements : ndarray, dtype=str
        The inferred elements, one for each atom name.
        An element is an empty string, if it could not be inferred
        from the respective atom name.

    Examples
    --------

    >>> print(infer_elements(atom_array)[:10])
    ['N' 'C' 'C' 'O' 'C' 'C' 'O' 'N' 'H' 'H']
    >>> print(infer_elements(["CA", "C", "C1", "OD1", "HD21", "1H", "FE"]))
    ['C' 'C' 'C' 'O' 'H' 'H' 'FE']
    """
    if isinstance(atoms, (AtomArray, AtomArrayStack)):
        atom_names = atoms.atom_name
    else:
        atom_names = atoms
    # The explicit string dtype keeps the documented return type for
    # empty input, where NumPy would otherwise fall back to float64
    return np.array(
        [_guess_element(name) for name in atom_names], dtype=str
    )


# Upper-case symbols of the chemical elements
_elements = [elem.upper() for elem in
    ["H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne", "Na", "Mg",
    "Al", "Si", "P", "S", "Cl", "Ar", "K", "Ca", "Sc", "Ti", "V", "Cr", "Mn", "Fe",
    "Co", "Ni", "Cu", "Zn", "Ga", "Ge", "As", "Se", "Br", "Kr", "Rb", "Sr", "Y",
    "Zr", "Nb", "Mo", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd", "In", "Sn", "Sb", "Te",
    "I", "Xe", "Cs", "Ba", "La", "Ce", "Pr", "Nd", "Pm", "Sm", "Eu", "Gd", "Tb",
    "Dy", "Ho", "Er", "Tm", "Yb", "Lu", "Hf", "Ta", "W", "Re", "Os", "Ir", "Pt",
    "Au", "Hg", "Tl", "Pb", "Bi", "Po", "At", "Rn", "Fr", "Ra", "Ac", "Th", "Pa",
    "U", "Np", "Pu", "Am", "Cm", "Bk", "Cf", "Es", "Fm", "Md", "No", "Lr", "Rf",
    "Db", "Sg", "Bh", "Hs", "Mt", "Ds", "Rg", "Cn", "Nh", "Fl", "Mc", "Lv", "Ts",
    "Og"]
]


def _guess_element(atom_name):
    """
    Guess the element of a single atom from its name.

    Parameters
    ----------
    atom_name : str
        The atom name, e.g. ``"CA"`` or ``"1H"``.

    Returns
    -------
    element : str
        The upper-case element symbol.
        An empty string, if no element could be guessed.

    Warns
    -----
    UserWarning
        If the name contains non-digit characters, but no element
        symbol matches them.

    Notes
    -----
    Names starting with ``C``, ``N``, ``O``, ``S`` or ``H`` are always
    assigned to that single-letter element, so for example ``"CL"``
    gives ``"C"`` and ``"HE"`` gives ``"H"``.
    """
    # Remove digits (1H -> H)
    elem = "".join(char for char in atom_name if not char.isdigit()).upper()
    if not elem:
        return ""

    # Some often used elements for biomolecules take precedence over
    # any two-letter symbol starting with the same letter
    if elem.startswith(("C", "N", "O", "S", "H")):
        return elem[0]

    # Exactly match element abbreviations:
    # the two-letter symbol is tried before the one-letter symbol
    for symbol in (elem[:2], elem[0]):
        if symbol in _elements:
            return symbol

    warnings.warn(f"Could not infer element for '{atom_name}'")
    return ""
30 changes: 0 additions & 30 deletions tests/structure/test_generalio.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from tempfile import NamedTemporaryFile
import biotite.structure as struc
import biotite.structure.io as strucio
from biotite.structure.io.general import _guess_element
import glob
import os
from os.path import join, splitext
Expand Down Expand Up @@ -173,32 +172,3 @@ def test_small_molecule():
os.remove(temp.name)

assert test_array == ref_array


@pytest.mark.parametrize(
    "name,expected",
    [
        ("CA", "C"),
        ("C", "C"),
        ("CB", "C"),
        ("OD1", "O"),
        ("HD21", "H"),
        ("1H", "H"),
        ("CL", "C"),
        ("HE", "H"),
        ("SD", "S"),
        ("NA", "N"),
        ("NX", "N"),
        ("BE", "BE"),
        ("BEA", "BE"),
        ("K", "K"),
        ("KA", "K"),
        ("QWERT", ""),
    ]
)
def test_guess_element(name, expected):
    """
    Check that the element guessed from an atom name matches the
    expected element for a set of known examples.

    Elements are automatically guessed in GRO and PDB files where the
    *element* column is missing.
    """
    assert _guess_element(name) == expected
2 changes: 1 addition & 1 deletion tests/structure/test_pdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def test_extra_fields(hybrid36):


@pytest.mark.filterwarnings("ignore")
def test_guess_elements():
def test_inferred_elements():
# Read valid pdb file
path = join(data_dir("structure"), "1l2y.pdb")
pdb_file = pdb.PDBFile.read(path)
Expand Down
28 changes: 27 additions & 1 deletion tests/structure/test_repair.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,30 @@ def test_create_continuous_res_ids(multi_chain, restart_each_chain):
).tolist()
else:
assert test_res_ids.tolist() \
== (np.arange(len(test_res_ids)) + 1).tolist()
== (np.arange(len(test_res_ids)) + 1).tolist()


@pytest.mark.parametrize(
    "name,expected",
    [
        ("CA", "C"),
        ("C", "C"),
        ("CB", "C"),
        ("OD1", "O"),
        ("HD21", "H"),
        ("1H", "H"),
        # ("CL", "CL") is intentionally omitted:
        # this is an edge case where inference is difficult
        ("HE", "H"),
        ("SD", "S"),
        ("NA", "N"),
        ("NX", "N"),
        ("BE", "BE"),
        ("BEA", "BE"),
        ("K", "K"),
        ("KA", "K"),
        ("QWERT", ""),
    ]
)
def test_infer_elements(name, expected):
    """
    Check that the element inferred from a single atom name matches
    the expected element for a set of known examples.
    """
    inferred = struc.infer_elements([name])
    assert inferred[0] == expected

0 comments on commit 4dc6c91

Please sign in to comment.