@@ -17,78 +17,65 @@ class ConvBase : public JsKernel {
17
17
ConvBase (const OpKernelInfo& info, bool is_channels_last, bool is_fused_conv) : JsKernel(info),
18
18
conv_attrs_ (info),
19
19
w_is_const_(false ) {
20
- TensorShapeVector kernel_shape;
21
20
const size_t pads_vec_size = conv_attrs_.pads .size () == 0 ? 4 : conv_attrs_.pads .size ();
22
21
std::vector<int32_t > local_pads (pads_vec_size, 0 );
23
22
for (size_t i = 0 ; i < conv_attrs_.pads .size () && i < pads_vec_size; ++i) {
24
23
local_pads[i] = gsl::narrow_cast<int32_t >(conv_attrs_.pads [i]);
25
24
}
26
25
26
+ TensorShapeVector kernel_shape;
27
27
if (conv_attrs_.kernel_shape_specified ) {
28
28
ORT_ENFORCE (info.GetAttrs (" kernel_shape" , kernel_shape).IsOK ());
29
29
}
30
+ std::vector<int32_t > kernel_shapes (kernel_shape.size (), 0 );
31
+ if (conv_attrs_.kernel_shape_specified ) {
32
+ for (size_t i = 0 ; i < kernel_shape.size (); ++i) {
33
+ kernel_shapes[i] = gsl::narrow_cast<int32_t >(kernel_shape[i]);
34
+ }
35
+ }
36
+
37
+ std::vector<int32_t > strides (conv_attrs_.strides .size (), 0 );
38
+ for (size_t i = 0 ; i < conv_attrs_.strides .size (); ++i) {
39
+ strides[i] = gsl::narrow_cast<int32_t >(conv_attrs_.strides [i]);
40
+ }
41
+
42
+ std::vector<int32_t > dilations (conv_attrs_.dilations .size (), 0 );
43
+ for (size_t i = 0 ; i < conv_attrs_.dilations .size (); ++i) {
44
+ dilations[i] = gsl::narrow_cast<int32_t >(conv_attrs_.dilations [i]);
45
+ }
46
+
30
47
conv_attrs_.activation = info.GetAttrOrDefault <std::string>(" activation" , " " );
31
48
std::vector<float > activation_params = info.GetAttrsOrDefault <float >(" activation_params" );
32
49
int64_t channels_last = is_channels_last ? 1 : info.GetAttrOrDefault <int64_t >(" channels_last" , 0 );
33
- auto kernel_shape_0 = conv_attrs_.kernel_shape_specified && kernel_shape.size () > 0 ? kernel_shape[0 ] : 0 ;
34
- auto kernel_shape_1 = conv_attrs_.kernel_shape_specified && kernel_shape.size () > 1 ? kernel_shape[1 ] : 0 ;
50
+
35
51
// currently only support Conv 1D/2D. TODO: support Conv3D and other
36
- if (conv_attrs_.dilations .size () == 1 ||
37
- (conv_attrs_.kernel_shape_specified && kernel_shape.size () == 1 ) ||
38
- conv_attrs_.strides .size () == 1 ) {
39
- JSEP_INIT_KERNEL_ATTRIBUTE (Conv, ({
40
- " format" : $8 ? " NHWC" : " NCHW" ,
41
- " auto_pad" : $1 ,
42
- " dilations" : [$2 ],
43
- " group" : $3 ,
44
- " kernel_shape" : [$4 ],
45
- " pads" : $5 ? Array.from (HEAP32.subarray ($5 , $6 )) : [],
46
- " strides" : [$7 ],
47
- " w_is_const" : () JS_ARROW (!!HEAP8[$9 ]),
48
- " activation" : UTF8ToString ($10 ),
49
- " activation_params" : $11 ? Array.from (HEAPF32.subarray ($11 , $12 )) : []
50
- }),
51
- static_cast <int32_t >(conv_attrs_.auto_pad ),
52
- static_cast <int32_t >(conv_attrs_.dilations .size () > 0 ? conv_attrs_.dilations [0 ] : 0 ),
53
- static_cast <int32_t >(conv_attrs_.group ),
54
- static_cast <int32_t >(kernel_shape_0),
55
- JSEP_HEAP32_INDEX_START (local_pads),
56
- JSEP_HEAP32_INDEX_END (local_pads),
57
- static_cast <int32_t >(conv_attrs_.strides .size () > 0 ? conv_attrs_.strides [0 ] : 0 ),
58
- static_cast <int32_t >(channels_last),
59
- JSEP_HEAP8_INDEX (&w_is_const_),
60
- conv_attrs_.activation .c_str (),
61
- JSEP_HEAP32_INDEX_START (activation_params),
62
- JSEP_HEAP32_INDEX_END (activation_params));
63
- } else {
64
- JSEP_INIT_KERNEL_ATTRIBUTE (Conv, ({
65
- " format" : $11 ? " NHWC" : " NCHW" ,
66
- " auto_pad" : $1 ,
67
- " dilations" : [ $2 , $3 ],
68
- " group" : $4 ,
69
- " kernel_shape" : [ $5 , $6 ],
70
- " pads" : $7 ? Array.from (HEAP32.subarray ($7 , $8 )) : [],
71
- " strides" : [ $9 , $10 ],
72
- " w_is_const" : () JS_ARROW (!!HEAP8[$12 ]),
73
- " activation" : UTF8ToString ($13 ),
74
- " activation_params" : $14 ? Array.from (HEAPF32.subarray ($14 , $15 )) : []
75
- }),
76
- static_cast <int32_t >(conv_attrs_.auto_pad ),
77
- static_cast <int32_t >(conv_attrs_.dilations .size () > 0 ? conv_attrs_.dilations [0 ] : 0 ),
78
- static_cast <int32_t >(conv_attrs_.dilations .size () > 1 ? conv_attrs_.dilations [1 ] : 0 ),
79
- static_cast <int32_t >(conv_attrs_.group ),
80
- static_cast <int32_t >(kernel_shape_0),
81
- static_cast <int32_t >(kernel_shape_1),
82
- JSEP_HEAP32_INDEX_START (local_pads),
83
- JSEP_HEAP32_INDEX_END (local_pads),
84
- static_cast <int32_t >(conv_attrs_.strides .size () > 0 ? conv_attrs_.strides [0 ] : 0 ),
85
- static_cast <int32_t >(conv_attrs_.strides .size () > 1 ? conv_attrs_.strides [1 ] : 0 ),
86
- static_cast <int32_t >(channels_last),
87
- JSEP_HEAP8_INDEX (&w_is_const_),
88
- conv_attrs_.activation .c_str (),
89
- JSEP_HEAP32_INDEX_START (activation_params),
90
- JSEP_HEAP32_INDEX_END (activation_params));
91
- }
52
+ JSEP_INIT_KERNEL_ATTRIBUTE (Conv, ({
53
+ " format" : $11 ? " NHWC" : " NCHW" ,
54
+ " auto_pad" : $1 ,
55
+ " dilations" : $2 ? Array.from (HEAP32.subarray ($2 , $3 )) : [],
56
+ " group" : $4 ,
57
+ " kernel_shape" : $5 ? Array.from (HEAP32.subarray ($5 , $6 )) : [],
58
+ " pads" : $7 ? Array.from (HEAP32.subarray ($7 , $8 )) : [],
59
+ " strides" : $9 ? Array.from (HEAP32.subarray ($9 , $10 )) : [],
60
+ " w_is_const" : () JS_ARROW (!!HEAP8[$12 ]),
61
+ " activation" : UTF8ToString ($13 ),
62
+ " activation_params" : $14 ? Array.from (HEAPF32.subarray ($14 , $15 )) : []
63
+ }),
64
+ static_cast <int32_t >(conv_attrs_.auto_pad ),
65
+ JSEP_HEAP32_INDEX_START (dilations),
66
+ JSEP_HEAP32_INDEX_END (dilations),
67
+ static_cast <int32_t >(conv_attrs_.group ),
68
+ JSEP_HEAP32_INDEX_START (kernel_shape),
69
+ JSEP_HEAP32_INDEX_END (kernel_shape),
70
+ JSEP_HEAP32_INDEX_START (local_pads),
71
+ JSEP_HEAP32_INDEX_END (local_pads),
72
+ JSEP_HEAP32_INDEX_START (strides),
73
+ JSEP_HEAP32_INDEX_END (strides),
74
+ static_cast <int32_t >(channels_last),
75
+ JSEP_HEAP8_INDEX (&w_is_const_),
76
+ conv_attrs_.activation .c_str (),
77
+ JSEP_HEAP32_INDEX_START (activation_params),
78
+ JSEP_HEAP32_INDEX_END (activation_params));
92
79
}
93
80
94
81
Status PrePack (const Tensor& tensor, int input_idx, AllocatorPtr alloc,
0 commit comments