@@ -5,83 +5,78 @@ use burn::{prelude::Backend, tensor::Tensor};
/// ```ignore
/// broadcast!(
///     a: Tensor<Backend, RANK_A>,
- ///     b: Tensor<Backend, RANK_B>
+ ///     b: Tensor<RANK_B>
/// )
/// ```
///
/// # Parameters
/// - `a`: Identifier for the first tensor variable (e.g., `a`).
- /// - `backend`: The backend type used in the first tensor (e.g., `MyBackend`).
- /// - `dims1`: The static rank of the first tensor (e.g., `2`, `3`, etc.).
+ /// - `Backend`: The backend type used by both tensors (e.g., `NdArray<f32>`).
+ /// - `RANK_A`: The static rank of the first tensor (e.g., `2`, `3`, etc.).
///
/// - `b`: Identifier for the second tensor variable (e.g., `b`).
- /// - `backend`: The backend type used in the second tensor (must match `backend` for correctness).
- /// - `dims2`: The static rank of the second tensor.
- ///
- /// # Expansion
- /// Expands to:
- /// ```rust
- /// {
- ///     const N: usize = max(dims1, dims2);
- ///     broadcast::<B, N, dims1, dims2>(a, b)
- /// }
- /// ```
+ /// - `RANK_B`: The static rank of the second tensor.
///
/// # Example
/// ```rust
- /// let a: Tensor<MyBackend, 2> = ... ;
- /// let b: Tensor<MyBackend, 4> = ... ;
+ /// use burn::backend::ndarray::{NdArray, NdArrayDevice};
+ /// use burn::tensor::Tensor;
+ ///
+ /// let device = &NdArrayDevice::default();
+ /// type B = NdArray<f32>;
///
- /// let result = broadcast!(
- ///     a: Tensor<MyBackend, 2>,
- ///     b: Tensor<MyBackend, 4>
- /// );
- /// // Expands to: broadcast::<MyBackend, 4, 2, 4>(a, b)
- /// ```
+ /// let a = Tensor::<B, 3>::from_data(
+ ///     [
+ ///         [[2, 8, 7, 2], [9, 14, 13, 12], [9, 14, 13, 12]],
+ ///         [[2, 8, 7, 2], [9, 14, 13, 12], [9, 14, 13, 12]],
+ ///     ],
+ ///     device,
+ /// );
///
+ /// let b = Tensor::<B, 2>::from_data([[4, 11, 10, 5]], device);
+ ///
+ /// let (a, b) = broadcast!(a: Tensor<B, 3>, b: Tensor<2>);
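+ /// // Both `a` and `b` now have shape [2, 3, 4].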
+ ///
+ /// let a_add_b = a.add(b);
+ ///
+ /// // Output:
+ /// // Tensor {
+ /// //   data:
+ /// // [[[ 6.0, 19.0, 17.0,  7.0],
+ /// //   [13.0, 25.0, 23.0, 17.0],
+ /// //   [13.0, 25.0, 23.0, 17.0]],
+ /// //  [[ 6.0, 19.0, 17.0,  7.0],
+ /// //   [13.0, 25.0, 23.0, 17.0],
+ /// //   [13.0, 25.0, 23.0, 17.0]]],
+ /// //   shape: [2, 3, 4],
+ /// //   device: Cpu,
+ /// //   backend: "ndarray",
+ /// //   kind: "Float",
+ /// //   dtype: "f32",
+ /// // }
+ /// ```
#[macro_export]
macro_rules! broadcast {
    (
-         $a:ident: Tensor<$backend1:ty, $dims1:tt>,
-         $b:ident: Tensor<$backend2:ty, $dims2:tt>
+         $a:ident: Tensor<$backend:ty, $dims1:tt>,
+         $b:ident: Tensor<$dims2:tt>
    ) => {{
-         use $crate::ops::broadcast_op;
+         use $crate::broadcast::broadcast_op;
        const fn max(a: usize, b: usize) -> usize {
            if a > b { a } else { b }
        }
-
-         const N: usize = max($dims1, $dims2);
-
-         broadcast_op::<$backend1, N, $dims1, $dims2>($a, $b)
-     }};
- }
-
- #[macro_export]
- macro_rules! add_broadcast {
-     (
-         $a:ident: Tensor<$backend1:ty, $dims1:tt>,
-         $b:ident: Tensor<$backend2:ty, $dims2:tt>
-     ) => {{
-         use $crate::ops::broadcast_op;
-         const fn max(a: usize, b: usize) -> usize {
-             if a > b { a } else { b }
-         }
-
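+         // N is the output rank: the larger of the two input ranks, evaluated at compile time.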
        const N: usize = max($dims1, $dims2);

-         let (a, b) = broadcast_op::<$backend1, N, $dims1, $dims2>($a, $b);
-         a.add(b)
+         broadcast_op::<$backend, N, $dims1, $dims2>(&$a, &$b)
    }};
}

pub fn broadcast_op<B: Backend, const N: usize, const DA: usize, const DB: usize>(
-     a: Tensor<B, DA>,
-     b: Tensor<B, DB>,
+     a: &Tensor<B, DA>,
+     b: &Tensor<B, DB>,
) -> (Tensor<B, N>, Tensor<B, N>) {
    // pad left with 1s
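+     // `unsqueeze` consumes self, so the borrowed tensors are cloned before being padded to rank N.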
-     let a = a.unsqueeze::<N>();
-     let b = b.unsqueeze::<N>();
+     let a = a.clone().unsqueeze::<N>();
+     let b = b.clone().unsqueeze::<N>();

    let b_shape = b.shape().dims::<N>();
@@ -129,8 +124,8 @@ pub fn broadcast_op<B: Backend, const N: usize, const DA: usize, const DB: usize
#[cfg(test)]
mod tests {
-     use burn::backend::ndarray::{NdArray, NdArrayDevice};
    use super::*;
+     use burn::backend::ndarray::{NdArray, NdArrayDevice};

    #[test]
    fn test_broadcast_multi_dims() {
@@ -140,7 +135,7 @@ mod tests {
        let a = Tensor::<B, 6>::empty([7, 6, 2, 3, 1, 9], device);
        let b = Tensor::<B, 4>::empty([2, 1, 7, 1], device);

-         let (a, b) = broadcast!(a: Tensor<B, 6>, b: Tensor<B, 4>);
+         let (a, b) = broadcast!(a: Tensor<B, 6>, b: Tensor<4>);
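+         // b is left-padded with 1s to rank 6; both expand to the common shape [7, 6, 2, 3, 7, 9].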

        assert_eq!(a.shape(), b.shape());
    }
@@ -160,14 +155,21 @@ mod tests {
        let b = Tensor::<B, 2>::from_data([[4, 11, 10, 5]], device);

-         let (a, b) = broadcast!(a: Tensor<B, 3>, b: Tensor<B, 2>);
-
+         let (a, b) = broadcast!(a: Tensor<B, 3>, b: Tensor<2>);
        let a_add_b = a.add(b);

        Tensor::<B, 3>::from_data(
            [
-                 [[6, 19, 17, 7], [13, 25, 23, 17], [13, 25, 23, 17]],
-                 [[6, 19, 17, 7], [13, 25, 23, 17], [13, 25, 23, 17]],
+                 [
+                     [6.0, 19.0, 17.0, 7.0],
+                     [13.0, 25.0, 23.0, 17.0],
+                     [13.0, 25.0, 23.0, 17.0],
+                 ],
+                 [
+                     [6.0, 19.0, 17.0, 7.0],
+                     [13.0, 25.0, 23.0, 17.0],
+                     [13.0, 25.0, 23.0, 17.0],
+                 ],
            ],
            device,
        )
@@ -181,13 +183,10 @@ mod tests {
        type B = NdArray<f32>;

        let a = Tensor::<B, 1>::from_data([3.0, 2.0, 6.0, 3.0], device);
-
        let b = Tensor::<B, 1>::from_data([1.0, 0.5, 4.0, 7.0], device);
-
        let a = a.reshape([-1, 1]);

-         let (a, b) = broadcast!(a: Tensor<B, 2>, b: Tensor<B, 1>);
-
+         let (a, b) = broadcast!(a: Tensor<B, 2>, b: Tensor<1>);
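+         // a: [4, 1] broadcast against b: [4] (left-padded to [1, 4]) yields the common shape [4, 4].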
        let max_a_b = a.max_pair(b);

        Tensor::<B, 2>::from_data(
@@ -209,12 +208,10 @@ mod tests {
        type B = NdArray<f32>;

        let a = Tensor::<B, 1>::from_data([1.1, 2.2, 3.3], device);
-
        let b = Tensor::<B, 1>::from_data([4.0, 5.0, 6.0, 7.0], device);
-
        let a = a.reshape([-1, 1]);

-         let (a, b) = broadcast!(a: Tensor<B, 2>, b: Tensor<B, 1>);
+         let (a, b) = broadcast!(a: Tensor<B, 2>, b: Tensor<1>);
        let add_a_b = a.add(b);

        Tensor::<B, 2>::from_data(
@@ -232,7 +229,7 @@ mod tests {
        let b = Tensor::<B, 1>::from_data([4.0, 5.0, 6.0, 7.0], device);

        let b = b.reshape([-1, 1]);
-         let (a, b) = broadcast!(a: Tensor<B, 1>, b: Tensor<B, 2>);
+         let (a, b) = broadcast!(a: Tensor<B, 1>, b: Tensor<2>);
        let add_a_b = a.add(b);

        Tensor::<B, 2>::from_data(
@@ -254,13 +251,10 @@ mod tests {
        type B = NdArray<f32>;

        let a = Tensor::<B, 1>::from_data([3.0, 2.0, 6.0, 3.0], device);
-
        let b = Tensor::<B, 1>::from_data([1.0, 0.5, 4.0, 7.0, 8.0], device);

        let b = b.reshape([-1, 1]);
-
-         let (a, b) = broadcast!(a: Tensor<B, 1>, b: Tensor<B, 2>);
-
+         let (a, b) = broadcast!(a: Tensor<B, 1>, b: Tensor<2>);
        let max_a_b = a.max_pair(b);

        Tensor::<B, 2>::from_data(
@@ -292,7 +286,7 @@ mod tests {
        );

        let b = Tensor::<B, 1>::from_data([1.0, 0.5, 4.0, 7.0], device);
-         let (a, b) = broadcast!(a: Tensor<B, 2>, b: Tensor<B, 1>);
+         let (a, b) = broadcast!(a: Tensor<B, 2>, b: Tensor<1>);

        let add_a_b = a.add(b);