Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions pxdesign/runner/dumper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import json
import os
from pathlib import Path
Expand Down Expand Up @@ -195,12 +194,12 @@ def save_structure_cif(
entity_poly_type (dict[str, str]): The entity poly type information.
pdb_id (str): The PDB ID for the entry.
"""
pred_atom_array = copy.deepcopy(atom_array)
pred_pose = pred_coordinate.cpu().numpy()
pred_atom_array.coord = pred_pose
original_coord = atom_array.coord
atom_array.coord = pred_coordinate.cpu().numpy()
save_atoms_to_cif(
output_fpath,
pred_atom_array,
atom_array,
entity_poly_type,
pdb_id,
)
atom_array.coord = original_coord
Comment on lines +197 to +205
Copy link

Copilot AI Mar 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

atom_array.coord is restored only on the happy path. If save_atoms_to_cif() raises (I/O error, invalid structure, etc.), the function exits early and leaves atom_array mutated for the rest of the run. Wrap the coord swap in a try/finally so the original coordinates are restored even on exceptions.

Copilot uses AI. Check for mistakes.
Empty file added tests/__init__.py
Empty file.
44 changes: 44 additions & 0 deletions tests/test_avoid_deepcopy_cif.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""Test that CIF writing uses coord swap instead of deepcopy."""

import ast
from pathlib import Path

import numpy as np


def test_no_deepcopy_in_dumper():
"""Verify that dumper.py no longer uses copy.deepcopy."""
filepath = Path(__file__).parent.parent / "pxdesign" / "runner" / "dumper.py"
tree = ast.parse(filepath.read_text())

for node in ast.walk(tree):
if isinstance(node, ast.Call):
func = node.func
if (
isinstance(func, ast.Attribute)
and func.attr == "deepcopy"
and isinstance(func.value, ast.Name)
and func.value.id == "copy"
):
raise AssertionError(
"Found copy.deepcopy() in dumper.py. "
"Use coord save/restore instead to avoid N_sample full copies."
)


def test_coord_restore_correctness():
"""Verify the coord swap pattern restores original coordinates."""
original_coords = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

class MockArray:
def __init__(self):
self.coord = original_coords.copy()

mock = MockArray()
new_coords = np.array([[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]])

saved = mock.coord
mock.coord = new_coords
assert np.array_equal(mock.coord, new_coords), "Should have new coords during write"
mock.coord = saved
assert np.array_equal(mock.coord, original_coords), "Should restore original coords after write"