Skip to content

Commit

Permalink
Merge pull request #107 from prickly-u/evaluate_fix
Browse files Browse the repository at this point in the history
evaluate.py fix
  • Loading branch information
gosha20777 authored Feb 18, 2020
2 parents 48aba41 + 9ef4c59 commit 6e1f8e0
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions keras_retinanet/bin/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,21 @@
from .. import models
from ..preprocessing.csv_generator import CSVGenerator
from ..preprocessing.pascal_voc import PascalVocGenerator
from ..utils.anchors import make_shapes_callback
from ..utils.config import read_config_file, parse_anchor_parameters
from ..utils.eval import evaluate
from ..utils.gpu import setup_gpu
from ..utils.keras_version import check_keras_version
from ..utils.tf_version import check_tf_version


def create_generator(args):
def create_generator(args, preprocess_image):
""" Create generators for evaluation.
"""
common_args = {
'preprocess_image' : preprocess_image,
}

if args.dataset_type == 'coco':
# import here to prevent unnecessary dependency on cocoapi
from ..preprocessing.coco import CocoGenerator
Expand All @@ -49,6 +54,7 @@ def create_generator(args):
image_max_side=args.image_max_side,
config=args.config,
shuffle_groups=False,
**common_args
)
elif args.dataset_type == 'pascal':
validation_generator = PascalVocGenerator(
Expand All @@ -58,6 +64,7 @@ def create_generator(args):
image_max_side=args.image_max_side,
config=args.config,
shuffle_groups=False,
**common_args
)
elif args.dataset_type == 'csv':
validation_generator = CSVGenerator(
Expand All @@ -67,6 +74,7 @@ def create_generator(args):
image_max_side=args.image_max_side,
config=args.config,
shuffle_groups=False,
**common_args
)
else:
raise ValueError('Invalid data type received: {}'.format(args.dataset_type))
Expand Down Expand Up @@ -129,7 +137,8 @@ def main(args=None):
args.config = read_config_file(args.config)

# create the generator
generator = create_generator(args)
backbone = models.backbone(args.backbone)
generator = create_generator(args, backbone.preprocess_image)

# optionally load anchor parameters
anchor_params = None
Expand All @@ -140,6 +149,8 @@ def main(args=None):
print('Loading model, this may take a second...')
model = models.load_model(args.model, backbone_name=args.backbone)

generator.compute_shapes = make_shapes_callback(model)

# optionally convert the model
if args.convert_model:
model = models.convert_model(model, anchor_params=anchor_params)
Expand Down

0 comments on commit 6e1f8e0

Please sign in to comment.