Skip to content

Commit d4886bc

Browse files
authored
Broadcasting support for aabb_transform (#107)
* ruff * fix broadcasting for aabb_transform
1 parent 4c1e20d commit d4886bc

File tree

6 files changed

+83
-22
lines changed

6 files changed

+83
-22
lines changed

pylinalg/misc.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -67,30 +67,40 @@ def aabb_transform(aabb, matrix, /, *, out=None, dtype=None) -> np.ndarray:
6767
"""
6868

6969
aabb = np.asarray(aabb, dtype=float)
70-
matrix = np.asarray(matrix, dtype=float).transpose((-1, -2))
70+
matrix = np.asarray(matrix, dtype=float)
71+
72+
# transpose last two dimensions
73+
axes = list(range(matrix.ndim))
74+
axes[-2:] = axes[-1], axes[-2]
75+
matrix = matrix.transpose(axes)
7176

7277
if out is None:
73-
out = np.empty_like(aabb, dtype=dtype)
78+
# Compute output shape by broadcasting aabb and matrix shapes (excluding last 2 dims)
79+
aabb_shape = aabb.shape[:-2]
80+
matrix_shape = matrix.shape[:-2]
81+
broadcast_shape = np.broadcast_shapes(aabb_shape, matrix_shape)
82+
out = np.empty((*broadcast_shape, *aabb.shape[-2:]), dtype=dtype)
7483

7584
corners = np.full(
76-
aabb.shape[:-2] + (8, 4),
85+
(*aabb.shape[:-2], 8, 4),
7786
# Fill value of 1 is used for homogeneous coordinates.
7887
fill_value=1.0,
7988
dtype=float,
8089
)
90+
8191
# x
82-
corners[..., 0::2, 0] = aabb[..., 0, 0]
83-
corners[..., 1::2, 0] = aabb[..., 1, 0]
92+
corners[..., 0::2, 0] = aabb[..., 0, 0, np.newaxis]
93+
corners[..., 1::2, 0] = aabb[..., 1, 0, np.newaxis]
8494

8595
# y
86-
corners[..., 0::4, 1] = aabb[..., 0, 1]
87-
corners[..., 1::4, 1] = aabb[..., 0, 1]
88-
corners[..., 2::4, 1] = aabb[..., 1, 1]
89-
corners[..., 3::4, 1] = aabb[..., 1, 1]
96+
corners[..., 0::4, 1] = aabb[..., 0, 1, np.newaxis]
97+
corners[..., 1::4, 1] = aabb[..., 0, 1, np.newaxis]
98+
corners[..., 2::4, 1] = aabb[..., 1, 1, np.newaxis]
99+
corners[..., 3::4, 1] = aabb[..., 1, 1, np.newaxis]
90100

91101
# z
92-
corners[..., 0:4, 2] = aabb[..., 0, 2]
93-
corners[..., 4:8, 2] = aabb[..., 1, 2]
102+
corners[..., 0:4, 2] = aabb[..., 0, 2, np.newaxis]
103+
corners[..., 4:8, 2] = aabb[..., 1, 2, np.newaxis]
94104

95105
corners = corners @ matrix
96106
out[..., 0, :] = np.min(corners[..., :-1], axis=-2)

pylinalg/quaternion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ def quat_from_axis_angle(axis, angle, /, *, out=None, dtype=None) -> np.ndarray:
251251
out = np.empty((*out_shape, 4), dtype=dtype)
252252

253253
# result should be independent of the length of the given axis
254-
lengths_shape = axis.shape[:-1] + (1,)
254+
lengths_shape = (*axis.shape[:-1], 1)
255255
axis = axis / np.linalg.norm(axis, axis=-1).reshape(lengths_shape)
256256

257257
out[..., :3] = axis * np.sin(angle / 2).reshape(lengths_shape)

pylinalg/vector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def vec_normalize(vectors, /, *, out=None, dtype=None) -> np.ndarray:
3030
if out is None:
3131
out = np.empty_like(vectors, dtype=dtype)
3232

33-
lengths_shape = vectors.shape[:-1] + (1,)
33+
lengths_shape = (*vectors.shape[:-1], 1)
3434
lengths = np.linalg.norm(vectors, axis=-1).reshape(lengths_shape)
3535
return np.divide(vectors, lengths, out=out)
3636

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
[project]
44
name = "pylinalg"
5-
version = "0.6.7"
5+
version = "0.6.8"
66
description = "Linear algebra utilities for Python"
77
readme = "README.md"
88
license = { file = "LICENSE" }
@@ -59,5 +59,5 @@ ignore = [
5959
]
6060

6161
[tool.ruff.lint.per-file-ignores]
62-
"__init__.py" = ["F401", "F403"]
62+
"__init__.py" = ["F401", "F403", "RUF048"]
6363
"tests/conftest.py" = ["B008"]

tests/conftest.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,16 @@
1313
def pytest_report_header(config):
1414
# report the CPU model to allow detecting platform-specific problems
1515
if platform.system() == "Windows":
16-
name = (
17-
subprocess.check_output(["wmic", "cpu", "get", "name"])
18-
.decode()
19-
.strip()
20-
.split("\n")[1]
21-
)
22-
cpu_info = " ".join([name])
16+
try:
17+
name = (
18+
subprocess.check_output(["wmic", "cpu", "get", "name"])
19+
.decode()
20+
.strip()
21+
.split("\n")[1]
22+
)
23+
cpu_info = " ".join([name])
24+
except Exception:
25+
cpu_info = "Unknown CPU (wmic not available)"
2326
elif platform.system() == "Linux":
2427
info_string = subprocess.check_output(["lscpu"]).decode()
2528
for line in info_string.split("\n"):

tests/test_misc.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,51 @@ def test_aabb_transform(point, translation, scale):
4949
scale_matrix = la.mat_from_scale(scale)
5050
result = la.aabb_transform(aabb, scale_matrix)
5151
assert np.allclose(result, np.sort(aabb * scale, axis=0), atol=1e-10)
52+
53+
54+
def test_aabb_transform_single():
55+
"""Test single transform."""
56+
aabb = np.array([[-1, -1, -1], [1, 1, 1]])
57+
translation = np.array([1, 0, 0])
58+
translation_matrix = la.mat_from_translation(translation)
59+
60+
expected = aabb + translation
61+
result = la.aabb_transform(aabb, translation_matrix)
62+
assert np.allclose(result, expected, atol=1e-10)
63+
64+
65+
def test_aabb_transform_broadcasting():
66+
"""Test pairwise broadcasting of AABBs and matrices."""
67+
aabbs = np.array(
68+
[
69+
[[-1, -1, -1], [1, 1, 1]],
70+
[[-2, -2, -2], [2, 2, 2]],
71+
]
72+
)
73+
translations = np.array(
74+
[
75+
[1, 0, 0],
76+
[0, 1, 0],
77+
]
78+
)
79+
translation_matrices = la.mat_from_translation(translations)
80+
81+
expected = aabbs + translations[:, np.newaxis, :]
82+
result = la.aabb_transform(aabbs, translation_matrices)
83+
assert np.allclose(result, expected, atol=1e-10)
84+
85+
86+
def test_aabb_transform_broadcasting_2():
87+
"""Test broadcasting many matrices and one AABB."""
88+
aabb = np.array([[-1, -1, -1], [1, 1, 1]])
89+
translations = np.array(
90+
[
91+
[1, 0, 0],
92+
[0, 1, 0],
93+
]
94+
)
95+
translation_matrices = la.mat_from_translation(translations)
96+
97+
expected = aabb[np.newaxis, ...] + translations[:, np.newaxis, :]
98+
result = la.aabb_transform(aabb, translation_matrices)
99+
assert np.allclose(result, expected, atol=1e-10)

0 commit comments

Comments
 (0)