diff --git a/dpdata/amber/md.py b/dpdata/amber/md.py index cb4f2d25e..32396f0f7 100644 --- a/dpdata/amber/md.py +++ b/dpdata/amber/md.py @@ -18,6 +18,78 @@ force_convert = energy_convert +def cell_lengths_angles_to_cell( + cell_lengths: np.ndarray, cell_angles: np.ndarray +) -> np.ndarray: + """Convert cell lengths and angles to cell vectors. + + Parameters + ---------- + cell_lengths : np.ndarray + Cell lengths with shape (..., 3) where the last dimension + corresponds to [a, b, c] + cell_angles : np.ndarray + Cell angles in degrees with shape (..., 3) where the last dimension + corresponds to [alpha, beta, gamma] + + Returns + ------- + np.ndarray + Cell vectors with shape (..., 3, 3) where the last two dimensions + form the cell matrix + + Notes + ----- + Uses the standard crystallographic convention: + - v1 = [a, 0, 0] + - v2 = [b*cos(gamma), b*sin(gamma), 0] + - v3 = [c*cos(beta), c*(cos(alpha) - cos(beta)*cos(gamma))/sin(gamma), c*z] + where z = sqrt(1 - cos²(alpha) - cos²(beta) - cos²(gamma) + 2*cos(alpha)*cos(beta)*cos(gamma))/sin(gamma) + """ + # Convert to radians + alpha = np.deg2rad(cell_angles[..., 0]) # angle between b and c + beta = np.deg2rad(cell_angles[..., 1]) # angle between a and c + gamma = np.deg2rad(cell_angles[..., 2]) # angle between a and b + + a = cell_lengths[..., 0] + b = cell_lengths[..., 1] + c = cell_lengths[..., 2] + + cos_alpha = np.cos(alpha) + cos_beta = np.cos(beta) + cos_gamma = np.cos(gamma) + sin_gamma = np.sin(gamma) + + # Calculate the z-component of the third vector + z_factor = ( + 1 + - cos_alpha**2 + - cos_beta**2 + - cos_gamma**2 + + 2 * cos_alpha * cos_beta * cos_gamma + ) + z_factor = np.maximum(z_factor, 0) # Ensure non-negative for sqrt + z = np.sqrt(z_factor) / sin_gamma + + # Build cell vectors + shape = cell_lengths.shape[:-1] + (3, 3) + cell = np.zeros(shape) + + # First vector: [a, 0, 0] + cell[..., 0, 0] = a + + # Second vector: [b*cos(gamma), b*sin(gamma), 0] + cell[..., 1, 0] = b * cos_gamma + cell[..., 1, 1] = b * sin_gamma + + # Third vector: [c*cos(beta), c*(cos(alpha) - cos(beta)*cos(gamma))/sin(gamma), c*z] + cell[..., 2, 0] = c * cos_beta + cell[..., 2, 1] = c * (cos_alpha - cos_beta * cos_gamma) / sin_gamma + cell[..., 2, 2] = c * z + + return cell + + def read_amber_traj( parm7_file, nc_file, @@ -85,15 +157,8 @@ def read_amber_traj( coords = np.array(f.variables["coordinates"][:]) cell_lengths = np.array(f.variables["cell_lengths"][:]) cell_angles = np.array(f.variables["cell_angles"][:]) - if np.all(cell_angles > 89.99) and np.all(cell_angles < 90.01): - # only support 90 - # TODO: support other angles - shape = cell_lengths.shape - cells = np.zeros((shape[0], 3, 3)) - for ii in range(3): - cells[:, ii, ii] = cell_lengths[:, ii] - else: - raise RuntimeError("Unsupported cells") + # Convert cell lengths and angles to cell vectors for all cases + cells = cell_lengths_angles_to_cell(cell_lengths, cell_angles) if labeled: with netcdf_file(mdfrc_file, "r") as f: diff --git a/tests/test_amber_nonorthogonal.py b/tests/test_amber_nonorthogonal.py new file mode 100644 index 000000000..b1b111111 --- /dev/null +++ b/tests/test_amber_nonorthogonal.py @@ -0,0 +1,166 @@ +from __future__ import annotations + +import unittest + +import numpy as np + +from dpdata.amber.md import cell_lengths_angles_to_cell + + +class TestAmberNonOrthogonalCells(unittest.TestCase): + def test_orthogonal_cell_conversion(self): + """Test that orthogonal cells (90° angles) work correctly.""" + # Test case: simple cubic cell with a=10, b=15, c=20, all angles=90° + cell_lengths = np.array([[10.0, 15.0, 20.0]]) + cell_angles = np.array([[90.0, 90.0, 90.0]]) + + expected_cell = np.array( + [[[10.0, 0.0, 0.0], [0.0, 15.0, 0.0], [0.0, 0.0, 20.0]]] + ) + + result_cell = cell_lengths_angles_to_cell(cell_lengths, cell_angles) + + np.testing.assert_allclose(result_cell, expected_cell, rtol=1e-12, atol=1e-14) + + def test_monoclinic_cell_conversion(self): + """Test monoclinic cell (beta != 90°, alpha=gamma=90°).""" + # Test case: monoclinic cell with a=10, b=15, c=20, alpha=90°, beta=120°, gamma=90° + cell_lengths = np.array([[10.0, 15.0, 20.0]]) + cell_angles = np.array([[90.0, 120.0, 90.0]]) + + # Expected result: + # v1 = [10, 0, 0] + # v2 = [0, 15, 0] (gamma=90°) + # v3 = [20*cos(120°), 0, 20*sin(120°)] = [-10, 0, 17.32...] + cos_120 = np.cos(np.deg2rad(120.0)) # -0.5 + sin_120 = np.sin(np.deg2rad(120.0)) # sqrt(3)/2 + + expected_cell = np.array( + [ + [ + [10.0, 0.0, 0.0], + [0.0, 15.0, 0.0], + [20.0 * cos_120, 0.0, 20.0 * sin_120], + ] + ] + ) + + result_cell = cell_lengths_angles_to_cell(cell_lengths, cell_angles) + + np.testing.assert_allclose(result_cell, expected_cell, rtol=1e-12, atol=1e-14) + + def test_hexagonal_cell_conversion(self): + """Test hexagonal cell (gamma=120°, alpha=beta=90°).""" + # Test case: hexagonal cell with a=10, b=10, c=15, alpha=90°, beta=90°, gamma=120° + cell_lengths = np.array([[10.0, 10.0, 15.0]]) + cell_angles = np.array([[90.0, 90.0, 120.0]]) + + # Expected result: + # v1 = [10, 0, 0] + # v2 = [10*cos(120°), 10*sin(120°), 0] = [-5, 8.66..., 0] + # v3 = [0, 0, 15] (alpha=beta=90°) + cos_120 = np.cos(np.deg2rad(120.0)) # -0.5 + sin_120 = np.sin(np.deg2rad(120.0)) # sqrt(3)/2 + + expected_cell = np.array( + [ + [ + [10.0, 0.0, 0.0], + [10.0 * cos_120, 10.0 * sin_120, 0.0], + [0.0, 0.0, 15.0], + ] + ] + ) + + result_cell = cell_lengths_angles_to_cell(cell_lengths, cell_angles) + + np.testing.assert_allclose(result_cell, expected_cell, rtol=1e-12, atol=1e-14) + + def test_triclinic_cell_conversion(self): + """Test triclinic cell (all angles != 90°).""" + # Test case: triclinic cell with a=8, b=10, c=12, alpha=70°, beta=80°, gamma=110° + cell_lengths = np.array([[8.0, 10.0, 12.0]]) + cell_angles = np.array([[70.0, 80.0, 110.0]]) + + result_cell = cell_lengths_angles_to_cell(cell_lengths, cell_angles) + + # Check that the result has the right shape + self.assertEqual(result_cell.shape, (1, 3, 3)) + + # Check that the cell vectors have the correct lengths + computed_lengths = np.linalg.norm(result_cell[0], axis=1) + expected_lengths = np.array([8.0, 10.0, 12.0]) + np.testing.assert_allclose(computed_lengths, expected_lengths, rtol=1e-12) + + # Check that the angles between vectors are correct + v1, v2, v3 = result_cell[0] + + # Angle between v2 and v3 should be alpha (70°) + cos_alpha = np.dot(v2, v3) / (np.linalg.norm(v2) * np.linalg.norm(v3)) + alpha_computed = np.rad2deg(np.arccos(cos_alpha)) + np.testing.assert_allclose(alpha_computed, 70.0, rtol=1e-8) + + # Angle between v1 and v3 should be beta (80°) + cos_beta = np.dot(v1, v3) / (np.linalg.norm(v1) * np.linalg.norm(v3)) + beta_computed = np.rad2deg(np.arccos(cos_beta)) + np.testing.assert_allclose(beta_computed, 80.0, rtol=1e-8) + + # Angle between v1 and v2 should be gamma (110°) + cos_gamma = np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2)) + gamma_computed = np.rad2deg(np.arccos(cos_gamma)) + np.testing.assert_allclose(gamma_computed, 110.0, rtol=1e-8) + + def test_extreme_angles_case(self): + """Test edge case with angles very far from 90°.""" + cell_lengths = np.array([[5.0, 8.0, 12.0]]) + cell_angles = np.array([[60.0, 70.0, 130.0]]) # all far from 90° + + # Should work without error + result = cell_lengths_angles_to_cell(cell_lengths, cell_angles) + self.assertEqual(result.shape, (1, 3, 3)) + + # Verify the lengths are preserved + computed_lengths = np.linalg.norm(result[0], axis=1) + expected_lengths = np.array([5.0, 8.0, 12.0]) + np.testing.assert_allclose(computed_lengths, expected_lengths, rtol=1e-10) + + def test_multiple_frames(self): + """Test that multiple frames are handled correctly.""" + # Test case: 3 frames with different cell parameters + cell_lengths = np.array( + [ + [10.0, 10.0, 10.0], # cubic + [8.0, 12.0, 15.0], # orthorhombic + [10.0, 10.0, 12.0], + ] + ) # hexagonal-like + cell_angles = np.array( + [[90.0, 90.0, 90.0], [90.0, 90.0, 90.0], [90.0, 90.0, 120.0]] + ) + + result_cell = cell_lengths_angles_to_cell(cell_lengths, cell_angles) + + # Check shape + self.assertEqual(result_cell.shape, (3, 3, 3)) + + # Check first frame (cubic) + expected_frame1 = np.array( + [[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]] + ) + np.testing.assert_allclose( + result_cell[0], expected_frame1, rtol=1e-12, atol=1e-14 + ) + + # Check third frame (hexagonal-like) + cos_120 = np.cos(np.deg2rad(120.0)) # -0.5 + sin_120 = np.sin(np.deg2rad(120.0)) # sqrt(3)/2 + expected_frame3 = np.array( + [[10.0, 0.0, 0.0], [10.0 * cos_120, 10.0 * sin_120, 0.0], [0.0, 0.0, 12.0]] + ) + np.testing.assert_allclose( + result_cell[2], expected_frame3, rtol=1e-12, atol=1e-14 + ) + + +if __name__ == "__main__": + unittest.main()