Commit 6df1934

Adding option for broadcasting tensors similar to torch

1 parent 38874eb

1 file changed: +310 −0

@@ -0,0 +1,310 @@
use burn::{prelude::Backend, tensor::Tensor};

/// Broadcast two tensors with potentially different static ranks to a common rank.
///
/// # Syntax
/// ```ignore
/// broadcast!(
///     a: Tensor<Backend, RANK_A>,
///     b: Tensor<Backend, RANK_B>
/// )
/// ```
///
/// # Parameters
/// - `a`: Identifier for the first tensor variable (e.g., `a`).
/// - `b`: Identifier for the second tensor variable (e.g., `b`).
/// - `Backend`: The backend type of each tensor (e.g., `MyBackend`); both tensors
///   must use the same backend for the expansion to type-check.
/// - `RANK_A` / `RANK_B`: The static ranks of the two tensors (e.g., `2`, `3`, etc.).
///
/// # Expansion
/// Expands to:
/// ```ignore
/// {
///     const N: usize = max(RANK_A, RANK_B);
///     broadcast_op::<Backend, N, RANK_A, RANK_B>(a, b)
/// }
/// ```
///
/// # Example
/// ```ignore
/// let a: Tensor<MyBackend, 2> = ...;
/// let b: Tensor<MyBackend, 4> = ...;
///
/// let result = broadcast!(
///     a: Tensor<MyBackend, 2>,
///     b: Tensor<MyBackend, 4>
/// );
/// // Expands to: broadcast_op::<MyBackend, 4, 2, 4>(a, b)
/// ```
#[macro_export]
macro_rules! broadcast {
    (
        $a:ident : Tensor<$backend1:ty, $dims1:tt>,
        $b:ident : Tensor<$backend2:ty, $dims2:tt>
    ) => {{
        use $crate::ops::broadcast_op;

        const fn max(a: usize, b: usize) -> usize {
            if a > b { a } else { b }
        }

        const N: usize = max($dims1, $dims2);

        broadcast_op::<$backend1, N, $dims1, $dims2>($a, $b)
    }};
}
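// Illustrative sketch (not part of this commit): any elementwise binary op can
// follow `broadcast!`, not just the `add`/`max_pair` used in the tests below.
// `MyBackend` stands in for any `Backend` implementation:
//
//     let a: Tensor<MyBackend, 3> = ...; // shape [2, 1, 4]
//     let b: Tensor<MyBackend, 2> = ...; // shape [3, 1]
//     let (a, b) = broadcast!(a: Tensor<MyBackend, 3>, b: Tensor<MyBackend, 2>);
//     let product = a.mul(b); // shape [2, 3, 4]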
/// Broadcast two tensors to a common rank (see [`broadcast!`]) and return their
/// elementwise sum, mirroring torch's broadcasting `+`.
#[macro_export]
macro_rules! add_broadcast {
    (
        $a:ident : Tensor<$backend1:ty, $dims1:tt>,
        $b:ident : Tensor<$backend2:ty, $dims2:tt>
    ) => {{
        use $crate::ops::broadcast_op;

        const fn max(a: usize, b: usize) -> usize {
            if a > b { a } else { b }
        }

        const N: usize = max($dims1, $dims2);

        let (a, b) = broadcast_op::<$backend1, N, $dims1, $dims2>($a, $b);
        a.add(b)
    }};
}
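// Illustrative usage sketch for `add_broadcast!` (not exercised by the tests
// below; `MyBackend` is a placeholder for any `Backend` implementation):
//
//     let a: Tensor<MyBackend, 2> = ...; // shape [3, 1]
//     let b: Tensor<MyBackend, 1> = ...; // shape [4]
//     let sum = add_broadcast!(a: Tensor<MyBackend, 2>, b: Tensor<MyBackend, 1>);
//     // `sum` has shape [3, 4], as torch's `a + b` would.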
pub fn broadcast_op<B: Backend, const N: usize, const DA: usize, const DB: usize>(
    a: Tensor<B, DA>,
    b: Tensor<B, DB>,
) -> (Tensor<B, N>, Tensor<B, N>) {
    // Pad both shapes on the left with 1s up to the common rank N.
    let a = a.unsqueeze::<N>();
    let b = b.unsqueeze::<N>();

    let b_shape = b.shape().dims::<N>();

    // Build the target shapes in burn's `expand` format: -1 means "keep this
    // tensor's size in that dimension", a concrete value means "expand to it".

    // Target shape for b: wherever b has size 1, take a's size; otherwise keep b's.
    let b_shape_new: Vec<i64> = a
        .shape()
        .dims::<N>()
        .iter()
        .enumerate()
        .map(|(i, &val)| if b_shape[i] == 1 { val as i64 } else { -1_i64 })
        .collect();

    // Target shape for a: wherever a has size 1, take b's size; otherwise keep a's.
    let a_shape = a.shape().dims::<N>();

    let a_shape_new: Vec<i64> = b
        .shape()
        .dims::<N>()
        .iter()
        .enumerate()
        .map(|(i, &val)| if a_shape[i] == 1 { val as i64 } else { -1_i64 })
        .collect();

    // Expand both tensors to the common broadcast shape.
    let b = b.expand::<N, [i64; N]>(b_shape_new.try_into().unwrap());
    let a = a.expand::<N, [i64; N]>(a_shape_new.try_into().unwrap());

    (a, b)
}
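// Worked example of the shape computation above (illustrative, not part of the
// commit): broadcasting a with shape [2, 1, 4] against b with shape [3, 1].
//
//   unsqueeze to N = 3:  a = [2, 1, 4],  b = [1, 3, 1]
//   b_shape_new: b is 1 at dims 0 and 2, so take a's sizes there -> [2, -1, 4]
//   a_shape_new: a is 1 at dim 1, so take b's size there         -> [-1, 3, -1]
//   after expand, both tensors share the common shape [2, 3, 4].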
#[cfg(test)]
mod tests {
    use burn::backend::ndarray::{NdArray, NdArrayDevice};

    use super::*;

    #[test]
    fn test_broadcast_multi_dims() {
        let device = &NdArrayDevice::default();
        type B = NdArray<f32>;

        let a = Tensor::<B, 6>::empty([7, 6, 2, 3, 1, 9], device);
        let b = Tensor::<B, 4>::empty([2, 1, 7, 1], device);

        let (a, b) = broadcast!(a: Tensor<B, 6>, b: Tensor<B, 4>);

        assert_eq!(a.shape(), b.shape());
    }

    #[test]
    fn test_broadcast_multi_dims_values() {
        let device = &NdArrayDevice::default();
        type B = NdArray<f32>;

        let a = Tensor::<B, 3>::from_data(
            [
                [[2, 8, 7, 2], [9, 14, 13, 12], [9, 14, 13, 12]],
                [[2, 8, 7, 2], [9, 14, 13, 12], [9, 14, 13, 12]],
            ],
            device,
        );

        let b = Tensor::<B, 2>::from_data([[4, 11, 10, 5]], device);

        let (a, b) = broadcast!(a: Tensor<B, 3>, b: Tensor<B, 2>);

        let a_add_b = a.add(b);

        Tensor::<B, 3>::from_data(
            [
                [[6, 19, 17, 7], [13, 25, 23, 17], [13, 25, 23, 17]],
                [[6, 19, 17, 7], [13, 25, 23, 17], [13, 25, 23, 17]],
            ],
            device,
        )
        .into_data()
        .assert_eq(&a_add_b.to_data(), true);
    }

    #[test]
    fn test_max_broadcast() {
        let device = &NdArrayDevice::default();
        type B = NdArray<f32>;

        let a = Tensor::<B, 1>::from_data([3.0, 2.0, 6.0, 3.0], device);
        let b = Tensor::<B, 1>::from_data([1.0, 0.5, 4.0, 7.0], device);

        let a = a.reshape([-1, 1]);

        let (a, b) = broadcast!(a: Tensor<B, 2>, b: Tensor<B, 1>);

        let max_a_b = a.max_pair(b);

        Tensor::<B, 2>::from_data(
            [
                [3.0, 3.0, 4.0, 7.0],
                [2.0, 2.0, 4.0, 7.0],
                [6.0, 6.0, 6.0, 7.0],
                [3.0, 3.0, 4.0, 7.0],
            ],
            device,
        )
        .into_data()
        .assert_eq(&max_a_b.to_data(), true);
    }

    #[test]
    fn test_add_broadcast() {
        let device = &NdArrayDevice::default();
        type B = NdArray<f32>;

        let a = Tensor::<B, 1>::from_data([1.1, 2.2, 3.3], device);
        let b = Tensor::<B, 1>::from_data([4.0, 5.0, 6.0, 7.0], device);

        let a = a.reshape([-1, 1]);

        let (a, b) = broadcast!(a: Tensor<B, 2>, b: Tensor<B, 1>);
        let add_a_b = a.add(b);

        Tensor::<B, 2>::from_data(
            [
                [5.1, 6.1, 7.1, 8.1],
                [6.2, 7.2, 8.2, 9.2],
                [7.3, 8.3, 9.3, 10.3],
            ],
            device,
        )
        .into_data()
        .assert_eq(&add_a_b.to_data(), true);

        let a = Tensor::<B, 1>::from_data([1.1, 2.2, 3.3], device);
        let b = Tensor::<B, 1>::from_data([4.0, 5.0, 6.0, 7.0], device);

        let b = b.reshape([-1, 1]);
        let (a, b) = broadcast!(a: Tensor<B, 1>, b: Tensor<B, 2>);
        let add_a_b = a.add(b);

        Tensor::<B, 2>::from_data(
            [
                [5.1, 6.2, 7.3],
                [6.1, 7.2, 8.3],
                [7.1, 8.2, 9.3],
                [8.1, 9.2, 10.3],
            ],
            device,
        )
        .into_data()
        .assert_eq(&add_a_b.to_data(), true);
    }

    #[test]
    fn test_max_broadcast_uneven() {
        let device = &NdArrayDevice::default();
        type B = NdArray<f32>;

        let a = Tensor::<B, 1>::from_data([3.0, 2.0, 6.0, 3.0], device);
        let b = Tensor::<B, 1>::from_data([1.0, 0.5, 4.0, 7.0, 8.0], device);

        let b = b.reshape([-1, 1]);

        let (a, b) = broadcast!(a: Tensor<B, 1>, b: Tensor<B, 2>);

        let max_a_b = a.max_pair(b);

        Tensor::<B, 2>::from_data(
            [
                [3.0, 2.0, 6.0, 3.0],
                [3.0, 2.0, 6.0, 3.0],
                [4.0, 4.0, 6.0, 4.0],
                [7.0, 7.0, 7.0, 7.0],
                [8.0, 8.0, 8.0, 8.0],
            ],
            device,
        )
        .into_data()
        .assert_eq(&max_a_b.to_data(), true);
    }

    #[test]
    fn test_add_broadcast_diff_dims() {
        let device = &NdArrayDevice::default();
        type B = NdArray<f32>;

        let a = Tensor::<B, 2>::from_data(
            [
                [3.0, 2.0, 6.0, 3.0],
                [3.0, 2.0, 6.0, 3.0],
                [8.0, 7.0, 7.0, 13.0],
            ],
            device,
        );

        let b = Tensor::<B, 1>::from_data([1.0, 0.5, 4.0, 7.0], device);
        let (a, b) = broadcast!(a: Tensor<B, 2>, b: Tensor<B, 1>);

        let add_a_b = a.add(b);

        Tensor::<B, 2>::from_data(
            [
                [4.0, 2.5, 10.0, 10.0],
                [4.0, 2.5, 10.0, 10.0],
                [9.0, 7.5, 11.0, 20.0],
            ],
            device,
        )
        .into_data()
        .assert_eq(&add_a_b.to_data(), true);
    }
}
