@@ -298,8 +298,7 @@ def quantized_layer_norm_per_tensor(
298
298
)
299
299
300
300
301
- @impl (m , "quantized_conv_nchw" )
302
- def quantized_conv_nchw (
301
+ def quantized_conv (
303
302
input_tensor : torch .Tensor ,
304
303
weight : torch .Tensor ,
305
304
bias : torch .Tensor ,
@@ -374,6 +373,120 @@ def quantized_conv_nchw(
374
373
)
375
374
376
375
376
+ @impl (m , "quantized_conv_nchw" )
377
+ def quantized_conv_nchw (
378
+ input_tensor : torch .Tensor ,
379
+ weight : torch .Tensor ,
380
+ bias : torch .Tensor ,
381
+ stride : tuple [int , int ],
382
+ padding : tuple [int , int ],
383
+ dilation : tuple [int , int ],
384
+ groups : int ,
385
+ in_zero_point : int ,
386
+ weight_zero_point : torch .Tensor ,
387
+ bias_scale : torch .Tensor ,
388
+ output_scale : float ,
389
+ output_zero_point : int ,
390
+ out_multiplier : torch .Tensor ,
391
+ out_shift : torch .Tensor ,
392
+ ) -> torch .Tensor :
393
+ """
394
+ Quantized convolution operation.
395
+
396
+ Args:
397
+ - input_tensor (Tensor): The activations tensor
398
+ - weight (Tensor): The weight tensor
399
+ - bias (Tensor): The bias tensor
400
+ - stride (Tuple[int]): The stride of the convolution
401
+ - padding (Tuple[int]): The padding of the convolution
402
+ - dilation (Tuple[int]): The dilation of the convolution
403
+ - groups (int): The number of groups
404
+ - in_zero_point (int): The quantized mapping of zero for the input
405
+ - weight_zero_point (Tensor): The quantized mapping of zero for the weight
406
+ - bias_scale (Tensor): The quantized bias scale
407
+ - output_scale (float): The scale of the output
408
+ - output_zero_point (int): The zero point of the output
409
+ - out_multiplier (Tensor): Unused
410
+ - out_shift (Tensor): Unused
411
+ """
412
+ if not input_tensor .is_contiguous (memory_format = torch .contiguous_format ):
413
+ raise ValueError ("Input tensor must be in NCHW format" )
414
+ return quantized_conv (
415
+ input_tensor ,
416
+ weight ,
417
+ bias ,
418
+ stride ,
419
+ padding ,
420
+ dilation ,
421
+ groups ,
422
+ in_zero_point ,
423
+ weight_zero_point ,
424
+ bias_scale ,
425
+ output_scale ,
426
+ output_zero_point ,
427
+ out_multiplier ,
428
+ out_shift ,
429
+ )
430
+
431
+
432
+ @impl (m , "quantized_conv_nhwc" )
433
+ def quantized_conv_nhwc (
434
+ input_tensor : torch .Tensor ,
435
+ weight : torch .Tensor ,
436
+ bias : torch .Tensor ,
437
+ stride : tuple [int , int ],
438
+ padding : tuple [int , int ],
439
+ dilation : tuple [int , int ],
440
+ groups : int ,
441
+ in_zero_point : int ,
442
+ weight_zero_point : torch .Tensor ,
443
+ bias_scale : torch .Tensor ,
444
+ output_scale : float ,
445
+ output_zero_point : int ,
446
+ out_multiplier : torch .Tensor ,
447
+ out_shift : torch .Tensor ,
448
+ ) -> torch .Tensor :
449
+ """
450
+ Quantized convolution operation.
451
+
452
+ Args:
453
+ - input_tensor (Tensor): The activations tensor
454
+ - weight (Tensor): The weight tensor
455
+ - bias (Tensor): The bias tensor
456
+ - stride (Tuple[int]): The stride of the convolution
457
+ - padding (Tuple[int]): The padding of the convolution
458
+ - dilation (Tuple[int]): The dilation of the convolution
459
+ - groups (int): The number of groups
460
+ - in_zero_point (int): The quantized mapping of zero for the input
461
+ - weight_zero_point (Tensor): The quantized mapping of zero for the weight
462
+ - bias_scale (Tensor): The quantized bias scale
463
+ - output_scale (float): The scale of the output
464
+ - output_zero_point (int): The zero point of the output
465
+ - out_multiplier (Tensor): Unused
466
+ - out_shift (Tensor): Unused
467
+ """
468
+
469
+ if not input_tensor .is_contiguous (memory_format = torch .channels_last ):
470
+ raise ValueError ("Input tensor must be in NHWC format" )
471
+
472
+ return quantized_conv (
473
+ input_tensor ,
474
+ weight ,
475
+ bias ,
476
+ stride ,
477
+ padding ,
478
+ dilation ,
479
+ groups ,
480
+ in_zero_point ,
481
+ weight_zero_point ,
482
+ bias_scale ,
483
+ output_scale ,
484
+ output_zero_point ,
485
+ out_multiplier ,
486
+ out_shift ,
487
+ )
488
+
489
+
377
490
@impl (m , "requantize" )
378
491
def requantize (
379
492
input : torch .Tensor ,
0 commit comments