Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
madsbk committed Sep 5, 2023
1 parent 48c95ca commit 010fd14
Showing 1 changed file with 22 additions and 10 deletions.
32 changes: 22 additions & 10 deletions python/benchmarks/zarr-compressors.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,15 @@ class Compressor:


compressors = {
"lz4-default": Compressor(device=Device.CPU, codec=numcodecs.LZ4()),
"lz4-blosc": Compressor(
device=Device.CPU, codec=numcodecs.blosc.Blosc(cname="lz4")
),
"lz4-nvcomp": Compressor(device=Device.GPU, codec=kvikio.zarr.LZ4()),
"snappy-nvcomp": Compressor(device=Device.GPU, codec=kvikio.zarr.Snappy()),
"cascaded-nvcomp": Compressor(device=Device.GPU, codec=kvikio.zarr.Cascaded()),
"gdeflate-nvcomp": Compressor(device=Device.GPU, codec=kvikio.zarr.Gdeflate()),
"bitcomp-nvcomp": Compressor(device=Device.GPU, codec=kvikio.zarr.Bitcomp()),
}


Expand Down Expand Up @@ -85,17 +90,20 @@ def main(args):
)
)

nbytes = f"{format_bytes(args.nbytes)} bytes ({args.nbytes})"
print("Encode/decode benchmark")
print("----------------------------------")
print(f"GPU | {gpu_name}")
print(f"GPU Memory Total | {mem_total}")
print(f"BAR1 Memory Total | {bar1_total}")
print(f"GPU | {gpu_name}")
print(f"GPU Memory Total | {mem_total}")
print(f"BAR1 Memory Total | {bar1_total}")
print("----------------------------------")
print(f"nbytes | {args.nbytes} bytes ({format_bytes(args.nbytes)})")
print(f"4K aligned | {args.nbytes % 4096 == 0}")
print(f"nruns | {args.nruns}")
print(f"nbytes | {nbytes}")
print(f"4K aligned | {args.nbytes % 4096 == 0}")
print(f"nruns | {args.nruns}")
print("==================================")

encode_output = ""
decode_output = ""
# Run each benchmark using the requested APIs
for comp_name, comp in ((n, compressors[n]) for n in args.compressors):
rs = []
Expand All @@ -109,7 +117,7 @@ def main(args):

def pprint_api_res(name, samples):
mean = statistics.mean(samples) if len(samples) > 1 else samples[0]
ret = f"{comp_name} {name}".ljust(18)
ret = f"{comp_name} {name}".ljust(24)
ret += f"| {format_bytes(mean).rjust(10)}/s".ljust(14)
if len(samples) > 1:
stdev = statistics.stdev(samples) / mean * 100
Expand All @@ -120,8 +128,12 @@ def pprint_api_res(name, samples):
ret = ret[:-2] + ")" # Replace trailing comma
return ret

print(pprint_api_res("encode", rs))
print(pprint_api_res("decode", ws))
encode_output += pprint_api_res("", rs) + "\n"
decode_output += pprint_api_res("", ws) + "\n"
print("Encode:")
print(encode_output)
print("Decode:")
print(decode_output)


if __name__ == "__main__":
Expand All @@ -143,7 +155,7 @@ def pprint_api_res(name, samples):
parser.add_argument(
"--nruns",
metavar="RUNS",
default=3,
default=1,
type=int,
help="Number of runs per API (default: %(default)s).",
)
Expand Down

0 comments on commit 010fd14

Please sign in to comment.