Skip to content

Commit a60299c

Browse files
committed
Update code formatting
1 parent f810c57 commit a60299c

File tree

1 file changed

+21
-6
lines changed

1 file changed

+21
-6
lines changed

benchmark_models.py

+21-6
Original file line numberDiff line numberDiff line change
@@ -35,17 +35,28 @@
3535
device_name = str(torch.cuda.get_device_name(0))
3636
# Training settings
3737
parser = argparse.ArgumentParser(description="PyTorch Benchmarking")
38-
parser.add_argument("--WARM_UP", "-w", type=int, default=5, required=False, help="Num of warm up")
39-
parser.add_argument("--NUM_TEST", "-n", type=int, default=50, required=False, help="Num of Test")
38+
parser.add_argument(
39+
"--WARM_UP", "-w", type=int, default=5, required=False, help="Num of warm up"
40+
)
41+
parser.add_argument(
42+
"--NUM_TEST", "-n", type=int, default=50, required=False, help="Num of Test"
43+
)
4044
parser.add_argument(
4145
"--BATCH_SIZE", "-b", type=int, default=12, required=False, help="Num of batch size"
4246
)
4347
parser.add_argument(
4448
"--NUM_CLASSES", "-c", type=int, default=1000, required=False, help="Num of class"
4549
)
46-
parser.add_argument("--NUM_GPU", "-g", type=int, default=1, required=False, help="Num of gpus")
4750
parser.add_argument(
48-
"--folder", "-f", type=str, default="result", required=False, help="folder to save results"
51+
"--NUM_GPU", "-g", type=int, default=1, required=False, help="Num of gpus"
52+
)
53+
parser.add_argument(
54+
"--folder",
55+
"-f",
56+
type=str,
57+
default="result",
58+
required=False,
59+
help="folder to save results",
4960
)
5061
args = parser.parse_args()
5162
args.BATCH_SIZE *= args.NUM_GPU
@@ -97,7 +108,9 @@ def train(precision="single"):
97108
end = time.time()
98109
if step >= args.WARM_UP:
99110
durations.append((end - start) * 1000)
100-
print(f"{model_name} model average train time : {sum(durations)/len(durations)}ms")
111+
print(
112+
f"{model_name} model average train time : {sum(durations)/len(durations)}ms"
113+
)
101114
del model
102115
benchmark[model_name] = durations
103116
return benchmark
@@ -115,7 +128,9 @@ def inference(precision="float"):
115128
model = model.to("cuda")
116129
model.eval()
117130
durations = []
118-
print(f"Benchmarking Inference {precision} precision type {model_name} ")
131+
print(
132+
f"Benchmarking Inference {precision} precision type {model_name} "
133+
)
119134
for step, img in enumerate(rand_loader):
120135
img = getattr(img, precision)()
121136
torch.cuda.synchronize()

0 commit comments

Comments (0)