@@ -48,31 +48,73 @@ const squeezeShape = (shape: readonly number[], adjustedPerm: number[]): { newSh
48
48
return { newShape, newPerm } ;
49
49
} ;
50
50
51
/**
 * Reports whether applying `perm` to a tensor of the given `shape` is equivalent to a reshape.
 *
 * Axes whose extent is exactly 1 are ignored. The permutation counts as a reshape when every
 * remaining axis shows up in `perm` in non-decreasing order of its original index.
 * Example: shape (1,1,1024,4096) with perm (2,0,3,1) qualifies.
 *
 * @param perm - Axis permutation; `perm[i]` is the input axis placed at output position `i`.
 * @param shape - Dimensions of the input tensor, indexed by the entries of `perm`.
 * @returns `true` if the axes with extent other than 1 keep their relative order, else `false`.
 */
const isTransposeReshape = (perm: number[], shape: readonly number[]) => {
  // Largest original index among the non-unit axes visited so far.
  let highestAxisSeen = 0;
  for (const axis of perm) {
    if (shape[axis] !== 1) {
      // A non-unit axis stepping backwards means the data order really changes.
      if (axis < highestAxisSeen) {
        return false;
      }
      highestAxisSeen = axis;
    }
  }
  return true;
};
66
+
51
67
export const createTransposeProgramInfo = ( inputTensor : TensorView , permAttr : number [ ] ) : ProgramInfo => {
52
68
const inputDataType = inputTensor . dataType ;
53
69
const inputRank = inputTensor . dims . length ;
54
70
const perm = getAdjustedPerm ( inputRank , permAttr ) ;
55
71
const outputShape = getOutputShape ( inputTensor . dims , perm ) ;
72
+ let newInputShape = inputTensor . dims ;
73
+ let newOutputShape = outputShape ;
74
+ const transposeAsReshape = isTransposeReshape ( perm , inputTensor . dims ) ;
75
+ let getShaderSource ;
76
+ if ( transposeAsReshape ) {
77
+ getShaderSource = ( shaderHelper : ShaderHelper ) => {
78
+ const input = inputVariable ( 'input' , inputDataType , newInputShape , 4 ) ;
79
+ const output = outputVariable ( 'output' , inputDataType , newOutputShape , 4 ) ;
80
+ return `
81
+ ${ shaderHelper . registerUniform ( 'output_size' , 'u32' ) . declareVariables ( input , output ) }
82
+ ${ shaderHelper . mainStart ( ) }
83
+ ${ shaderHelper . guardAgainstOutOfBoundsWorkgroupSizes ( 'uniforms.output_size' ) }
84
+ output[global_idx] = input[global_idx];
85
+ }` ;
86
+ } ;
87
+
88
+ return {
89
+ name : 'TransposeCopy' ,
90
+ shaderCache : { inputDependencies : [ 'type' ] } ,
91
+ getRunData : ( ) => {
92
+ const outputSize = ShapeUtil . size ( outputShape ) ;
93
+ return {
94
+ outputs : [ { dims : outputShape , dataType : inputTensor . dataType } ] ,
95
+ dispatchGroup : { x : Math . ceil ( outputSize / 64 /* workgroup size */ / 4 /* components */ ) } ,
96
+ programUniforms : [ { type : DataType . uint32 , data : Math . ceil ( outputSize / 4 ) } ] ,
97
+ } ;
98
+ } ,
99
+ getShaderSource,
100
+ } ;
101
+ }
56
102
const { newShape, newPerm } = squeezeShape ( inputTensor . dims , perm ) ;
57
103
const channelsLast = ShapeUtil . areEqual ( newPerm , [ 2 , 3 , 1 ] ) ;
58
104
const channelsFirst = ShapeUtil . areEqual ( newPerm , [ 3 , 1 , 2 ] ) ;
59
- const useShared = ( newShape . length === 2 && newPerm [ 0 ] > newPerm [ 1 ] ) || channelsLast || channelsFirst ;
60
- let newInputShape = useShared ? newShape : inputTensor . dims ;
61
- let newOutputShape = outputShape ;
105
+ const useShared = newShape . length === 2 || channelsLast || channelsFirst ;
62
106
if ( useShared ) {
63
107
newInputShape = channelsLast
64
108
? [ newShape [ 0 ] , newShape [ 1 ] * newShape [ 2 ] ]
65
109
: channelsFirst
66
110
? [ newShape [ 0 ] * newShape [ 1 ] , newShape [ 2 ] ]
67
111
: newShape ;
68
112
newOutputShape = [ newInputShape [ 1 ] , newInputShape [ 0 ] ] ;
69
- }
70
- const input = inputVariable ( 'a' , inputDataType , newInputShape . length ) ;
71
- const output = outputVariable ( 'output' , inputDataType , newOutputShape . length ) ;
72
- const tileSize = 16 ;
73
- let getShaderSource ;
74
- if ( useShared ) {
75
- getShaderSource = ( shaderHelper : ShaderHelper ) => `
113
+ const tileSize = 16 ;
114
+ getShaderSource = ( shaderHelper : ShaderHelper ) => {
115
+ const input = inputVariable ( 'a' , inputDataType , newInputShape . length ) ;
116
+ const output = outputVariable ( 'output' , inputDataType , newOutputShape . length ) ;
117
+ return `
76
118
${ shaderHelper . registerUniform ( 'output_size' , 'u32' ) . declareVariables ( input , output ) }
77
119
var<workgroup> tile : array<array<${ output . type . value } , ${ tileSize + 1 } >, ${ tileSize } >;
78
120
${ shaderHelper . mainStart ( [ tileSize , tileSize , 1 ] ) }
@@ -92,8 +134,29 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu
92
134
${ output . setByIndices ( `${ output . type . indices } (output_row, output_col)` , 'tile[local_id.x][local_id.y]' ) }
93
135
}
94
136
}` ;
95
- } else {
96
- getShaderSource = ( shaderHelper : ShaderHelper ) => `
137
+ } ;
138
+ return {
139
+ name : 'TransposeShared' ,
140
+ shaderCache : { inputDependencies : [ 'type' ] } ,
141
+ getRunData : ( ) => {
142
+ const outputSize = ShapeUtil . size ( outputShape ) ;
143
+ return {
144
+ outputs : [ { dims : outputShape , dataType : inputTensor . dataType } ] ,
145
+ dispatchGroup : { x : Math . ceil ( newOutputShape [ 1 ] / tileSize ) , y : Math . ceil ( newOutputShape [ 0 ] / tileSize ) } ,
146
+ programUniforms : [
147
+ { type : DataType . uint32 , data : outputSize } ,
148
+ ...createTensorShapeVariables ( newInputShape , newOutputShape ) ,
149
+ ] ,
150
+ } ;
151
+ } ,
152
+ getShaderSource,
153
+ } ;
154
+ }
155
+
156
+ getShaderSource = ( shaderHelper : ShaderHelper ) => {
157
+ const input = inputVariable ( 'a' , inputDataType , newInputShape . length ) ;
158
+ const output = outputVariable ( 'output' , inputDataType , newOutputShape . length ) ;
159
+ return `
97
160
${ shaderHelper . registerUniform ( 'output_size' , 'u32' ) . declareVariables ( input , output ) }
98
161
99
162
${ permFunctionBody ( perm , inputRank , input , output ) }
@@ -106,17 +169,15 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu
106
169
107
170
${ output . setByOffset ( 'global_idx' , input . getByIndices ( 'aIndices' ) ) }
108
171
}` ;
109
- }
172
+ } ;
110
173
return {
111
- name : useShared ? 'TransposeShared' : 'Transpose' ,
174
+ name : 'Transpose' ,
112
175
shaderCache : { hint : `${ permAttr } ` , inputDependencies : [ 'rank' ] } ,
113
176
getRunData : ( ) => {
114
177
const outputSize = ShapeUtil . size ( outputShape ) ;
115
178
return {
116
179
outputs : [ { dims : outputShape , dataType : inputTensor . dataType } ] ,
117
- dispatchGroup : useShared
118
- ? { x : Math . ceil ( newOutputShape [ 1 ] / tileSize ) , y : Math . ceil ( newOutputShape [ 0 ] / tileSize ) }
119
- : { x : Math . ceil ( outputSize / 64 /* workgroup size */ ) } ,
180
+ dispatchGroup : { x : Math . ceil ( outputSize / 64 /* workgroup size */ ) } ,
120
181
programUniforms : [
121
182
{ type : DataType . uint32 , data : outputSize } ,
122
183
...createTensorShapeVariables ( newInputShape , newOutputShape ) ,
0 commit comments