Skip to content

Commit e81228c

Browse files
committed
Adds OUTPUT_PADDING to ConvTrans2D
- Draft state. - Unsure if correct, but a very simple and quick test gives the same result from pytorch. - Note: Tensorflow result differs, both from dfdx and from pytorch. Reference pytorch test: ```python import torch x = np.array([[[[0.1, 0.7], [0.3, 0.4]]]]) w = np.array([[[[-0.1, -0.3, 0.7], [0.8, -0.2, 0.1], [0.3, 0.4, -0.5]]]]) a = torch.nn.ConvTranspose2d(output_padding=0, in_channels=1, out_channels=1, kernel_size=3, stride=2, padding=1, bias = False) b = torch.nn.ConvTranspose2d(output_padding=1, in_channels=1, out_channels=1, kernel_size=3, stride=2, padding=1, bias = False) x = torch.from_numpy(x).float() w0 = torch.from_numpy(w).float() with torch.no_grad(): a.weight = torch.nn.Parameter(w0) b.weight = torch.nn.Parameter(w0) ya = a(x) yb = b(x) print(ya.size()) # torch.Size([1, 1, 3, 3]) print(yb.size()) # torch.Size([1, 1, 4, 4]) print(ya) print(yb) ```
1 parent 1175903 commit e81228c

File tree

3 files changed

+132
-55
lines changed

3 files changed

+132
-55
lines changed

dfdx-core/src/tensor_ops/convtrans2d/mod.rs

+61-30
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ pub(super) trait ConvTrans2DKernel<E: Dtype>: Storage<E> {
5151
) -> Result<(), Error>;
5252
}
5353

54-
pub trait TryConvTrans2D<Stride, Padding, Dilation, Groups>: Sized {
54+
pub trait TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding>: Sized {
5555
type Convolved;
5656

5757
/// Applies a 2D convolution to the input tensor.
@@ -61,8 +61,9 @@ pub trait TryConvTrans2D<Stride, Padding, Dilation, Groups>: Sized {
6161
padding: Padding,
6262
dilation: Dilation,
6363
groups: Groups,
64+
output_padding: OutputPadding,
6465
) -> Self::Convolved {
65-
self.try_convtrans2d(stride, padding, dilation, groups)
66+
self.try_convtrans2d(stride, padding, dilation, groups, output_padding)
6667
.unwrap()
6768
}
6869

@@ -73,6 +74,7 @@ pub trait TryConvTrans2D<Stride, Padding, Dilation, Groups>: Sized {
7374
padding: Padding,
7475
dilation: Dilation,
7576
groups: Groups,
77+
output_padding: OutputPadding,
7678
) -> Result<Self::Convolved, Error>;
7779
}
7880

@@ -82,27 +84,31 @@ impl<
8284
const PADDING: usize,
8385
const DILATION: usize,
8486
Groups: Dim,
87+
const OUTPUT_PADDING: usize,
8588
const DIM: usize,
86-
> TryConvTrans2D<Const<STRIDE>, Const<PADDING>, Const<DILATION>, Groups>
89+
> TryConvTrans2D<Const<STRIDE>, Const<PADDING>, Const<DILATION>, Groups, Const<OUTPUT_PADDING>>
8790
for (Const<DIM>, Const<KERNEL>)
8891
where
89-
Const<{ (DIM - 1) * STRIDE - 2 * PADDING + DILATION * (KERNEL - 1) + 1 }>: Sized,
92+
Const<{ (DIM - 1) * STRIDE - 2 * PADDING + DILATION * (KERNEL - 1) + 1 + OUTPUT_PADDING }>:
93+
Sized,
9094
{
91-
type Convolved = Const<{ (DIM - 1) * STRIDE - 2 * PADDING + DILATION * (KERNEL - 1) + 1 }>;
95+
type Convolved =
96+
Const<{ (DIM - 1) * STRIDE - 2 * PADDING + DILATION * (KERNEL - 1) + 1 + OUTPUT_PADDING }>;
9297

9398
fn try_convtrans2d(
9499
self,
95100
_: Const<STRIDE>,
96101
_: Const<PADDING>,
97102
_: Const<DILATION>,
98103
_: Groups,
104+
_: Const<OUTPUT_PADDING>,
99105
) -> Result<Self::Convolved, Error> {
100106
Ok(Const)
101107
}
102108
}
103109

104-
impl<Kernel: Dim, Stride: Dim, Padding: Dim, Dilation: Dim, Groups: Dim>
105-
TryConvTrans2D<Stride, Padding, Dilation, Groups> for (usize, Kernel)
110+
impl<Kernel: Dim, Stride: Dim, Padding: Dim, Dilation: Dim, Groups: Dim, OutputPadding: Dim>
111+
TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding> for (usize, Kernel)
106112
{
107113
type Convolved = usize;
108114

@@ -112,18 +118,33 @@ impl<Kernel: Dim, Stride: Dim, Padding: Dim, Dilation: Dim, Groups: Dim>
112118
padding: Padding,
113119
dilation: Dilation,
114120
_: Groups,
121+
output_padding: OutputPadding,
115122
) -> Result<Self::Convolved, Error> {
116123
let (dim, kernel) = self;
117-
Ok(
118-
((dim - 1) * stride.size() + dilation.size() * (kernel.size() - 1) + 1)
119-
.checked_sub(2 * padding.size())
120-
.unwrap(),
121-
)
124+
Ok(((dim - 1) * stride.size()
125+
+ dilation.size() * (kernel.size() - 1)
126+
+ 1
127+
+ output_padding.size())
128+
.checked_sub(2 * padding.size())
129+
.unwrap())
122130
}
123131
}
124132

125-
impl<InpChan, OutChanOverGroups, Kernel, Stride, Padding, Dilation, Groups, H, W, E, D, T>
126-
TryConvTrans2D<Stride, Padding, Dilation, Groups>
133+
impl<
134+
InpChan,
135+
OutChanOverGroups,
136+
Kernel,
137+
Stride,
138+
Padding,
139+
Dilation,
140+
Groups,
141+
OutputPadding,
142+
H,
143+
W,
144+
E,
145+
D,
146+
T,
147+
> TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding>
127148
for (
128149
Tensor<(InpChan, H, W), E, D, T>,
129150
Tensor<(InpChan, OutChanOverGroups, Kernel, Kernel), E, D>,
@@ -136,23 +157,26 @@ where
136157
Padding: Dim,
137158
Dilation: Dim,
138159
Groups: Dim,
160+
OutputPadding: Dim,
139161
H: Dim,
140162
W: Dim,
141163
E: Dtype,
142164
D: ConvTrans2DKernel<E> + crate::tensor_ops::reshape_to::ReshapeKernel<E>,
143165
T: Tape<E, D>,
144166
OutChanOverGroups: std::ops::Mul<Groups>,
145167
<OutChanOverGroups as std::ops::Mul<Groups>>::Output: Dim,
146-
(H, Kernel): TryConvTrans2D<Stride, Padding, Dilation, Groups>,
147-
(W, Kernel): TryConvTrans2D<Stride, Padding, Dilation, Groups>,
148-
<(H, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups>>::Convolved: Dim,
149-
<(W, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups>>::Convolved: Dim,
168+
(H, Kernel): TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding>,
169+
(W, Kernel): TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding>,
170+
<(H, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding>>::Convolved:
171+
Dim,
172+
<(W, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding>>::Convolved:
173+
Dim,
150174
{
151175
type Convolved = Tensor<
152176
(
153177
<OutChanOverGroups as std::ops::Mul<Groups>>::Output,
154-
<(H, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups>>::Convolved,
155-
<(W, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups>>::Convolved,
178+
<(H, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding>>::Convolved,
179+
<(W, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding>>::Convolved,
156180
),
157181
E,
158182
D,
@@ -165,11 +189,13 @@ where
165189
padding: Padding,
166190
dilation: Dilation,
167191
groups: Groups,
192+
output_padding: OutputPadding,
168193
) -> Result<Self::Convolved, Error> {
169194
let (img, filters) = self;
170195
let (inp_chan, h, w) = img.shape;
171196
let img = img.try_reshape_like(&(Const::<1>, inp_chan, h, w))?;
172-
let out = (img, filters).try_convtrans2d(stride, padding, dilation, groups)?;
197+
let out =
198+
(img, filters).try_convtrans2d(stride, padding, dilation, groups, output_padding)?;
173199
let (_, out_chan, out_h, out_w) = out.shape;
174200
out.try_reshape_like(&(out_chan, out_h, out_w))
175201
}
@@ -182,13 +208,14 @@ impl<
182208
Padding,
183209
Dilation,
184210
Groups,
211+
OutputPadding,
185212
Batch,
186213
H,
187214
W,
188215
E,
189216
D,
190217
T,
191-
> TryConvTrans2D<Stride, Padding, Dilation, Groups>
218+
> TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding>
192219
for (
193220
Tensor<(Batch, InpChan, H, W), E, D, T>,
194221
Tensor<(InpChan, OutChanOverGroups, Kernel, Kernel), E, D>,
@@ -201,6 +228,7 @@ where
201228
Padding: Dim,
202229
Dilation: Dim,
203230
Groups: Dim,
231+
OutputPadding: Dim,
204232
Batch: Dim,
205233
H: Dim,
206234
W: Dim,
@@ -209,17 +237,19 @@ where
209237
T: Tape<E, D>,
210238
OutChanOverGroups: std::ops::Mul<Groups>,
211239
<OutChanOverGroups as std::ops::Mul<Groups>>::Output: Dim,
212-
(H, Kernel): TryConvTrans2D<Stride, Padding, Dilation, Groups>,
213-
(W, Kernel): TryConvTrans2D<Stride, Padding, Dilation, Groups>,
214-
<(H, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups>>::Convolved: Dim,
215-
<(W, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups>>::Convolved: Dim,
240+
(H, Kernel): TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding>,
241+
(W, Kernel): TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding>,
242+
<(H, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding>>::Convolved:
243+
Dim,
244+
<(W, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding>>::Convolved:
245+
Dim,
216246
{
217247
type Convolved = Tensor<
218248
(
219249
Batch,
220250
<OutChanOverGroups as std::ops::Mul<Groups>>::Output,
221-
<(H, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups>>::Convolved,
222-
<(W, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups>>::Convolved,
251+
<(H, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding>>::Convolved,
252+
<(W, Kernel) as TryConvTrans2D<Stride, Padding, Dilation, Groups, OutputPadding>>::Convolved,
223253
),
224254
E,
225255
D,
@@ -232,6 +262,7 @@ where
232262
padding: Padding,
233263
dilation: Dilation,
234264
groups: Groups,
265+
output_padding: OutputPadding,
235266
) -> Result<Self::Convolved, Error> {
236267
let (img, filters) = self;
237268
assert_eq!(img.shape.1, filters.shape.0);
@@ -242,8 +273,8 @@ where
242273
if img.strides != img.shape.strides() || filters.strides != filters.shape.strides() {
243274
panic!("Image & filter inputs to conv2d must be contiguous");
244275
}
245-
let h_out = (h, kernel).convtrans2d(stride, padding, dilation, groups);
246-
let w_out = (w, kernel).convtrans2d(stride, padding, dilation, groups);
276+
let h_out = (h, kernel).convtrans2d(stride, padding, dilation, groups, output_padding);
277+
let w_out = (w, kernel).convtrans2d(stride, padding, dilation, groups, output_padding);
247278
let op = ConvTrans2DOp {
248279
stride: stride.size(),
249280
padding: padding.size(),

dfdx-core/src/tensor_ops/convtrans2d/tests.rs

+14-14
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ fn test_convtrans2d_default() {
3333
],
3434
])
3535
.to_dtype::<TestDtype>();
36-
let y =
37-
(x.leaky_trace(), w.clone()).convtrans2d(Const::<1>, Const::<0>, Const::<1>, Const::<1>);
36+
let y = (x.leaky_trace(), w.clone())
37+
.convtrans2d(Const::<1>, Const::<0>, Const::<1>, Const::<1>, Const::<0>);
3838
#[rustfmt::skip]
3939
assert_close_to_literal!(
4040
y,
@@ -125,8 +125,8 @@ fn test_convtrans2d_stride_2() {
125125
],
126126
])
127127
.to_dtype::<TestDtype>();
128-
let y =
129-
(x.leaky_trace(), w.clone()).convtrans2d(Const::<2>, Const::<0>, Const::<1>, Const::<1>);
128+
let y = (x.leaky_trace(), w.clone())
129+
.convtrans2d(Const::<2>, Const::<0>, Const::<1>, Const::<1>, Const::<0>);
130130
#[rustfmt::skip]
131131
assert_close_to_literal!(
132132
y,
@@ -223,8 +223,8 @@ fn test_convtrans2d_padded() {
223223
],
224224
])
225225
.to_dtype::<TestDtype>();
226-
let y =
227-
(x.leaky_trace(), w.clone()).convtrans2d(Const::<1>, Const::<1>, Const::<1>, Const::<1>);
226+
let y = (x.leaky_trace(), w.clone())
227+
.convtrans2d(Const::<1>, Const::<1>, Const::<1>, Const::<1>, Const::<0>);
228228
assert_close_to_literal!(
229229
y,
230230
[
@@ -283,8 +283,8 @@ fn test_convtrans2d_batched() {
283283
let x: Tensor<Rank3<3, 28, 28>, TestDtype, _> = dev.sample_normal();
284284
let w: Tensor<Rank4<3, 5, 6, 6>, TestDtype, _> = dev.sample_normal();
285285

286-
let y: Tensor<Rank3<5, 83, 83>, _, _, _> =
287-
(x.leaky_trace(), w.clone()).convtrans2d(Const::<3>, Const::<2>, Const::<1>, Const::<1>);
286+
let y: Tensor<Rank3<5, 83, 83>, _, _, _> = (x.leaky_trace(), w.clone())
287+
.convtrans2d(Const::<3>, Const::<2>, Const::<1>, Const::<1>, Const::<0>);
288288
let y0 = y.retaped::<NoneTape>();
289289
let grads0 = y.square().mean().backward();
290290
let x0 = grads0.get(&x);
@@ -294,8 +294,8 @@ fn test_convtrans2d_batched() {
294294
.broadcast::<Rank4<10, 3, 28, 28>, _>()
295295
.reshape::<Rank4<10, 3, 28, 28>>();
296296

297-
let y: Tensor<Rank4<10, 5, 83, 83>, _, _, _> =
298-
(x.leaky_trace(), w.clone()).convtrans2d(Const::<3>, Const::<2>, Const::<1>, Const::<1>);
297+
let y: Tensor<Rank4<10, 5, 83, 83>, _, _, _> = (x.leaky_trace(), w.clone())
298+
.convtrans2d(Const::<3>, Const::<2>, Const::<1>, Const::<1>, Const::<0>);
299299
for i in 0..10 {
300300
assert_close_to_tensor!(y0, y.retaped::<NoneTape>().select(dev.tensor(i)), 1e-5);
301301
}
@@ -341,8 +341,8 @@ fn test_convtrans2d_grouped() {
341341
],
342342
])
343343
.to_dtype::<TestDtype>();
344-
let y =
345-
(x.leaky_trace(), w.clone()).convtrans2d(Const::<1>, Const::<0>, Const::<1>, Const::<2>);
344+
let y = (x.leaky_trace(), w.clone())
345+
.convtrans2d(Const::<1>, Const::<0>, Const::<1>, Const::<2>, Const::<0>);
346346
#[rustfmt::skip]
347347
assert_close_to_literal!(
348348
y,
@@ -451,8 +451,8 @@ fn test_convtrans2d_dilated() {
451451
],
452452
])
453453
.to_dtype::<TestDtype>();
454-
let y =
455-
(x.leaky_trace(), w.clone()).convtrans2d(Const::<1>, Const::<0>, Const::<2>, Const::<1>);
454+
let y = (x.leaky_trace(), w.clone())
455+
.convtrans2d(Const::<1>, Const::<0>, Const::<2>, Const::<1>, Const::<0>);
456456
#[rustfmt::skip]
457457
assert_close_to_literal!(
458458
y,

0 commit comments

Comments
 (0)