diff --git a/python/benchmarks/zarr-compressors.py b/python/benchmarks/zarr-compressors.py index 051fd3f59b..b36302e140 100644 --- a/python/benchmarks/zarr-compressors.py +++ b/python/benchmarks/zarr-compressors.py @@ -30,10 +30,15 @@ class Compressor: compressors = { + "lz4-default": Compressor(device=Device.CPU, codec=numcodecs.LZ4()), "lz4-blosc": Compressor( device=Device.CPU, codec=numcodecs.blosc.Blosc(cname="lz4") ), "lz4-nvcomp": Compressor(device=Device.GPU, codec=kvikio.zarr.LZ4()), + "snappy-nvcomp": Compressor(device=Device.GPU, codec=kvikio.zarr.Snappy()), + "cascaded-nvcomp": Compressor(device=Device.GPU, codec=kvikio.zarr.Cascaded()), + "gdeflate-nvcomp": Compressor(device=Device.GPU, codec=kvikio.zarr.Gdeflate()), + "bitcomp-nvcomp": Compressor(device=Device.GPU, codec=kvikio.zarr.Bitcomp()), } @@ -85,17 +90,20 @@ def main(args): ) ) + nbytes = f"{format_bytes(args.nbytes)} bytes ({args.nbytes})" print("Encode/decode benchmark") print("----------------------------------") - print(f"GPU | {gpu_name}") - print(f"GPU Memory Total | {mem_total}") - print(f"BAR1 Memory Total | {bar1_total}") + print(f"GPU | {gpu_name}") + print(f"GPU Memory Total | {mem_total}") + print(f"BAR1 Memory Total | {bar1_total}") print("----------------------------------") - print(f"nbytes | {args.nbytes} bytes ({format_bytes(args.nbytes)})") - print(f"4K aligned | {args.nbytes % 4096 == 0}") - print(f"nruns | {args.nruns}") + print(f"nbytes | {nbytes}") + print(f"4K aligned | {args.nbytes % 4096 == 0}") + print(f"nruns | {args.nruns}") print("==================================") + encode_output = "" + decode_output = "" # Run each benchmark using the requested APIs for comp_name, comp in ((n, compressors[n]) for n in args.compressors): rs = [] @@ -109,7 +117,7 @@ def main(args): def pprint_api_res(name, samples): mean = statistics.mean(samples) if len(samples) > 1 else samples[0] - ret = f"{comp_name} {name}".ljust(18) + ret = f"{comp_name} {name}".ljust(24) ret += f"| {format_bytes(mean).rjust(10)}/s".ljust(14) if len(samples) > 1: stdev = statistics.stdev(samples) / mean * 100 @@ -120,8 +128,12 @@ def pprint_api_res(name, samples): ret = ret[:-2] + ")" # Replace trailing comma return ret - print(pprint_api_res("encode", rs)) - print(pprint_api_res("decode", ws)) + encode_output += pprint_api_res("", rs) + "\n" + decode_output += pprint_api_res("", ws) + "\n" + print("Encode:") + print(encode_output) + print("Decode:") + print(decode_output) if __name__ == "__main__": @@ -143,7 +155,7 @@ def pprint_api_res(name, samples): parser.add_argument( "--nruns", metavar="RUNS", - default=3, + default=1, type=int, help="Number of runs per API (default: %(default)s).", )