Skip to content

Commit

Permalink
bugfix , support for torch 2.0 (#22)
Browse files Browse the repository at this point in the history
* bug fix , support for torch 2.0

* add git ignore
  • Loading branch information
hellojixian authored Jun 30, 2023
1 parent a60299c commit 03b0dd5
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 3 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
result/
6 changes: 4 additions & 2 deletions benchmark_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ def train(precision="single"):
benchmark = {}
for model_type in MODEL_LIST.keys():
for model_name in MODEL_LIST[model_type]:
model = getattr(model_type, model_name)(pretrained=False)
if model_name[-8:] == '_Weights': continue
model = getattr(model_type, model_name)()
if args.NUM_GPU > 1:
model = nn.DataParallel(model, device_ids=range(args.NUM_GPU))
model = getattr(model, precision)()
Expand Down Expand Up @@ -121,7 +122,8 @@ def inference(precision="float"):
with torch.no_grad():
for model_type in MODEL_LIST.keys():
for model_name in MODEL_LIST[model_type]:
model = getattr(model_type, model_name)(pretrained=False)
if model_name[-8:] == '_Weights': continue
model = getattr(model_type, model_name)()
if args.NUM_GPU > 1:
model = nn.DataParallel(model, device_ids=range(args.NUM_GPU))
model = getattr(model, precision)()
Expand Down
2 changes: 1 addition & 1 deletion requirement.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
matplotlib
torchvision
torch>=1.0.0
torch==2.0.0
pandas
plotly
cufflinks
Expand Down

0 comments on commit 03b0dd5

Please sign in to comment.