Commit 4bc28cb ("Updates")
Parent: 47c712c

1 file changed: crates/burn-tensor/src/tensor/api/broadcast.rs (61 additions, 67 deletions)
@@ -5,83 +5,78 @@ use burn::{prelude::Backend, tensor::Tensor};
 /// ```ignore
 /// broadcast!(
 ///     a: Tensor<Backend, RANK_A>,
-///     b: Tensor<Backend, RANK_B>
+///     b: Tensor<RANK_B>
 /// )
 /// ```
 ///
 /// # Parameters
 /// - `a`: Identifier for the first tensor variable (e.g., `a`).
-///   - `backend`: The backend type used in the first tensor (e.g., `MyBackend`).
-///   - `dims1`: The static rank of the first tensor (e.g., `2`, `3`, etc.).
+///   - `Backend`: The backend to use
+///   - `RANK_A`: The static rank of the first tensor (e.g., `2`, `3`, etc.).
 ///
 /// - `b`: Identifier for the second tensor variable (e.g., `b`).
-///   - `backend`: The backend type used in the second tensor (must match `backend` for correctness).
-///   - `dims2`: The static rank of the second tensor.
-///
-/// # Expansion
-/// Expands to:
-/// ```rust
-/// {
-///     const N: usize = max(dims1, dims2);
-///     broadcast::<B, N, dims1, dims2>(a, b)
-/// }
-/// ```
+///   - `RANK_B`: The static rank of the second tensor.
 ///
 /// # Example
 /// ```rust
-/// let a: Tensor<MyBackend, 2> = ...;
-/// let b: Tensor<MyBackend, 4> = ...;
+/// let device = &NdArrayDevice::default();
+/// type B = NdArray<f32>;
 ///
-/// let result = broadcast!(
-///     a: Tensor<MyBackend, 2>,
-///     b: Tensor<MyBackend, 4>
-/// );
-/// // Expands to: broadcast::<MyBackend, 4, 2, 4>(a, b)
-/// ```
+/// let a = Tensor::<B, 3>::from_data(
+///     [
+///         [[2, 8, 7, 2], [9, 14, 13, 12], [9, 14, 13, 12]],
+///         [[2, 8, 7, 2], [9, 14, 13, 12], [9, 14, 13, 12]],
+///     ],
+///     device,
+/// );
 ///
+/// let b = Tensor::<B, 2>::from_data([[4, 11, 10, 5]], device);
+///
+/// let (a, b) = broadcast!(a:Tensor<B, 3>, b:Tensor<2>);
+///
+/// let a_add_b = a.add(b);
+///
+/// // Output:
+/// // Tensor {
+/// //   data:
+/// // [[[ 6.0, 19.0, 17.0, 7.0],
+/// //   [13.0, 25.0, 23.0, 17.0],
+/// //   [13.0, 25.0, 23.0, 17.0]],
+/// //  [[ 6.0, 19.0, 17.0, 7.0],
+/// //   [13.0, 25.0, 23.0, 17.0],
+/// //   [13.0, 25.0, 23.0, 17.0]]],
+/// //   shape: [2, 3, 4],
+/// //   device: Cpu,
+/// //   backend: "ndarray",
+/// //   kind: "Float",
+/// //   dtype: "f32",
+/// // }
+/// ```
 #[macro_export]
 macro_rules! broadcast {
     (
-        $a:ident : Tensor<$backend1:ty, $dims1:tt>,
-        $b:ident : Tensor<$backend2:ty, $dims2:tt>
+        $a:ident : Tensor<$backend:ty, $dims1:tt>,
+        $b:ident : Tensor<$dims2:tt>
     ) => {{
-        use $crate::ops::broadcast_op;
+        use $crate::broadcast::broadcast_op;
         const fn max(a: usize, b: usize) -> usize {
             if a > b { a } else { b }
         }
-
-        const N: usize = max($dims1, $dims2);
-
-        broadcast_op::<$backend1, N, $dims1, $dims2>($a, $b)
-    }};
-}
 
-#[macro_export]
-macro_rules! add_broadcast {
-    (
-        $a:ident : Tensor<$backend1:ty, $dims1:tt>,
-        $b:ident : Tensor<$backend2:ty, $dims2:tt>
-    ) => {{
-        use $crate::ops::broadcast_op;
-        const fn max(a: usize, b: usize) -> usize {
-            if a > b { a } else { b }
-        }
-
         const N: usize = max($dims1, $dims2);
 
-        let (a,b) = broadcast_op::<$backend1, N, $dims1, $dims2>($a, $b);
-        a.add(b)
+        broadcast_op::<$backend, N, $dims1, $dims2>(&$a, &$b)
     }};
 }
 
 pub fn broadcast_op<B: Backend, const N: usize, const DA: usize, const DB: usize>(
-    a: Tensor<B, DA>,
-    b: Tensor<B, DB>,
+    a: &Tensor<B, DA>,
+    b: &Tensor<B, DB>,
 ) -> (Tensor<B, N>, Tensor<B, N>) {
     // pad left with 1s
 
-    let a = a.unsqueeze::<N>();
-    let b = b.unsqueeze::<N>();
+    let a = a.clone().unsqueeze::<N>();
+    let b = b.clone().unsqueeze::<N>();
 
     let b_shape = b.shape().dims::<N>();
@@ -129,8 +124,8 @@ pub fn broadcast_op<B: Backend, const N: usize, const DA: usize, const DB: usize
 
 #[cfg(test)]
 mod tests {
-    use burn::backend::ndarray::{NdArray, NdArrayDevice};
     use super::*;
+    use burn::backend::ndarray::{NdArray, NdArrayDevice};
 
     #[test]
     fn test_broadcast_multi_dims() {
@@ -140,7 +135,7 @@ mod tests {
         let a = Tensor::<B, 6>::empty([7, 6, 2, 3, 1, 9], device);
         let b = Tensor::<B, 4>::empty([2, 1, 7, 1], device);
 
-        let (a, b) = broadcast!(a: Tensor<B, 6>, b: Tensor<B, 4>);
+        let (a, b) = broadcast!(a: Tensor<B, 6>, b: Tensor<4>);
 
         assert_eq!(a.shape(), b.shape());
     }
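A worked note on why that assertion holds, assuming the NumPy-style rule implied by the `// pad left with 1s` comment in `broadcast_op`: the lower-rank shape is left-padded with 1s to the common rank `N`, then every size-1 axis stretches to the other tensor's extent:

```rust
// a:      [7, 6, 2, 3, 1, 9]
// b:            [2, 1, 7, 1]  ->  left-padded to rank 6: [1, 1, 2, 1, 7, 1]
// result: [7, 6, 2, 3, 7, 9]  (size-1 axes stretch; other axes must match)
```

Both returned tensors therefore share the shape `[7, 6, 2, 3, 7, 9]`, which is exactly what `assert_eq!(a.shape(), b.shape())` checks.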
@@ -160,14 +155,21 @@ mod tests {
 
         let b = Tensor::<B, 2>::from_data([[4, 11, 10, 5]], device);
 
-        let (a, b) = broadcast!(a:Tensor<B, 3>, b:Tensor<B, 2>);
-
+        let (a, b) = broadcast!(a:Tensor<B, 3>, b:Tensor<2>);
         let a_add_b = a.add(b);
 
         Tensor::<B, 3>::from_data(
             [
-                [[6, 19, 17, 7], [13, 25, 23, 17], [13, 25, 23, 17]],
-                [[6, 19, 17, 7], [13, 25, 23, 17], [13, 25, 23, 17]],
+                [
+                    [6.0, 19.0, 17.0, 7.0],
+                    [13.0, 25.0, 23.0, 17.0],
+                    [13.0, 25.0, 23.0, 17.0],
+                ],
+                [
+                    [6.0, 19.0, 17.0, 7.0],
+                    [13.0, 25.0, 23.0, 17.0],
+                    [13.0, 25.0, 23.0, 17.0],
+                ],
             ],
             device,
         )
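The expected data in this test is simply `b`'s single row `[4, 11, 10, 5]` added to each row of `a` once both are broadcast to `[2, 3, 4]`:

```rust
// [2,  8,  7,  2] + [4, 11, 10, 5] = [ 6, 19, 17,  7]
// [9, 14, 13, 12] + [4, 11, 10, 5] = [13, 25, 23, 17]
```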
@@ -181,13 +183,10 @@ mod tests {
         type B = NdArray<f32>;
 
         let a = Tensor::<B, 1>::from_data([3.0, 2.0, 6.0, 3.0], device);
-
         let b = Tensor::<B, 1>::from_data([1.0, 0.5, 4.0, 7.0], device);
-
         let a = a.reshape([-1, 1]);
 
-        let (a, b) = broadcast!(a:Tensor<B, 2>, b:Tensor<B, 1>);
-
+        let (a, b) = broadcast!(a:Tensor<B, 2>, b:Tensor<1>);
         let max_a_b = a.max_pair(b);
 
         Tensor::<B, 2>::from_data(
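Here `a` has been reshaped into a `[4, 1]` column while `b` stays a rank-1 row, so the pair broadcasts to `[4, 4]` and `max_pair` compares element-wise. A worked first row (the expected tensor itself is collapsed out of this hunk), pitting `a[0] = 3.0` against every element of `b`:

```rust
// [max(3.0, 1.0), max(3.0, 0.5), max(3.0, 4.0), max(3.0, 7.0)]
//   => [3.0, 3.0, 4.0, 7.0]
```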
@@ -209,12 +208,10 @@ mod tests {
         type B = NdArray<f32>;
 
         let a = Tensor::<B, 1>::from_data([1.1, 2.2, 3.3], device);
-
         let b = Tensor::<B, 1>::from_data([4.0, 5.0, 6.0, 7.0], device);
-
         let a = a.reshape([-1, 1]);
 
-        let (a, b) = broadcast!(a:Tensor<B, 2>, b:Tensor<B, 1>);
+        let (a, b) = broadcast!(a:Tensor<B, 2>, b:Tensor<1>);
         let add_a_b = a.add(b);
 
         Tensor::<B, 2>::from_data(
@@ -232,7 +229,7 @@ mod tests {
         let b = Tensor::<B, 1>::from_data([4.0, 5.0, 6.0, 7.0], device);
 
         let b = b.reshape([-1, 1]);
-        let (a, b) = broadcast!(a:Tensor<B, 1>, b:Tensor<B, 2>);
+        let (a, b) = broadcast!(a:Tensor<B, 1>, b:Tensor<2>);
         let add_a_b = a.add(b);
 
         Tensor::<B, 2>::from_data(
@@ -254,13 +251,10 @@ mod tests {
         type B = NdArray<f32>;
 
         let a = Tensor::<B, 1>::from_data([3.0, 2.0, 6.0, 3.0], device);
-
         let b = Tensor::<B, 1>::from_data([1.0, 0.5, 4.0, 7.0, 8.0], device);
 
         let b = b.reshape([-1, 1]);
-
-        let (a, b) = broadcast!(a:Tensor<B, 1>, b:Tensor<B, 2>);
-
+        let (a, b) = broadcast!(a:Tensor<B, 1>, b:Tensor<2>);
         let max_a_b = a.max_pair(b);
 
         Tensor::<B, 2>::from_data(
@@ -292,7 +286,7 @@ mod tests {
         );
 
         let b = Tensor::<B, 1>::from_data([1.0, 0.5, 4.0, 7.0], device);
-        let (a, b) = broadcast!(a:Tensor<B, 2>, b:Tensor<B, 1>);
+        let (a, b) = broadcast!(a:Tensor<B, 2>, b:Tensor<1>);
 
         let add_a_b = a.add(b);
 
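Taken together, the new surface looks like this: a minimal, self-contained sketch of calling the updated macro (assuming `broadcast!` is visible at the crate root via `#[macro_export]` and that `burn` with the `ndarray` backend is available; everything else follows the tests above):

```rust
use burn::backend::ndarray::{NdArray, NdArrayDevice};
use burn::tensor::Tensor;

fn main() {
    type B = NdArray<f32>;
    let device = &NdArrayDevice::default();

    // A [3, 1] column and a rank-1 row of length 3.
    let a = Tensor::<B, 2>::from_data([[1.0], [2.0], [3.0]], device);
    let b = Tensor::<B, 1>::from_data([10.0, 20.0, 30.0], device);

    // New syntax: the backend is written once, on the first tensor.
    let (a, b) = broadcast!(a: Tensor<B, 2>, b: Tensor<1>);

    // Both tensors now have shape [3, 3]; element-wise ops line up.
    let sum = a.add(b);
    println!("{sum}");
}
```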