@@ -41,7 +41,11 @@ export const createConvTranspose2DProgramInfo = (
41
41
const hasBias = inputs . length > 2 ;
42
42
const outputShape = attributes . outputShape ;
43
43
const isChannelsLast = attributes . format === 'NHWC' ;
44
- const components = isChannelsLast && attributes . group === 1 ? getMaxComponents ( outputShape [ 3 ] ) : 1 ;
44
+ const group = attributes . group ;
45
+ const wShape = inputs [ 1 ] . dims ;
46
+ const inputChannelsPerGroup = wShape [ 2 ] / group ;
47
+ const outputChannelsPerGroup = wShape [ 3 ] ;
48
+ const components = isChannelsLast ? getMaxComponents ( outputChannelsPerGroup ) : 1 ;
45
49
const outputSize = ShapeUtil . size ( outputShape ) / components ;
46
50
const dispatch = [ Math . ceil ( outputSize / 64 ) , 1 , 1 ] ;
47
51
LOG_DEBUG ( 'verbose' , ( ) => `[conv2d_backprop_webgpu] dispatch = ${ dispatch } ` ) ;
@@ -65,11 +69,6 @@ export const createConvTranspose2DProgramInfo = (
65
69
effectiveFilterDims [ 1 ] - 1 - Math . floor ( ( attributes . pads [ 1 ] + attributes . pads [ 3 ] ) / 2 ) ,
66
70
] ;
67
71
68
- const group = attributes . group ;
69
- const wShape = inputs [ 1 ] . dims ;
70
- const inputChannelsPerGroup = wShape [ 2 ] / group ;
71
- const outputChannelsPerGroup = wShape [ 3 ] ;
72
-
73
72
const programUniforms : ProgramUniform [ ] = [
74
73
{ type : DataType . uint32 , data : outputSize } ,
75
74
{ type : DataType . uint32 , data : strides } ,
0 commit comments