Skip to content

Commit f7de627

Browse files
committed
use outputChannelsPerGroup to calc components
1 parent 035ed29 commit f7de627

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts

+5-6
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,11 @@ export const createConvTranspose2DProgramInfo = (
4141
const hasBias = inputs.length > 2;
4242
const outputShape = attributes.outputShape;
4343
const isChannelsLast = attributes.format === 'NHWC';
44-
const components = isChannelsLast && attributes.group === 1 ? getMaxComponents(outputShape[3]) : 1;
44+
const group = attributes.group;
45+
const wShape = inputs[1].dims;
46+
const inputChannelsPerGroup = wShape[2] / group;
47+
const outputChannelsPerGroup = wShape[3];
48+
const components = isChannelsLast ? getMaxComponents(outputChannelsPerGroup) : 1;
4549
const outputSize = ShapeUtil.size(outputShape) / components;
4650
const dispatch = [Math.ceil(outputSize / 64), 1, 1];
4751
LOG_DEBUG('verbose', () => `[conv2d_backprop_webgpu] dispatch = ${dispatch}`);
@@ -65,11 +69,6 @@ export const createConvTranspose2DProgramInfo = (
6569
effectiveFilterDims[1] - 1 - Math.floor((attributes.pads[1] + attributes.pads[3]) / 2),
6670
];
6771

68-
const group = attributes.group;
69-
const wShape = inputs[1].dims;
70-
const inputChannelsPerGroup = wShape[2] / group;
71-
const outputChannelsPerGroup = wShape[3];
72-
7372
const programUniforms: ProgramUniform[] = [
7473
{ type: DataType.uint32, data: outputSize },
7574
{ type: DataType.uint32, data: strides },

0 commit comments

Comments
 (0)