@@ -67,30 +67,40 @@ def aabb_transform(aabb, matrix, /, *, out=None, dtype=None) -> np.ndarray:
6767 """
6868
6969 aabb = np .asarray (aabb , dtype = float )
70- matrix = np .asarray (matrix , dtype = float ).transpose ((- 1 , - 2 ))
70+ matrix = np .asarray (matrix , dtype = float )
71+
72+ # transpose last two dimensions
73+ axes = list (range (matrix .ndim ))
74+ axes [- 2 :] = axes [- 1 ], axes [- 2 ]
75+ matrix = matrix .transpose (axes )
7176
7277 if out is None :
73- out = np .empty_like (aabb , dtype = dtype )
78+ # Compute output shape by broadcasting aabb and matrix shapes (excluding last 2 dims)
79+ aabb_shape = aabb .shape [:- 2 ]
80+ matrix_shape = matrix .shape [:- 2 ]
81+ broadcast_shape = np .broadcast_shapes (aabb_shape , matrix_shape )
82+ out = np .empty ((* broadcast_shape , * aabb .shape [- 2 :]), dtype = dtype )
7483
7584 corners = np .full (
76- aabb .shape [:- 2 ] + ( 8 , 4 ),
85+ ( * aabb .shape [:- 2 ], 8 , 4 ),
7786 # Fill value of 1 is used for homogeneous coordinates.
7887 fill_value = 1.0 ,
7988 dtype = float ,
8089 )
90+
8191 # x
82- corners [..., 0 ::2 , 0 ] = aabb [..., 0 , 0 ]
83- corners [..., 1 ::2 , 0 ] = aabb [..., 1 , 0 ]
92+ corners [..., 0 ::2 , 0 ] = aabb [..., 0 , 0 , np . newaxis ]
93+ corners [..., 1 ::2 , 0 ] = aabb [..., 1 , 0 , np . newaxis ]
8494
8595 # y
86- corners [..., 0 ::4 , 1 ] = aabb [..., 0 , 1 ]
87- corners [..., 1 ::4 , 1 ] = aabb [..., 0 , 1 ]
88- corners [..., 2 ::4 , 1 ] = aabb [..., 1 , 1 ]
89- corners [..., 3 ::4 , 1 ] = aabb [..., 1 , 1 ]
96+ corners [..., 0 ::4 , 1 ] = aabb [..., 0 , 1 , np . newaxis ]
97+ corners [..., 1 ::4 , 1 ] = aabb [..., 0 , 1 , np . newaxis ]
98+ corners [..., 2 ::4 , 1 ] = aabb [..., 1 , 1 , np . newaxis ]
99+ corners [..., 3 ::4 , 1 ] = aabb [..., 1 , 1 , np . newaxis ]
90100
91101 # z
92- corners [..., 0 :4 , 2 ] = aabb [..., 0 , 2 ]
93- corners [..., 4 :8 , 2 ] = aabb [..., 1 , 2 ]
102+ corners [..., 0 :4 , 2 ] = aabb [..., 0 , 2 , np . newaxis ]
103+ corners [..., 4 :8 , 2 ] = aabb [..., 1 , 2 , np . newaxis ]
94104
95105 corners = corners @ matrix
96106 out [..., 0 , :] = np .min (corners [..., :- 1 ], axis = - 2 )
0 commit comments