Skip to content

Commit cbff80f

Browse files
committedJun 6, 2018
Raise clear error if last training weights are not foundIf using the --weights=last (or --model=last) to resume trainingbut the weights are not found now it raises a clear error message.
1 parent a688a66 commit cbff80f

9 files changed

+19
-16
lines changed
 

‎mrcnn/model.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -2062,26 +2062,30 @@ def find_last(self):
20622062
"""Finds the last checkpoint file of the last trained model in the
20632063
model directory.
20642064
Returns:
2065-
log_dir: The directory where events and weights are saved
2066-
checkpoint_path: the path to the last checkpoint file
2065+
The path of the last checkpoint file
20672066
"""
20682067
# Get directory names. Each directory corresponds to a model
20692068
dir_names = next(os.walk(self.model_dir))[1]
20702069
key = self.config.NAME.lower()
20712070
dir_names = filter(lambda f: f.startswith(key), dir_names)
20722071
dir_names = sorted(dir_names)
20732072
if not dir_names:
2074-
return None, None
2073+
import errno
2074+
raise FileNotFoundError(
2075+
errno.ENOENT,
2076+
"Could not find model directory under {}".format(self.model_dir))
20752077
# Pick last directory
20762078
dir_name = os.path.join(self.model_dir, dir_names[-1])
20772079
# Find the last checkpoint
20782080
checkpoints = next(os.walk(dir_name))[2]
20792081
checkpoints = filter(lambda f: f.startswith("mask_rcnn"), checkpoints)
20802082
checkpoints = sorted(checkpoints)
20812083
if not checkpoints:
2082-
return dir_name, None
2084+
import errno
2085+
raise FileNotFoundError(
2086+
errno.ENOENT, "Could not find weight files in {}".format(dir_name))
20832087
checkpoint = os.path.join(dir_name, checkpoints[-1])
2084-
return dir_name, checkpoint
2088+
return checkpoint
20852089

20862090
def load_weights(self, filepath, by_name=False, exclude=None):
20872091
"""Modified version of the correspoding Keras function with

‎samples/balloon/balloon.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ class InferenceConfig(BalloonConfig):
336336
utils.download_trained_weights(weights_path)
337337
elif args.weights.lower() == "last":
338338
# Find last trained weights
339-
weights_path = model.find_last()[1]
339+
weights_path = model.find_last()
340340
elif args.weights.lower() == "imagenet":
341341
# Start from ImageNet trained weights
342342
weights_path = model.get_imagenet_weights()

‎samples/balloon/inspect_balloon_model.ipynb

+1-1
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@
265265
"# weights_path = \"/path/to/mask_rcnn_balloon.h5\"\n",
266266
"\n",
267267
"# Or, load the last model you trained\n",
268-
"weights_path = model.find_last()[1]\n",
268+
"weights_path = model.find_last()\n",
269269
"\n",
270270
"# Load weights\n",
271271
"print(\"Loading weights \", weights_path)\n",

‎samples/coco/coco.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,7 @@ class InferenceConfig(CocoConfig):
462462
model_path = COCO_MODEL_PATH
463463
elif args.model.lower() == "last":
464464
# Find last trained weights
465-
model_path = model.find_last()[1]
465+
model_path = model.find_last()
466466
elif args.model.lower() == "imagenet":
467467
# Start from ImageNet trained weights
468468
model_path = model.get_imagenet_weights()

‎samples/coco/inspect_model.ipynb

+1-1
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@
270270
"elif config.NAME == \"coco\":\n",
271271
" weights_path = COCO_MODEL_PATH\n",
272272
"# Or, uncomment to load the last model you trained\n",
273-
"# weights_path = model.find_last()[1]\n",
273+
"# weights_path = model.find_last()\n",
274274
"\n",
275275
"# Load weights\n",
276276
"print(\"Loading weights \", weights_path)\n",

‎samples/coco/inspect_weights.ipynb

+1-1
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@
150150
"elif config.NAME == \"coco\":\n",
151151
" weights_path = COCO_MODEL_PATH\n",
152152
"# Or, uncomment to load the last model you trained\n",
153-
"# weights_path = model.find_last()[1]\n",
153+
"# weights_path = model.find_last()\n",
154154
"\n",
155155
"# Load weights\n",
156156
"print(\"Loading weights \", weights_path)\n",

‎samples/nucleus/inspect_nucleus_model.ipynb

+1-1
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@
258258
"# weights_path = \"/path/to/mask_rcnn_nucleus.h5\"\n",
259259
"\n",
260260
"# Or, load the last model you trained\n",
261-
"weights_path = model.find_last()[1]\n",
261+
"weights_path = model.find_last()\n",
262262
"\n",
263263
"# Load weights\n",
264264
"print(\"Loading weights \", weights_path)\n",

‎samples/nucleus/nucleus.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,7 @@ def detect(model, dataset_dir, subset):
464464
utils.download_trained_weights(weights_path)
465465
elif args.weights.lower() == "last":
466466
# Find last trained weights
467-
weights_path = model.find_last()[1]
467+
weights_path = model.find_last()
468468
elif args.weights.lower() == "imagenet":
469469
# Start from ImageNet trained weights
470470
weights_path = model.get_imagenet_weights()

‎samples/shapes/train_shapes.ipynb

+3-4
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,7 @@
458458
" \"mrcnn_bbox\", \"mrcnn_mask\"])\n",
459459
"elif init_with == \"last\":\n",
460460
" # Load the last model you trained and continue training\n",
461-
" model.load_weights(model.find_last()[1], by_name=True)"
461+
" model.load_weights(model.find_last(), by_name=True)"
462462
]
463463
},
464464
{
@@ -875,10 +875,9 @@
875875
"# Get path to saved weights\n",
876876
"# Either set a specific path or find last trained weights\n",
877877
"# model_path = os.path.join(ROOT_DIR, \".h5 file name here\")\n",
878-
"model_path = model.find_last()[1]\n",
878+
"model_path = model.find_last()\n",
879879
"\n",
880-
"# Load trained weights (fill in path to trained weights here)\n",
881-
"assert model_path != \"\", \"Provide path to trained weights\"\n",
880+
"# Load trained weights\n",
882881
"print(\"Loading weights from \", model_path)\n",
883882
"model.load_weights(model_path, by_name=True)"
884883
]

0 commit comments

Comments
 (0)