diff --git a/protenix/data/core/featurizer.py b/protenix/data/core/featurizer.py
index fdfaddc..0e3ae81 100644
--- a/protenix/data/core/featurizer.py
+++ b/protenix/data/core/featurizer.py
@@ -68,23 +68,22 @@ def encoder(
         """
         num_keys = len(encode_def_dict_or_list)
         if isinstance(encode_def_dict_or_list, dict):
-            items = encode_def_dict_or_list.items()
             assert (
                 num_keys == max(encode_def_dict_or_list.values()) + 1
             ), "Do not use discontinuous number, which might causing potential bugs in the code"
+            idx_map = encode_def_dict_or_list
         elif isinstance(encode_def_dict_or_list, list):
-            items = ((key, idx) for idx, key in enumerate(encode_def_dict_or_list))
+            idx_map = {key: idx for idx, key in enumerate(encode_def_dict_or_list)}
         else:
             raise TypeError(
                 "encode_def_dict_or_list must be a list or dict, "
                 f"but got {type(encode_def_dict_or_list)}"
             )
-        onehot_dict = {
-            key: [int(i == idx) for i in range(num_keys)] for key, idx in items
-        }
-        onehot_encoded_data = [onehot_dict[item] for item in input_list]
-        onehot_tensor = torch.Tensor(onehot_encoded_data)
-        return onehot_tensor
+        # Vectorized: map input items to integer indices, then use F.one_hot
+        indices = torch.tensor(
+            [idx_map[item] for item in input_list], dtype=torch.long
+        )
+        return torch.nn.functional.one_hot(indices, num_classes=num_keys).float()
 
     @staticmethod
     def restype_onehot_encoded(restype_list: list[str]) -> torch.Tensor:
@@ -132,21 +131,15 @@ def ref_atom_name_chars_encoded(atom_names: list[str]) -> torch.Tensor:
         Returns:
             torch.Tensor: A Tensor of character encoded atom names
         """
-        onehot_dict = {}
-        for index, key in enumerate(range(64)):
-            onehot = [0] * 64
-            onehot[index] = 1
-            onehot_dict[key] = onehot
-        # [N_atom, 4, 64]
-        mol_encode = []
-        for atom_name in atom_names:
-            # [4, 64]
-            atom_encode = []
-            for name_str in atom_name.ljust(4):
-                atom_encode.append(onehot_dict[ord(name_str) - 32])
-            mol_encode.append(atom_encode)
-        onehot_tensor = torch.Tensor(mol_encode)
-        return onehot_tensor
+        # Vectorized: build padded string, convert to char codes, use one_hot
+        n = len(atom_names)
+        padded = "".join(name.ljust(4)[:4] for name in atom_names)
+        char_codes = np.frombuffer(padded.encode("ascii"), dtype=np.uint8)
+        indices = (char_codes.astype(np.int64) - 32).clip(0, 63)
+        indices_tensor = torch.from_numpy(indices).reshape(n, 4)
+        return torch.nn.functional.one_hot(
+            indices_tensor, num_classes=64
+        ).float()
 
     @staticmethod
     def get_prot_nuc_frame(token: Token, centre_atom: Atom) -> tuple[int, list[int]]:
diff --git a/protenix/data/pipeline/dataset.py b/protenix/data/pipeline/dataset.py
index 8d10b32..fa3bf47 100644
--- a/protenix/data/pipeline/dataset.py
+++ b/protenix/data/pipeline/dataset.py
@@ -250,10 +250,11 @@ def is_valid(row):
                 > 0
             ]
         else:
+            # Vectorized equivalent of:
+            # indices_list[indices_list["eval_type"].apply(lambda x: x in EvaluationChainInterface)]
+            # .isin() checks each element of the Series against the set, same semantics as the lambda.
             indices_list = indices_list[
-                indices_list["eval_type"].apply(
-                    lambda x: x in EvaluationChainInterface
-                )
+                indices_list["eval_type"].isin(EvaluationChainInterface)
             ]
         self.check_indices_list(indices_list, "find_eval_chain_interface filtering")
         if self.limits > 0 and len(indices_list) > self.limits:
diff --git a/protenix/data/template/template_featurizer.py b/protenix/data/template/template_featurizer.py
index 7339e56..92f3dac 100644
--- a/protenix/data/template/template_featurizer.py
+++ b/protenix/data/template/template_featurizer.py
@@ -577,45 +577,61 @@ def as_data_dict(self) -> BatchDict:
             "template_atom_mask": self.atom_mask,
         }
 
+    # Shared config instance to avoid repeated object creation
+    _DGRAM_CONFIG = DistogramFeaturesConfig(
+        min_bin=3.25, max_bin=50.75, num_bins=39
+    )
+
     def as_protenix_dict(self) -> BatchDict:
         """Compute additional features and return as Protenix dictionary."""
         features = self.as_data_dict()
-        dgrams, pb_masks = [], []
-        unit_vectors, bb_masks = [], []
-
         num_templates = self.aatype.shape[0]
+        num_res = self.aatype.shape[1]
+
+        # Pre-allocate output arrays instead of list append + stack
+        all_pb_masks = np.empty(
+            (num_templates, num_res, num_res), dtype=np.float32
+        )
+        all_dgrams = np.empty(
+            (num_templates, num_res, num_res, 39), dtype=np.float32
+        )
+        all_unit_vectors = np.empty(
+            (num_templates, num_res, num_res, 3), dtype=np.float32
+        )
+        all_bb_masks = np.empty(
+            (num_templates, num_res, num_res), dtype=np.float32
+        )
+
+        config = Templates._DGRAM_CONFIG
+        is_lig = getattr(self, "is_ligand", None)
         for i in range(num_templates):
            aatype = self.aatype[i]
            mask = self.atom_mask[i]
            pos = self.atom_positions[i] * mask[..., None]
 
-            # Compute pseudo-beta positions and mask
-            pb_pos, pb_mask = TemplateFeatures.pseudo_beta_fn(aatype, pos, mask)
+            pb_pos, pb_mask = TemplateFeatures.pseudo_beta_fn(
+                aatype, pos, mask, is_ligand=is_lig
+            )
             pb_mask_2d = pb_mask[:, None] * pb_mask[None, :]
 
-            # Compute distogram
             dgram = TemplateFeatures.dgram_from_positions(
-                pb_pos,
-                config=DistogramFeaturesConfig(
-                    min_bin=3.25, max_bin=50.75, num_bins=39
-                ),
+                pb_pos, config=config
             )
-            dgrams.append(dgram * pb_mask_2d[..., None])
-            pb_masks.append(pb_mask_2d)
+            all_dgrams[i] = dgram * pb_mask_2d[..., None]
+            all_pb_masks[i] = pb_mask_2d
 
-            # Compute normalized unit vectors between residues
             uv, bb_mask_2d = TemplateFeatures.compute_template_unit_vector(
                 aatype, pos, mask
             )
-            unit_vectors.append(uv * bb_mask_2d[..., None])
-            bb_masks.append(bb_mask_2d)
+            all_unit_vectors[i] = uv * bb_mask_2d[..., None]
+            all_bb_masks[i] = bb_mask_2d
 
         features.update(
             {
-                "template_pseudo_beta_mask": np.stack(pb_masks),
-                "template_distogram": np.stack(dgrams),
-                "template_unit_vector": np.stack(unit_vectors),
-                "template_backbone_frame_mask": np.stack(bb_masks),
+                "template_pseudo_beta_mask": all_pb_masks,
+                "template_distogram": all_dgrams,
+                "template_unit_vector": all_unit_vectors,
+                "template_backbone_frame_mask": all_bb_masks,
             }
         )
         return features
diff --git a/protenix/data/template/template_parser.py b/protenix/data/template/template_parser.py
index be994d8..251c033 100644
--- a/protenix/data/template/template_parser.py
+++ b/protenix/data/template/template_parser.py
@@ -20,6 +20,7 @@
 import re
 from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple
 
+import numpy as np
 from Bio import PDB
 from Bio.Data import PDBData
 
@@ -635,18 +636,35 @@
 
     @staticmethod
     def _get_indices(seq: str, start: int) -> List[int]:
-        """Calculates residue indices for a sequence with gaps or insertions."""
-        indices = []
-        curr = start
-        for char in seq:
-            if char == "-":
-                indices.append(-1)
-            elif char.islower():
-                curr += 1
-            else:
-                indices.append(curr)
-                curr += 1
-        return indices
+        """Calculates residue indices for a sequence with gaps or insertions.
+
+        Vectorized equivalent of::
+
+            indices = []
+            counter = start
+            for char in seq:
+                if char == '-':
+                    indices.append(-1)
+                elif char.islower():
+                    counter += 1
+                else:  # uppercase
+                    indices.append(counter)
+                    counter += 1
+        """
+        char_arr = np.frombuffer(seq.encode("ascii"), dtype=np.uint8)
+        is_gap = char_arr == ord("-")
+        is_lower = (char_arr >= ord("a")) & (char_arr <= ord("z"))
+        is_upper = ~is_gap & ~is_lower
+        # Positions that produce output (gap or uppercase)
+        is_output = is_gap | is_upper
+        # Positions that increment the counter (uppercase or lowercase)
+        is_increment = is_upper | is_lower
+        counter = np.cumsum(is_increment) - is_increment + start  # exclusive scan: count BEFORE each position
+        # For output positions: gap -> -1, uppercase -> counter value
+        gap_at_output = is_gap[is_output]
+        counter_at_output = counter[is_output]
+        indices = np.where(gap_at_output, -1, counter_at_output)
+        return indices.tolist()
 
     @staticmethod
     def _parse_description(desc: str) -> HitMetadata:
diff --git a/protenix/data/template/template_utils.py b/protenix/data/template/template_utils.py
index 4874895..eefa6fe 100644
--- a/protenix/data/template/template_utils.py
+++ b/protenix/data/template/template_utils.py
@@ -225,27 +225,32 @@ def pseudo_beta_fn(
 
         return pseudo_beta, pseudo_beta_mask
 
+    # Pre-computed bin edges (class-level cache to avoid recomputation)
+    _dgram_cache: dict = {}
+
     @staticmethod
     def dgram_from_positions(
         positions: np.ndarray, config: DistogramFeaturesConfig
     ) -> np.ndarray:
         """Compute distogram from amino acid positions."""
-        lower_breaks = np.linspace(config.min_bin, config.max_bin, config.num_bins)
-        lower_breaks = np.square(lower_breaks)
-        upper_breaks = np.concatenate(
-            [lower_breaks[1:], np.array([1e8], dtype=np.float32)], axis=-1
-        )
-        dist2 = np.sum(
-            np.square(
-                np.expand_dims(positions, axis=-2) - np.expand_dims(positions, axis=-3)
-            ),
-            axis=-1,
-            keepdims=True,
-        )
-
-        dgram = (dist2 > lower_breaks).astype(np.float32) * (
-            dist2 < upper_breaks
-        ).astype(np.float32)
+        cache_key = (config.min_bin, config.max_bin, config.num_bins)
+        if cache_key not in TemplateFeatures._dgram_cache:
+            lower = np.linspace(
+                config.min_bin, config.max_bin, config.num_bins, dtype=np.float32
+            )
+            lower = np.square(lower)
+            upper = np.empty_like(lower)
+            upper[:-1] = lower[1:]
+            upper[-1] = 1e8
+            TemplateFeatures._dgram_cache[cache_key] = (lower, upper)
+        lower_breaks, upper_breaks = TemplateFeatures._dgram_cache[cache_key]
+
+        # Compute squared distances using einsum (avoids large intermediate)
+        pos = positions.astype(np.float32, copy=False)
+        diff = pos[:, np.newaxis, :] - pos[np.newaxis, :, :]
+        dist2 = np.einsum("ijk,ijk->ij", diff, diff)[..., np.newaxis]
+
+        dgram = ((dist2 > lower_breaks) & (dist2 < upper_breaks)).astype(np.float32)
         return dgram
 
     @staticmethod
@@ -256,11 +261,8 @@ def compute_template_unit_vector(
         epsilon: float = 1e-6,
     ) -> tuple[np.ndarray, np.ndarray]:
         """Simplified calculation of template unit vector."""
-        # Get backbone indices (C, CA, N) for each residue from protein_data_processing
-        # Group 0: [C, CA, N]
-        backbone_indices = RESTYPE_RIGIDGROUP_DENSE_ATOM_IDX[aatype, 0]  # [num_res, 3]
+        backbone_indices = RESTYPE_RIGIDGROUP_DENSE_ATOM_IDX[aatype, 0]
 
-        # Indices according to protein_data_processing.py: C is 0, CA is 1, N is 2
         c_idx = backbone_indices[:, 0]
         ca_idx = backbone_indices[:, 1]
         n_idx = backbone_indices[:, 2]
@@ -268,9 +270,9 @@
         num_res = aatype.shape[0]
         res_indices = np.arange(num_res)
 
-        c_pos = atom_positions[res_indices, c_idx]
-        ca_pos = atom_positions[res_indices, ca_idx]
-        n_pos = atom_positions[res_indices, n_idx]
+        c_pos = atom_positions[res_indices, c_idx].astype(np.float32, copy=False)
+        ca_pos = atom_positions[res_indices, ca_idx].astype(np.float32, copy=False)
+        n_pos = atom_positions[res_indices, n_idx].astype(np.float32, copy=False)
 
         c_mask = atom_mask[res_indices, c_idx]
         ca_mask = atom_mask[res_indices, ca_idx]
@@ -278,37 +280,30 @@
 
         mask = (c_mask * ca_mask * n_mask).astype(np.float32)
 
-        # Local frame: origin at CA
-        # x-axis along C-CA (following original code convention)
+        # Local frame: CA origin, C-CA is x-axis (following AF3 convention)
+        # Uses einsum for inline norm computation instead of np.linalg.norm
         v1 = c_pos - ca_pos
         v2 = n_pos - ca_pos
 
-        e1 = v1 / (np.linalg.norm(v1, axis=-1, keepdims=True) + epsilon)
-        # Orthogonalize v2 against e1
-        e2 = v2 - np.sum(v2 * e1, axis=-1, keepdims=True) * e1
-        e2 = e2 / (np.linalg.norm(e2, axis=-1, keepdims=True) + epsilon)
-        # e3 = e1 x e2
+        v1_norm = np.sqrt(np.einsum("ij,ij->i", v1, v1))[:, np.newaxis] + epsilon
+        e1 = v1 / v1_norm
+        e2 = v2 - np.einsum("ij,ij->i", v2, e1)[:, np.newaxis] * e1
+        e2_norm = np.sqrt(np.einsum("ij,ij->i", e2, e2))[:, np.newaxis] + epsilon
+        e2 = e2 / e2_norm
         e3 = np.cross(e1, e2)
 
-        # Relative positions of all CA atoms to all local frames
-        # diff[i, j] = CA[j] - CA[i]
-        diff = ca_pos[None, :, :] - ca_pos[:, None, :]  # [num_res, num_res, 3]
-
-        # Transform to local frame: P' = R^T @ diff
-        # R = [e1 | e2 | e3] -> x' = e1 . diff, y' = e2 . diff, z' = e3 . diff
-        ux = np.sum(e1[:, None, :] * diff, axis=-1)
-        uy = np.sum(e2[:, None, :] * diff, axis=-1)
-        uz = np.sum(e3[:, None, :] * diff, axis=-1)
-
-        unit_vector = np.stack([ux, uy, uz], axis=-1)  # [num_res, num_res, 3]
+        # Build rotation matrix and transform via einsum
+        R = np.stack([e1, e2, e3], axis=-1)  # [num_res, 3, 3]
+        diff = ca_pos[np.newaxis, :, :] - ca_pos[:, np.newaxis, :]
+        unit_vector = np.einsum("ilk,ijl->ijk", R, diff)
 
-        # Normalize to unit vector
-        uv_norm = np.linalg.norm(unit_vector, axis=-1, keepdims=True)
-        unit_vector = unit_vector / (uv_norm + epsilon)
+        uv_norm = np.sqrt(
+            np.einsum("ijk,ijk->ij", unit_vector, unit_vector)
+        )[..., np.newaxis] + epsilon
+        unit_vector = unit_vector / uv_norm
 
         # 2D mask
         mask_2d = mask[:, None] * mask[None, :]
-
         return unit_vector, mask_2d
 
 
@@ -502,17 +497,18 @@ def _check_residue_distances(
     ):
         """Verifies that distance between consecutive CA atoms is within limits."""
         ca_idx = ATOM37_ORDER["CA"]
-        prev_pos = None
-        for i, (p, m) in enumerate(zip(pos, mask)):
-            if m[ca_idx]:
-                curr_pos = p[ca_idx]
-                if prev_pos is not None:
-                    dist = np.linalg.norm(curr_pos - prev_pos)
-                    if dist > max_dist:
-                        raise CaDistanceError(
-                            f"Distance between residues {i} and previous is {dist:.2f} > {max_dist}"
-                        )
-                prev_pos = curr_pos
+        ca_mask = mask[:, ca_idx].astype(bool)
+        if ca_mask.sum() < 2:
+            return
+        ca_pos = pos[ca_mask, ca_idx, :]
+        diffs = ca_pos[1:] - ca_pos[:-1]
+        dists = np.sqrt(np.einsum("ij,ij->i", diffs, diffs))
+        max_found = dists.max()
+        if max_found > max_dist:
+            bad_idx = int(np.argmax(dists))
+            raise CaDistanceError(
+                f"Distance between residues at index {bad_idx} is {max_found:.2f} > {max_dist}"
+            )
 
     def _get_atom_positions(
         self,