     quantized_conv_nhwc,
     quantized_layer_norm_per_tensor,
     quantized_linear,
+    quantized_relu,
 )
 from executorch.backends.cadence.aot.typing_stubs import expand

@@ -744,3 +745,84 @@ def test_quantized_conv(
             torch.equal(output, expected_output),
             f"Output values don't match expected. Got {output}, expected {expected_output}",
         )
+
+    @expand(
+        [
+            # Test case 1: Basic int8 case with negative scale
+            (
+                "basic_int8",
+                torch.tensor([-1, 0, 1, 3], dtype=torch.int8),  # input
+                torch.tensor([0], dtype=torch.int8),  # X_zero_point (scalar broadcast)
+                0,  # out_zero_point
+                torch.tensor([1073741824]),  # out_multiplier (0.5 * 2^31)
+                torch.tensor([0]),  # out_shift
+                torch.int8,  # dtype
+                torch.tensor(
+                    [0, 0, 0, -2], dtype=torch.int8
+                ),  # expected: relu(-1,0,1,3) = (0,0,1,3) * (-0.5) + 0 = (0,0,-0.5,-1.5) -> (0,0,0,-2)
+            ),
+            # Test case 2: uint8 with non-zero zero point
+            (
+                "uint8_with_zp",
+                torch.tensor([126, 128, 130, 132], dtype=torch.uint8),  # input
+                torch.tensor([128], dtype=torch.uint8),  # X_zero_point
+                64,  # out_zero_point
+                torch.tensor([536870912]),  # out_multiplier (0.25 * 2^31)
+                torch.tensor([0]),  # out_shift
+                torch.uint8,  # dtype
+                torch.tensor(
+                    [64, 64, 64, 63], dtype=torch.uint8
+                ),  # expected: relu(-2,0,2,4) = (0,0,2,4) * (-0.25) + 64 = (64,64,63.5,63) -> (64,64,64,63)
+            ),
+            # Test case 3: All negative values (should all become zero after ReLU)
+            (
+                "all_negative_int8",
+                torch.tensor([-5, -3, -1], dtype=torch.int8),  # input
+                torch.tensor([0], dtype=torch.int8),  # X_zero_point
+                10,  # out_zero_point
+                torch.tensor([1073741824]),  # out_multiplier (0.5 * 2^31)
+                torch.tensor([0]),  # out_shift
+                torch.int8,  # dtype
+                torch.tensor(
+                    [10, 10, 10], dtype=torch.int8
+                ),  # expected: relu(-5,-3,-1) = (0,0,0) * (-0.5) + 10 = (10,10,10)
+            ),
+            # Test case 4: All positive values with shift (effective scale becomes -1.0)
+            (
+                "positive_with_shift",
+                torch.tensor([2, 4, 6, 8], dtype=torch.int8),  # input
+                torch.tensor([1], dtype=torch.int8),  # X_zero_point
+                5,  # out_zero_point
+                torch.tensor([1073741824]),  # out_multiplier (0.5 * 2^31)
+                torch.tensor([1]),  # out_shift (multiply by 2^1 = 2)
+                torch.int8,  # dtype
+                torch.tensor(
+                    [4, 2, 0, -2], dtype=torch.int8
+                ),  # expected: relu(1,3,5,7) = (1,3,5,7) * (-1.0) + 5 = (4,2,0,-2)
+            ),
+        ]
+    )
+    def test_quantized_relu(
+        self,
+        name: str,
+        X: torch.Tensor,
+        X_zero_point: torch.Tensor,
+        out_zero_point: int,
+        out_multiplier: torch.Tensor,
+        out_shift: torch.Tensor,
+        dtype: torch.dtype,
+        expected_output: torch.Tensor,
+    ) -> None:
+        output = quantized_relu(
+            X, X_zero_point, out_zero_point, out_multiplier, out_shift
+        )
+
+        # Verify output properties
+        self.assertEqual(output.dtype, dtype, f"Output dtype should be {dtype}")
+        self.assertEqual(output.shape, X.shape, "Output shape should match input shape")
+
+        # Verify output matches expected values
+        self.assertTrue(
+            torch.equal(output, expected_output),
+            f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
+        )
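For readers puzzling over the negative effective scale in the expected-value comments: the numbers follow from a ReLU relative to the input zero point followed by a requantization whose scale is taken as `-(out_multiplier / 2^31) * 2^out_shift`. The sketch below is only an illustration of that arithmetic under those assumptions, not the actual Cadence reference implementation; the helper name `_reference_quantized_relu` and its rounding/clamping choices are hypothetical.

```python
import torch


def _reference_quantized_relu(
    X: torch.Tensor,
    X_zero_point: torch.Tensor,
    out_zero_point: int,
    out_multiplier: torch.Tensor,
    out_shift: torch.Tensor,
) -> torch.Tensor:
    """Illustrative-only model of the math assumed by the test comments."""
    # ReLU relative to the input zero point, computed in int32 to avoid uint8 wraparound.
    rectified = torch.clamp(
        X.to(torch.int32) - X_zero_point.to(torch.int32), min=0
    ).to(torch.float32)
    # Assumed requantization scale: Q31 multiplier, power-of-two shift, negated.
    scale = -(out_multiplier.item() / (1 << 31)) * (2.0 ** out_shift.item())
    # Round (half-to-even, as torch.round does), then clamp to the output dtype's range.
    out = torch.round(rectified * scale + out_zero_point)
    info = torch.iinfo(X.dtype)
    return out.clamp(info.min, info.max).to(X.dtype)


# For the "basic_int8" case this yields tensor([0, 0, 0, -2], dtype=torch.int8),
# matching the expected output in the parameterized table above.
```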