From 5f1ae2a381b8a19cae0ad44c5fd22de5d6abcc82 Mon Sep 17 00:00:00 2001 From: WangZhen <23097963+0x45f@users.noreply.github.com> Date: Wed, 26 Feb 2025 13:40:21 +0800 Subject: [PATCH] Fix index put core shpes (#461) --- benchmark/core_shapes.yaml | 16 +++++ benchmark/test_select_and_slice_perf.py | 89 +++++++++++++------------ 2 files changed, 62 insertions(+), 43 deletions(-) diff --git a/benchmark/core_shapes.yaml b/benchmark/core_shapes.yaml index 43f373a86..c7e8b3040 100644 --- a/benchmark/core_shapes.yaml +++ b/benchmark/core_shapes.yaml @@ -192,3 +192,19 @@ AttentionBenchmark: - [4, 8, 2048, 128] - [4, 8, 3072, 128] - [4, 8, 4096, 128] + +IndexPutAccFalseBenchmark: + shapes: + - [[268435456,], [[65536,],], [65536,]] + - [[32, 32], [[8,], [2, 8]], [8,]] + - [[1024, 1024], [[4, 64],], [1024,]] + - [[512, 512, 512], [[2, 128], [128,], [128,]], [128,]] + - [[512, 512, 512], [[2, 128],], [512,]] + +IndexPutAccTrueBenchmark: + shapes: + - [[268435456,], [[65536,],], [65536,]] + - [[32, 32], [[8,], [8,]], [8,]] + - [[1024, 1024], [[64,], [64,]], [64,]] + - [[512, 512, 512], [[128,], [128,], [128,]], [128,]] + - [[512, 512, 512], [[2, 128], [2, 128], [2, 128]], [2, 128]] diff --git a/benchmark/test_select_and_slice_perf.py b/benchmark/test_select_and_slice_perf.py index 6fa619b27..534666271 100644 --- a/benchmark/test_select_and_slice_perf.py +++ b/benchmark/test_select_and_slice_perf.py @@ -304,45 +304,46 @@ def inner(shapes, dtype, device): return inner -@pytest.mark.index_put -def test_index_put_acc_false_perf(): - class IndexPutBenchmark(GenericBenchmark): - def set_more_shapes(self): - INDEX_PUT_SHAPE = ( - ((2**28,), ((2**16,),), (2**16,)), - ((32, 32), ((8,), (8,)), (8,)), - ((32, 32), ((8,), (2, 8)), (8,)), - ((32, 32), ((2, 8),), (32,)), - ((1024, 1024), ((64,), (64,)), (64,)), +class IndexPutAccFalseBenchmark(GenericBenchmark): + def set_more_shapes(self): + INDEX_PUT_SHAPE = ( + ((2**28,), ((2**16,),), (2**16,)), + ((32, 32), ((8,), (8,)), (8,)), + ((32, 32), ((8,), (2, 8)), (8,)), + ((32, 32), ((2, 8),), (32,)), + ((1024, 1024), ((64,), (64,)), (64,)), + ( + (1024, 1024), ( - (1024, 1024), + (64,), ( - (64,), - ( - 4, - 64, - ), + 4, + 64, ), - (64,), ), + (64,), + ), + ( + (1024, 1024), ( - (1024, 1024), ( - ( - 4, - 64, - ), + 4, + 64, ), - (1024,), ), - ((512, 512, 512), ((128,), (128,), (128,)), (128,)), - ((512, 512, 512), ((2, 128), (128,), (128,)), (128,)), - ((512, 512, 512), ((2, 128),), (512,)), - ) - self.shapes = INDEX_PUT_SHAPE - return None - - bench = IndexPutBenchmark( + (1024,), + ), + ((512, 512, 512), ((128,), (128,), (128,)), (128,)), + ((512, 512, 512), ((2, 128), (128,), (128,)), (128,)), + ((512, 512, 512), ((2, 128),), (512,)), + ) + self.shapes = INDEX_PUT_SHAPE + return None + + +@pytest.mark.index_put +def test_index_put_acc_false_perf(): + bench = IndexPutAccFalseBenchmark( op_name="index_put", torch_op=torch.index_put, input_fn=index_put_input_fn(False), @@ -351,20 +352,22 @@ def set_more_shapes(self): bench.run() +class IndexPutAccTrueBenchmark(GenericBenchmark): + def set_more_shapes(self): + INDEX_PUT_SHAPE = ( + ((2**28,), ((2**16,),), (2**16,)), + ((32, 32), ((8,), (8,)), (8,)), + ((1024, 1024), ((64,), (64,)), (64,)), + ((512, 512, 512), ((128,), (128,), (128,)), (128,)), + ((512, 512, 512), ((2, 128), (2, 128), (2, 128)), (2, 128)), + ) + self.shapes = INDEX_PUT_SHAPE + return None + + @pytest.mark.index_put def test_index_put_acc_true_perf(): - class IndexPutBenchmark(GenericBenchmark): - def set_more_shapes(self): - INDEX_PUT_SHAPE = ( - ((2**28,), ((2**16,),), (2**16,)), - ((32, 32), ((8,), (8,)), (8,)), - ((1024, 1024), ((64,), (64,)), (64,)), - ((512, 512, 512), ((128,), (128,), (128,)), (128,)), - ) - self.shapes = INDEX_PUT_SHAPE - return None - - bench = IndexPutBenchmark( + bench = IndexPutAccTrueBenchmark( op_name="index_put", torch_op=torch.index_put, input_fn=index_put_input_fn(True),