1
1
use crate :: {
2
2
CubeBackend , CubeRuntime , FloatElement , IntElement ,
3
3
element:: BoolElement ,
4
+ execute_with_dtype,
4
5
kernel:: {
5
6
self ,
6
7
conv:: { ConvStrategy , ConvTranspose2dStrategy } ,
25
26
bias : Option < FloatTensor < Self > > ,
26
27
options : ConvOptions < 1 > ,
27
28
) -> FloatTensor < Self > {
28
- kernel:: conv:: conv :: < R , F , 1 > ( x, weight, bias, options, ConvStrategy :: default ( ) ) . unwrap ( )
29
+ execute_with_dtype ! (
30
+ float( x. dtype) ,
31
+ E ,
32
+ kernel:: conv:: conv:: <R , E , 1 >( x, weight, bias, options, ConvStrategy :: default ( ) )
33
+ . unwrap( )
34
+ )
29
35
}
30
36
31
37
fn conv2d (
34
40
bias : Option < FloatTensor < Self > > ,
35
41
options : ConvOptions < 2 > ,
36
42
) -> FloatTensor < Self > {
37
- kernel:: conv:: conv :: < R , F , 2 > ( x, weight, bias, options, ConvStrategy :: default ( ) ) . unwrap ( )
43
+ execute_with_dtype ! (
44
+ float( x. dtype) ,
45
+ E ,
46
+ kernel:: conv:: conv:: <R , E , 2 >( x, weight, bias, options, ConvStrategy :: default ( ) )
47
+ . unwrap( )
48
+ )
38
49
}
39
50
40
51
fn deform_conv2d (
45
56
bias : Option < FloatTensor < Self > > ,
46
57
options : DeformConvOptions < 2 > ,
47
58
) -> FloatTensor < Self > {
48
- kernel:: conv:: deform_conv2d :: < R , F > ( x, offset, weight, mask, bias, options) . unwrap ( )
59
+ execute_with_dtype ! (
60
+ float( x. dtype) ,
61
+ E ,
62
+ kernel:: conv:: deform_conv2d:: <R , E >( x, offset, weight, mask, bias, options) . unwrap( )
63
+ )
49
64
}
50
65
51
66
fn deform_conv2d_backward (
@@ -57,16 +72,19 @@ where
57
72
output_grad : FloatTensor < Self > ,
58
73
options : DeformConvOptions < 2 > ,
59
74
) -> DeformConv2dBackward < Self > {
60
- kernel:: conv:: deform_conv2d_backward :: < R , F , I , BT > (
61
- x,
62
- offset,
63
- weight,
64
- mask,
65
- bias,
66
- output_grad,
67
- options,
68
- )
69
- . unwrap ( )
75
+ execute_with_dtype ! ( float( x. dtype) , E , {
76
+ let ( x, o, w, m, b) = kernel:: conv:: deform_conv2d_backward:: <R , E , I , BT >(
77
+ x,
78
+ offset,
79
+ weight,
80
+ mask,
81
+ bias,
82
+ output_grad,
83
+ options,
84
+ )
85
+ . unwrap( ) ;
86
+ DeformConv2dBackward :: new( x, o, w, m, b)
87
+ } )
70
88
}
71
89
72
90
fn conv3d (
75
93
bias : Option < FloatTensor < Self > > ,
76
94
options : ConvOptions < 3 > ,
77
95
) -> FloatTensor < Self > {
78
- kernel:: conv:: conv :: < R , F , 3 > ( x, weight, bias, options, ConvStrategy :: Direct ) . unwrap ( )
96
+ execute_with_dtype ! (
97
+ float( x. dtype) ,
98
+ E ,
99
+ kernel:: conv:: conv:: <R , E , 3 >( x, weight, bias, options, ConvStrategy :: Direct ) . unwrap( )
100
+ )
79
101
}
80
102
81
103
fn conv_transpose2d (
@@ -84,14 +106,18 @@ where
84
106
bias : Option < FloatTensor < Self > > ,
85
107
options : ConvTransposeOptions < 2 > ,
86
108
) -> FloatTensor < Self > {
87
- kernel:: conv:: conv_transpose2d :: < R , F , I > (
88
- x,
89
- weight,
90
- bias,
91
- options,
92
- ConvTranspose2dStrategy :: default ( ) ,
109
+ execute_with_dtype ! (
110
+ float( x. dtype) ,
111
+ E ,
112
+ kernel:: conv:: conv_transpose2d:: <R , E , I >(
113
+ x,
114
+ weight,
115
+ bias,
116
+ options,
117
+ ConvTranspose2dStrategy :: default ( ) ,
118
+ )
119
+ . unwrap( )
93
120
)
94
- . unwrap ( )
95
121
}
96
122
97
123
fn conv_transpose3d (
@@ -100,7 +126,11 @@ where
100
126
bias : Option < FloatTensor < Self > > ,
101
127
options : ConvTransposeOptions < 3 > ,
102
128
) -> FloatTensor < Self > {
103
- kernel:: conv:: conv_transpose3d :: < R , F > ( x, weight, bias, options)
129
+ execute_with_dtype ! (
130
+ float( x. dtype) ,
131
+ E ,
132
+ kernel:: conv:: conv_transpose3d:: <R , E >( x, weight, bias, options)
133
+ )
104
134
}
105
135
106
136
fn avg_pool2d (
@@ -110,7 +140,11 @@ where
110
140
padding : [ usize ; 2 ] ,
111
141
count_include_pad : bool ,
112
142
) -> FloatTensor < Self > {
113
- kernel:: pool:: avg_pool2d :: < R , F > ( x, kernel_size, stride, padding, count_include_pad)
143
+ execute_with_dtype ! (
144
+ float( x. dtype) ,
145
+ E ,
146
+ kernel:: pool:: avg_pool2d:: <R , E >( x, kernel_size, stride, padding, count_include_pad)
147
+ )
114
148
}
115
149
116
150
fn avg_pool2d_backward (
@@ -121,13 +155,17 @@ where
121
155
padding : [ usize ; 2 ] ,
122
156
count_include_pad : bool ,
123
157
) -> FloatTensor < Self > {
124
- kernel:: pool:: avg_pool2d_backward :: < R , F > (
125
- x,
126
- grad,
127
- kernel_size,
128
- stride,
129
- padding,
130
- count_include_pad,
158
+ execute_with_dtype ! (
159
+ float( x. dtype) ,
160
+ E ,
161
+ kernel:: pool:: avg_pool2d_backward:: <R , E >(
162
+ x,
163
+ grad,
164
+ kernel_size,
165
+ stride,
166
+ padding,
167
+ count_include_pad,
168
+ )
131
169
)
132
170
}
133
171
@@ -138,7 +176,11 @@ where
138
176
padding : [ usize ; 2 ] ,
139
177
dilation : [ usize ; 2 ] ,
140
178
) -> FloatTensor < Self > {
141
- kernel:: pool:: max_pool2d :: < R , F > ( x, kernel_size, stride, padding, dilation)
179
+ execute_with_dtype ! (
180
+ float( x. dtype) ,
181
+ E ,
182
+ kernel:: pool:: max_pool2d:: <R , E >( x, kernel_size, stride, padding, dilation)
183
+ )
142
184
}
143
185
144
186
fn max_pool2d_with_indices (
@@ -148,15 +190,17 @@ where
148
190
padding : [ usize ; 2 ] ,
149
191
dilation : [ usize ; 2 ] ,
150
192
) -> MaxPool2dWithIndices < Self > {
151
- let ( output, indices) = kernel:: pool:: max_pool2d_with_indices :: < R , F , I > (
152
- x,
153
- kernel_size,
154
- stride,
155
- padding,
156
- dilation,
157
- ) ;
193
+ execute_with_dtype ! ( float( x. dtype) , E , {
194
+ let ( output, indices) = kernel:: pool:: max_pool2d_with_indices:: <R , E , I >(
195
+ x,
196
+ kernel_size,
197
+ stride,
198
+ padding,
199
+ dilation,
200
+ ) ;
158
201
159
- MaxPool2dWithIndices :: new ( output, indices)
202
+ MaxPool2dWithIndices :: new( output, indices)
203
+ } )
160
204
}
161
205
162
206
fn max_pool2d_with_indices_backward (
@@ -168,34 +212,54 @@ where
168
212
output_grad : FloatTensor < Self > ,
169
213
indices : IntTensor < Self > ,
170
214
) -> MaxPool2dBackward < Self > {
171
- MaxPool2dBackward :: new ( kernel:: pool:: max_pool2d_with_indices_backward :: < R , F , I > (
172
- x,
173
- output_grad,
174
- indices,
175
- kernel_size,
176
- stride,
177
- padding,
178
- dilation,
179
- ) )
215
+ execute_with_dtype ! (
216
+ int( indices. dtype) ,
217
+ I ,
218
+ execute_with_dtype!(
219
+ float( x. dtype) ,
220
+ E ,
221
+ MaxPool2dBackward :: new( kernel:: pool:: max_pool2d_with_indices_backward:: <R , E , I >(
222
+ x,
223
+ output_grad,
224
+ indices,
225
+ kernel_size,
226
+ stride,
227
+ padding,
228
+ dilation,
229
+ ) )
230
+ )
231
+ )
180
232
}
181
233
182
234
fn adaptive_avg_pool2d ( x : FloatTensor < Self > , output_size : [ usize ; 2 ] ) -> FloatTensor < Self > {
183
- kernel:: pool:: adaptive_avg_pool2d :: < R , F > ( x, output_size)
235
+ execute_with_dtype ! (
236
+ float( x. dtype) ,
237
+ E ,
238
+ kernel:: pool:: adaptive_avg_pool2d:: <R , E >( x, output_size)
239
+ )
184
240
}
185
241
186
242
fn adaptive_avg_pool2d_backward (
187
243
x : FloatTensor < Self > ,
188
244
grad : FloatTensor < Self > ,
189
245
) -> FloatTensor < Self > {
190
- kernel:: pool:: adaptive_avg_pool2d_backward :: < R , F > ( x, grad)
246
+ execute_with_dtype ! (
247
+ float( x. dtype) ,
248
+ E ,
249
+ kernel:: pool:: adaptive_avg_pool2d_backward:: <R , E >( x, grad)
250
+ )
191
251
}
192
252
193
253
fn interpolate (
194
254
x : FloatTensor < Self > ,
195
255
output_size : [ usize ; 2 ] ,
196
256
options : InterpolateOptions ,
197
257
) -> FloatTensor < Self > {
198
- kernel:: interpolate:: interpolate :: < R , F > ( x, output_size, options)
258
+ execute_with_dtype ! (
259
+ float( x. dtype) ,
260
+ E ,
261
+ kernel:: interpolate:: interpolate:: <R , E >( x, output_size, options)
262
+ )
199
263
}
200
264
201
265
fn interpolate_backward (
@@ -204,6 +268,10 @@ where
204
268
output_size : [ usize ; 2 ] ,
205
269
options : InterpolateOptions ,
206
270
) -> FloatTensor < Self > {
207
- kernel:: interpolate:: interpolate_backward :: < R , F > ( x, grad, output_size, options)
271
+ execute_with_dtype ! (
272
+ float( x. dtype) ,
273
+ E ,
274
+ kernel:: interpolate:: interpolate_backward:: <R , E >( x, grad, output_size, options)
275
+ )
208
276
}
209
277
}
0 commit comments