From 1035fcb37db75227b3de9ddddf8236772c75780f Mon Sep 17 00:00:00 2001
From: Yudi Sun
Date: Fri, 5 Apr 2024 14:04:22 -0400
Subject: [PATCH] [CuBLAS] Add CuBLAS benchmarks

Some cuBLAS benchmarking results on an RTX 2080 Ti (all measurements are
median latencies):

SECTION 1
FP32 Matrix Multiply: C (bs x m x n) = A (bs x m x k) @ B (bs x k x n)

Group 1: m = 512, n = 512, k = 512

 bs | cublas_batched_gemm | cublas_strided_gemm | hidet.ops.matmul optimized | PyTorch
----+---------------------+---------------------+----------------------------+--------
  1 |              69.0us |              41.0us |                     37.0us |  44.6us
  2 |             111.7us |              75.8us |                     69.2us |  71.7us
  4 |             124.9us |              97.2us |                    100.8us |  96.3us
  8 |             190.5us |             191.1us |                    204.7us | 187.6us

Group 2: m = 1024, n = 1024, k = 2048

 bs | cublas_batched_gemm | cublas_strided_gemm | hidet.ops.matmul optimized | PyTorch
----+---------------------+---------------------+----------------------------+--------
  1 |             405.1us |             419.2us |                    370.7us | 405.1us
  2 |             725.3us |             859.9us |                    800.8us | 719.2us
  4 |              1442us |              1592us |                     1606us |  1466us
  8 |              2658us |              2830us |                     3475us |  2753us

SECTION 2
FP16 Matrix Multiply: C (bs x m x n) = A (bs x m x k) @ B (bs x k x n)

Group 1: m = 512, n = 512, k = 512

 bs | cublas_batched_gemm | cublas_strided_gemm | hidet.ops.matmul optimized | PyTorch
----+---------------------+---------------------+----------------------------+--------
  1 |              63.5us |              34.0us |                     34.9us |  41.0us
  2 |              66.0us |              30.2us |                     64.8us |  45.1us
  4 |              72.7us |              32.4us |                     24.4us |  46.3us
  8 |              81.2us |              36.2us |                     38.5us |  47.8us

Group 2: m = 1024, n = 1024, k = 2048

 bs | cublas_batched_gemm | cublas_strided_gemm | hidet.ops.matmul optimized | PyTorch
----+---------------------+---------------------+----------------------------+--------
  1 |              71.0us |              60.1us |                     65.5us |  90.6us
  2 |             114.8us |             112.3us |                    123.1us | 160.5us
  4 |             225.1us |             223.4us |                    245.6us | 319.8us
  8 |             442.8us |             439.1us |                    733.2us | 634.8us
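For reference, below is a minimal sketch of driving strided_gemm by hand for
one configuration, using the same argument order the benchmark exercises
(sizes, dtypes, tensors, element strides, transpose flags, compute type). The
chosen shapes and the sanity check against hidet.ops.matmul are illustrative,
not part of this patch, and assume hidet's imperative API (hidet.cuda.synchronize,
Tensor.cpu().numpy()):

    import math
    import numpy as np
    import hidet
    from hidet import ops
    from hidet.cuda.cublas import cublasComputeType

    # Illustrative problem size: C (2 x 512 x 512) = A (2 x 512 x 512) @ B (2 x 512 x 512).
    bs, m, n, k = 2, 512, 512, 512
    a = hidet.randn((bs, m, k), device='cuda', dtype='float32') / math.sqrt(k)
    b = hidet.randn((bs, k, n), device='cuda', dtype='float32') / math.sqrt(k)
    c = hidet.empty((bs, m, n), device='cuda', dtype='float32')

    # One strided GEMM call; per-sample element strides are m*k, k*n and m*n.
    hidet.cuda.cublas.strided_gemm(
        bs, m, n, k, a.dtype, b.dtype, c.dtype, a, b, c,
        m * k, k * n, m * n, False, False,
        cublasComputeType.CUBLAS_COMPUTE_32F,
    )
    hidet.cuda.synchronize()  # the call is asynchronous on the current stream

    # Compare against hidet's own matmul; tolerances are illustrative.
    ref = ops.matmul(a, b)
    np.testing.assert_allclose(c.cpu().numpy(), ref.cpu().numpy(), rtol=1e-2, atol=1e-2)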
---
 python/hidet/cuda/cublas/benchmark.py | 96 +++++++++++++++++++++++++++
 1 file changed, 96 insertions(+)
 create mode 100644 python/hidet/cuda/cublas/benchmark.py

diff --git a/python/hidet/cuda/cublas/benchmark.py b/python/hidet/cuda/cublas/benchmark.py
new file mode 100644
index 000000000..5aa8939b3
--- /dev/null
+++ b/python/hidet/cuda/cublas/benchmark.py
@@ -0,0 +1,96 @@
+import math
+import torch
+import numpy as np
+
+import hidet
+from hidet.cuda.cublas import cublasComputeType
+from hidet.utils.benchmark import do_bench
+from hidet import ops
+
+
+def benchmark_cublas_batched_gemm(bs, m, n, k, dtype, compute_type):
+    # Batched GEMM: a, b, c are lists of bs separate (m, k), (k, n), (m, n) matrices.
+    a, b, c = [], [], []
+    for _ in range(bs):
+        a.append(hidet.randn((m, k), device='cuda', dtype=dtype) / math.sqrt(k))
+        b.append(hidet.randn((k, n), device='cuda', dtype=dtype) / math.sqrt(k))
+        c.append(hidet.empty((m, n), device='cuda', dtype=dtype))
+
+    latencies = do_bench(
+        lambda: hidet.cuda.cublas.batched_gemm(
+            bs, m, n, k, a[0].dtype, b[0].dtype, c[0].dtype, a, b, c, False, False, compute_type
+        ),
+        warmup=10,
+        rep=100,
+    )
+
+    print(f"cublas_batched_gemm results for configuration: dtype = {dtype}, input shape = {[bs, m, n, k]}")
+    print(f"Median latency: {latencies[1]} milliseconds")
+    print("-------------------------------------------------")
+
+
+def benchmark_cublas_strided_gemm(bs, m, n, k, dtype, compute_type):
+    # Strided batched GEMM over contiguous 3-D tensors; per-sample element
+    # strides are m * k, k * n and m * n.
+    a = hidet.randn((bs, m, k), device='cuda', dtype=dtype) / math.sqrt(k)
+    b = hidet.randn((bs, k, n), device='cuda', dtype=dtype) / math.sqrt(k)
+    c = hidet.empty((bs, m, n), device='cuda', dtype=dtype)
+
+    latencies = do_bench(
+        lambda: hidet.cuda.cublas.strided_gemm(
+            bs, m, n, k, a.dtype, b.dtype, c.dtype, a, b, c, m * k, k * n, m * n, False, False, compute_type
+        ),
+        warmup=10,
+        rep=100,
+    )
+
+    print(f"cublas_strided_gemm results for configuration: dtype = {dtype}, input shape = {[bs, m, n, k]}")
+    print(f"Median latency: {latencies[1]} milliseconds")
+    print("-------------------------------------------------")
+
+
+def benchmark_torch_batched_matmul(bs, m, n, k, dtype, compute_type):
+    # compute_type is unused here; the parameter keeps the call signature uniform.
+    a = torch.from_numpy(np.random.randn(bs, m, k).astype(dtype)).cuda()
+    b = torch.from_numpy(np.random.randn(bs, k, n).astype(dtype)).cuda()
+
+    latencies = do_bench(lambda: a @ b, warmup=10, rep=100)
+
+    print(f"torch_batched_matmul results for configuration: dtype = {dtype}, input shape = {[bs, m, n, k]}")
+    print(f"Median latency: {latencies[1]} milliseconds")
+    print("-------------------------------------------------")
+
+
+def benchmark_hidet_batched_matmul(bs, m, n, k, dtype, compute_type):
+    # compute_type is unused here; hidet selects its own schedules in search space 2.
+    a = hidet.symbol((bs, m, k), device='cuda', dtype=dtype)
+    b = hidet.symbol((bs, k, n), device='cuda', dtype=dtype)
+    c = ops.matmul(a, b)
+    hidet.option.search_space(2)
+    graph = hidet.trace_from(c, inputs=[a, b])
+    graph = hidet.graph.optimize(graph)
+    graph = graph.cuda_graph()
+
+    latencies = do_bench(lambda: graph.run_async(), warmup=10, rep=100)
+
+    print(f"hidet_batched_matmul results for configuration: dtype = {dtype}, input shape = {[bs, m, n, k]}")
+    print(f"Median latency: {latencies[1]} milliseconds")
+    print("-------------------------------------------------")
+
+
+if __name__ == '__main__':
+    sizes = [
+        # Group 1
+        [1, 512, 512, 512],
+        [2, 512, 512, 512],
+        [4, 512, 512, 512],
+        [8, 512, 512, 512],
+        # Group 2
+        [1, 1024, 1024, 2048],
+        [2, 1024, 1024, 2048],
+        [4, 1024, 1024, 2048],
+        [8, 1024, 1024, 2048],
+    ]
+    dtypes = [['float32', cublasComputeType.CUBLAS_COMPUTE_32F], ['float16', cublasComputeType.CUBLAS_COMPUTE_16F]]
+
+    for data_type in dtypes:
+        for size in sizes:
+            # Uncomment the other benchmarks to reproduce the full comparison above.
+            # benchmark_cublas_batched_gemm(*(size + data_type))
+            benchmark_cublas_strided_gemm(*(size + data_type))
+            # benchmark_torch_batched_matmul(*(size + data_type))
+            # benchmark_hidet_batched_matmul(*(size + data_type))
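
To try the benchmark, run the script directly; only the cublas_strided_gemm
comparison is enabled by default, and the batched cuBLAS, PyTorch, and hidet
variants can be enabled by uncommenting the corresponding calls in __main__:

    python python/hidet/cuda/cublas/benchmark.py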