Skip to content

Commit 5568d9e

Browse files
authored
Update prepare_models.py
Added docstring
1 parent 663b187 commit 5568d9e

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

prepare_models.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
1+
import os
12
import torch
23
from torch.autograd import Variable as V
34
import torchvision.models as models
45
from torch.nn import functional as F
5-
import os
66

7-
''' Function to prepare context and body model.'''
7+
88
def prep_models(context_model='resnet18', body_model='resnet18', model_dir='./'):
9+
''' Download imagenet pretrained models for context_model and body_model.
10+
:param context_model: Model to use for conetxt features.
11+
:param body_model: Model to use for body features.
12+
:param model_dir: Directory path where to store pretrained models.
13+
:return: Yolo model after loading model weights
14+
'''
915
model_name = '%s_places365.pth.tar' % context_model
1016
model_file = os.path.join(model_dir, model_name)
1117
if not os.path.exists(model_file):
@@ -47,7 +53,8 @@ def prep_models(context_model='resnet18', body_model='resnet18', model_dir='./')
4753
print ('completed preparing body model')
4854
return model_context, model_body
4955

56+
5057
if __name__ == '__main__':
51-
prep_models(model_dir='./')
58+
prep_models(model_dir='proj/debug_exp/models')
5259

5360

0 commit comments

Comments
 (0)