Skip to content

Commit 68e9c5a

Browse files
Andrew Grebenisan and facebook-github-bot authored and committed
Backend-agnostic quantized_conv_nhwc (channels last) (pytorch#13954)
Summary: Ongoing work in providing python backend-agnostic references for Cadence custom ops. Reviewed By: hsharma35 Differential Revision: D81526626
1 parent 705150c commit 68e9c5a

File tree

2 files changed

+455
-268
lines changed

2 files changed

+455
-268
lines changed

backends/cadence/aot/ref_implementations.py

Lines changed: 115 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -298,8 +298,7 @@ def quantized_layer_norm_per_tensor(
298298
)
299299

300300

301-
@impl(m, "quantized_conv_nchw")
302-
def quantized_conv_nchw(
301+
def quantized_conv(
303302
input_tensor: torch.Tensor,
304303
weight: torch.Tensor,
305304
bias: torch.Tensor,
@@ -374,6 +373,120 @@ def quantized_conv_nchw(
374373
)
375374

376375

376+
@impl(m, "quantized_conv_nchw")
def quantized_conv_nchw(
    input_tensor: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    stride: tuple[int, int],
    padding: tuple[int, int],
    dilation: tuple[int, int],
    groups: int,
    in_zero_point: int,
    weight_zero_point: torch.Tensor,
    bias_scale: torch.Tensor,
    output_scale: float,
    output_zero_point: int,
    out_multiplier: torch.Tensor,
    out_shift: torch.Tensor,
) -> torch.Tensor:
    """
    Quantized convolution over an NCHW (channels-first) activation tensor.

    Thin layout-checking wrapper: verifies the activations are in standard
    contiguous (NCHW) memory format, then delegates all the work to the
    shared layout-agnostic ``quantized_conv`` reference implementation.

    Args:
    - input_tensor (Tensor): The activations tensor
    - weight (Tensor): The weight tensor
    - bias (Tensor): The bias tensor
    - stride (Tuple[int]): The stride of the convolution
    - padding (Tuple[int]): The padding of the convolution
    - dilation (Tuple[int]): The dilation of the convolution
    - groups (int): The number of groups
    - in_zero_point (int): The quantized mapping of zero for the input
    - weight_zero_point (Tensor): The quantized mapping of zero for the weight
    - bias_scale (Tensor): The quantized bias scale
    - output_scale (float): The scale of the output
    - output_zero_point (int): The zero point of the output
    - out_multiplier (Tensor): Unused
    - out_shift (Tensor): Unused

    Raises:
        ValueError: if the input is not contiguous in NCHW memory format.
    """
    # Guard clause: reject activations whose memory layout is not
    # standard-contiguous before handing off to the shared kernel.
    if not input_tensor.is_contiguous(memory_format=torch.contiguous_format):
        raise ValueError("Input tensor must be in NCHW format")

    return quantized_conv(
        input_tensor, weight, bias, stride, padding, dilation, groups,
        in_zero_point, weight_zero_point, bias_scale, output_scale,
        output_zero_point, out_multiplier, out_shift,
    )
430+
431+
432+
@impl(m, "quantized_conv_nhwc")
def quantized_conv_nhwc(
    input_tensor: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    stride: tuple[int, int],
    padding: tuple[int, int],
    dilation: tuple[int, int],
    groups: int,
    in_zero_point: int,
    weight_zero_point: torch.Tensor,
    bias_scale: torch.Tensor,
    output_scale: float,
    output_zero_point: int,
    out_multiplier: torch.Tensor,
    out_shift: torch.Tensor,
) -> torch.Tensor:
    """
    Quantized convolution over an NHWC (channels-last) activation tensor.

    Thin layout-checking wrapper: verifies the activations are contiguous in
    channels-last (NHWC) memory format, then delegates all the work to the
    shared layout-agnostic ``quantized_conv`` reference implementation.

    Args:
    - input_tensor (Tensor): The activations tensor
    - weight (Tensor): The weight tensor
    - bias (Tensor): The bias tensor
    - stride (Tuple[int]): The stride of the convolution
    - padding (Tuple[int]): The padding of the convolution
    - dilation (Tuple[int]): The dilation of the convolution
    - groups (int): The number of groups
    - in_zero_point (int): The quantized mapping of zero for the input
    - weight_zero_point (Tensor): The quantized mapping of zero for the weight
    - bias_scale (Tensor): The quantized bias scale
    - output_scale (float): The scale of the output
    - output_zero_point (int): The zero point of the output
    - out_multiplier (Tensor): Unused
    - out_shift (Tensor): Unused

    Raises:
        ValueError: if the input is not contiguous in channels-last format.
    """
    # Guard clause: reject activations whose memory layout is not
    # channels-last before handing off to the shared kernel.
    if not input_tensor.is_contiguous(memory_format=torch.channels_last):
        raise ValueError("Input tensor must be in NHWC format")

    return quantized_conv(
        input_tensor, weight, bias, stride, padding, dilation, groups,
        in_zero_point, weight_zero_point, bias_scale, output_scale,
        output_zero_point, out_multiplier, out_shift,
    )
488+
489+
377490
@impl(m, "requantize")
378491
def requantize(
379492
input: torch.Tensor,

0 commit comments

Comments (0)