@@ -51,7 +51,7 @@ pub(super) trait ConvTrans2DKernel<E: Dtype>: Storage<E> {
51
51
) -> Result < ( ) , Error > ;
52
52
}
53
53
54
- pub trait TryConvTrans2D < Stride , Padding , Dilation , Groups > : Sized {
54
+ pub trait TryConvTrans2D < Stride , Padding , Dilation , Groups , OutputPadding > : Sized {
55
55
type Convolved ;
56
56
57
57
/// Applies a 2D convolution to the input tensor.
@@ -61,8 +61,9 @@ pub trait TryConvTrans2D<Stride, Padding, Dilation, Groups>: Sized {
61
61
padding : Padding ,
62
62
dilation : Dilation ,
63
63
groups : Groups ,
64
+ output_padding : OutputPadding ,
64
65
) -> Self :: Convolved {
65
- self . try_convtrans2d ( stride, padding, dilation, groups)
66
+ self . try_convtrans2d ( stride, padding, dilation, groups, output_padding )
66
67
. unwrap ( )
67
68
}
68
69
@@ -73,6 +74,7 @@ pub trait TryConvTrans2D<Stride, Padding, Dilation, Groups>: Sized {
73
74
padding : Padding ,
74
75
dilation : Dilation ,
75
76
groups : Groups ,
77
+ output_padding : OutputPadding ,
76
78
) -> Result < Self :: Convolved , Error > ;
77
79
}
78
80
@@ -82,27 +84,31 @@ impl<
82
84
const PADDING : usize ,
83
85
const DILATION : usize ,
84
86
Groups : Dim ,
87
+ const OUTPUT_PADDING : usize ,
85
88
const DIM : usize ,
86
- > TryConvTrans2D < Const < STRIDE > , Const < PADDING > , Const < DILATION > , Groups >
89
+ > TryConvTrans2D < Const < STRIDE > , Const < PADDING > , Const < DILATION > , Groups , Const < OUTPUT_PADDING > >
87
90
for ( Const < DIM > , Const < KERNEL > )
88
91
where
89
- Const < { ( DIM - 1 ) * STRIDE - 2 * PADDING + DILATION * ( KERNEL - 1 ) + 1 } > : Sized ,
92
+ Const < { ( DIM - 1 ) * STRIDE - 2 * PADDING + DILATION * ( KERNEL - 1 ) + 1 + OUTPUT_PADDING } > :
93
+ Sized ,
90
94
{
91
- type Convolved = Const < { ( DIM - 1 ) * STRIDE - 2 * PADDING + DILATION * ( KERNEL - 1 ) + 1 } > ;
95
+ type Convolved =
96
+ Const < { ( DIM - 1 ) * STRIDE - 2 * PADDING + DILATION * ( KERNEL - 1 ) + 1 + OUTPUT_PADDING } > ;
92
97
93
98
fn try_convtrans2d (
94
99
self ,
95
100
_: Const < STRIDE > ,
96
101
_: Const < PADDING > ,
97
102
_: Const < DILATION > ,
98
103
_: Groups ,
104
+ _: Const < OUTPUT_PADDING > ,
99
105
) -> Result < Self :: Convolved , Error > {
100
106
Ok ( Const )
101
107
}
102
108
}
103
109
104
- impl < Kernel : Dim , Stride : Dim , Padding : Dim , Dilation : Dim , Groups : Dim >
105
- TryConvTrans2D < Stride , Padding , Dilation , Groups > for ( usize , Kernel )
110
+ impl < Kernel : Dim , Stride : Dim , Padding : Dim , Dilation : Dim , Groups : Dim , OutputPadding : Dim >
111
+ TryConvTrans2D < Stride , Padding , Dilation , Groups , OutputPadding > for ( usize , Kernel )
106
112
{
107
113
type Convolved = usize ;
108
114
@@ -112,18 +118,33 @@ impl<Kernel: Dim, Stride: Dim, Padding: Dim, Dilation: Dim, Groups: Dim>
112
118
padding : Padding ,
113
119
dilation : Dilation ,
114
120
_: Groups ,
121
+ output_padding : OutputPadding ,
115
122
) -> Result < Self :: Convolved , Error > {
116
123
let ( dim, kernel) = self ;
117
- Ok (
118
- ( ( dim - 1 ) * stride. size ( ) + dilation. size ( ) * ( kernel. size ( ) - 1 ) + 1 )
119
- . checked_sub ( 2 * padding. size ( ) )
120
- . unwrap ( ) ,
121
- )
124
+ Ok ( ( ( dim - 1 ) * stride. size ( )
125
+ + dilation. size ( ) * ( kernel. size ( ) - 1 )
126
+ + 1
127
+ + output_padding. size ( ) )
128
+ . checked_sub ( 2 * padding. size ( ) )
129
+ . unwrap ( ) )
122
130
}
123
131
}
124
132
125
- impl < InpChan , OutChanOverGroups , Kernel , Stride , Padding , Dilation , Groups , H , W , E , D , T >
126
- TryConvTrans2D < Stride , Padding , Dilation , Groups >
133
+ impl <
134
+ InpChan ,
135
+ OutChanOverGroups ,
136
+ Kernel ,
137
+ Stride ,
138
+ Padding ,
139
+ Dilation ,
140
+ Groups ,
141
+ OutputPadding ,
142
+ H ,
143
+ W ,
144
+ E ,
145
+ D ,
146
+ T ,
147
+ > TryConvTrans2D < Stride , Padding , Dilation , Groups , OutputPadding >
127
148
for (
128
149
Tensor < ( InpChan , H , W ) , E , D , T > ,
129
150
Tensor < ( InpChan , OutChanOverGroups , Kernel , Kernel ) , E , D > ,
@@ -136,23 +157,26 @@ where
136
157
Padding : Dim ,
137
158
Dilation : Dim ,
138
159
Groups : Dim ,
160
+ OutputPadding : Dim ,
139
161
H : Dim ,
140
162
W : Dim ,
141
163
E : Dtype ,
142
164
D : ConvTrans2DKernel < E > + crate :: tensor_ops:: reshape_to:: ReshapeKernel < E > ,
143
165
T : Tape < E , D > ,
144
166
OutChanOverGroups : std:: ops:: Mul < Groups > ,
145
167
<OutChanOverGroups as std:: ops:: Mul < Groups > >:: Output : Dim ,
146
- ( H , Kernel ) : TryConvTrans2D < Stride , Padding , Dilation , Groups > ,
147
- ( W , Kernel ) : TryConvTrans2D < Stride , Padding , Dilation , Groups > ,
148
- <( H , Kernel ) as TryConvTrans2D < Stride , Padding , Dilation , Groups > >:: Convolved : Dim ,
149
- <( W , Kernel ) as TryConvTrans2D < Stride , Padding , Dilation , Groups > >:: Convolved : Dim ,
168
+ ( H , Kernel ) : TryConvTrans2D < Stride , Padding , Dilation , Groups , OutputPadding > ,
169
+ ( W , Kernel ) : TryConvTrans2D < Stride , Padding , Dilation , Groups , OutputPadding > ,
170
+ <( H , Kernel ) as TryConvTrans2D < Stride , Padding , Dilation , Groups , OutputPadding > >:: Convolved :
171
+ Dim ,
172
+ <( W , Kernel ) as TryConvTrans2D < Stride , Padding , Dilation , Groups , OutputPadding > >:: Convolved :
173
+ Dim ,
150
174
{
151
175
type Convolved = Tensor <
152
176
(
153
177
<OutChanOverGroups as std:: ops:: Mul < Groups > >:: Output ,
154
- <( H , Kernel ) as TryConvTrans2D < Stride , Padding , Dilation , Groups > >:: Convolved ,
155
- <( W , Kernel ) as TryConvTrans2D < Stride , Padding , Dilation , Groups > >:: Convolved ,
178
+ <( H , Kernel ) as TryConvTrans2D < Stride , Padding , Dilation , Groups , OutputPadding > >:: Convolved ,
179
+ <( W , Kernel ) as TryConvTrans2D < Stride , Padding , Dilation , Groups , OutputPadding > >:: Convolved ,
156
180
) ,
157
181
E ,
158
182
D ,
@@ -165,11 +189,13 @@ where
165
189
padding : Padding ,
166
190
dilation : Dilation ,
167
191
groups : Groups ,
192
+ output_padding : OutputPadding ,
168
193
) -> Result < Self :: Convolved , Error > {
169
194
let ( img, filters) = self ;
170
195
let ( inp_chan, h, w) = img. shape ;
171
196
let img = img. try_reshape_like ( & ( Const :: < 1 > , inp_chan, h, w) ) ?;
172
- let out = ( img, filters) . try_convtrans2d ( stride, padding, dilation, groups) ?;
197
+ let out =
198
+ ( img, filters) . try_convtrans2d ( stride, padding, dilation, groups, output_padding) ?;
173
199
let ( _, out_chan, out_h, out_w) = out. shape ;
174
200
out. try_reshape_like ( & ( out_chan, out_h, out_w) )
175
201
}
@@ -182,13 +208,14 @@ impl<
182
208
Padding ,
183
209
Dilation ,
184
210
Groups ,
211
+ OutputPadding ,
185
212
Batch ,
186
213
H ,
187
214
W ,
188
215
E ,
189
216
D ,
190
217
T ,
191
- > TryConvTrans2D < Stride , Padding , Dilation , Groups >
218
+ > TryConvTrans2D < Stride , Padding , Dilation , Groups , OutputPadding >
192
219
for (
193
220
Tensor < ( Batch , InpChan , H , W ) , E , D , T > ,
194
221
Tensor < ( InpChan , OutChanOverGroups , Kernel , Kernel ) , E , D > ,
@@ -201,6 +228,7 @@ where
201
228
Padding : Dim ,
202
229
Dilation : Dim ,
203
230
Groups : Dim ,
231
+ OutputPadding : Dim ,
204
232
Batch : Dim ,
205
233
H : Dim ,
206
234
W : Dim ,
@@ -209,17 +237,19 @@ where
209
237
T : Tape < E , D > ,
210
238
OutChanOverGroups : std:: ops:: Mul < Groups > ,
211
239
<OutChanOverGroups as std:: ops:: Mul < Groups > >:: Output : Dim ,
212
- ( H , Kernel ) : TryConvTrans2D < Stride , Padding , Dilation , Groups > ,
213
- ( W , Kernel ) : TryConvTrans2D < Stride , Padding , Dilation , Groups > ,
214
- <( H , Kernel ) as TryConvTrans2D < Stride , Padding , Dilation , Groups > >:: Convolved : Dim ,
215
- <( W , Kernel ) as TryConvTrans2D < Stride , Padding , Dilation , Groups > >:: Convolved : Dim ,
240
+ ( H , Kernel ) : TryConvTrans2D < Stride , Padding , Dilation , Groups , OutputPadding > ,
241
+ ( W , Kernel ) : TryConvTrans2D < Stride , Padding , Dilation , Groups , OutputPadding > ,
242
+ <( H , Kernel ) as TryConvTrans2D < Stride , Padding , Dilation , Groups , OutputPadding > >:: Convolved :
243
+ Dim ,
244
+ <( W , Kernel ) as TryConvTrans2D < Stride , Padding , Dilation , Groups , OutputPadding > >:: Convolved :
245
+ Dim ,
216
246
{
217
247
type Convolved = Tensor <
218
248
(
219
249
Batch ,
220
250
<OutChanOverGroups as std:: ops:: Mul < Groups > >:: Output ,
221
- <( H , Kernel ) as TryConvTrans2D < Stride , Padding , Dilation , Groups > >:: Convolved ,
222
- <( W , Kernel ) as TryConvTrans2D < Stride , Padding , Dilation , Groups > >:: Convolved ,
251
+ <( H , Kernel ) as TryConvTrans2D < Stride , Padding , Dilation , Groups , OutputPadding > >:: Convolved ,
252
+ <( W , Kernel ) as TryConvTrans2D < Stride , Padding , Dilation , Groups , OutputPadding > >:: Convolved ,
223
253
) ,
224
254
E ,
225
255
D ,
@@ -232,6 +262,7 @@ where
232
262
padding : Padding ,
233
263
dilation : Dilation ,
234
264
groups : Groups ,
265
+ output_padding : OutputPadding ,
235
266
) -> Result < Self :: Convolved , Error > {
236
267
let ( img, filters) = self ;
237
268
assert_eq ! ( img. shape. 1 , filters. shape. 0 ) ;
@@ -242,8 +273,8 @@ where
242
273
if img. strides != img. shape . strides ( ) || filters. strides != filters. shape . strides ( ) {
243
274
panic ! ( "Image & filter inputs to conv2d must be contiguous" ) ;
244
275
}
245
- let h_out = ( h, kernel) . convtrans2d ( stride, padding, dilation, groups) ;
246
- let w_out = ( w, kernel) . convtrans2d ( stride, padding, dilation, groups) ;
276
+ let h_out = ( h, kernel) . convtrans2d ( stride, padding, dilation, groups, output_padding ) ;
277
+ let w_out = ( w, kernel) . convtrans2d ( stride, padding, dilation, groups, output_padding ) ;
247
278
let op = ConvTrans2DOp {
248
279
stride : stride. size ( ) ,
249
280
padding : padding. size ( ) ,
0 commit comments