From e71e0aa93574d47d7cc2482f3d9b68190fce5470 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Wed, 1 Mar 2023 14:24:40 +0200 Subject: [PATCH 01/13] YOLO model --- test/assets/yolov4-tiny-3l.cfg | 327 +++ test/test_models.py | 158 +- test/test_models_detection_anchor_utils.py | 39 +- test/test_models_detection_target_matching.py | 80 + test/test_models_detection_yolo_networks.py | 83 + torchvision/models/detection/__init__.py | 3 + torchvision/models/detection/anchor_utils.py | 54 + .../models/detection/target_matching.py | 375 ++++ torchvision/models/detection/yolo.py | 414 ++++ torchvision/models/detection/yolo_loss.py | 291 +++ torchvision/models/detection/yolo_networks.py | 1980 +++++++++++++++++ torchvision/models/yolo.py | 729 ++++++ 12 files changed, 4476 insertions(+), 57 deletions(-) create mode 100644 test/assets/yolov4-tiny-3l.cfg create mode 100644 test/test_models_detection_target_matching.py create mode 100644 test/test_models_detection_yolo_networks.py create mode 100644 torchvision/models/detection/target_matching.py create mode 100644 torchvision/models/detection/yolo.py create mode 100644 torchvision/models/detection/yolo_loss.py create mode 100644 torchvision/models/detection/yolo_networks.py create mode 100644 torchvision/models/yolo.py diff --git a/test/assets/yolov4-tiny-3l.cfg b/test/assets/yolov4-tiny-3l.cfg new file mode 100644 index 00000000000..c2dcb29e481 --- /dev/null +++ b/test/assets/yolov4-tiny-3l.cfg @@ -0,0 +1,327 @@ +[net] +batch=64 +subdivisions=1 +width=608 +height=608 +channels=3 +momentum=0.9 +decay=0.0005 +angle=0 +saturation = 1.5 +exposure = 1.5 +hue=.1 + +learning_rate=0.00261 +burn_in=1000 +max_batches = 500200 +policy=steps +steps=400000,450000 +scales=.1,.1 + +[convolutional] +batch_normalize=1 +filters=32 +size=3 +stride=2 +pad=1 +activation=leaky + +[convolutional] +batch_normalize=1 +filters=64 +size=3 +stride=2 +pad=1 +activation=leaky + +[convolutional] +batch_normalize=1 +filters=64 +size=3 +stride=1 +pad=1 +activation=leaky + +[route] +layers=-1 +groups=2 +group_id=1 + +[convolutional] +batch_normalize=1 +filters=32 +size=3 +stride=1 +pad=1 +activation=leaky + +[convolutional] +batch_normalize=1 +filters=32 +size=3 +stride=1 +pad=1 +activation=leaky + +[route] +layers = -1,-2 + +[convolutional] +batch_normalize=1 +filters=64 +size=1 +stride=1 +pad=1 +activation=leaky + +[route] +layers = -6,-1 + +[maxpool] +size=2 +stride=2 + +[convolutional] +batch_normalize=1 +filters=128 +size=3 +stride=1 +pad=1 +activation=leaky + +[route] +layers=-1 +groups=2 +group_id=1 + +[convolutional] +batch_normalize=1 +filters=64 +size=3 +stride=1 +pad=1 +activation=leaky + +[convolutional] +batch_normalize=1 +filters=64 +size=3 +stride=1 +pad=1 +activation=leaky + +[route] +layers = -1,-2 + +[convolutional] +batch_normalize=1 +filters=128 +size=1 +stride=1 +pad=1 +activation=leaky + +[route] +layers = -6,-1 + +[maxpool] +size=2 +stride=2 + +[convolutional] +batch_normalize=1 +filters=256 +size=3 +stride=1 +pad=1 +activation=leaky + +[route] +layers=-1 +groups=2 +group_id=1 + +[convolutional] +batch_normalize=1 +filters=128 +size=3 +stride=1 +pad=1 +activation=leaky + +[convolutional] +batch_normalize=1 +filters=128 +size=3 +stride=1 +pad=1 +activation=leaky + +[route] +layers = -1,-2 + +[convolutional] +batch_normalize=1 +filters=256 +size=1 +stride=1 +pad=1 +activation=leaky + +[route] +layers = -6,-1 + +[maxpool] +size=2 +stride=2 + +[convolutional] +batch_normalize=1 +filters=512 +size=3 +stride=1 +pad=1 +activation=leaky + +################################## + +[convolutional] +batch_normalize=1 +filters=256 +size=1 +stride=1 +pad=1 +activation=leaky + +[convolutional] +batch_normalize=1 +filters=512 +size=3 +stride=1 +pad=1 +activation=leaky + +[convolutional] +size=1 +stride=1 +pad=1 +filters=255 +activation=linear + + + +[yolo] +mask = 6,7,8 +anchors = 12, 16, 19, 36, 40, 28, 36, 75, 76, 55, 72, 146, 142, 110, 192, 243, 459, 401 +classes=80 +num=9 +jitter=.3 +scale_x_y = 1.05 +cls_normalizer=1.0 +iou_normalizer=0.07 +iou_loss=ciou +ignore_thresh = .7 +truth_thresh = 1 +random=0 +resize=1.5 +nms_kind=greedynms +beta_nms=0.6 + +[route] +layers = -4 + +[convolutional] +batch_normalize=1 +filters=128 +size=1 +stride=1 +pad=1 +activation=leaky + +[upsample] +stride=2 + +[route] +layers = -1, 23 + +[convolutional] +batch_normalize=1 +filters=256 +size=3 +stride=1 +pad=1 +activation=leaky + +[convolutional] +size=1 +stride=1 +pad=1 +filters=255 +activation=linear + +[yolo] +mask = 3,4,5 +anchors = 12, 16, 19, 36, 40, 28, 36, 75, 76, 55, 72, 146, 142, 110, 192, 243, 459, 401 +classes=80 +num=9 +jitter=.3 +scale_x_y = 1.05 +cls_normalizer=1.0 +iou_normalizer=0.07 +iou_loss=ciou +ignore_thresh = .7 +truth_thresh = 1 +random=0 +resize=1.5 +nms_kind=greedynms +beta_nms=0.6 + + +[route] +layers = -3 + +[convolutional] +batch_normalize=1 +filters=64 +size=1 +stride=1 +pad=1 +activation=leaky + +[upsample] +stride=2 + +[route] +layers = -1, 15 + +[convolutional] +batch_normalize=1 +filters=128 +size=3 +stride=1 +pad=1 +activation=leaky + +[convolutional] +size=1 +stride=1 +pad=1 +filters=255 +activation=linear + +[yolo] +mask = 0,1,2 +anchors = 12, 16, 19, 36, 40, 28, 36, 75, 76, 55, 72, 146, 142, 110, 192, 243, 459, 401 +classes=80 +num=9 +jitter=.3 +scale_x_y = 1.05 +cls_normalizer=1.0 +iou_normalizer=0.07 +iou_loss=ciou +ignore_thresh = .7 +truth_thresh = 1 +random=0 +resize=1.5 +nms_kind=greedynms +beta_nms=0.6 diff --git a/test/test_models.py b/test/test_models.py index e1a288f4eb5..f9128d7d0e1 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -19,10 +19,12 @@ from PIL import Image from torchvision import models, transforms from torchvision.models import get_model_builder, list_models +from torchvision.models.detection import yolo_darknet ACCEPT = os.getenv("EXPECTTEST_ACCEPT", "0") == "1" SKIP_BIG_MODEL = os.getenv("SKIP_BIG_MODEL", "1") == "1" +DARKNET_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "yolov4-tiny-3l.cfg") def list_model_fns(module): @@ -370,6 +372,10 @@ def _check_input_backprop(model, inputs): "input_shape": (1, 3, 16, 224, 224), }, "googlenet": {"init_weights": True}, + "yolov4": { + "num_classes": 10, + "input_shape": (3, 224, 224), + }, } # speeding up slow models: slow_models = [ @@ -467,6 +473,10 @@ def is_skippable(model_name, device): "max_trainable": 5, "n_trn_params_per_layer": [54, 64, 83, 96, 106, 107], }, + "yolov4": { + "max_trainable": 5, + "n_trn_params_per_layer": [1, 1, 1, 1, 1, 1], # TODO: Fill the correct values. + }, } @@ -783,6 +793,61 @@ def check_out(out): _check_input_backprop(model, x) +def check_model_output(out): + assert len(out) == 1 + + def compact(tensor): + tensor = tensor.cpu() + size = tensor.size() + elements_per_sample = functools.reduce(operator.mul, size[1:], 1) + if elements_per_sample > 30: + return compute_mean_std(tensor) + else: + return subsample_tensor(tensor) + + def subsample_tensor(tensor): + num_elems = tensor.size(0) + num_samples = 20 + if num_elems <= num_samples: + return tensor + + ith_index = num_elems // num_samples + return tensor[ith_index - 1 :: ith_index] + + def compute_mean_std(tensor): + # can't compute mean of integral tensor + tensor = tensor.to(torch.double) + mean = torch.mean(tensor) + std = torch.std(tensor) + return {"mean": mean, "std": std} + + output = map_nested_tensor_object(out, tensor_map_fn=compact) + prec = 0.01 + try: + # We first try to assert the entire output if possible. This is not + # only the best way to assert results but also handles the cases + # where we need to create a new expected result. + _assert_expected(output, model_name, prec=prec) + except AssertionError: + # Unfortunately detection models are flaky due to the unstable sort + # in NMS. If matching across all outputs fails, use the same approach + # as in NMSTester.test_nms_cuda to see if this is caused by duplicate + # scores. + expected_file = _get_expected_file(model_name) + expected = torch.load(expected_file) + torch.testing.assert_close( + output[0]["scores"], expected[0]["scores"], rtol=prec, atol=prec, check_device=False, check_dtype=False + ) + + # Note: Fmassa proposed turning off NMS by adapting the threshold + # and then using the Hungarian algorithm as in DETR to find the + # best match between output and expected boxes and eliminate some + # of the flakiness. Worth exploring. + return False # Partial validation performed + + return True # Full validation performed + + @pytest.mark.parametrize("model_fn", list_model_fns(models.detection)) @pytest.mark.parametrize("dev", cpu_and_gpu()) def test_detection_model(model_fn, dev): @@ -809,61 +874,7 @@ def test_detection_model(model_fn, dev): out = model(model_input) assert model_input[0] is x - def check_out(out): - assert len(out) == 1 - - def compact(tensor): - tensor = tensor.cpu() - size = tensor.size() - elements_per_sample = functools.reduce(operator.mul, size[1:], 1) - if elements_per_sample > 30: - return compute_mean_std(tensor) - else: - return subsample_tensor(tensor) - - def subsample_tensor(tensor): - num_elems = tensor.size(0) - num_samples = 20 - if num_elems <= num_samples: - return tensor - - ith_index = num_elems // num_samples - return tensor[ith_index - 1 :: ith_index] - - def compute_mean_std(tensor): - # can't compute mean of integral tensor - tensor = tensor.to(torch.double) - mean = torch.mean(tensor) - std = torch.std(tensor) - return {"mean": mean, "std": std} - - output = map_nested_tensor_object(out, tensor_map_fn=compact) - prec = 0.01 - try: - # We first try to assert the entire output if possible. This is not - # only the best way to assert results but also handles the cases - # where we need to create a new expected result. - _assert_expected(output, model_name, prec=prec) - except AssertionError: - # Unfortunately detection models are flaky due to the unstable sort - # in NMS. If matching across all outputs fails, use the same approach - # as in NMSTester.test_nms_cuda to see if this is caused by duplicate - # scores. - expected_file = _get_expected_file(model_name) - expected = torch.load(expected_file) - torch.testing.assert_close( - output[0]["scores"], expected[0]["scores"], rtol=prec, atol=prec, check_device=False, check_dtype=False - ) - - # Note: Fmassa proposed turning off NMS by adapting the threshold - # and then using the Hungarian algorithm as in DETR to find the - # best match between output and expected boxes and eliminate some - # of the flakiness. Worth exploring. - return False # Partial validation performed - - return True # Full validation performed - - full_validation = check_out(out) + full_validation = check_model_output(out) _check_jit_scriptable(model, ([x],), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out) if dev == "cuda": @@ -871,7 +882,7 @@ def compute_mean_std(tensor): out = model(model_input) # See autocast_flaky_numerics comment at top of file. if model_name not in autocast_flaky_numerics: - full_validation &= check_out(out) + full_validation &= check_model_output(out) if not full_validation: msg = ( @@ -886,6 +897,41 @@ def compute_mean_std(tensor): _check_input_backprop(model, model_input) +@pytest.mark.parametrize("dev", cpu_and_gpu()) +def test_yolo_darknet(dev): + set_rng_seed(0) + dtype = torch.get_default_dtype() + input_shape = (3, 300, 300) + + model = yolo_darknet(DARKNET_CONFIG) + model.eval().to(device=dev, dtype=dtype) + x = _get_image(input_shape=input_shape, real_image=False, device=dev, dtype=dtype) + model_input = [x] + with torch.no_grad(), freeze_rng_state(): + out = model(model_input) + assert model_input[0] is x + + full_validation = check_model_output(out) + _check_jit_scriptable(model, ([x],), unwrapper=None, eager_out=out) + + if dev == "cuda": + with torch.cuda.amp.autocast(), torch.no_grad(), freeze_rng_state(): + out = model(model_input) + full_validation &= check_model_output(out) + + if not full_validation: + msg = ( + f"The output of yolo_darknet could only be partially validated. " + "This is likely due to unit-test flakiness, but you may " + "want to do additional manual checks if you made " + "significant changes to the codebase." + ) + warnings.warn(msg, RuntimeWarning) + pytest.skip(msg) + + _check_input_backprop(model, model_input) + + @pytest.mark.parametrize("model_fn", list_model_fns(models.detection)) def test_detection_model_validation(model_fn): set_rng_seed(0) diff --git a/test/test_models_detection_anchor_utils.py b/test/test_models_detection_anchor_utils.py index 645d4624d64..661b8685bda 100644 --- a/test/test_models_detection_anchor_utils.py +++ b/test/test_models_detection_anchor_utils.py @@ -1,7 +1,7 @@ import pytest import torch from common_utils import assert_equal -from torchvision.models.detection.anchor_utils import AnchorGenerator, DefaultBoxGenerator +from torchvision.models.detection.anchor_utils import AnchorGenerator, DefaultBoxGenerator, global_xy, grid_centers, grid_offsets from torchvision.models.detection.image_list import ImageList @@ -97,3 +97,40 @@ def test_defaultbox_generator(self): assert tuple(dboxes[1].shape) == (4, 4) torch.testing.assert_close(dboxes[0], dboxes_output, rtol=1e-5, atol=1e-8) torch.testing.assert_close(dboxes[1], dboxes_output, rtol=1e-5, atol=1e-8) + + +@pytest.mark.parametrize("width,height", [(10, 5)]) +def test_grid_offsets(width: int, height: int): + size = torch.tensor([width, height]) + offsets = grid_offsets(size) + assert offsets.shape == (height, width, 2) + assert torch.equal(offsets[0, :, 0], torch.arange(width, dtype=offsets.dtype)) + assert torch.equal(offsets[0, :, 1], torch.zeros(width, dtype=offsets.dtype)) + assert torch.equal(offsets[:, 0, 0], torch.zeros(height, dtype=offsets.dtype)) + assert torch.equal(offsets[:, 0, 1], torch.arange(height, dtype=offsets.dtype)) + + +@pytest.mark.parametrize("width,height", [(10, 5)]) +def test_grid_centers(width: int, height: int): + size = torch.tensor([width, height]) + centers = grid_centers(size) + assert centers.shape == (height, width, 2) + assert torch.equal(centers[0, :, 0], 0.5 + torch.arange(width, dtype=torch.float)) + assert torch.equal(centers[0, :, 1], 0.5 * torch.ones(width)) + assert torch.equal(centers[:, 0, 0], 0.5 * torch.ones(height)) + assert torch.equal(centers[:, 0, 1], 0.5 + torch.arange(height, dtype=torch.float)) + + +def test_global_xy(): + xy = torch.ones((2, 4, 4, 3, 2)) * 0.5 # 4x4 grid of coordinates to the center of the cell. + image_size = torch.tensor([400, 200]) + xy = global_xy(xy, image_size) + assert xy.shape == (2, 4, 4, 3, 2) + assert torch.all(xy[:, :, 0, :, 0] == 50) + assert torch.all(xy[:, 0, :, :, 1] == 25) + assert torch.all(xy[:, :, 1, :, 0] == 150) + assert torch.all(xy[:, 1, :, :, 1] == 75) + assert torch.all(xy[:, :, 2, :, 0] == 250) + assert torch.all(xy[:, 2, :, :, 1] == 125) + assert torch.all(xy[:, :, 3, :, 0] == 350) + assert torch.all(xy[:, 3, :, :, 1] == 175) diff --git a/test/test_models_detection_target_matching.py b/test/test_models_detection_target_matching.py new file mode 100644 index 00000000000..dc99669a947 --- /dev/null +++ b/test/test_models_detection_target_matching.py @@ -0,0 +1,80 @@ +import pytest +import torch +from torchvision.models.detection.anchor_utils import grid_centers +from torchvision.models.detection.target_matching import aligned_iou, iou_below, is_inside_box, _sim_ota_match + + +@pytest.mark.parametrize( + "dims1, dims2, expected_ious", + [ + ( + torch.tensor([[1.0, 1.0], [10.0, 1.0], [100.0, 10.0]]), + torch.tensor([[1.0, 10.0], [2.0, 20.0]]), + torch.tensor([[1.0 / 10.0, 1.0 / 40.0], [1.0 / 19.0, 2.0 / 48.0], [10.0 / 1000.0, 20.0 / 1020.0]]), + ) + ], +) +def test_aligned_iou(dims1, dims2, expected_ious): + torch.testing.assert_close(aligned_iou(dims1, dims2), expected_ious) + + +def test_iou_below(): + tl = torch.rand((10, 10, 3, 2)) * 100 + br = tl + 10 + pred_boxes = torch.cat((tl, br), -1) + target_boxes = torch.stack((pred_boxes[1, 1, 0], pred_boxes[3, 5, 1])) + result = iou_below(pred_boxes, target_boxes, 0.9) + assert result.shape == (10, 10, 3) + assert not result[1, 1, 0] + assert not result[3, 5, 1] + + +def test_is_inside_box(): + """ + centers: + [[1,1; 3,1; 5,1; 7,1; 9,1; 11,1; 13,1; 15,1; 17,1; 19,1] + [1,3; 3,3; 5,3; 7,3; 9,3; 11,3; 13,3; 15,3; 17,3; 19,3] + [1,5; 3,5; 5,5; 7,5; 9,5; 11,5; 13,5; 15,5; 17,5; 19,5] + [1,7; 3,7; 5,7; 7,7; 9,7; 11,7; 13,7; 15,7; 17,7; 19,7] + [1,9; 3,9; 5,9; 7,9; 9,9; 11,9; 13,9; 15,9; 17,9; 19,9]] + + is_inside[0]: + [[F, F, F, F, F, F, F, F, F, F] + [F, T, T, F, F, F, F, F, F, F] + [F, T, T, F, F, F, F, F, F, F] + [F, F, F, F, F, F, F, F, F, F] + [F, F, F, F, F, F, F, F, F, F]] + + is_inside[1]: + [[F, F, F, F, F, F, F, F, F, F] + [F, F, F, F, F, F, F, F, F, F] + [F, F, F, F, F, F, F, F, F, F] + [F, F, F, F, F, F, F, F, F, F] + [F, F, F, F, F, F, F, T, T, F]] + """ + size = torch.tensor([10, 5]) + centers = grid_centers(size) * 2.0 + centers = centers.view(-1, 2) + boxes = torch.tensor([[2, 2, 6, 6], [14, 8, 18, 10]]) + is_inside = is_inside_box(centers, boxes).view(2, 5, 10) + assert torch.count_nonzero(is_inside) == 6 + assert torch.all(is_inside[0, 1:3, 1:3]) + assert torch.all(is_inside[1, 4, 7:9]) + + +def test_sim_ota_match(): + # IoUs will determined that 2 and 1 predictions will be selected for the first and the second target. + ious = torch.tensor([[0.1, 0.1, 0.9, 0.9], [0.2, 0.3, 0.4, 0.1]]) + # Costs will determine that the first and the last prediction will be selected for the first target, and the first + # prediction will be selected for the second target. Since the first prediction was selected for both targets, it + # will be matched to the best target only (the second one). + costs = torch.tensor([[0.3, 0.5, 0.4, 0.3], [0.1, 0.2, 0.5, 0.3]]) + matched_preds, matched_targets = _sim_ota_match(costs, ious) + assert len(matched_preds) == 4 + assert matched_preds[0] + assert not matched_preds[1] + assert not matched_preds[2] + assert matched_preds[3] + assert len(matched_targets) == 2 # Two predictions were matched. + assert matched_targets[0] == 1 # Which target was matched to the first prediction. + assert matched_targets[1] == 0 # Which target was matched to the last prediction. diff --git a/test/test_models_detection_yolo_networks.py b/test/test_models_detection_yolo_networks.py new file mode 100644 index 00000000000..121f59df7cf --- /dev/null +++ b/test/test_models_detection_yolo_networks.py @@ -0,0 +1,83 @@ +import pytest +import torch.nn as nn +from torchvision.models.detection.yolo_networks import ( + _create_convolutional, + _create_maxpool, + _create_shortcut, + _create_upsample, +) + + +@pytest.mark.parametrize( + "config", + [ + ({"batch_normalize": 1, "filters": 8, "size": 3, "stride": 1, "pad": 1, "activation": "leaky"}), + ({"batch_normalize": 0, "filters": 2, "size": 1, "stride": 1, "pad": 1, "activation": "mish"}), + ({"batch_normalize": 1, "filters": 6, "size": 3, "stride": 2, "pad": 1, "activation": "logistic"}), + ({"batch_normalize": 0, "filters": 4, "size": 3, "stride": 2, "pad": 0, "activation": "linear"}), + ], +) +def test_create_convolutional(config): + conv, _ = _create_convolutional(config, [3]) + + assert conv.conv.out_channels == config["filters"] + assert conv.conv.kernel_size == (config["size"], config["size"]) + assert conv.conv.stride == (config["stride"], config["stride"]) + + pad_size = (config["size"] - 1) // 2 if config["pad"] else 0 + if config["pad"]: + assert conv.conv.padding == (pad_size, pad_size) + + if config["batch_normalize"]: + assert isinstance(conv.norm, nn.BatchNorm2d) + + if config["activation"] == "linear": + assert isinstance(conv.act, nn.Identity) + elif config["activation"] == "logistic": + assert isinstance(conv.act, nn.Sigmoid) + else: + assert conv.act.__class__.__name__.lower().startswith(config["activation"]) + + +@pytest.mark.parametrize( + "config", + [ + ({ "size": 2, "stride": 2 }), + ({ "size": 6, "stride": 3 }), + ], +) +def test_create_maxpool(config): + pad_size, remainder = divmod(max(config["size"], config["stride"]) - config["stride"], 2) + maxpool, _ = _create_maxpool(config, [3]) + + assert maxpool.maxpool.kernel_size == config["size"] + assert maxpool.maxpool.stride == config["stride"] + assert maxpool.maxpool.padding == pad_size + if remainder != 0: + assert isinstance(maxpool.pad, nn.ZeroPad2d) + + +@pytest.mark.parametrize( + "config", + [ + ({"from": 1, "activation": "linear"}), + ({"from": 3, "activation": "linear"}), + ], +) +def test_create_shortcut(config): + shortcut, _ = _create_shortcut(config, [3]) + + assert shortcut.source_layer == config["from"] + + +@pytest.mark.parametrize( + "config", + [ + ({"stride": 2}), + ({"stride": 4}), + ], +) +def test_create_upsample(config): + upsample, _ = _create_upsample(config, [3]) + + assert upsample.scale_factor == float(config["stride"]) diff --git a/torchvision/models/detection/__init__.py b/torchvision/models/detection/__init__.py index 4146651c737..faff00ff86a 100644 --- a/torchvision/models/detection/__init__.py +++ b/torchvision/models/detection/__init__.py @@ -1,3 +1,4 @@ +from .darknet_network import DarknetNetwork from .faster_rcnn import * from .fcos import * from .keypoint_rcnn import * @@ -5,3 +6,5 @@ from .retinanet import * from .ssd import * from .ssdlite import * +from .yolo import * +from .yolo_networks import YOLOV4TinyNetwork, YOLOV4Network, YOLOV4P6Network, YOLOV5Network, YOLOV7Network, YOLOXNetwork diff --git a/torchvision/models/detection/anchor_utils.py b/torchvision/models/detection/anchor_utils.py index 253f6502a9b..d7c01a4af11 100644 --- a/torchvision/models/detection/anchor_utils.py +++ b/torchvision/models/detection/anchor_utils.py @@ -266,3 +266,57 @@ def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Ten ) dboxes.append(dboxes_in_image) return dboxes + + +def grid_offsets(grid_size: Tensor) -> Tensor: + """Given a grid size, returns a tensor containing offsets to the grid cells. + + Args: + The width and height of the grid in a tensor. + + Returns: + A ``[height, width, 2]`` tensor containing the grid cell `(x, y)` offsets. + """ + x_range = torch.arange(grid_size[0].item(), device=grid_size.device) + y_range = torch.arange(grid_size[1].item(), device=grid_size.device) + grid_y, grid_x = torch.meshgrid((y_range, x_range), indexing="ij") + return torch.stack((grid_x, grid_y), -1) + + +def grid_centers(grid_size: Tensor) -> Tensor: + """Given a grid size, returns a tensor containing coordinates to the centers of the grid cells. + + Returns: + A ``[height, width, 2]`` tensor containing coordinates to the centers of the grid cells. + """ + return grid_offsets(grid_size) + 0.5 + + +@torch.jit.script +def global_xy(xy: Tensor, image_size: Tensor) -> Tensor: + """Adds offsets to the predicted box center coordinates to obtain global coordinates to the image. + + The predicted coordinates are interpreted as coordinates inside a grid cell whose width and height is 1. Adding + offset to the cell, dividing by the grid size, and multiplying by the image size, we get global coordinates in the + image scale. + + The function needs the ``@torch.jit.script`` decorator in order for ONNX generation to work. The tracing based + generator will loose track of e.g. ``xy.shape[1]`` and treat it as a Python variable and not a tensor. This will + cause the dimension to be treated as a constant in the model, which prevents dynamic input sizes. + + Args: + xy: The predicted center coordinates before scaling. Values from zero to one in a tensor sized + ``[batch_size, height, width, boxes_per_cell, 2]``. + image_size: Width and height in a vector that will be used to scale the coordinates. + + Returns: + Global coordinates scaled to the size of the network input image, in a tensor with the same shape as the input + tensor. + """ + height = xy.shape[1] + width = xy.shape[2] + grid_size = torch.tensor([width, height], device=xy.device) + # Scripting requires explicit conversion to a floating point type. + offset = grid_offsets(grid_size).to(xy.dtype).unsqueeze(2) # [height, width, 1, 2] + scale = torch.true_divide(image_size, grid_size) + return (xy + offset) * scale diff --git a/torchvision/models/detection/target_matching.py b/torchvision/models/detection/target_matching.py new file mode 100644 index 00000000000..d78456a74f5 --- /dev/null +++ b/torchvision/models/detection/target_matching.py @@ -0,0 +1,375 @@ +from abc import ABC, abstractmethod +from typing import Dict, List, Sequence, Tuple, Union + +import torch +from torch import Tensor + +from ...ops import box_convert, box_iou +from .anchor_utils import grid_centers +from .yolo_loss import YOLOLoss + + +def aligned_iou(dims1: Tensor, dims2: Tensor) -> Tensor: + """Calculates a matrix of intersections over union from box dimensions, assuming that the boxes are located at + the same coordinates. + + Args: + dims1: Width and height of `N` boxes. Tensor of size ``[N, 2]``. + dims2: Width and height of `M` boxes. Tensor of size ``[M, 2]``. + + Returns: + Tensor of size ``[N, M]`` containing the pairwise IoU values for every element in ``dims1`` and ``dims2`` + """ + area1 = dims1[:, 0] * dims1[:, 1] # [N] + area2 = dims2[:, 0] * dims2[:, 1] # [M] + + inter_wh = torch.min(dims1[:, None, :], dims2) # [N, M, 2] + inter = inter_wh[:, :, 0] * inter_wh[:, :, 1] # [N, M] + union = area1[:, None] + area2 - inter # [N, M] + + return inter / union + + +def iou_below(pred_boxes: Tensor, target_boxes: Tensor, threshold: float) -> Tensor: + """Creates a binary mask whose value will be ``True``, unless the predicted box overlaps any target + significantly (IoU greater than ``threshold``). + + Args: + pred_boxes: The predicted corner coordinates. Tensor of size ``[height, width, boxes_per_cell, 4]``. + target_boxes: Corner coordinates of the target boxes. Tensor of size ``[height, width, boxes_per_cell, 4]``. + + Returns: + A boolean tensor sized ``[height, width, boxes_per_cell]``, with ``False`` where the predicted box overlaps a + target significantly and ``True`` elsewhere. + """ + shape = pred_boxes.shape[:-1] + pred_boxes = pred_boxes.view(-1, 4) + ious = box_iou(pred_boxes, target_boxes) + best_iou = ious.max(-1).values + below_threshold = best_iou <= threshold + return below_threshold.view(shape) + + +def is_inside_box(points: Tensor, boxes: Tensor) -> Tensor: + """Get pairwise truth values of whether the point is inside the box. + + Args: + points: point (x, y) coordinates, [points, 2] + boxes: box (x1, y1, x2, y2) coordinates, [boxes, 4] + + Returns: + A tensor shaped ``[boxes, points]`` containing pairwise truth values of whether the points are inside the boxes. + """ + points = points.unsqueeze(0) # [1, points, 2] + boxes = boxes.unsqueeze(1) # [boxes, 1, 4] + lt = points - boxes[..., :2] # [boxes, points, 2] + rb = boxes[..., 2:] - points # [boxes, points, 2] + deltas = torch.cat((lt, rb), -1) # [boxes, points, 4] + return deltas.min(-1).values > 0.0 # [boxes, points] + + +class ShapeMatching(ABC): + """Selects which anchors are used to predict each target, by comparing the shape of the target box to a set of + prior shapes. + + Most YOLO variants match targets to anchors based on prior shapes that are assigned to the anchors in the model + configuration. The subclasses of ``ShapeMatching`` implement matching rules that compare the width and height of + the targets to each prior shape (regardless of the location where the target is). When the model includes multiple + detection layers, different shapes are defined for each layer. Usually there are three detection layers and three + prior shapes per layer. + + Args: + ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the prior shape has IoU + with some target greater than this threshold, the predictor will not be taken into account when calculating + the confidence loss. + """ + + def __init__(self, ignore_bg_threshold: float = 0.7) -> None: + self.ignore_bg_threshold = ignore_bg_threshold + + def __call__( + self, + preds: Dict[str, Tensor], + targets: Dict[str, Tensor], + image_size: Tensor, + ) -> Tuple[List[Tensor], Tensor, Tensor]: + """For each target, selects predictions from the same grid cell, where the center of the target box is. + + Typically there are three predictions per grid cell. Subclasses implement ``match()``, which selects the + predictions within the grid cell. + + Args: + preds: Predictions for a single image. + targets: Training targets for a single image. + image_size: Input image width and height. + + Returns: + The indices of the matched predictions, background mask, and a mask for selecting the matched targets. + """ + height, width = preds["boxes"].shape[:2] + device = preds["boxes"].device + + # A multiplier for scaling image coordinates to feature map coordinates + grid_size = torch.tensor([width, height], device=device) + image_to_grid = torch.true_divide(grid_size, image_size) + + # Bounding box center coordinates are converted to the feature map dimensions so that the whole number tells the + # cell index and the fractional part tells the location inside the cell. + xywh = box_convert(targets["boxes"], in_fmt="xyxy", out_fmt="cxcywh") + grid_xy = xywh[:, :2] * image_to_grid + cell_i = grid_xy[:, 0].to(torch.int64).clamp(0, width - 1) + cell_j = grid_xy[:, 1].to(torch.int64).clamp(0, height - 1) + + target_selector, anchor_selector = self.match(xywh[:, 2:]) + cell_i = cell_i[target_selector] + cell_j = cell_j[target_selector] + + # Background mask is used to select anchors that are not responsible for predicting any object, for + # calculating the part of the confidence loss with zero as the target confidence. It is set to False, if a + # predicted box overlaps any target significantly, or if a prediction is matched to a target. + background_mask = iou_below(preds["boxes"], targets["boxes"], self.ignore_bg_threshold) + background_mask[cell_j, cell_i, anchor_selector] = False + + pred_selector = [cell_j, cell_i, anchor_selector] + + return pred_selector, background_mask, target_selector + + @abstractmethod + def match(self, wh: Tensor) -> Union[Tuple[Tensor, Tensor], Tensor]: + """Selects anchors for each target based on the predicted shapes. The subclasses implement this method. + + Args: + wh: A matrix of predicted width and height values. + + Returns: + matched_targets, matched_anchors: Two vectors or a `2xN` matrix. The first vector is used to select the + targets that this layer matched and the second one lists the matching anchors within the grid cell. + """ + pass + + +class HighestIoUMatching(ShapeMatching): + """For each target, select the prior shape that gives the highest IoU. + + This is the original YOLO matching rule. + + Args: + prior_shapes: A list of all the prior box dimensions. The list should contain (width, height) tuples in the + network input resolution. + prior_shape_idxs: List of indices to ``prior_shapes`` that is used to select the (usually 3) prior shapes that + this layer uses. + ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the prior shape has IoU + with some target greater than this threshold, the predictor will not be taken into account when calculating + the confidence loss. + """ + + def __init__( + self, prior_shapes: Sequence[Tuple[int, int]], prior_shape_idxs: Sequence[int], ignore_bg_threshold: float = 0.7 + ) -> None: + super().__init__(ignore_bg_threshold) + self.prior_shapes = prior_shapes + # anchor_map maps the anchor indices to anchors in this layer, or to -1 if it's not an anchor of this layer. + # This layer ignores the target if all the selected anchors are in another layer. + self.anchor_map = [ + prior_shape_idxs.index(idx) if idx in prior_shape_idxs else -1 for idx in range(len(prior_shapes)) + ] + + def match(self, wh: Tensor) -> Union[Tuple[Tensor, Tensor], Tensor]: + prior_wh = torch.tensor(self.prior_shapes, dtype=wh.dtype, device=wh.device) + anchor_map = torch.tensor(self.anchor_map, dtype=torch.int64, device=wh.device) + + ious = aligned_iou(wh, prior_wh) + highest_iou_anchors = ious.max(1).indices + highest_iou_anchors = anchor_map[highest_iou_anchors] + matched_targets = highest_iou_anchors >= 0 + matched_anchors = highest_iou_anchors[matched_targets] + return matched_targets, matched_anchors + + +class IoUThresholdMatching(ShapeMatching): + """For each target, select all prior shapes that give a high enough IoU. + + Args: + prior_shapes: A list of all the prior box dimensions. The list should contain (width, height) tuples in the + network input resolution. + prior_shape_idxs: List of indices to ``prior_shapes`` that is used to select the (usually 3) prior shapes that + this layer uses. + threshold: IoU treshold for matching. + ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the corresponding anchor + has IoU with some target greater than this threshold, the predictor will not be taken into account when + calculating the confidence loss. + """ + + def __init__( + self, + prior_shapes: Sequence[Tuple[int, int]], + prior_shape_idxs: Sequence[int], + threshold: float, + ignore_bg_threshold: float = 0.7, + ) -> None: + super().__init__(ignore_bg_threshold) + self.prior_shapes = [prior_shapes[idx] for idx in prior_shape_idxs] + self.threshold = threshold + + def match(self, wh: Tensor) -> Union[Tuple[Tensor, Tensor], Tensor]: + prior_wh = torch.tensor(self.prior_shapes, dtype=wh.dtype, device=wh.device) + + ious = aligned_iou(wh, prior_wh) + above_threshold = (ious > self.threshold).nonzero() + return above_threshold.T + + +class SizeRatioMatching(ShapeMatching): + """For each target, select those prior shapes, whose width and height relative to the target is below given + ratio. + + This is the matching rule used by Ultralytics YOLOv5 implementation. + + Args: + prior_shapes: A list of all the prior box dimensions. The list should contain (width, height) tuples in the + network input resolution. + prior_shape_idxs: List of indices to ``prior_shapes`` that is used to select the (usually 3) prior shapes that + this layer uses. + threshold: Size ratio threshold for matching. + ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the corresponding anchor + has IoU with some target greater than this threshold, the predictor will not be taken into account when + calculating the confidence loss. + """ + + def __init__( + self, + prior_shapes: Sequence[Tuple[int, int]], + prior_shape_idxs: Sequence[int], + threshold: float, + ignore_bg_threshold: float = 0.7, + ) -> None: + super().__init__(ignore_bg_threshold) + self.prior_shapes = [prior_shapes[idx] for idx in prior_shape_idxs] + self.threshold = threshold + + def match(self, wh: Tensor) -> Union[Tuple[Tensor, Tensor], Tensor]: + prior_wh = torch.tensor(self.prior_shapes, dtype=wh.dtype, device=wh.device) + + wh_ratio = wh[:, None, :] / prior_wh[None, :, :] # [num_targets, num_anchors, 2] + wh_ratio = torch.max(wh_ratio, 1.0 / wh_ratio) + wh_ratio = wh_ratio.max(2).values # [num_targets, num_anchors] + below_threshold = (wh_ratio < self.threshold).nonzero() + return below_threshold.T + + +def _sim_ota_match(costs: Tensor, ious: Tensor) -> Tuple[Tensor, Tensor]: + """Implements the SimOTA matching rule. + + The number of units supplied by each supplier (training target) needs to be decided in the Optimal Transport + problem. "Dynamic k Estimation" uses the sum of the top 10 IoU values (casted to int) between the target and the + predicted boxes. + + Args: + costs: Sum of losses for (prediction, target) pairs: ``[targets, predictions]`` + ious: IoUs for (prediction, target) pairs: ``[targets, predictions]`` + + Returns: + A mask of predictions that were matched, and the indices of the matched targets. The latter contains as many + elements as there are ``True`` values in the mask. + """ + matching_matrix = torch.zeros_like(costs, dtype=torch.bool) + + if ious.numel() > 0: + # For each target, define k as the sum of the 10 highest IoUs. + top10_iou = torch.topk(ious, min(10, ious.shape[1])).values.sum(1) + ks = torch.clip(top10_iou.int(), min=1) + + # For each target, select k predictions with lowest cost. + for target_idx, (cost, k) in enumerate(zip(costs, ks)): + prediction_idx = torch.topk(cost, k, largest=False).indices + matching_matrix[target_idx, prediction_idx] = True + + # If there's more than one match for some prediction, match it with the best target. Now we consider all + # targets, regardless of whether they were originally matched with the prediction or not. + more_than_one_match = matching_matrix.sum(0) > 1 + best_targets = costs[:, more_than_one_match].argmin(0) + matching_matrix[:, more_than_one_match] = False + matching_matrix[best_targets, more_than_one_match] = True + + # For those predictions that were matched, get the index of the target. + pred_mask = matching_matrix.sum(0) > 0 + target_selector = matching_matrix[:, pred_mask].int().argmax(0) + return pred_mask, target_selector + + +class SimOTAMatching: + """Selects which anchors are used to predict each target using the SimOTA matching rule. + + This is the matching rule used by YOLOX. + + Args: + loss_func: A ``LossFunction`` object that can be used to calculate the pairwise costs. + range: For each target, restrict to the anchors that are within an `N x N` grid cell are centered at the target, + where `N` is the value of this parameter. + """ + + def __init__(self, loss_func: YOLOLoss, range: float = 5.0) -> None: + self.loss_func = loss_func + self.range = range + + def __call__( + self, + preds: Dict[str, Tensor], + targets: Dict[str, Tensor], + image_size: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor]: + """For each target, selects predictions using the SimOTA matching rule. + + Args: + preds: Predictions for a single image. + targets: Training targets for a single image. + image_size: Input image width and height. + + Returns: + A mask of predictions that were matched, background mask (inverse of the first mask), and the indices of the + matched targets. The last tensor contains as many elements as there are ``True`` values in the first mask. + """ + height, width, boxes_per_cell, num_classes = preds["classprobs"].shape + device = preds["boxes"].device + + # A multiplier for scaling feature map coordinates to image coordinates + grid_size = torch.tensor([width, height], device=device) + grid_to_image = torch.true_divide(image_size, grid_size) + + # Create a matrix for selecting the anchors that are inside the target bounding boxes. + centers = grid_centers(grid_size).view(-1, 2) * grid_to_image + inside_matrix = is_inside_box(centers, targets["boxes"]) + + # Set the width and height of all target bounding boxes to self.range grid cells and create a matrix for + # selecting the anchors that are now inside the boxes. If a small target has no anchors inside its bounding + # box, it will be matched to one of these anchors, but a high penalty will ensure that anchors that are inside + # the bounding box will be preferred. + xywh = box_convert(targets["boxes"], in_fmt="xyxy", out_fmt="cxcywh") + xy = xywh[:, :2] + wh = self.range * grid_to_image * torch.ones_like(xy) + xywh = torch.cat((xy, wh), -1) + boxes = box_convert(xywh, in_fmt="cxcywh", out_fmt="xyxy") + close_matrix = is_inside_box(centers, boxes) + + # In the first step we restrict ourselves to the grid cells whose center is inside or close enough to one or + # more targets. The prediction grids are flattened and masked using a [height * width] boolean vector. + mask = (inside_matrix | close_matrix).sum(0) > 0 + shape = (height * width, boxes_per_cell) + fg_preds = { + "boxes": preds["boxes"].view(*shape, 4)[mask].view(-1, 4), + "confidences": preds["confidences"].view(shape)[mask].view(-1), + "classprobs": preds["classprobs"].view(*shape, num_classes)[mask].view(-1, num_classes), + } + + losses, ious = self.loss_func.pairwise(fg_preds, targets, input_is_normalized=False) + costs = losses.overlap + losses.confidence + losses.classification + costs += 100000.0 * ~inside_matrix[:, mask].repeat_interleave(boxes_per_cell, 1) + pred_mask, target_selector = _sim_ota_match(costs, ious) + + # Add the anchor dimension to the mask and replace True values with the results of the actual SimOTA matching. + mask = mask.view(height, width).unsqueeze(-1).repeat(1, 1, boxes_per_cell) + mask[mask.nonzero().T.tolist()] = pred_mask + + background_mask = torch.logical_not(mask) + + return mask, background_mask, target_selector diff --git a/torchvision/models/detection/yolo.py b/torchvision/models/detection/yolo.py new file mode 100644 index 00000000000..969ef8ed409 --- /dev/null +++ b/torchvision/models/detection/yolo.py @@ -0,0 +1,414 @@ +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch import Tensor + +from ...ops import batched_nms +from ...transforms import functional as F +from .._api import register_model, WeightsEnum +from .._utils import _ovewrite_value_param +from ..yolo import YOLOV4Backbone +from .backbone_utils import _validate_trainable_layers +from .yolo_networks import DarknetNetwork, YOLOV4Network + +TARGET = Dict[str, Any] +TARGETS = List[TARGET] + + +def validate_batch(images: List[Tensor], targets: TARGETS) -> None: + """Reads a batch of data, validates the format, and stacks the images into a single tensor. + + Args: + batch: The batch of data read by the :class:`~torch.utils.data.DataLoader`. + + Returns: + The input batch with images stacked into a single tensor. + """ + if not images: + raise ValueError("No images in batch.") + + shape = images[0].shape + for image in images: + if not isinstance(image, Tensor): + raise ValueError(f"Expected image to be of type Tensor, got {type(image).__name__}.") + if image.shape != shape: + raise ValueError(f"Images with different shapes in one batch: {shape} and {image.shape}") + + if targets is None: + return + + if len(images) != len(targets): + raise ValueError(f"Got {len(images)} images, but targets for {len(targets)} images.") + + for target in targets: + boxes = target["boxes"] + if not isinstance(boxes, Tensor): + raise ValueError(f"Expected target boxes to be of type Tensor, got {type(boxes).__name__}.") + if (boxes.ndim != 2) or (boxes.shape[-1] != 4): + raise ValueError(f"Expected target boxes to be tensors of shape [N, 4], got {list(boxes.shape)}.") + labels = target["labels"] + if not isinstance(labels, Tensor): + raise ValueError(f"Expected target labels to be of type Tensor, got {type(labels).__name__}.") + if (labels.ndim < 1) or (labels.ndim > 2) or (len(labels) != len(boxes)): + raise ValueError( + f"Expected target labels to be tensors of shape [N] or [N, num_classes], got {list(labels.shape)}." + ) + + +class YOLO(nn.Module): + """YOLO implementation that supports the most important features of YOLOv3, YOLOv4, YOLOv5, YOLOv7, Scaled-YOLOv4, + and YOLOX. + + *YOLOv3 paper*: `Joseph Redmon and Ali Farhadi `_ + + *YOLOv4 paper*: `Alexey Bochkovskiy, Chien-Yao Wang, and Hong-Yuan Mark Liao `_ + + *YOLOv7 paper*: `Chien-Yao Wang, Alexey Bochkovskiy, and Hong-Yuan Mark Liao `_ + + *Scaled-YOLOv4 paper*: `Chien-Yao Wang, Alexey Bochkovskiy, and Hong-Yuan Mark Liao + `_ + + *YOLOX paper*: `Zheng Ge, Songtao Liu, Feng Wang, Zeming Li, and Jian Sun `_ + + The network architecture can be written in PyTorch, or read from a Darknet configuration file using the + :class:`~.yolo_networks.DarknetNetwork` class. ``DarknetNetwork`` is also able to read weights that have been saved + by Darknet. + + The input is expected to be a list of images. Each image is a tensor with shape ``[channels, height, width]``. The + images from a single batch will be stacked into a single tensor, so the sizes have to match. Different batches can + have different image sizes, as long as the size is divisible by the ratio in which the network downsamples the + input. + + During training, the model expects both the image tensors and a list of targets. *Each target is a dictionary + containing the following tensors*: + + - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in `(x1, y1, x2, y2)` format + - labels (``Int64Tensor[N]`` or ``BoolTensor[N, classes]``): the class label or a boolean class mask for each + ground-truth box + + :func:`~.yolo_module.YOLO.forward` method returns all predictions from all detection layers in one tensor with shape + ``[N, anchors, classes + 5]``, where ``anchors`` is the total number of anchors in all detection layers. The + coordinates are scaled to the input image size. During training it also returns a dictionary containing the + classification, box overlap, and confidence losses. + + During inference, the model requires only the image tensor. :func:`~.yolo_module.YOLO.infer` method filters and + processes the predictions. If a prediction has a high score for more than one class, it will be duplicated. *The + processed output is returned in a dictionary containing the following tensors*: + + - boxes (``FloatTensor[N, 4]``): predicted bounding box `(x1, y1, x2, y2)` coordinates in image space + - scores (``FloatTensor[N]``): detection confidences + - labels (``Int64Tensor[N]``): the predicted labels for each object + + Detection using a Darknet configuration and pretrained weights: + + >>> from urllib.request import urlretrieve + >>> import torch + >>> from torchvision.models.detection import DarknetNetwork, YOLO + >>> + >>> urlretrieve("https://raw.githubusercontent.com/AlexeyAB/darknet/master/cfg/yolov4-tiny-3l.cfg", "yolov4-tiny-3l.cfg") + >>> urlretrieve("https://github.com/AlexeyAB/darknet/releases/download/darknet_yolo_v4_pre/yolov4-tiny.conv.29", "yolov4-tiny.conv.29") + >>> network = DarknetNetwork("yolov4-tiny-3l.cfg", "yolov4-tiny.conv.29") + >>> model = YOLO(network) + >>> image = torch.rand(3, 608, 608) + >>> detections = model.infer(image) + + Detection using a predefined YOLOv4 network: + + >>> import torch + >>> from torchvision.models.detection import YOLOV4Network, YOLO + >>> + >>> network = YOLOV4Network(num_classes=91) + >>> model = YOLO(network) + >>> image = torch.rand(3, 608, 608) + >>> detections = model.infer(image) + + Args: + network: A module that represents the network layers. This can be obtained from a Darknet configuration using + :func:`~.yolo_networks.DarknetNetwork`, or it can be defined as PyTorch code. + confidence_threshold: Postprocessing will remove bounding boxes whose confidence score is not higher than this + threshold. + nms_threshold: Non-maximum suppression will remove bounding boxes whose IoU with a higher confidence box is + higher than this threshold, if the predicted categories are equal. + detections_per_image: Keep at most this number of highest-confidence detections per image. + """ + + def __init__( + self, + network: nn.Module, + confidence_threshold: float = 0.2, + nms_threshold: float = 0.45, + detections_per_image: int = 300, + ) -> None: + super().__init__() + + self.network = network + self.confidence_threshold = confidence_threshold + self.nms_threshold = nms_threshold + self.detections_per_image = detections_per_image + + def forward( + self, + images: List[Tensor], + targets: Optional[TARGETS] = None, + ) -> Union[Tensor, Tuple[Tensor, Tensor]]: + """Runs a forward pass through the network (all layers listed in ``self.network``), and if training targets + are provided, computes the losses from the detection layers. + + Detections are concatenated from the detection layers. Each detection layer will produce a number of detections + that depends on the size of the feature map and the number of anchors per feature map cell. + + Args: + images: Images to be processed. Tensor of size + ``[batch_size, channels, height, width]``. + targets: If set, computes losses from detection layers against these targets. A list of + target dictionaries, one for each image. + + Returns: + detections (:class:`~torch.Tensor`), losses (:class:`~torch.Tensor`): Detections, and if targets were + provided, a dictionary of losses. Detections are shaped ``[batch_size, anchors, classes + 5]``, where + ``anchors`` is the feature map size (width * height) times the number of anchors per cell. The predicted box + coordinates are in `(x1, y1, x2, y2)` format and scaled to the input image size. + """ + validate_batch(images, targets) + detections, losses, hits = self.network(torch.stack(images), targets) + + detections = torch.cat(detections, 1) + if targets is None: + return detections + + losses = torch.stack(losses).sum(0) + return detections, losses, hits + + def infer(self, image: Tensor) -> Dict[str, Tensor]: + """Feeds an image to the network and returns the detected bounding boxes, confidence scores, and class + labels. + + If a prediction has a high score for more than one class, it will be duplicated. + + Args: + image: An input image, a tensor of uint8 values sized ``[channels, height, width]``. + + Returns: + A dictionary containing tensors "boxes", "scores", and "labels". "boxes" is a matrix of detected bounding + box `(x1, y1, x2, y2)` coordinates. "scores" is a vector of confidence scores for the bounding box + detections. "labels" is a vector of predicted class labels. + """ + if not isinstance(image, Tensor): + image = F.to_tensor(image) + + was_training = self.training + self.eval() + + detections = self([image]) + detections = self.process_detections(detections) + detections = detections[0] + + if was_training: + self.train() + return detections + + def process_detections(self, preds: Tensor) -> List[Dict[str, Tensor]]: + """Splits the detection tensor returned by a forward pass into a list of prediction dictionaries, and + filters them based on confidence threshold, non-maximum suppression (NMS), and maximum number of + predictions. + + If for any single detection there are multiple categories whose score is above the confidence threshold, the + detection will be duplicated to create one detection for each category. NMS processes one category at a time, + iterating over the bounding boxes in descending order of confidence score, and removes lower scoring boxes that + have an IoU greater than the NMS threshold with a higher scoring box. + + The returned detections are sorted by descending confidence. The items of the dictionaries are as follows: + - boxes (``Tensor[batch_size, N, 4]``): detected bounding box `(x1, y1, x2, y2)` coordinates + - scores (``Tensor[batch_size, N]``): detection confidences + - labels (``Int64Tensor[batch_size, N]``): the predicted class IDs + + Args: + preds: A tensor of detected bounding boxes and their attributes. + + Returns: + Filtered detections. A list of prediction dictionaries, one for each image. + """ + + def process(boxes: Tensor, confidences: Tensor, classprobs: Tensor) -> Dict[str, Any]: + scores = classprobs * confidences[:, None] + + # Select predictions with high scores. If a prediction has a high score for more than one class, it will be + # duplicated. + idxs, labels = (scores > self.confidence_threshold).nonzero().T + boxes = boxes[idxs] + scores = scores[idxs, labels] + + keep = batched_nms(boxes, scores, labels, self.nms_threshold) + keep = keep[: self.detections_per_image] + return {"boxes": boxes[keep], "scores": scores[keep], "labels": labels[keep]} + + return [process(p[..., :4], p[..., 4], p[..., 5:]) for p in preds] + + def process_targets(self, targets: TARGETS) -> TARGETS: + """Duplicates multi-label targets to create one target for each label. + + Args: + targets: List of target dictionaries. Each dictionary must contain "boxes" and "labels". "labels" is either + a one-dimensional list of class IDs, or a two-dimensional boolean class map. + + Returns: + Single-label targets. A list of target dictionaries, one for each image. + """ + + def process(boxes: Tensor, labels: Tensor, **other: Any) -> Dict[str, Any]: + if labels.ndim == 2: + idxs, labels = labels.nonzero().T + boxes = boxes[idxs] + return {"boxes": boxes, "labels": labels, **other} + + return [process(**t) for t in targets] + + +class YOLOV4_Backbone_Weights(WeightsEnum): + DEFAULT = None + + +class YOLOV4_Weights(WeightsEnum): + DEFAULT = None + + +def freeze_backbone_layers(backbone: nn.Module, trainable_layers: Optional[int], is_trained: bool) -> None: + """Freezes backbone layers layers that won't be used for training. + + Args: + backbone: The backbone network. + trainable_layers: Number of trainable layers (stages), starting from the final stage. + is_trained: Set to ``True`` when using pre-trained weights. Otherwise will issue a warning if + ``trainable_layers`` is set. + """ + num_layers = len(backbone.stages) + trainable_layers = _validate_trainable_layers(is_trained, trainable_layers, num_layers, 3) + + layers_to_train = [f"stages.{idx}" for idx in range(num_layers - trainable_layers, num_layers)] + if trainable_layers == num_layers: + layers_to_train.append("stem") + + for name, parameter in backbone.named_parameters(): + if all([not name.startswith(layer) for layer in layers_to_train]): + parameter.requires_grad_(False) + + +@register_model() +def yolov4( + weights: Optional[YOLOV4_Weights] = None, + progress: bool = True, + in_channels: int = 3, + num_classes: Optional[int] = None, + weights_backbone: Optional[YOLOV4_Backbone_Weights] = None, + trainable_backbone_layers: Optional[int] = None, + confidence_threshold: float = 0.2, + nms_threshold: float = 0.45, + detections_per_image: int = 300, + **kwargs: Any, +) -> YOLO: + """ + Constructs a YOLOv4 model. + + .. betastatus:: detection module + + Example: + + >>> import torch + >>> from torchvision.models.detection import yolov4, YOLOV4_Weights + >>> + >>> model = yolov4(weights=YOLOV4_Weights.DEFAULT) + >>> image = torch.rand(3, 608, 608) + >>> detections = model.infer(image) + + Args: + weights: Pretrained weights to use. See :class:`~.YOLOV4_Weights` below for more details and possible values. By + default, the model will be initialized randomly. + progress: If ``True``, displays a progress bar of the download to ``stderr``. + in_channels: Number of channels in the input image. + num_classes: Number of output classes of the model (including the background). By default, this value is set to + 91 or read from the weights. + weights_backbone: Pretrained weights for the backbone. See :class:`~.YOLOV4_Backbone_Weights` below for more + details and possible values. By default, the backbone will be initialized randomly. + trainable_backbone_layers: Number of trainable (not frozen) layers (stages), starting from the final stage. + Valid values are between 0 and the number of stages in the backbone. By default, this value is set to 3. + confidence_threshold: Postprocessing will remove bounding boxes whose confidence score is not higher than this + threshold. + nms_threshold: Non-maximum suppression will remove bounding boxes whose IoU with a higher confidence box is + higher than this threshold, if the predicted categories are equal. + detections_per_image: Keep at most this number of highest-confidence detections per image. + **kwargs: Parameters passed to the ``.YOLOV4Network`` class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: .YOLOV4_Weights + :members: + + .. autoclass:: .YOLOV4_Backbone_Weights + :members: + """ + weights = YOLOV4_Weights.verify(weights) + weights_backbone = YOLOV4_Backbone_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"])) + elif num_classes is None: + num_classes = 91 + + backbone_kwargs = {key: kwargs[key] for key in ("widths", "activation", "normalization") if key in kwargs} + backbone = YOLOV4Backbone(in_channels, **backbone_kwargs) + + is_trained = weights is not None or weights_backbone is not None + freeze_backbone_layers(backbone, trainable_backbone_layers, is_trained) + + if weights_backbone is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + + network = YOLOV4Network(num_classes, backbone, **kwargs) + model = YOLO(network, confidence_threshold, nms_threshold, detections_per_image) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + + return model + + +def yolo_darknet( + config_path: str, + weights_path: Optional[str] = None, + confidence_threshold: float = 0.2, + nms_threshold: float = 0.45, + detections_per_image: int = 300, + **kwargs: Any, +) -> YOLO: + """ + Constructs a YOLO model from a Darknet configuration file. + + .. betastatus:: detection module + + Example: + + >>> from urllib.request import urlretrieve + >>> from torchvision.models.detection import yolo_darknet + >>> + >>> urlretrieve("https://raw.githubusercontent.com/AlexeyAB/darknet/master/cfg/yolov4-tiny-3l.cfg", "yolov4-tiny-3l.cfg") + >>> urlretrieve("https://github.com/AlexeyAB/darknet/releases/download/darknet_yolo_v4_pre/yolov4-tiny.conv.29", "yolov4-tiny.conv.29") + >>> model = yolo_darknet("yolov4-tiny-3l.cfg", "yolov4-tiny.conv.29") + >>> image = torch.rand(3, 608, 608) + >>> detections = model.infer(image) + + Args: + config_path: Path to a Darknet configuration file that defines the network architecture. + weights_path: Path to a Darknet weights file to load. + confidence_threshold: Postprocessing will remove bounding boxes whose confidence score is not higher than this + threshold. + nms_threshold: Non-maximum suppression will remove bounding boxes whose IoU with a higher confidence box is + higher than this threshold, if the predicted categories are equal. + detections_per_image: Keep at most this number of highest-confidence detections per image. + **kwargs: Parameters passed to the ``.YOLOV4Network`` class. Please refer to the `source code + `_ + for more details about this class. + """ + network = DarknetNetwork(config_path, weights_path) + return YOLO(network, confidence_threshold, nms_threshold, detections_per_image) diff --git a/torchvision/models/detection/yolo_loss.py b/torchvision/models/detection/yolo_loss.py new file mode 100644 index 00000000000..d8f05f93d49 --- /dev/null +++ b/torchvision/models/detection/yolo_loss.py @@ -0,0 +1,291 @@ +from dataclasses import dataclass +from typing import Callable, Dict, Optional, Tuple, Union + +import torch +from torch import Tensor +from torch.nn.functional import binary_cross_entropy, binary_cross_entropy_with_logits + +from torchvision.ops import ( + box_iou, + generalized_box_iou, + generalized_box_iou_loss, + distance_box_iou, + distance_box_iou_loss, + complete_box_iou, + complete_box_iou_loss, +) + + +def box_iou_loss(boxes1: Tensor, boxes2: Tensor) -> Tensor: + return 1.0 - box_iou(boxes1, boxes2).diagonal() + + +_iou_and_loss_functions = { + "iou": (box_iou, box_iou_loss), + "giou": (generalized_box_iou, generalized_box_iou_loss), + "diou": (distance_box_iou, distance_box_iou_loss), + "ciou": (complete_box_iou, complete_box_iou_loss), +} + + +def _get_iou_and_loss_functions(name: str) -> Tuple[Callable, Callable]: + """Returns functions for calculating the IoU and the IoU loss, given the IoU variant name. + + Args: + name: Name of the IoU variant. Either "iou", "giou", "diou", or "ciou". + + Returns: + A tuple of two functions. The first function calculates the pairwise IoU and the second function calculates the + elementwise loss. + """ + if name not in _iou_and_loss_functions: + raise ValueError(f"Unknown IoU function '{name}'.") + iou_func, loss_func = _iou_and_loss_functions[name] + if not callable(iou_func): + raise ValueError(f"The IoU function '{name}' is not supported by the installed version of Torchvision.") + assert callable(loss_func) + return iou_func, loss_func + + +def _size_compensation(targets: Tensor, image_size: Tensor) -> Tuple[Tensor, Tensor]: + """Calcuates the size compensation factor for the overlap loss. + + The overlap losses for each target should be multiplied by the returned weight. The returned value is + `2 - (unit_width * unit_height)`, which is large for small boxes (the maximum value is 2) and small for large boxes + (the minimum value is 1). + + Args: + targets: An ``[N, 4]`` matrix of target `(x1, y1, x2, y2)` coordinates. + image_size: Image size, which is used to scale the target boxes to unit coordinates. + + Returns: + The size compensation factor. + """ + unit_wh = targets[:, 2:] / image_size + return 2 - (unit_wh[:, 0] * unit_wh[:, 1]) + + +def _pairwise_confidence_loss( + preds: Tensor, overlap: Tensor, bce_func: Callable, predict_overlap: Optional[float] +) -> Tensor: + """Calculates the confidence loss for every pair of a foreground anchor and a target. + + If ``predict_overlap`` is ``True``, ``overlap`` will be used as the target confidence. Otherwise the target + confidence is 1. The method returns a matrix of losses for target/prediction pairs. + + Args: + preds: An ``[N]`` vector of predicted confidences. + overlap: An ``[M, N]`` matrix of overlaps between all target and predicted bounding boxes. + bce_func: A function for calculating binary cross entropy. + predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that target + confidence is one if there's an object, and 1.0 means that the target confidence is the overlap. + + Returns: + An ``[M, N]`` matrix of confidence losses between all targets and predictions. + """ + if predict_overlap is not None: + # When predicting overlap, target confidence is different for each pair of a prediction and a target. The + # tensors have to be broadcasted to [M, N]. + preds = preds.unsqueeze(0).expand(overlap.shape) + targets = torch.ones_like(preds) - predict_overlap + # Distance-IoU may return negative "overlaps", so we have to make sure that the targets are not negative. + targets += predict_overlap * overlap.detach().clamp(min=0) + return bce_func(preds, targets, reduction="none") + else: + # When not predicting overlap, target confidence is the same for every target, but we should still return a + # matrix. + targets = torch.ones_like(preds) + return bce_func(preds, targets, reduction="none").unsqueeze(0).expand(overlap.shape) + + +def _foreground_confidence_loss( + preds: Tensor, overlap: Tensor, bce_func: Callable, predict_overlap: Optional[float] +) -> Tensor: + """Calculates the sum of the confidence losses for foreground anchors and their matched targets. + + If ``predict_overlap`` is ``True``, ``overlap`` will be used as the target confidence. Otherwise the target + confidence is 1. The method returns a vector of losses for each foreground anchor. + + Args: + preds: A vector of predicted confidences. + overlap: A vector of overlaps between matched target and predicted bounding boxes. + bce_func: A function for calculating binary cross entropy. + predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that target + confidence is one if there's an object, and 1.0 means that the target confidence is the overlap. + + Returns: + The sum of the confidence losses for foreground anchors. + """ + targets = torch.ones_like(preds) + if predict_overlap is not None: + targets -= predict_overlap + # Distance-IoU may return negative "overlaps", so we have to make sure that the targets are not negative. + targets += predict_overlap * overlap.detach().clamp(min=0) + return bce_func(preds, targets, reduction="sum") + + +def _background_confidence_loss(preds: Tensor, bce_func: Callable) -> Tensor: + """Calculates the sum of the confidence losses for background anchors. + + Args: + preds: A vector of predicted confidences for background anchors. + bce_func: A function for calculating binary cross entropy. + + Returns: + The sum of the background confidence losses. + """ + targets = torch.zeros_like(preds) + return bce_func(preds, targets, reduction="sum") + + +def _target_labels_to_probs(targets: Tensor, num_classes: int, dtype: torch.dtype) -> Tensor: + """If ``targets`` is a vector of class labels, converts it to a matrix of one-hot class probabilities. + + Args: + targets: An ``[M, C]`` matrix of target class probabilities or an ``[M]`` vector of class labels. + num_classes: The number of classes (C dimension) for the new targets. If ``targets`` is already two-dimensional, + checks that the length of the second dimension matches this number. + dtype: Floating-point data type to be used for the one-hot targets. + + Returns: + An ``[M, C]`` matrix of target class probabilities. + """ + if targets.ndim == 1: + # The data may contain a different number of classes than what the model predicts. In case a label is + # greater than the number of predicted classes, it will be mapped to the last class. + last_class = torch.tensor(num_classes - 1, device=targets.device) + targets = torch.min(targets, last_class) + targets = torch.nn.functional.one_hot(targets, num_classes) + elif targets.shape[-1] != num_classes: + raise ValueError( + f"The number of classes in the data ({targets.shape[-1]}) doesn't match the number of classes " + f"predicted by the model ({num_classes})." + ) + return targets.to(dtype=dtype) + + +@dataclass +class Losses: + overlap: Tensor + confidence: Tensor + classification: Tensor + + +class YOLOLoss: + """A class for calculating the YOLO losses from predictions and targets. + + Args: + overlap_func: A function for calculating the pairwise overlaps between two sets of boxes. Either a string or a + function that returns a matrix of pairwise overlaps. Valid string values are "iou", "giou", "diou", and + "ciou" (default). + predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that target + confidence is one if there's an object, and 1.0 means that the target confidence is the output of + ``overlap_func``. + overlap_loss_multiplier: Overlap loss will be scaled by this value. + confidence_loss_multiplier: Confidence loss will be scaled by this value. + class_loss_multiplier: Classification loss will be scaled by this value. + """ + + def __init__( + self, + overlap_func: Union[str, Callable] = "ciou", + predict_overlap: Optional[float] = None, + overlap_multiplier: float = 5.0, + confidence_multiplier: float = 1.0, + class_multiplier: float = 1.0, + ): + if callable(overlap_func): + self._pairwise_overlap = overlap_func + self._elementwise_overlap_loss = lambda boxes1, boxes2: 1.0 - overlap_func(boxes1, boxes2).diagonal() + else: + self._pairwise_overlap, self._elementwise_overlap_loss = _get_iou_and_loss_functions(overlap_func) + + self.predict_overlap = predict_overlap + self.overlap_multiplier = overlap_multiplier + self.confidence_multiplier = confidence_multiplier + self.class_multiplier = class_multiplier + + def pairwise( + self, + preds: Dict[str, Tensor], + targets: Dict[str, Tensor], + input_is_normalized: bool, + ) -> Tuple[Losses, Tensor]: + """Calculates matrices containing the losses for all prediction/target pairs. + + This method is called for obtaining costs for SimOTA matching. + + Args: + preds: A dictionary of predictions, containing "boxes", "confidences", and "classprobs". + targets: A dictionary of training targets, containing "boxes" and "labels". + input_is_normalized: If ``False``, input is logits, if ``True``, input is normalized to `0..1`. + + Returns: + Loss matrices and an overlap matrix. + """ + if input_is_normalized: + bce_func = binary_cross_entropy + else: + bce_func = binary_cross_entropy_with_logits + + overlap = self._pairwise_overlap(targets["boxes"], preds["boxes"]) + overlap_loss = 1.0 - overlap + + confidence_loss = _pairwise_confidence_loss(preds["confidences"], overlap, bce_func, self.predict_overlap) + + pred_probs = preds["classprobs"].unsqueeze(0) # [1, preds, classes] + target_probs = _target_labels_to_probs(targets["labels"], pred_probs.shape[-1], pred_probs.dtype) + target_probs = target_probs.unsqueeze(1) # [targets, 1, classes] + pred_probs, target_probs = torch.broadcast_tensors(pred_probs, target_probs) + class_loss = bce_func(pred_probs, target_probs, reduction="none").sum(-1) + + losses = Losses( + overlap_loss * self.overlap_multiplier, + confidence_loss * self.confidence_multiplier, + class_loss * self.class_multiplier, + ) + + return losses, overlap + + def elementwise_sums( + self, + preds: Dict[str, Tensor], + targets: Dict[str, Tensor], + input_is_normalized: bool, + image_size: Tensor, + ) -> Losses: + """Calculates the sums of the losses for optimization, over prediction/target pairs, assuming the + predictions and targets have been matched (there are as many predictions and targets). + + Args: + preds: A dictionary of predictions, containing "boxes", "confidences", and "classprobs". + targets: A dictionary of training targets, containing "boxes" and "labels". + input_is_normalized: If ``False``, input is logits, if ``True``, input is normalized to `0..1`. + image_size: Width and height in a vector that defines the scale of the target coordinates. + + Returns: + The final losses. + """ + if input_is_normalized: + bce_func = binary_cross_entropy + else: + bce_func = binary_cross_entropy_with_logits + + overlap_loss = self._elementwise_overlap_loss(targets["boxes"], preds["boxes"]) + overlap = 1.0 - overlap_loss + overlap_loss = (overlap_loss * _size_compensation(targets["boxes"], image_size)).sum() + + confidence_loss = _foreground_confidence_loss(preds["confidences"], overlap, bce_func, self.predict_overlap) + confidence_loss += _background_confidence_loss(preds["bg_confidences"], bce_func) + + pred_probs = preds["classprobs"] + target_probs = _target_labels_to_probs(targets["labels"], pred_probs.shape[-1], pred_probs.dtype) + class_loss = bce_func(pred_probs, target_probs, reduction="sum") + + losses = Losses( + overlap_loss * self.overlap_multiplier, + confidence_loss * self.confidence_multiplier, + class_loss * self.class_multiplier, + ) + + return losses diff --git a/torchvision/models/detection/yolo_networks.py b/torchvision/models/detection/yolo_networks.py new file mode 100644 index 00000000000..630cbe27711 --- /dev/null +++ b/torchvision/models/detection/yolo_networks.py @@ -0,0 +1,1980 @@ +import io +import re +from collections import OrderedDict +from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union +from warnings import warn + +import numpy as np +import torch +import torch.nn as nn +from torch import Tensor + +from ...ops import box_convert +from ..yolo import ( + Conv, + CSPSPP, + CSPStage, + ELANStage, + FastSPP, + MaxPool, + RouteLayer, + ShortcutLayer, + YOLOV4Backbone, + YOLOV4TinyBackbone, + YOLOV5Backbone, + YOLOV7Backbone, +) +from .anchor_utils import global_xy +from .target_matching import ( + HighestIoUMatching, + IoUThresholdMatching, + ShapeMatching, + SimOTAMatching, + SizeRatioMatching, +) +from .yolo_loss import YOLOLoss + +CONFIG = Dict[str, Any] +CREATE_LAYER_OUTPUT = Tuple[nn.Module, int] # layer, num_outputs +TARGET = Dict[str, Any] +TARGETS = List[TARGET] +NETWORK_OUTPUT = Tuple[List[Tensor], List[Tensor], List[int]] # detections, losses, hits + + +class DetectionLayer(nn.Module): + """A YOLO detection layer. + + A YOLO model has usually 1 - 3 detection layers at different resolutions. The loss is summed from all of them. + + Args: + num_classes: Number of different classes that this layer predicts. + prior_shapes: A list of prior box dimensions for this layer, used for scaling the predicted dimensions. The list + should contain (width, height) tuples in the network input resolution. + matching_func: The matching algorithm to be used for assigning targets to anchors. + loss_func: ``LossFunction`` object for calculating the losses. + xy_scale: Eliminate "grid sensitivity" by scaling the box coordinates by this factor. Using a value > 1.0 helps + to produce coordinate values close to one. + input_is_normalized: The input is normalized by logistic activation in the previous layer. In this case the + detection layer will not take the sigmoid of the coordinate and probability predictions, and the width and + height are scaled up so that the maximum value is four times the anchor dimension. This is used by the + Darknet configurations of Scaled-YOLOv4. + """ + + def __init__( + self, + num_classes: int, + prior_shapes: List[Tuple[int, int]], + matching_func: Callable, + loss_func: YOLOLoss, + xy_scale: float = 1.0, + input_is_normalized: bool = False, + ) -> None: + super().__init__() + + self.num_classes = num_classes + self.prior_shapes = prior_shapes + self.matching_func = matching_func + self.loss_func = loss_func + self.xy_scale = xy_scale + self.input_is_normalized = input_is_normalized + + def forward(self, x: Tensor, image_size: Tensor) -> Tuple[Tensor, List[Dict[str, Tensor]]]: + """Runs a forward pass through this YOLO detection layer. + + Maps cell-local coordinates to global coordinates in the image space, scales the bounding boxes with the + anchors, converts the center coordinates to corner coordinates, and maps probabilities to the `]0, 1[` range + using sigmoid. + + If targets are given, computes also losses from the predictions and the targets. This layer is responsible only + for the targets that best match one of the anchors assigned to this layer. Training losses will be saved to the + ``losses`` attribute. ``hits`` attribute will be set to the number of targets that this layer was responsible + for. ``losses`` is a tensor of three elements: the overlap, confidence, and classification loss. + + Args: + x: The output from the previous layer. The size of this tensor has to be + ``[batch_size, anchors_per_cell * (num_classes + 5), height, width]``. + image_size: Image width and height in a vector (defines the scale of the predicted and target coordinates). + + Returns: + The layer output, with normalized probabilities, in a tensor sized + ``[batch_size, anchors_per_cell * height * width, num_classes + 5]`` and a list of dictionaries, containing + the same predictions, but with unnormalized probabilities (for loss calculation). + """ + batch_size, num_features, height, width = x.shape + num_attrs = self.num_classes + 5 + anchors_per_cell = int(torch.div(num_features, num_attrs, rounding_mode="floor")) + if anchors_per_cell != len(self.prior_shapes): + raise ValueError( + "The model predicts {} bounding boxes per spatial location, but {} prior box dimensions are defined " + "for this layer.".format(anchors_per_cell, len(self.prior_shapes)) + ) + + # Reshape the output to have the bounding box attributes of each grid cell on its own row. + x = x.permute(0, 2, 3, 1) # [batch_size, height, width, anchors_per_cell * num_attrs] + x = x.view(batch_size, height, width, anchors_per_cell, num_attrs) + + # Take the sigmoid of the bounding box coordinates, confidence score, and class probabilities, unless the input + # is normalized by the previous layer activation. Confidence and class losses use the unnormalized values if + # possible. + norm_x = x if self.input_is_normalized else torch.sigmoid(x) + xy = norm_x[..., :2] + wh = x[..., 2:4] + confidence = x[..., 4] + classprob = x[..., 5:] + norm_confidence = norm_x[..., 4] + norm_classprob = norm_x[..., 5:] + + # Eliminate grid sensitivity. The previous layer should output extremely high values for the sigmoid to produce + # x/y coordinates close to one. YOLOv4 solves this by scaling the x/y coordinates. + xy = xy * self.xy_scale - 0.5 * (self.xy_scale - 1) + + image_xy = global_xy(xy, image_size) + prior_shapes = torch.tensor(self.prior_shapes, dtype=wh.dtype, device=wh.device) + if self.input_is_normalized: + image_wh = 4 * torch.square(wh) * prior_shapes + else: + image_wh = torch.exp(wh) * prior_shapes + box = torch.cat((image_xy, image_wh), -1) + box = box_convert(box, in_fmt="cxcywh", out_fmt="xyxy") + output = torch.cat((box, norm_confidence.unsqueeze(-1), norm_classprob), -1) + output = output.reshape(batch_size, height * width * anchors_per_cell, num_attrs) + + # It's better to use binary_cross_entropy_with_logits() for loss computation, so we'll provide the unnormalized + # confidence and classprob, when available. + preds = [{"boxes": b, "confidences": c, "classprobs": p} for b, c, p in zip(box, confidence, classprob)] + + return output, preds + + def match_targets( + self, + preds: List[Dict[str, Tensor]], + return_preds: List[Dict[str, Tensor]], + targets: List[Dict[str, Tensor]], + image_size: Tensor, + ) -> Tuple[Dict[str, Tensor], Dict[str, Tensor]]: + """Matches the predictions to targets. + + Args: + preds: List of predictions for each image, as returned by the ``forward()`` method of this layer. These will + be matched to the training targets. + return_preds: List of predictions for each image. The matched predictions will be returned from this list. + When calculating the auxiliary loss for deep supervision, predictions from a different layer are used + for loss computation. + targets: List of training targets for each image. + image_size: Width and height in a vector that defines the scale of the target coordinates. + + Returns: + Two dictionaries, the matched predictions and targets. + """ + batch_size = len(preds) + if (len(targets) != batch_size) or (len(return_preds) != batch_size): + raise ValueError("Different batch size for predictions and targets.") + + matches = [] + for image_preds, image_return_preds, image_targets in zip(preds, return_preds, targets): + if image_targets["boxes"].shape[0] > 0: + pred_selector, background_selector, target_selector = self.matching_func( + image_preds, image_targets, image_size + ) + matched_preds = { + "boxes": image_return_preds["boxes"][pred_selector], + "confidences": image_return_preds["confidences"][pred_selector], + "bg_confidences": image_return_preds["confidences"][background_selector], + "classprobs": image_return_preds["classprobs"][pred_selector], + } + matched_targets = { + "boxes": image_targets["boxes"][target_selector], + "labels": image_targets["labels"][target_selector], + } + else: + device = image_preds["confidences"].device + matched_preds = { + "boxes": torch.empty((0, 4), device=device), + "confidences": torch.empty(0, device=device), + "bg_confidences": image_preds["confidences"].flatten(), + "classprobs": torch.empty((0, self.num_classes), device=device), + } + matched_targets = { + "boxes": torch.empty((0, 4), device=device), + "labels": torch.empty(0, dtype=torch.int64, device=device), + } + matches.append((matched_preds, matched_targets)) + + matched_preds = { + "boxes": torch.cat(tuple(m[0]["boxes"] for m in matches)), + "confidences": torch.cat(tuple(m[0]["confidences"] for m in matches)), + "bg_confidences": torch.cat(tuple(m[0]["bg_confidences"] for m in matches)), + "classprobs": torch.cat(tuple(m[0]["classprobs"] for m in matches)), + } + matched_targets = { + "boxes": torch.cat(tuple(m[1]["boxes"] for m in matches)), + "labels": torch.cat(tuple(m[1]["labels"] for m in matches)), + } + return matched_preds, matched_targets + + def calculate_losses( + self, + preds: List[Dict[str, Tensor]], + targets: List[Dict[str, Tensor]], + image_size: Tensor, + loss_preds: Optional[List[Dict[str, Tensor]]] = None, + ) -> Tuple[Tensor, int]: + """Matches the predictions to targets and computes the losses. + + Args: + preds: List of predictions for each image, as returned by ``forward()``. These will be matched to the + training targets and used to compute the losses (unless another set of predictions for loss computation + is given in ``loss_preds``). + targets: List of training targets for each image. + image_size: Width and height in a vector that defines the scale of the target coordinates. + loss_preds: List of predictions for each image. If given, these will be used for loss computation, instead + of the same predictions that were used for matching. This is needed for deep supervision in YOLOv7. + + Returns: + A vector of the overlap, confidence, and classification loss, normalized by batch size, and the number of + targets that were matched to this layer. + """ + if loss_preds is None: + loss_preds = preds + + matched_preds, matched_targets = self.match_targets(preds, loss_preds, targets, image_size) + + losses = self.loss_func.elementwise_sums(matched_preds, matched_targets, self.input_is_normalized, image_size) + losses = torch.stack((losses.overlap, losses.confidence, losses.classification)) / len(preds) + + hits = len(matched_targets["boxes"]) + + return losses, hits + + +def _create_detection_layer( + prior_shapes: Sequence[Tuple[int, int]], + prior_shape_idxs: Sequence[int], + matching_algorithm: Optional[str] = None, + matching_threshold: Optional[float] = None, + sim_ota_range: float = 5.0, + ignore_bg_threshold: float = 0.7, + overlap_func: Union[str, Callable] = "ciou", + predict_overlap: float = 1.0, + overlap_loss_multiplier: float = 5.0, + confidence_loss_multiplier: float = 1.0, + class_loss_multiplier: float = 1.0, + **kwargs: Any, +) -> DetectionLayer: + """Creates a detection layer module and the required loss function and target matching objects. + + Args: + prior_shapes: A list of all the prior box dimensions, used for scaling the predicted dimensions and possibly for + matching the targets to the anchors. The list should contain (width, height) tuples in the network input + resolution. + prior_shape_idxs: List of indices to ``prior_shapes`` that is used to select the (usually 3) prior shapes that + this layer uses. + matching_algorithm: Which algorithm to use for matching targets to anchors. "simota" (the SimOTA matching rule + from YOLOX), "size" (match those prior shapes, whose width and height relative to the target is below given + ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the prior shape that + gives the highest IoU, default). + matching_threshold: Threshold for "size" and "iou" matching algorithms. + sim_ota_range: The "simota" matching algorithm will restrict to the anchors that are within an `N x N` grid cell + area centered at the target, where `N` is the value of this parameter. + ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the corresponding anchor + has IoU with some target greater than this threshold, the predictor will not be taken into account when + calculating the confidence loss. + overlap_func: A function for calculating the pairwise overlaps between two sets of boxes. Either a string or a + function that returns a matrix of pairwise overlaps. Valid string values are "iou", "giou", "diou", and + "ciou" (default). + predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that target + confidence is one if there's an object, and 1.0 means that the target confidence is the output of + ``overlap_func``. + overlap_loss_multiplier: Overlap loss will be scaled by this value. + confidence_loss_multiplier: Confidence loss will be scaled by this value. + class_loss_multiplier: Classification loss will be scaled by this value. + num_classes: Number of different classes that this layer predicts. + xy_scale: Eliminate "grid sensitivity" by scaling the box coordinates by this factor. Using a value > 1.0 helps + to produce coordinate values close to one. + input_is_normalized: The input is normalized by logistic activation in the previous layer. In this case the + detection layer will not take the sigmoid of the coordinate and probability predictions, and the width and + height are scaled up so that the maximum value is four times the anchor dimension. This is used by the + Darknet configurations of Scaled-YOLOv4. + """ + matching_func: Union[ShapeMatching, SimOTAMatching] + if matching_algorithm == "simota": + loss_func = YOLOLoss( + overlap_func, None, overlap_loss_multiplier, confidence_loss_multiplier, class_loss_multiplier + ) + matching_func = SimOTAMatching(loss_func, sim_ota_range) + elif matching_algorithm == "size": + if matching_threshold is None: + raise ValueError("matching_threshold is required with size ratio matching.") + matching_func = SizeRatioMatching(prior_shapes, prior_shape_idxs, matching_threshold, ignore_bg_threshold) + elif matching_algorithm == "iou": + if matching_threshold is None: + raise ValueError("matching_threshold is required with IoU threshold matching.") + matching_func = IoUThresholdMatching(prior_shapes, prior_shape_idxs, matching_threshold, ignore_bg_threshold) + elif matching_algorithm == "maxiou" or matching_algorithm is None: + matching_func = HighestIoUMatching(prior_shapes, prior_shape_idxs, ignore_bg_threshold) + else: + raise ValueError(f"Matching algorithm `{matching_algorithm}´ is unknown.") + + loss_func = YOLOLoss( + overlap_func, predict_overlap, overlap_loss_multiplier, confidence_loss_multiplier, class_loss_multiplier + ) + + layer_shapes = [prior_shapes[i] for i in prior_shape_idxs] + return DetectionLayer(prior_shapes=layer_shapes, matching_func=matching_func, loss_func=loss_func, **kwargs) + + +def _run_detection( + detection_layer: DetectionLayer, + layer_input: Tensor, + targets: Optional[List[Dict[str, Tensor]]], + image_size: Tensor, + detections: List[Tensor], + losses: List[Tensor], + hits: List[int], +) -> None: + """Runs the detection layer on the inputs and appends the output to the ``detections`` list. + + If ``targets`` is given, also calculates the losses and appends to the ``losses`` list. + + Args: + detection_layer: The detection layer. + layer_input: Input to the detection layer. + targets: List of training targets for each image. + image_size: Width and height in a vector that defines the scale of the target coordinates. + detections: A list where a tensor containing the detections will be appended to. + losses: A list where a tensor containing the losses will be appended to, if ``targets`` is given. + hits: A list where the number of targets that matched this layer will be appended to, if ``targets`` is given. + """ + output, preds = detection_layer(layer_input, image_size) + detections.append(output) + + if targets is not None: + layer_losses, layer_hits = detection_layer.calculate_losses(preds, targets, image_size) + losses.append(layer_losses) + hits.append(layer_hits) + + +def _run_detection_with_aux_head( + detection_layer: DetectionLayer, + aux_detection_layer: DetectionLayer, + layer_input: Tensor, + aux_input: Tensor, + targets: Optional[List[Dict[str, Tensor]]], + image_size: Tensor, + aux_weight: float, + detections: List[Tensor], + losses: List[Tensor], + hits: List[int], +) -> None: + """Runs the detection layer on the inputs and appends the output to the ``detections`` list. + + If ``targets`` is given, also runs the auxiliary detection layer on the auxiliary inputs, calculates the losses, and + appends the losses to the ``losses`` list. + + Args: + detection_layer: The lead detection layer. + aux_detection_layer: The auxiliary detection layer. + layer_input: Input to the lead detection layer. + aux_input: Input to the auxiliary detection layer. + targets: List of training targets for each image. + image_size: Width and height in a vector that defines the scale of the target coordinates. + aux_weight: Weight of the auxiliary loss. + detections: A list where a tensor containing the detections will be appended to. + losses: A list where a tensor containing the losses will be appended to, if ``targets`` is given. + hits: A list where the number of targets that matched this layer will be appended to, if ``targets`` is given. + """ + output, preds = detection_layer(layer_input, image_size) + detections.append(output) + + if targets is not None: + # Match lead head predictions to targets and calculate losses from lead head outputs. + layer_losses, layer_hits = detection_layer.calculate_losses(preds, targets, image_size) + losses.append(layer_losses) + hits.append(layer_hits) + + # Match lead head predictions to targets and calculate losses from auxiliary head outputs. + _, aux_preds = aux_detection_layer(aux_input, image_size) + layer_losses, layer_hits = aux_detection_layer.calculate_losses( + preds, targets, image_size, loss_preds=aux_preds + ) + losses.append(layer_losses * aux_weight) + hits.append(layer_hits) + + +@torch.jit.script +def _get_image_size(images: Tensor) -> Tensor: + """Get the image size from an input tensor. + + The function needs the ``@torch.jit.script`` decorator in order for ONNX generation to work. The tracing based + generator will loose track of e.g. ``images.shape[1]`` and treat it as a Python variable and not a tensor. This will + cause the dimension to be treated as a constant in the model, which prevents dynamic input sizes. + + Args: + images: An image batch to take the width and height from. + + Returns: + A tensor that contains the image width and height. + """ + height = images.shape[2] + width = images.shape[3] + return torch.tensor([width, height], device=images.device) + + +class YOLOV4TinyNetwork(nn.Module): + """The "tiny" network architecture from YOLOv4. + + Args: + num_classes: Number of different classes that this model predicts. + backbone: A backbone network that returns the output from each stage. + width: The number of channels in the narrowest convolutional layer. The wider convolutional layers will use a + number of channels that is a multiple of this value. + activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic", + "linear", or "none". + normalization: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none". + prior_shapes: A list of prior box dimensions, used for scaling the predicted dimensions and possibly for + matching the targets to the anchors. The list should contain (width, height) tuples in the network input + resolution. There should be `3N` tuples, where `N` defines the number of anchors per spatial location. They + are assigned to the layers from the lowest (high-resolution) to the highest (low-resolution) layer, meaning + that you typically want to sort the shapes from the smallest to the largest. + matching_algorithm: Which algorithm to use for matching targets to anchors. "simota" (the SimOTA matching rule + from YOLOX), "size" (match those prior shapes, whose width and height relative to the target is below given + ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the prior shape that + gives the highest IoU, default). + matching_threshold: Threshold for "size" and "iou" matching algorithms. + ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the prior shape has IoU + with some target greater than this threshold, the predictor will not be taken into account when calculating + the confidence loss. + overlap_func: A function for calculating the pairwise overlaps between two sets of boxes. Either a string or a + function that returns a matrix of pairwise overlaps. Valid string values are "iou", "giou", "diou", and + "ciou" (default). + predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that target + confidence is one if there's an object, and 1.0 means that the target confidence is the output of + ``overlap_func``. + overlap_loss_multiplier: Overlap loss will be scaled by this value. + confidence_loss_multiplier: Confidence loss will be scaled by this value. + class_loss_multiplier: Classification loss will be scaled by this value. + xy_scale: Eliminate "grid sensitivity" by scaling the box coordinates by this factor. Using a value > 1.0 helps + to produce coordinate values close to one. + """ + + def __init__( + self, + num_classes: int, + backbone: Optional[nn.Module] = None, + width: int = 32, + activation: Optional[str] = "leaky", + normalization: Optional[str] = "batchnorm", + prior_shapes: Optional[List[Tuple[int, int]]] = None, + **kwargs: Any, + ) -> None: + super().__init__() + + # By default use the prior shapes that have been learned from the COCO data. + if prior_shapes is None: + prior_shapes = [ + (12, 16), + (19, 36), + (40, 28), + (36, 75), + (76, 55), + (72, 146), + (142, 110), + (192, 243), + (459, 401), + ] + anchors_per_cell = 3 + else: + anchors_per_cell, modulo = divmod(len(prior_shapes), 3) + if modulo != 0: + raise ValueError("The number of provided prior shapes needs to be divisible by 3.") + num_outputs = (5 + num_classes) * anchors_per_cell + + def conv(in_channels: int, out_channels: int, kernel_size: int = 1) -> nn.Module: + return Conv(in_channels, out_channels, kernel_size, stride=1, activation=activation, norm=normalization) + + def upsample(in_channels: int, out_channels: int) -> nn.Module: + channels = conv(in_channels, out_channels) + upsample = nn.Upsample(scale_factor=2, mode="nearest") + return nn.Sequential(OrderedDict([("channels", channels), ("upsample", upsample)])) + + def outputs(in_channels: int) -> nn.Module: + return nn.Conv2d(in_channels, num_outputs, kernel_size=1, stride=1, bias=True) + + def detect(prior_shape_idxs: Sequence[int]) -> DetectionLayer: + assert prior_shapes is not None + return _create_detection_layer( + prior_shapes, prior_shape_idxs, num_classes=num_classes, input_is_normalized=False, **kwargs + ) + + self.backbone = backbone or YOLOV4TinyBackbone(width=width, activation=activation, normalization=normalization) + + self.fpn5 = conv(width * 16, width * 8) + self.out5 = nn.Sequential( + OrderedDict( + [ + ("channels", conv(width * 8, width * 16)), + (f"outputs_{num_outputs}", outputs(width * 16)), + ] + ) + ) + self.upsample5 = upsample(width * 8, width * 4) + + self.fpn4 = conv(width * 12, width * 8, kernel_size=3) + self.out4 = nn.Sequential(OrderedDict([(f"outputs_{num_outputs}", outputs(width * 8))])) + self.upsample4 = upsample(width * 8, width * 2) + + self.fpn3 = conv(width * 6, width * 4, kernel_size=3) + self.out3 = nn.Sequential(OrderedDict([(f"outputs_{num_outputs}", outputs(width * 4))])) + + self.detect3 = detect([0, 1, 2]) + self.detect4 = detect([3, 4, 5]) + self.detect5 = detect([6, 7, 8]) + + def forward(self, x: Tensor, targets: Optional[TARGETS] = None) -> NETWORK_OUTPUT: + detections: List[Tensor] = [] # Outputs from detection layers + losses: List[Tensor] = [] # Losses from detection layers + hits: List[int] = [] # Number of targets each detection layer was responsible for + + image_size = _get_image_size(x) + + c3, c4, c5 = self.backbone(x)[-3:] + + p5 = self.fpn5(c5) + x = torch.cat((self.upsample5(p5), c4), dim=1) + p4 = self.fpn4(x) + x = torch.cat((self.upsample4(p4), c3), dim=1) + p3 = self.fpn3(x) + + _run_detection(self.detect5, self.out5(p5), targets, image_size, detections, losses, hits) + _run_detection(self.detect4, self.out4(p4), targets, image_size, detections, losses, hits) + _run_detection(self.detect3, self.out3(p3), targets, image_size, detections, losses, hits) + return detections, losses, hits + + +class YOLOV4Network(nn.Module): + """Network architecture that corresponds approximately to the Cross Stage Partial Network from YOLOv4. + + Args: + num_classes: Number of different classes that this model predicts. + backbone: A backbone network that returns the output from each stage. + widths: Number of channels at each network stage. + activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic", + "linear", or "none". + normalization: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none". + prior_shapes: A list of prior box dimensions, used for scaling the predicted dimensions and possibly for + matching the targets to the anchors. The list should contain (width, height) tuples in the network input + resolution. There should be `3N` tuples, where `N` defines the number of anchors per spatial location. They + are assigned to the layers from the lowest (high-resolution) to the highest (low-resolution) layer, meaning + that you typically want to sort the shapes from the smallest to the largest. + matching_algorithm: Which algorithm to use for matching targets to anchors. "simota" (the SimOTA matching rule + from YOLOX), "size" (match those prior shapes, whose width and height relative to the target is below given + ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the prior shape that + gives the highest IoU, default). + matching_threshold: Threshold for "size" and "iou" matching algorithms. + ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the prior shape has IoU + with some target greater than this threshold, the predictor will not be taken into account when calculating + the confidence loss. + overlap_func: A function for calculating the pairwise overlaps between two sets of boxes. Either a string or a + function that returns a matrix of pairwise overlaps. Valid string values are "iou", "giou", "diou", and + "ciou" (default). + predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that target + confidence is one if there's an object, and 1.0 means that the target confidence is the output of + ``overlap_func``. + overlap_loss_multiplier: Overlap loss will be scaled by this value. + confidence_loss_multiplier: Confidence loss will be scaled by this value. + class_loss_multiplier: Classification loss will be scaled by this value. + xy_scale: Eliminate "grid sensitivity" by scaling the box coordinates by this factor. Using a value > 1.0 helps + to produce coordinate values close to one. + """ + + def __init__( + self, + num_classes: int, + backbone: Optional[nn.Module] = None, + widths: Sequence[int] = (32, 64, 128, 256, 512, 1024), + activation: Optional[str] = "silu", + normalization: Optional[str] = "batchnorm", + prior_shapes: Optional[List[Tuple[int, int]]] = None, + **kwargs: Any, + ) -> None: + super().__init__() + + # By default use the prior shapes that have been learned from the COCO data. + if prior_shapes is None: + prior_shapes = [ + (12, 16), + (19, 36), + (40, 28), + (36, 75), + (76, 55), + (72, 146), + (142, 110), + (192, 243), + (459, 401), + ] + anchors_per_cell = 3 + else: + anchors_per_cell, modulo = divmod(len(prior_shapes), 3) + if modulo != 0: + raise ValueError("The number of provided prior shapes needs to be divisible by 3.") + num_outputs = (5 + num_classes) * anchors_per_cell + + def spp(in_channels: int, out_channels: int) -> nn.Module: + return CSPSPP(in_channels, out_channels, activation=activation, norm=normalization) + + def conv(in_channels: int, out_channels: int) -> nn.Module: + return Conv(in_channels, out_channels, kernel_size=1, stride=1, activation=activation, norm=normalization) + + def csp(in_channels: int, out_channels: int) -> nn.Module: + return CSPStage( + in_channels, + out_channels, + depth=2, + shortcut=False, + norm=normalization, + activation=activation, + ) + + def out(in_channels: int) -> nn.Module: + conv = Conv(in_channels, in_channels, kernel_size=3, stride=1, activation=activation, norm=normalization) + outputs = nn.Conv2d(in_channels, num_outputs, kernel_size=1) + return nn.Sequential(OrderedDict([("conv", conv), (f"outputs_{num_outputs}", outputs)])) + + def upsample(in_channels: int, out_channels: int) -> nn.Module: + channels = conv(in_channels, out_channels) + upsample = nn.Upsample(scale_factor=2, mode="nearest") + return nn.Sequential(OrderedDict([("channels", channels), ("upsample", upsample)])) + + def downsample(in_channels: int, out_channels: int) -> nn.Module: + return Conv(in_channels, out_channels, kernel_size=3, stride=2, activation=activation, norm=normalization) + + def detect(prior_shape_idxs: Sequence[int]) -> DetectionLayer: + assert prior_shapes is not None + return _create_detection_layer( + prior_shapes, prior_shape_idxs, num_classes=num_classes, input_is_normalized=False, **kwargs + ) + + if backbone is not None: + self.backbone = backbone + else: + self.backbone = YOLOV4Backbone(widths=widths, activation=activation, normalization=normalization) + + w3 = widths[-3] + w4 = widths[-2] + w5 = widths[-1] + + self.spp = spp(w5, w5) + + self.pre4 = conv(w4, w4 // 2) + self.upsample5 = upsample(w5, w4 // 2) + self.fpn4 = csp(w4, w4) + + self.pre3 = conv(w3, w3 // 2) + self.upsample4 = upsample(w4, w3 // 2) + self.fpn3 = csp(w3, w3) + + self.downsample3 = downsample(w3, w3) + self.pan4 = csp(w3 + w4, w4) + + self.downsample4 = downsample(w4, w4) + self.pan5 = csp(w4 + w5, w5) + + self.out3 = out(w3) + self.out4 = out(w4) + self.out5 = out(w5) + + self.detect3 = detect(range(0, anchors_per_cell)) + self.detect4 = detect(range(anchors_per_cell, anchors_per_cell * 2)) + self.detect5 = detect(range(anchors_per_cell * 2, anchors_per_cell * 3)) + + def forward(self, x: Tensor, targets: Optional[TARGETS] = None) -> NETWORK_OUTPUT: + detections: List[Tensor] = [] # Outputs from detection layers + losses: List[Tensor] = [] # Losses from detection layers + hits: List[int] = [] # Number of targets each detection layer was responsible for + + image_size = _get_image_size(x) + + c3, c4, x = self.backbone(x)[-3:] + c5 = self.spp(x) + + x = torch.cat((self.upsample5(c5), self.pre4(c4)), dim=1) + p4 = self.fpn4(x) + x = torch.cat((self.upsample4(p4), self.pre3(c3)), dim=1) + n3 = self.fpn3(x) + x = torch.cat((self.downsample3(n3), p4), dim=1) + n4 = self.pan4(x) + x = torch.cat((self.downsample4(n4), c5), dim=1) + n5 = self.pan5(x) + + _run_detection(self.detect3, self.out3(n3), targets, image_size, detections, losses, hits) + _run_detection(self.detect4, self.out4(n4), targets, image_size, detections, losses, hits) + _run_detection(self.detect5, self.out5(n5), targets, image_size, detections, losses, hits) + return detections, losses, hits + + +class YOLOV4P6Network(nn.Module): + """Network architecture that corresponds approximately to the variant of YOLOv4 with four detection layers. + + Args: + num_classes: Number of different classes that this model predicts. + backbone: A backbone network that returns the output from each stage. + widths: Number of channels at each network stage. + activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic", + "linear", or "none". + normalization: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none". + prior_shapes: A list of prior box dimensions, used for scaling the predicted dimensions and possibly for + matching the targets to the anchors. The list should contain (width, height) tuples in the network input + resolution. There should be `3N` tuples, where `N` defines the number of anchors per spatial location. They + are assigned to the layers from the lowest (high-resolution) to the highest (low-resolution) layer, meaning + that you typically want to sort the shapes from the smallest to the largest. + matching_algorithm: Which algorithm to use for matching targets to anchors. "simota" (the SimOTA matching rule + from YOLOX), "size" (match those prior shapes, whose width and height relative to the target is below given + ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the prior shape that + gives the highest IoU, default). + matching_threshold: Threshold for "size" and "iou" matching algorithms. + ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the prior shape has IoU + with some target greater than this threshold, the predictor will not be taken into account when calculating + the confidence loss. + overlap_func: A function for calculating the pairwise overlaps between two sets of boxes. Either a string or a + function that returns a matrix of pairwise overlaps. Valid string values are "iou", "giou", "diou", and + "ciou" (default). + predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that target + confidence is one if there's an object, and 1.0 means that the target confidence is the output of + ``overlap_func``. + overlap_loss_multiplier: Overlap loss will be scaled by this value. + confidence_loss_multiplier: Confidence loss will be scaled by this value. + class_loss_multiplier: Classification loss will be scaled by this value. + xy_scale: Eliminate "grid sensitivity" by scaling the box coordinates by this factor. Using a value > 1.0 helps + to produce coordinate values close to one. + """ + + def __init__( + self, + num_classes: int, + backbone: Optional[nn.Module] = None, + widths: Sequence[int] = (32, 64, 128, 256, 512, 1024, 1024), + activation: Optional[str] = "silu", + normalization: Optional[str] = "batchnorm", + prior_shapes: Optional[List[Tuple[int, int]]] = None, + **kwargs: Any, + ) -> None: + super().__init__() + + # By default use the prior shapes that have been learned from the COCO data. + if prior_shapes is None: + prior_shapes = [ + (13, 17), + (31, 25), + (24, 51), + (61, 45), + (61, 45), + (48, 102), + (119, 96), + (97, 189), + (97, 189), + (217, 184), + (171, 384), + (324, 451), + (324, 451), + (545, 357), + (616, 618), + (1024, 1024), + ] + anchors_per_cell = 4 + else: + anchors_per_cell, modulo = divmod(len(prior_shapes), 4) + if modulo != 0: + raise ValueError("The number of provided prior shapes needs to be divisible by 4.") + num_outputs = (5 + num_classes) * anchors_per_cell + + def spp(in_channels: int, out_channels: int) -> nn.Module: + return CSPSPP(in_channels, out_channels, activation=activation, norm=normalization) + + def conv(in_channels: int, out_channels: int) -> nn.Module: + return Conv(in_channels, out_channels, kernel_size=1, stride=1, activation=activation, norm=normalization) + + def csp(in_channels: int, out_channels: int) -> nn.Module: + return CSPStage( + in_channels, + out_channels, + depth=2, + shortcut=False, + norm=normalization, + activation=activation, + ) + + def out(in_channels: int) -> nn.Module: + conv = Conv(in_channels, in_channels, kernel_size=3, stride=1, activation=activation, norm=normalization) + outputs = nn.Conv2d(in_channels, num_outputs, kernel_size=1) + return nn.Sequential(OrderedDict([("conv", conv), (f"outputs_{num_outputs}", outputs)])) + + def upsample(in_channels: int, out_channels: int) -> nn.Module: + channels = conv(in_channels, out_channels) + upsample = nn.Upsample(scale_factor=2, mode="nearest") + return nn.Sequential(OrderedDict([("channels", channels), ("upsample", upsample)])) + + def downsample(in_channels: int, out_channels: int) -> nn.Module: + return Conv(in_channels, out_channels, kernel_size=3, stride=2, activation=activation, norm=normalization) + + def detect(prior_shape_idxs: Sequence[int]) -> DetectionLayer: + assert prior_shapes is not None + return _create_detection_layer( + prior_shapes, prior_shape_idxs, num_classes=num_classes, input_is_normalized=False, **kwargs + ) + + if backbone is not None: + self.backbone = backbone + else: + self.backbone = YOLOV4Backbone( + widths=widths, depths=(1, 1, 3, 15, 15, 7, 7), activation=activation, normalization=normalization + ) + + w3 = widths[-4] + w4 = widths[-3] + w5 = widths[-2] + w6 = widths[-1] + + self.spp = spp(w6, w6) + + self.pre5 = conv(w5, w5 // 2) + self.upsample6 = upsample(w6, w5 // 2) + self.fpn5 = csp(w5, w5) + + self.pre4 = conv(w4, w4 // 2) + self.upsample5 = upsample(w5, w4 // 2) + self.fpn4 = csp(w4, w4) + + self.pre3 = conv(w3, w3 // 2) + self.upsample4 = upsample(w4, w3 // 2) + self.fpn3 = csp(w3, w3) + + self.downsample3 = downsample(w3, w3) + self.pan4 = csp(w3 + w4, w4) + + self.downsample4 = downsample(w4, w4) + self.pan5 = csp(w4 + w5, w5) + + self.downsample5 = downsample(w5, w5) + self.pan6 = csp(w5 + w6, w6) + + self.out3 = out(w3) + self.out4 = out(w4) + self.out5 = out(w5) + self.out6 = out(w6) + + self.detect3 = detect(range(0, anchors_per_cell)) + self.detect4 = detect(range(anchors_per_cell, anchors_per_cell * 2)) + self.detect5 = detect(range(anchors_per_cell * 2, anchors_per_cell * 3)) + self.detect6 = detect(range(anchors_per_cell * 3, anchors_per_cell * 4)) + + def forward(self, x: Tensor, targets: Optional[TARGETS] = None) -> NETWORK_OUTPUT: + detections: List[Tensor] = [] # Outputs from detection layers + losses: List[Tensor] = [] # Losses from detection layers + hits: List[int] = [] # Number of targets each detection layer was responsible for + + image_size = _get_image_size(x) + + c3, c4, c5, x = self.backbone(x)[-4:] + c6 = self.spp(x) + + x = torch.cat((self.upsample6(c6), self.pre5(c5)), dim=1) + p5 = self.fpn5(x) + x = torch.cat((self.upsample5(p5), self.pre4(c4)), dim=1) + p4 = self.fpn4(x) + x = torch.cat((self.upsample4(p4), self.pre3(c3)), dim=1) + n3 = self.fpn3(x) + x = torch.cat((self.downsample3(n3), p4), dim=1) + n4 = self.pan4(x) + x = torch.cat((self.downsample4(n4), p5), dim=1) + n5 = self.pan5(x) + x = torch.cat((self.downsample5(n5), c6), dim=1) + n6 = self.pan6(x) + + _run_detection(self.detect3, self.out3(n3), targets, image_size, detections, losses, hits) + _run_detection(self.detect4, self.out4(n4), targets, image_size, detections, losses, hits) + _run_detection(self.detect5, self.out5(n5), targets, image_size, detections, losses, hits) + _run_detection(self.detect6, self.out6(n6), targets, image_size, detections, losses, hits) + return detections, losses, hits + + +class YOLOV5Network(nn.Module): + """The YOLOv5 network architecture. Different variants (n/s/m/l/x) can be achieved by adjusting the ``depth`` + and ``width`` parameters. + + Args: + num_classes: Number of different classes that this model predicts. + backbone: A backbone network that returns the output from each stage. + width: Number of channels in the narrowest convolutional layer. The wider convolutional layers will use a number + of channels that is a multiple of this value. The values used by the different variants are 16 (yolov5n), 32 + (yolov5s), 48 (yolov5m), 64 (yolov5l), and 80 (yolov5x). + depth: Repeat the bottleneck layers this many times. Can be used to make the network deeper. The values used by + the different variants are 1 (yolov5n, yolov5s), 2 (yolov5m), 3 (yolov5l), and 4 (yolov5x). + activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic", + "linear", or "none". + normalization: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none". + prior_shapes: A list of prior box dimensions, used for scaling the predicted dimensions and possibly for + matching the targets to the anchors. The list should contain (width, height) tuples in the network input + resolution. There should be `3N` tuples, where `N` defines the number of anchors per spatial location. They + are assigned to the layers from the lowest (high-resolution) to the highest (low-resolution) layer, meaning + that you typically want to sort the shapes from the smallest to the largest. + matching_algorithm: Which algorithm to use for matching targets to anchors. "simota" (the SimOTA matching rule + from YOLOX), "size" (match those prior shapes, whose width and height relative to the target is below given + ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the prior shape that + gives the highest IoU, default). + matching_threshold: Threshold for "size" and "iou" matching algorithms. + ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the prior shape has IoU + with some target greater than this threshold, the predictor will not be taken into account when calculating + the confidence loss. + overlap_func: A function for calculating the pairwise overlaps between two sets of boxes. Either a string or a + function that returns a matrix of pairwise overlaps. Valid string values are "iou", "giou", "diou", and + "ciou" (default). + predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that target + confidence is one if there's an object, and 1.0 means that the target confidence is the output of + ``overlap_func``. + overlap_loss_multiplier: Overlap loss will be scaled by this value. + confidence_loss_multiplier: Confidence loss will be scaled by this value. + class_loss_multiplier: Classification loss will be scaled by this value. + xy_scale: Eliminate "grid sensitivity" by scaling the box coordinates by this factor. Using a value > 1.0 helps + to produce coordinate values close to one. + """ + + def __init__( + self, + num_classes: int, + backbone: Optional[nn.Module] = None, + width: int = 64, + depth: int = 3, + activation: Optional[str] = "silu", + normalization: Optional[str] = "batchnorm", + prior_shapes: Optional[List[Tuple[int, int]]] = None, + **kwargs: Any, + ) -> None: + super().__init__() + + # By default use the prior shapes that have been learned from the COCO data. + if prior_shapes is None: + prior_shapes = [ + (12, 16), + (19, 36), + (40, 28), + (36, 75), + (76, 55), + (72, 146), + (142, 110), + (192, 243), + (459, 401), + ] + anchors_per_cell = 3 + else: + anchors_per_cell, modulo = divmod(len(prior_shapes), 3) + if modulo != 0: + raise ValueError("The number of provided prior shapes needs to be divisible by 3.") + num_outputs = (5 + num_classes) * anchors_per_cell + + def spp(in_channels: int, out_channels: int) -> nn.Module: + return FastSPP(in_channels, out_channels, activation=activation, norm=normalization) + + def downsample(in_channels: int, out_channels: int) -> nn.Module: + return Conv(in_channels, out_channels, kernel_size=3, stride=2, activation=activation, norm=normalization) + + def conv(in_channels: int, out_channels: int) -> nn.Module: + return Conv(in_channels, out_channels, kernel_size=1, stride=1, activation=activation, norm=normalization) + + def out(in_channels: int) -> nn.Module: + outputs = nn.Conv2d(in_channels, num_outputs, kernel_size=1) + return nn.Sequential(OrderedDict([(f"outputs_{num_outputs}", outputs)])) + + def csp(in_channels: int, out_channels: int) -> nn.Module: + return CSPStage( + in_channels, + out_channels, + depth=depth, + shortcut=False, + norm=normalization, + activation=activation, + ) + + def detect(prior_shape_idxs: Sequence[int]) -> DetectionLayer: + assert prior_shapes is not None + return _create_detection_layer( + prior_shapes, prior_shape_idxs, num_classes=num_classes, input_is_normalized=False, **kwargs + ) + + self.backbone = backbone or YOLOV5Backbone( + depth=depth, width=width, activation=activation, normalization=normalization + ) + + self.spp = spp(width * 16, width * 16) + + self.pan3 = csp(width * 8, width * 4) + self.out3 = out(width * 4) + + self.fpn4 = nn.Sequential( + OrderedDict( + [ + ("csp", csp(width * 16, width * 8)), + ("conv", conv(width * 8, width * 4)), + ] + ) + ) + self.pan4 = csp(width * 8, width * 8) + self.out4 = out(width * 8) + + self.fpn5 = conv(width * 16, width * 8) + self.pan5 = csp(width * 16, width * 16) + self.out5 = out(width * 16) + + self.upsample = nn.Upsample(scale_factor=2, mode="nearest") + + self.downsample3 = downsample(width * 4, width * 4) + self.downsample4 = downsample(width * 8, width * 8) + + self.detect3 = detect(range(0, anchors_per_cell)) + self.detect4 = detect(range(anchors_per_cell, anchors_per_cell * 2)) + self.detect5 = detect(range(anchors_per_cell * 2, anchors_per_cell * 3)) + + def forward(self, x: Tensor, targets: Optional[TARGETS] = None) -> NETWORK_OUTPUT: + detections: List[Tensor] = [] # Outputs from detection layers + losses: List[Tensor] = [] # Losses from detection layers + hits: List[int] = [] # Number of targets each detection layer was responsible for + + image_size = _get_image_size(x) + + c3, c4, x = self.backbone(x)[-3:] + c5 = self.spp(x) + + p5 = self.fpn5(c5) + x = torch.cat((self.upsample(p5), c4), dim=1) + p4 = self.fpn4(x) + x = torch.cat((self.upsample(p4), c3), dim=1) + + n3 = self.pan3(x) + x = torch.cat((self.downsample3(n3), p4), dim=1) + n4 = self.pan4(x) + x = torch.cat((self.downsample4(n4), p5), dim=1) + n5 = self.pan5(x) + + _run_detection(self.detect3, self.out3(n3), targets, image_size, detections, losses, hits) + _run_detection(self.detect4, self.out4(n4), targets, image_size, detections, losses, hits) + _run_detection(self.detect5, self.out5(n5), targets, image_size, detections, losses, hits) + return detections, losses, hits + + +class YOLOV7Network(nn.Module): + """Network architecture that corresponds to the W6 variant of YOLOv7 with four detection layers. + + Args: + num_classes: Number of different classes that this model predicts. + backbone: A backbone network that returns the output from each stage. + widths: Number of channels at each network stage. + activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic", + "linear", or "none". + normalization: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none". + prior_shapes: A list of prior box dimensions, used for scaling the predicted dimensions and possibly for + matching the targets to the anchors. The list should contain (width, height) tuples in the network input + resolution. There should be `3N` tuples, where `N` defines the number of anchors per spatial location. They + are assigned to the layers from the lowest (high-resolution) to the highest (low-resolution) layer, meaning + that you typically want to sort the shapes from the smallest to the largest. + aux_weight: Weight for the loss from the auxiliary heads. + matching_algorithm: Which algorithm to use for matching targets to anchors. "simota" (the SimOTA matching rule + from YOLOX), "size" (match those prior shapes, whose width and height relative to the target is below given + ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the prior shape that + gives the highest IoU, default). + matching_threshold: Threshold for "size" and "iou" matching algorithms. + ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the prior shape has IoU + with some target greater than this threshold, the predictor will not be taken into account when calculating + the confidence loss. + overlap_func: A function for calculating the pairwise overlaps between two sets of boxes. Either a string or a + function that returns a matrix of pairwise overlaps. Valid string values are "iou", "giou", "diou", and + "ciou" (default). + predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that target + confidence is one if there's an object, and 1.0 means that the target confidence is the output of + ``overlap_func``. + overlap_loss_multiplier: Overlap loss will be scaled by this value. + confidence_loss_multiplier: Confidence loss will be scaled by this value. + class_loss_multiplier: Classification loss will be scaled by this value. + xy_scale: Eliminate "grid sensitivity" by scaling the box coordinates by this factor. Using a value > 1.0 helps + to produce coordinate values close to one. + """ + + def __init__( + self, + num_classes: int, + backbone: Optional[nn.Module] = None, + widths: Sequence[int] = (32, 64, 128, 256, 512, 1024, 1024), + activation: Optional[str] = "silu", + normalization: Optional[str] = "batchnorm", + prior_shapes: Optional[List[Tuple[int, int]]] = None, + aux_weight: float = 0.25, + **kwargs: Any, + ) -> None: + super().__init__() + + self.aux_weight = aux_weight + + # By default use the prior shapes that have been learned from the COCO data. + if prior_shapes is None: + prior_shapes = [ + (13, 17), + (31, 25), + (24, 51), + (61, 45), + (61, 45), + (48, 102), + (119, 96), + (97, 189), + (97, 189), + (217, 184), + (171, 384), + (324, 451), + (324, 451), + (545, 357), + (616, 618), + (1024, 1024), + ] + anchors_per_cell = 4 + else: + anchors_per_cell, modulo = divmod(len(prior_shapes), 4) + if modulo != 0: + raise ValueError("The number of provided prior shapes needs to be divisible by 4.") + num_outputs = (5 + num_classes) * anchors_per_cell + + def spp(in_channels: int, out_channels: int) -> nn.Module: + return CSPSPP(in_channels, out_channels, activation=activation, norm=normalization) + + def conv(in_channels: int, out_channels: int) -> nn.Module: + return Conv(in_channels, out_channels, kernel_size=1, stride=1, activation=activation, norm=normalization) + + def elan(in_channels: int, out_channels: int) -> nn.Module: + return ELANStage( + in_channels, + out_channels, + split_channels=out_channels, + depth=4, + block_depth=1, + norm=normalization, + activation=activation, + ) + + def out(in_channels: int, hidden_channels: int) -> nn.Module: + conv = Conv( + in_channels, hidden_channels, kernel_size=3, stride=1, activation=activation, norm=normalization + ) + outputs = nn.Conv2d(hidden_channels, num_outputs, kernel_size=1) + return nn.Sequential(OrderedDict([("conv", conv), (f"outputs_{num_outputs}", outputs)])) + + def upsample(in_channels: int, out_channels: int) -> nn.Module: + channels = conv(in_channels, out_channels) + upsample = nn.Upsample(scale_factor=2, mode="nearest") + return nn.Sequential(OrderedDict([("channels", channels), ("upsample", upsample)])) + + def downsample(in_channels: int, out_channels: int) -> nn.Module: + return Conv(in_channels, out_channels, kernel_size=3, stride=2, activation=activation, norm=normalization) + + def detect(prior_shape_idxs: Sequence[int], range: float) -> DetectionLayer: + assert prior_shapes is not None + return _create_detection_layer( + prior_shapes, + prior_shape_idxs, + sim_ota_range=range, + num_classes=num_classes, + input_is_normalized=False, + **kwargs, + ) + + if backbone is not None: + self.backbone = backbone + else: + self.backbone = YOLOV7Backbone( + widths=widths, depth=2, block_depth=2, activation=activation, normalization=normalization + ) + + w3 = widths[-4] # 256 + w4 = widths[-3] # 512 + w5 = widths[-2] # 768 + w6 = widths[-1] # 1024 + + self.spp = spp(w6, w6 // 2) + + self.pre5 = conv(w5, w5 // 2) + self.upsample6 = upsample(w6 // 2, w5 // 2) + self.fpn5 = elan(w5, w5 // 2) + + self.pre4 = conv(w4, w4 // 2) + self.upsample5 = upsample(w5 // 2, w4 // 2) + self.fpn4 = elan(w4, w4 // 2) + + self.pre3 = conv(w3, w3 // 2) + self.upsample4 = upsample(w4 // 2, w3 // 2) + self.fpn3 = elan(w3, w3 // 2) + + self.downsample3 = downsample(w3 // 2, w4 // 2) + self.pan4 = elan(w4, w4 // 2) + + self.downsample4 = downsample(w4 // 2, w5 // 2) + self.pan5 = elan(w5, w5 // 2) + + self.downsample5 = downsample(w5 // 2, w6 // 2) + self.pan6 = elan(w6, w6 // 2) + + self.out3 = out(w3 // 2, w3) + self.aux_out3 = out(w3 // 2, w3 + (w3 // 4)) + self.out4 = out(w4 // 2, w4) + self.aux_out4 = out(w4 // 2, w4 + (w4 // 4)) + self.out5 = out(w5 // 2, w5) + self.aux_out5 = out(w5 // 2, w5 + (w5 // 4)) + self.out6 = out(w6 // 2, w6) + self.aux_out6 = out(w6 // 2, w6 + (w6 // 4)) + + self.detect3 = detect(range(0, anchors_per_cell), 5.0) + self.aux_detect3 = detect(range(0, anchors_per_cell), 3.0) + self.detect4 = detect(range(anchors_per_cell, anchors_per_cell * 2), 5.0) + self.aux_detect4 = detect(range(anchors_per_cell, anchors_per_cell * 2), 3.0) + self.detect5 = detect(range(anchors_per_cell * 2, anchors_per_cell * 3), 5.0) + self.aux_detect5 = detect(range(anchors_per_cell * 2, anchors_per_cell * 3), 3.0) + self.detect6 = detect(range(anchors_per_cell * 3, anchors_per_cell * 4), 5.0) + self.aux_detect6 = detect(range(anchors_per_cell * 3, anchors_per_cell * 4), 3.0) + + def forward(self, x: Tensor, targets: Optional[TARGETS] = None) -> NETWORK_OUTPUT: + detections: List[Tensor] = [] # Outputs from detection layers + losses: List[Tensor] = [] # Losses from detection layers + hits: List[int] = [] # Number of targets each detection layer was responsible for + + image_size = _get_image_size(x) + + c3, c4, c5, x = self.backbone(x)[-4:] + c6 = self.spp(x) + + x = torch.cat((self.upsample6(c6), self.pre5(c5)), dim=1) + p5 = self.fpn5(x) + x = torch.cat((self.upsample5(p5), self.pre4(c4)), dim=1) + p4 = self.fpn4(x) + x = torch.cat((self.upsample4(p4), self.pre3(c3)), dim=1) + n3 = self.fpn3(x) + x = torch.cat((self.downsample3(n3), p4), dim=1) + n4 = self.pan4(x) + x = torch.cat((self.downsample4(n4), p5), dim=1) + n5 = self.pan5(x) + x = torch.cat((self.downsample5(n5), c6), dim=1) + n6 = self.pan6(x) + + _run_detection_with_aux_head( + self.detect3, + self.aux_detect3, + self.out3(n3), + self.aux_out3(n3), + targets, + image_size, + self.aux_weight, + detections, + losses, + hits, + ) + _run_detection_with_aux_head( + self.detect4, + self.aux_detect4, + self.out4(n4), + self.aux_out4(p4), + targets, + image_size, + self.aux_weight, + detections, + losses, + hits, + ) + _run_detection_with_aux_head( + self.detect5, + self.aux_detect5, + self.out5(n5), + self.aux_out5(p5), + targets, + image_size, + self.aux_weight, + detections, + losses, + hits, + ) + _run_detection_with_aux_head( + self.detect6, + self.aux_detect6, + self.out6(n6), + self.aux_out6(c6), + targets, + image_size, + self.aux_weight, + detections, + losses, + hits, + ) + return detections, losses, hits + + +class YOLOXHead(nn.Module): + """A module that produces features for YOLO detection layer, decoupling the classification and localization + features. + + Args: + in_channels: Number of input channels that the module expects. + hidden_channels: Number of output channels in the hidden layers. + anchors_per_cell: Number of detections made at each spatial location of the feature map. + num_classes: Number of different classes that this model predicts. + activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic", + "linear", or "none". + norm: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none". + """ + + def __init__( + self, + in_channels: int, + hidden_channels: int, + anchors_per_cell: int, + num_classes: int, + activation: Optional[str] = "silu", + norm: Optional[str] = "batchnorm", + ) -> None: + super().__init__() + + def conv(in_channels: int, out_channels: int, kernel_size: int = 1) -> nn.Module: + return Conv(in_channels, out_channels, kernel_size, stride=1, activation=activation, norm=norm) + + def linear(in_channels: int, out_channels: int) -> nn.Module: + return nn.Conv2d(in_channels, out_channels, kernel_size=1) + + def features(num_channels: int) -> nn.Module: + return nn.Sequential( + conv(num_channels, num_channels, kernel_size=3), + conv(num_channels, num_channels, kernel_size=3), + ) + + def classprob(num_channels: int) -> nn.Module: + num_outputs = anchors_per_cell * num_classes + outputs = linear(num_channels, num_outputs) + return nn.Sequential(OrderedDict([("convs", features(num_channels)), (f"outputs_{num_outputs}", outputs)])) + + self.stem = conv(in_channels, hidden_channels) + self.feat = features(hidden_channels) + self.box = linear(hidden_channels, anchors_per_cell * 4) + self.confidence = linear(hidden_channels, anchors_per_cell) + self.classprob = classprob(hidden_channels) + + def forward(self, x: Tensor) -> Tensor: + x = self.stem(x) + features = self.feat(x) + box = self.box(features) + confidence = self.confidence(features) + classprob = self.classprob(x) + return torch.cat((box, confidence, classprob), dim=1) + + +class YOLOXNetwork(nn.Module): + """The YOLOX network architecture. Different variants (nano/tiny/s/m/l/x) can be achieved by adjusting the + ``depth`` and ``width`` parameters. + + Args: + num_classes: Number of different classes that this model predicts. + backbone: A backbone network that returns the output from each stage. + width: Number of channels in the narrowest convolutional layer. The wider convolutional layers will use a number + of channels that is a multiple of this value. The values used by the different variants are 24 (yolox-tiny), + 32 (yolox-s), 48 (yolox-m), and 64 (yolox-l). + depth: Repeat the bottleneck layers this many times. Can be used to make the network deeper. The values used by + the different variants are 1 (yolox-tiny, yolox-s), 2 (yolox-m), and 3 (yolox-l). + activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic", + "linear", or "none". + normalization: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none". + prior_shapes: A list of prior box dimensions, used for scaling the predicted dimensions and possibly for + matching the targets to the anchors. The list should contain (width, height) tuples in the network input + resolution. There should be `3N` tuples, where `N` defines the number of anchors per spatial location. They + are assigned to the layers from the lowest (high-resolution) to the highest (low-resolution) layer, meaning + that you typically want to sort the shapes from the smallest to the largest. + matching_algorithm: Which algorithm to use for matching targets to anchors. "simota" (the SimOTA matching rule + from YOLOX), "size" (match those prior shapes, whose width and height relative to the target is below given + ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the prior shape that + gives the highest IoU, default). + matching_threshold: Threshold for "size" and "iou" matching algorithms. + ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the prior shape has IoU + with some target greater than this threshold, the predictor will not be taken into account when calculating + the confidence loss. + overlap_func: A function for calculating the pairwise overlaps between two sets of boxes. Either a string or a + function that returns a matrix of pairwise overlaps. Valid string values are "iou", "giou", "diou", and + "ciou" (default). + predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that target + confidence is one if there's an object, and 1.0 means that the target confidence is the output of + ``overlap_func``. + overlap_loss_multiplier: Overlap loss will be scaled by this value. + confidence_loss_multiplier: Confidence loss will be scaled by this value. + class_loss_multiplier: Classification loss will be scaled by this value. + xy_scale: Eliminate "grid sensitivity" by scaling the box coordinates by this factor. Using a value > 1.0 helps + to produce coordinate values close to one. + """ + + def __init__( + self, + num_classes: int, + backbone: Optional[nn.Module] = None, + width: int = 64, + depth: int = 3, + activation: Optional[str] = "silu", + normalization: Optional[str] = "batchnorm", + prior_shapes: Optional[List[Tuple[int, int]]] = None, + **kwargs: Any, + ) -> None: + super().__init__() + + # By default use one anchor per cell and the stride as the prior size. + if prior_shapes is None: + prior_shapes = [(8, 8), (16, 16), (32, 32)] + anchors_per_cell = 1 + else: + anchors_per_cell, modulo = divmod(len(prior_shapes), 3) + if modulo != 0: + raise ValueError("The number of provided prior shapes needs to be divisible by 3.") + + def spp(in_channels: int, out_channels: int) -> nn.Module: + return FastSPP(in_channels, out_channels, activation=activation, norm=normalization) + + def downsample(in_channels: int, out_channels: int) -> nn.Module: + return Conv(in_channels, out_channels, kernel_size=3, stride=2, activation=activation, norm=normalization) + + def conv(in_channels: int, out_channels: int, kernel_size: int = 1) -> nn.Module: + return Conv(in_channels, out_channels, kernel_size, stride=1, activation=activation, norm=normalization) + + def csp(in_channels: int, out_channels: int) -> nn.Module: + return CSPStage( + in_channels, + out_channels, + depth=depth, + shortcut=False, + norm=normalization, + activation=activation, + ) + + def head(in_channels: int, hidden_channels: int) -> YOLOXHead: + return YOLOXHead( + in_channels, + hidden_channels, + anchors_per_cell, + num_classes, + activation=activation, + norm=normalization, + ) + + def detect(prior_shape_idxs: Sequence[int]) -> DetectionLayer: + assert prior_shapes is not None + return _create_detection_layer( + prior_shapes, prior_shape_idxs, num_classes=num_classes, input_is_normalized=False, **kwargs + ) + + self.backbone = backbone or YOLOV5Backbone( + depth=depth, width=width, activation=activation, normalization=normalization + ) + + self.spp = spp(width * 16, width * 16) + + self.pan3 = csp(width * 8, width * 4) + self.out3 = head(width * 4, width * 4) + + self.fpn4 = nn.Sequential( + OrderedDict( + [ + ("csp", csp(width * 16, width * 8)), + ("conv", conv(width * 8, width * 4)), + ] + ) + ) + self.pan4 = csp(width * 8, width * 8) + self.out4 = head(width * 8, width * 4) + + self.fpn5 = conv(width * 16, width * 8) + self.pan5 = csp(width * 16, width * 16) + self.out5 = head(width * 16, width * 4) + + self.upsample = nn.Upsample(scale_factor=2, mode="nearest") + + self.downsample3 = downsample(width * 4, width * 4) + self.downsample4 = downsample(width * 8, width * 8) + + self.detect3 = detect(range(0, anchors_per_cell)) + self.detect4 = detect(range(anchors_per_cell, anchors_per_cell * 2)) + self.detect5 = detect(range(anchors_per_cell * 2, anchors_per_cell * 3)) + + def forward(self, x: Tensor, targets: Optional[TARGETS] = None) -> NETWORK_OUTPUT: + detections: List[Tensor] = [] # Outputs from detection layers + losses: List[Tensor] = [] # Losses from detection layers + hits: List[int] = [] # Number of targets each detection layer was responsible for + + image_size = _get_image_size(x) + + c3, c4, x = self.backbone(x)[-3:] + c5 = self.spp(x) + + p5 = self.fpn5(c5) + x = torch.cat((self.upsample(p5), c4), dim=1) + p4 = self.fpn4(x) + x = torch.cat((self.upsample(p4), c3), dim=1) + + n3 = self.pan3(x) + x = torch.cat((self.downsample3(n3), p4), dim=1) + n4 = self.pan4(x) + x = torch.cat((self.downsample4(n4), p5), dim=1) + n5 = self.pan5(x) + + _run_detection(self.detect3, self.out3(n3), targets, image_size, detections, losses, hits) + _run_detection(self.detect4, self.out4(n4), targets, image_size, detections, losses, hits) + _run_detection(self.detect5, self.out5(n5), targets, image_size, detections, losses, hits) + return detections, losses, hits + + +class DarknetNetwork(nn.Module): + """This class can be used to parse the configuration files of the Darknet YOLOv4 implementation.""" + + def __init__( + self, config_path: str, weights_path: Optional[str] = None, in_channels: Optional[int] = None, **kwargs: Any + ) -> None: + """Parses a Darknet configuration file and creates the network structure. + + Iterates through the layers from the configuration and creates corresponding PyTorch modules. If + ``weights_path`` is given and points to a Darknet model file, loads the convolutional layer weights from the + file. + + Args: + config_path: Path to a Darknet configuration file that defines the network architecture. + weights_path: Path to a Darknet model file. If given, the model weights will be read from this file. + in_channels: Number of channels in the input image. + matching_algorithm: Which algorithm to use for matching targets to anchors. "simota" (the SimOTA matching + rule from YOLOX), "size" (match those prior shapes, whose width and height relative to the target is + below given ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the + prior shape that gives the highest IoU, default). + matching_threshold: Threshold for "size" and "iou" matching algorithms. + ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the corresponding + anchor has IoU with some target greater than this threshold, the predictor will not be taken into + account when calculating the confidence loss. + overlap_func: A function for calculating the pairwise overlaps between two sets of boxes. Either a string or + a function that returns a matrix of pairwise overlaps. Valid string values are "iou", "giou", "diou", + and "ciou". + predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that target + confidence is one if there's an object, and 1.0 means that the target confidence is the output of + ``overlap_func``. + overlap_loss_multiplier: Overlap loss will be scaled by this value. + confidence_loss_multiplier: Confidence loss will be scaled by this value. + class_loss_multiplier: Classification loss will be scaled by this value. + """ + super().__init__() + + with open(config_path) as config_file: + sections = self._read_config(config_file) + + if len(sections) < 2: + raise ValueError("The model configuration file should include at least two sections.") + + self.__dict__.update(sections[0]) + global_config = sections[0] + layer_configs = sections[1:] + + if in_channels is None: + in_channels = global_config.get("channels", 3) + assert isinstance(in_channels, int) + + self.layers = nn.ModuleList() + # num_inputs will contain the number of channels in the input of every layer up to the current layer. It is + # initialized with the number of channels in the input image. + num_inputs = [in_channels] + for layer_config in layer_configs: + config = {**global_config, **layer_config} + layer, num_outputs = _create_layer(config, num_inputs, **kwargs) + self.layers.append(layer) + num_inputs.append(num_outputs) + + if weights_path is not None: + with open(weights_path) as weight_file: + self.load_weights(weight_file) + + def forward(self, x: Tensor, targets: Optional[TARGETS] = None) -> NETWORK_OUTPUT: + outputs: List[Tensor] = [] # Outputs from all layers + detections: List[Tensor] = [] # Outputs from detection layers + losses: List[Tensor] = [] # Losses from detection layers + hits: List[int] = [] # Number of targets each detection layer was responsible for + + image_size = _get_image_size(x) + + for layer in self.layers: + if isinstance(layer, (RouteLayer, ShortcutLayer)): + x = layer(outputs) + elif isinstance(layer, DetectionLayer): + x, preds = layer(x, image_size) + detections.append(x) + if targets is not None: + layer_losses, layer_hits = layer.calculate_losses(preds, targets, image_size) + losses.append(layer_losses) + hits.append(layer_hits) + else: + x = layer(x) + + outputs.append(x) + + return detections, losses, hits + + def load_weights(self, weight_file: io.IOBase) -> None: + """Loads weights to layer modules from a pretrained Darknet model. + + One may want to continue training from pretrained weights, on a dataset with a different number of object + categories. The number of kernels in the convolutional layers just before each detection layer depends on the + number of output classes. The Darknet solution is to truncate the weight file and stop reading weights at the + first incompatible layer. For this reason the function silently leaves the rest of the layers unchanged, when + the weight file ends. + + Args: + weight_file: A file-like object containing model weights in the Darknet binary format. + """ + if not isinstance(weight_file, io.IOBase): + raise ValueError("weight_file must be a file-like object.") + + version = np.fromfile(weight_file, count=3, dtype=np.int32) + images_seen = np.fromfile(weight_file, count=1, dtype=np.int64) + print( + f"Loading weights from Darknet model version {version[0]}.{version[1]}.{version[2]} " + f"that has been trained on {images_seen[0]} images." + ) + + def read(tensor: Tensor) -> int: + """Reads the contents of ``tensor`` from the current position of ``weight_file``. + + Returns the number of elements read. If there's no more data in ``weight_file``, returns 0. + """ + np_array = np.fromfile(weight_file, count=tensor.numel(), dtype=np.float32) + num_elements = np_array.size + if num_elements > 0: + source = torch.from_numpy(np_array).view_as(tensor) + with torch.no_grad(): + tensor.copy_(source) + return num_elements + + for layer in self.layers: + # Weights are loaded only to convolutional layers + if not isinstance(layer, Conv): + continue + + # If convolution is followed by batch normalization, read the batch normalization parameters. Otherwise we + # read the convolution bias. + if isinstance(layer.norm, nn.Identity): + assert layer.conv.bias is not None + read(layer.conv.bias) + else: + assert isinstance(layer.norm, nn.BatchNorm2d) + assert layer.norm.running_mean is not None + assert layer.norm.running_var is not None + read(layer.norm.bias) + read(layer.norm.weight) + read(layer.norm.running_mean) + read(layer.norm.running_var) + + read_count = read(layer.conv.weight) + if read_count == 0: + return + + def _read_config(self, config_file: Iterable[str]) -> List[Dict[str, Any]]: + """Reads a Darnet network configuration file and returns a list of configuration sections. + + Args: + config_file: The configuration file to read. + + Returns: + A list of configuration sections. + """ + section_re = re.compile(r"\[([^]]+)\]") + list_variables = ("layers", "anchors", "mask", "scales") + variable_types = { + "activation": str, + "anchors": int, + "angle": float, + "batch": int, + "batch_normalize": bool, + "beta_nms": float, + "burn_in": int, + "channels": int, + "classes": int, + "cls_normalizer": float, + "decay": float, + "exposure": float, + "filters": int, + "from": int, + "groups": int, + "group_id": int, + "height": int, + "hue": float, + "ignore_thresh": float, + "iou_loss": str, + "iou_normalizer": float, + "iou_thresh": float, + "jitter": float, + "layers": int, + "learning_rate": float, + "mask": int, + "max_batches": int, + "max_delta": float, + "momentum": float, + "mosaic": bool, + "new_coords": int, + "nms_kind": str, + "num": int, + "obj_normalizer": float, + "pad": bool, + "policy": str, + "random": bool, + "resize": float, + "saturation": float, + "scales": float, + "scale_x_y": float, + "size": int, + "steps": str, + "stride": int, + "subdivisions": int, + "truth_thresh": float, + "width": int, + } + + section = None + sections = [] + + def convert(key: str, value: str) -> Union[str, int, float, List[Union[str, int, float]]]: + """Converts a value to the correct type based on key.""" + if key not in variable_types: + warn("Unknown YOLO configuration variable: " + key) + return value + if key in list_variables: + return [variable_types[key](v) for v in value.split(",")] + else: + return variable_types[key](value) + + for line in config_file: + line = line.strip() + if (not line) or (line[0] == "#"): + continue + + section_match = section_re.match(line) + if section_match: + if section is not None: + sections.append(section) + section = {"type": section_match.group(1)} + else: + if section is None: + raise RuntimeError("Darknet network configuration file does not start with a section header.") + key, value = line.split("=") + key = key.rstrip() + value = value.lstrip() + section[key] = convert(key, value) + if section is not None: + sections.append(section) + + return sections + + +def _create_layer(config: CONFIG, num_inputs: List[int], **kwargs: Any) -> CREATE_LAYER_OUTPUT: + """Calls one of the ``_create_(config, num_inputs)`` functions to create a PyTorch module from the + layer config. + + Args: + config: Dictionary of configuration options for this layer. + num_inputs: Number of channels in the input of every layer up to this layer. + + Returns: + module (:class:`~torch.nn.Module`), num_outputs (int): The created PyTorch module and the number of channels in + its output. + """ + create_func: Dict[str, Callable[..., CREATE_LAYER_OUTPUT]] = { + "convolutional": _create_convolutional, + "maxpool": _create_maxpool, + "route": _create_route, + "shortcut": _create_shortcut, + "upsample": _create_upsample, + "yolo": _create_yolo, + } + return create_func[config["type"]](config, num_inputs, **kwargs) + + +def _create_convolutional(config: CONFIG, num_inputs: List[int], **kwargs: Any) -> CREATE_LAYER_OUTPUT: + """Creates a convolutional layer. + + Args: + config: Dictionary of configuration options for this layer. + num_inputs: Number of channels in the input of every layer up to this layer. + + Returns: + module (:class:`~torch.nn.Module`), num_outputs (int): The created PyTorch module and the number of channels in + its output. + """ + batch_normalize = config.get("batch_normalize", False) + padding = (config["size"] - 1) // 2 if config["pad"] else 0 + + layer = Conv( + num_inputs[-1], + config["filters"], + kernel_size=config["size"], + stride=config["stride"], + padding=padding, + bias=not batch_normalize, + activation=config["activation"], + norm="batchnorm" if batch_normalize else None, + ) + return layer, config["filters"] + + +def _create_maxpool(config: CONFIG, num_inputs: List[int], **kwargs: Any) -> CREATE_LAYER_OUTPUT: + """Creates a max pooling layer. + + Padding is added so that the output resolution will be the input resolution divided by stride, rounded upwards. + + Args: + config: Dictionary of configuration options for this layer. + num_inputs: Number of channels in the input of every layer up to this layer. + + Returns: + module (:class:`~torch.nn.Module`), num_outputs (int): The created PyTorch module and the number of channels in + its output. + """ + layer = MaxPool(config["size"], config["stride"]) + return layer, num_inputs[-1] + + +def _create_route(config: CONFIG, num_inputs: List[int], **kwargs: Any) -> CREATE_LAYER_OUTPUT: + """Creates a routing layer. + + A routing layer concatenates the output (or part of it) from the layers specified by the "layers" configuration + option. + + Args: + config: Dictionary of configuration options for this layer. + num_inputs: Number of channels in the input of every layer up to this layer. + + Returns: + module (:class:`~torch.nn.Module`), num_outputs (int): The created PyTorch module and the number of channels in + its output. + """ + num_chunks = config.get("groups", 1) + chunk_idx = config.get("group_id", 0) + + # 0 is the first layer, -1 is the previous layer + last = len(num_inputs) - 1 + source_layers = [layer if layer >= 0 else last + layer for layer in config["layers"]] + + layer = RouteLayer(source_layers, num_chunks, chunk_idx) + + # The number of outputs of a source layer is the number of inputs of the next layer. + num_outputs = sum(num_inputs[layer + 1] // num_chunks for layer in source_layers) + + return layer, num_outputs + + +def _create_shortcut(config: CONFIG, num_inputs: List[int], **kwargs: Any) -> CREATE_LAYER_OUTPUT: + """Creates a shortcut layer. + + A shortcut layer adds a residual connection from the layer specified by the "from" configuration option. + + Args: + config: Dictionary of configuration options for this layer. + num_inputs: Number of channels in the input of every layer up to this layer. + + Returns: + module (:class:`~torch.nn.Module`), num_outputs (int): The created PyTorch module and the number of channels in + its output. + """ + layer = ShortcutLayer(config["from"]) + return layer, num_inputs[-1] + + +def _create_upsample(config: CONFIG, num_inputs: List[int], **kwargs: Any) -> CREATE_LAYER_OUTPUT: + """Creates a layer that upsamples the data. + + Args: + config: Dictionary of configuration options for this layer. + num_inputs: Number of channels in the input of every layer up to this layer. + + Returns: + module (:class:`~torch.nn.Module`), num_outputs (int): The created PyTorch module and the number of channels in + its output. + """ + layer = nn.Upsample(scale_factor=config["stride"], mode="nearest") + return layer, num_inputs[-1] + + +def _create_yolo( + config: CONFIG, + num_inputs: List[int], + prior_shapes: Optional[List[Tuple[int, int]]] = None, + matching_algorithm: Optional[str] = None, + matching_threshold: Optional[float] = None, + ignore_bg_threshold: Optional[float] = None, + overlap_func: Optional[Union[str, Callable]] = None, + predict_overlap: float = 1.0, + overlap_loss_multiplier: Optional[float] = None, + confidence_loss_multiplier: Optional[float] = None, + class_loss_multiplier: Optional[float] = None, + **kwargs: Any, +) -> CREATE_LAYER_OUTPUT: + """Creates a YOLO detection layer. + + Args: + config: Dictionary of configuration options for this layer. + num_inputs: Number of channels in the input of every layer up to this layer. Not used by the detection layer. + prior_shapes: A list of prior box dimensions, used for scaling the predicted dimensions and possibly for + matching the targets to the anchors. The list should contain (width, height) tuples in the network input + resolution. There should be `3N` tuples, where `N` defines the number of anchors per spatial location. They + are assigned to the layers from the lowest (high-resolution) to the highest (low-resolution) layer, meaning + that you typically want to sort the shapes from the smallest to the largest. + matching_algorithm: Which algorithm to use for matching targets to anchors. "simota" (the SimOTA matching rule + from YOLOX), "size" (match those prior shapes, whose width and height relative to the target is below given + ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the prior shape that + gives the highest IoU, default). + matching_threshold: Threshold for "size" and "iou" matching algorithms. + ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the corresponding anchor + has IoU with some target greater than this threshold, the predictor will not be taken into account when + calculating the confidence loss. + overlap_func: A function for calculating the pairwise overlaps between two sets of boxes. Either a string or a + function that returns a matrix of pairwise overlaps. Valid string values are "iou", "giou", "diou", and + "ciou". + predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that target + confidence is one if there's an object, and 1.0 means that the target confidence is the output of + ``overlap_func``. + overlap_loss_multiplier: Overlap loss will be scaled by this value. + confidence_loss_multiplier: Confidence loss will be scaled by this value. + class_loss_multiplier: Classification loss will be scaled by this value. + + Returns: + module (:class:`~torch.nn.Module`), num_outputs (int): The created PyTorch module and the number of channels in + its output (always 0 for a detection layer). + """ + if prior_shapes is None: + # The "anchors" list alternates width and height. + dims = config["anchors"] + prior_shapes = [(dims[i], dims[i + 1]) for i in range(0, len(dims), 2)] + if ignore_bg_threshold is None: + ignore_bg_threshold = config.get("ignore_thresh", 1.0) + assert isinstance(ignore_bg_threshold, float) + if overlap_func is None: + overlap_func = config.get("iou_loss", "iou") + assert isinstance(overlap_func, str) + if overlap_loss_multiplier is None: + overlap_loss_multiplier = config.get("iou_normalizer", 1.0) + assert isinstance(overlap_loss_multiplier, float) + if confidence_loss_multiplier is None: + confidence_loss_multiplier = config.get("obj_normalizer", 1.0) + assert isinstance(confidence_loss_multiplier, float) + if class_loss_multiplier is None: + class_loss_multiplier = config.get("cls_normalizer", 1.0) + assert isinstance(class_loss_multiplier, float) + + layer = _create_detection_layer( + num_classes=config["classes"], + prior_shapes=prior_shapes, + prior_shape_idxs=config["mask"], + matching_algorithm=matching_algorithm, + matching_threshold=matching_threshold, + ignore_bg_threshold=ignore_bg_threshold, + overlap_func=overlap_func, + predict_overlap=predict_overlap, + overlap_loss_multiplier=overlap_loss_multiplier, + confidence_loss_multiplier=confidence_loss_multiplier, + class_loss_multiplier=class_loss_multiplier, + xy_scale=config.get("scale_x_y", 1.0), + input_is_normalized=config.get("new_coords", 0) > 0, + ) + return layer, 0 diff --git a/torchvision/models/yolo.py b/torchvision/models/yolo.py new file mode 100644 index 00000000000..695b613353f --- /dev/null +++ b/torchvision/models/yolo.py @@ -0,0 +1,729 @@ +from collections import OrderedDict +from typing import List, Optional, Sequence, Tuple + +import torch +from torch import Tensor, nn + + +def _get_padding(kernel_size: int, stride: int) -> Tuple[int, nn.Module]: + """Returns the amount of padding needed by convolutional and max pooling layers. + + Determines the amount of padding needed to make the output size of the layer the input size divided by the stride. + The first value that the function returns is the amount of padding to be added to all sides of the input matrix + (``padding`` argument of the operation). If an uneven amount of padding is needed in different sides of the input, + the second variable that is returned is an ``nn.ZeroPad2d`` operation that adds an additional column and row of + padding. If the input size is not divisible by the stride, the output size will be rounded upwards. + + Args: + kernel_size: Size of the kernel. + stride: Stride of the operation. + + Returns: + padding, pad_op: The amount of padding to be added to all sides of the input and an ``nn.Identity`` or + ``nn.ZeroPad2d`` operation to add one more column and row of padding if necessary. + """ + # The output size is generally (input_size + padding - max(kernel_size, stride)) / stride + 1 and we want to + # make it equal to input_size / stride. + padding, remainder = divmod(max(kernel_size, stride) - stride, 2) + + # If the kernel size is an even number, we need one cell of extra padding, on top of the padding added by MaxPool2d + # on both sides. + pad_op: nn.Module = nn.Identity() if remainder == 0 else nn.ZeroPad2d((0, 1, 0, 1)) + + return padding, pad_op + + +def _create_activation_module(name: Optional[str]) -> nn.Module: + """Creates a layer activation module given its type as a string. + + Args: + name: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic", "linear", + or "none". + """ + if name == "relu": + return nn.ReLU(inplace=True) + if name == "leaky": + return nn.LeakyReLU(0.1, inplace=True) + if name == "mish": + return Mish() + if name == "silu" or name == "swish": + return nn.SiLU(inplace=True) + if name == "logistic": + return nn.Sigmoid() + if name == "linear" or name == "none" or name is None: + return nn.Identity() + raise ValueError(f"Activation type `{name}´ is unknown.") + + +def _create_normalization_module(name: Optional[str], num_channels: int) -> nn.Module: + """Creates a layer normalization module given its type as a string. + + Group normalization uses always 8 channels. The most common network widths are divisible by this number. + + Args: + name: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none". + num_channels: The number of input channels that the module expects. + """ + if name == "batchnorm": + return nn.BatchNorm2d(num_channels, eps=0.001) + if name == "groupnorm": + return nn.GroupNorm(8, num_channels, eps=0.001) + if name == "none" or name is None: + return nn.Identity() + raise ValueError(f"Normalization layer type `{name}´ is unknown.") + + +class Conv(nn.Module): + """A convolutional layer with optional layer normalization and activation. + + If ``padding`` is ``None``, the module tries to add padding so much that the output size will be the input size + divided by the stride. If the input size is not divisible by the stride, the output size will be rounded upwards. + + Args: + in_channels: Number of input channels that the layer expects. + out_channels: Number of output channels that the convolution produces. + kernel_size: Size of the convolving kernel. + stride: Stride of the convolution. + padding: Padding added to all four sides of the input. + bias: If ``True``, adds a learnable bias to the output. + activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic", + "linear", or "none". + norm: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none". + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 1, + stride: int = 1, + padding: Optional[int] = None, + bias: bool = False, + activation: Optional[str] = "silu", + norm: Optional[str] = "batchnorm", + ): + super().__init__() + + if padding is None: + padding, self.pad = _get_padding(kernel_size, stride) + else: + self.pad = nn.Identity() + + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias) + self.norm = _create_normalization_module(norm, out_channels) + self.act = _create_activation_module(activation) + + def forward(self, x: Tensor) -> Tensor: + x = self.pad(x) + x = self.conv(x) + x = self.norm(x) + return self.act(x) + + +class MaxPool(nn.Module): + """A max pooling layer with padding. + + The module tries to add padding so much that the output size will be the input size divided by the stride. If the + input size is not divisible by the stride, the output size will be rounded upwards. + """ + + def __init__(self, kernel_size: int, stride: int): + super().__init__() + padding, self.pad = _get_padding(kernel_size, stride) + self.maxpool = nn.MaxPool2d(kernel_size, stride, padding) + + def forward(self, x: Tensor) -> Tensor: + x = self.pad(x) + return self.maxpool(x) + + +class RouteLayer(nn.Module): + """A routing layer concatenates the output (or part of it) from given layers. + + Args: + source_layers: Indices of the layers whose output will be concatenated. + num_chunks: Layer outputs will be split into this number of chunks. + chunk_idx: Only the chunks with this index will be concatenated. + """ + + def __init__(self, source_layers: List[int], num_chunks: int, chunk_idx: int) -> None: + super().__init__() + self.source_layers = source_layers + self.num_chunks = num_chunks + self.chunk_idx = chunk_idx + + def forward(self, outputs: List[Tensor]) -> Tensor: + chunks = [torch.chunk(outputs[layer], self.num_chunks, dim=1)[self.chunk_idx] for layer in self.source_layers] + return torch.cat(chunks, dim=1) + + +class ShortcutLayer(nn.Module): + """A shortcut layer adds a residual connection from the source layer. + + Args: + source_layer: Index of the layer whose output will be added to the output of the previous layer. + """ + + def __init__(self, source_layer: int) -> None: + super().__init__() + self.source_layer = source_layer + + def forward(self, outputs: List[Tensor]) -> Tensor: + return outputs[-1] + outputs[self.source_layer] + + +class Mish(nn.Module): + """Mish activation.""" + + def forward(self, x: Tensor) -> Tensor: + return x * torch.tanh(nn.functional.softplus(x)) + + +class ReOrg(nn.Module): + """Re-organizes the tensor so that every square region of four cells is placed into four different channels. + + The result is a tensor with half the width and height, and four times as many channels. + """ + + def forward(self, x: Tensor) -> Tensor: + tl = x[..., ::2, ::2] + bl = x[..., 1::2, ::2] + tr = x[..., ::2, 1::2] + br = x[..., 1::2, 1::2] + return torch.cat((tl, bl, tr, br), dim=1) + + +class BottleneckBlock(nn.Module): + """A residual block with a bottleneck layer. + + Args: + in_channels: Number of input channels that the block expects. + out_channels: Number of output channels that the block produces. + hidden_channels: Number of output channels the (hidden) bottleneck layer produces. By default the number of + output channels of the block. + shortcut: Whether the block should include a shortcut connection. + activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic", + "linear", or "none". + norm: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none". + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + hidden_channels: Optional[int] = None, + shortcut: bool = True, + activation: Optional[str] = "silu", + norm: Optional[str] = "batchnorm", + ) -> None: + super().__init__() + + if hidden_channels is None: + hidden_channels = out_channels + + self.convs = nn.Sequential( + Conv(in_channels, hidden_channels, kernel_size=1, stride=1, activation=activation, norm=norm), + Conv(hidden_channels, out_channels, kernel_size=3, stride=1, activation=activation, norm=norm), + ) + self.shortcut = shortcut and in_channels == out_channels + + def forward(self, x: Tensor) -> Tensor: + y = self.convs(x) + return x + y if self.shortcut else y + + +class TinyStage(nn.Module): + """One stage of the "tiny" network architecture from YOLOv4. + + Args: + num_channels: Number of channels in the input of the stage. Partial output will have as many channels and full + output will have twice as many channels. + activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic", + "linear", or "none". + norm: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none". + """ + + def __init__( + self, + num_channels: int, + activation: Optional[str] = "leaky", + norm: Optional[str] = "batchnorm", + ) -> None: + super().__init__() + + hidden_channels = num_channels // 2 + self.conv1 = Conv(hidden_channels, hidden_channels, kernel_size=3, stride=1, activation=activation, norm=norm) + self.conv2 = Conv(hidden_channels, hidden_channels, kernel_size=3, stride=1, activation=activation, norm=norm) + self.mix = Conv(num_channels, num_channels, kernel_size=1, stride=1, activation=activation, norm=norm) + + def forward(self, x: Tensor) -> Tensor: + partial = torch.chunk(x, 2, dim=1)[1] + y1 = self.conv1(partial) + y2 = self.conv2(y1) + partial_output = self.mix(torch.cat((y2, y1), dim=1)) + full_output = torch.cat((x, partial_output), dim=1) + return partial_output, full_output + + +class CSPStage(nn.Module): + """One stage of a Cross Stage Partial Network (CSPNet). + + Encapsulates a number of bottleneck blocks in the "fusion first" CSP structure. + + `Chien-Yao Wang et al. `_ + + Args: + in_channels: Number of input channels that the CSP stage expects. + out_channels: Number of output channels that the CSP stage produces. + depth: Number of bottleneck blocks that the CSP stage contains. + shortcut: Whether the bottleneck blocks should include a shortcut connection. + activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic", + "linear", or "none". + norm: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none". + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + depth: int = 1, + shortcut: bool = True, + activation: Optional[str] = "silu", + norm: Optional[str] = "batchnorm", + ) -> None: + super().__init__() + + # Instead of splitting the N output channels of a convolution into two parts, we can equivalently perform two + # convolutions with N/2 output channels. + hidden_channels = out_channels // 2 + + self.split1 = Conv(in_channels, hidden_channels, kernel_size=1, stride=1, activation=activation, norm=norm) + self.split2 = Conv(in_channels, hidden_channels, kernel_size=1, stride=1, activation=activation, norm=norm) + bottlenecks: List[nn.Module] = [ + BottleneckBlock(hidden_channels, hidden_channels, shortcut=shortcut, norm=norm, activation=activation) + for _ in range(depth) + ] + self.bottlenecks = nn.Sequential(*bottlenecks) + self.mix = Conv(hidden_channels * 2, out_channels, kernel_size=1, stride=1, activation=activation, norm=norm) + + def forward(self, x: Tensor) -> Tensor: + y1 = self.bottlenecks(self.split1(x)) + y2 = self.split2(x) + return self.mix(torch.cat((y1, y2), dim=1)) + + +class ELANStage(nn.Module): + """One stage of an Efficient Layer Aggregation Network (ELAN). + + `Chien-Yao Wang et al. `_ + + Args: + in_channels: Number of input channels that the ELAN stage expects. + out_channels: Number of output channels that the ELAN stage produces. + hidden_channels: Number of output channels that the computational blocks produce. The default value is half the + number of output channels of the block, as in YOLOv7-W6, but the value varies between the variants. + split_channels: Number of channels in each part after splitting the input to the cross stage connection and the + computational blocks. The default value is the number of hidden channels, as in all YOLOv7 backbones. Most + YOLOv7 heads use twice the number of hidden channels. + depth: Number of computational blocks that the ELAN stage contains. The default value is 2. YOLOv7 backbones use + 2 to 4 blocks per stage. + block_depth: Number of convolutional layers in one computational block. The default value is 2. YOLOv7 backbones + have two convolutions per block. YOLOv7 heads (except YOLOv7-X) have 2 to 8 blocks with only one convolution + in each. + activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic", + "linear", or "none". + norm: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none". + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + hidden_channels: Optional[int] = None, + split_channels: Optional[int] = None, + depth: int = 2, + block_depth: int = 2, + activation: Optional[str] = "silu", + norm: Optional[str] = "batchnorm", + ) -> None: + super().__init__() + + def conv3x3(in_channels: int, out_channels: int) -> nn.Module: + return Conv(in_channels, out_channels, kernel_size=3, stride=1, activation=activation, norm=norm) + + def block(in_channels: int, out_channels: int) -> nn.Module: + convs = [conv3x3(in_channels, out_channels)] + for _ in range(block_depth - 1): + convs.append(conv3x3(out_channels, out_channels)) + return nn.Sequential(*convs) + + # Instead of splitting the N output channels of a convolution into two parts, we can equivalently perform two + # convolutions with N/2 output channels. However, in many YOLOv7 architectures, the number of hidden channels is + # not exactly half the number of output channels. + if hidden_channels is None: + hidden_channels = out_channels // 2 + + if split_channels is None: + split_channels = hidden_channels + + self.split1 = Conv(in_channels, split_channels, kernel_size=1, stride=1, activation=activation, norm=norm) + self.split2 = Conv(in_channels, split_channels, kernel_size=1, stride=1, activation=activation, norm=norm) + + blocks = [block(split_channels, hidden_channels)] + for _ in range(depth - 1): + blocks.append(block(hidden_channels, hidden_channels)) + self.blocks = nn.ModuleList(blocks) + + total_channels = (split_channels * 2) + (hidden_channels * depth) + self.mix = Conv(total_channels, out_channels, kernel_size=1, stride=1, activation=activation, norm=norm) + + def forward(self, x: Tensor) -> Tensor: + outputs = [self.split1(x), self.split2(x)] + x = outputs[-1] + for block in self.blocks: + x = block(x) + outputs.append(x) + return self.mix(torch.cat(outputs, dim=1)) + + +class CSPSPP(nn.Module): + """Spatial pyramid pooling module from the Cross Stage Partial Network from YOLOv4. + + Args: + in_channels: Number of input channels that the module expects. + out_channels: Number of output channels that the module produces. + activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic", + "linear", or "none". + norm: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none". + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + activation: Optional[str] = "silu", + norm: Optional[str] = "batchnorm", + ): + super().__init__() + + def conv(in_channels: int, out_channels: int, kernel_size: int = 1) -> nn.Module: + return Conv(in_channels, out_channels, kernel_size=kernel_size, stride=1, activation=activation, norm=norm) + + self.conv1 = nn.Sequential( + conv(in_channels, out_channels), + conv(out_channels, out_channels, kernel_size=3), + conv(out_channels, out_channels), + ) + self.conv2 = conv(in_channels, out_channels) + + self.maxpool1 = MaxPool(kernel_size=5, stride=1) + self.maxpool2 = MaxPool(kernel_size=9, stride=1) + self.maxpool3 = MaxPool(kernel_size=13, stride=1) + + self.mix1 = nn.Sequential( + conv(4 * out_channels, out_channels), + conv(out_channels, out_channels, kernel_size=3), + ) + self.mix2 = Conv(2 * out_channels, out_channels) + + def forward(self, x: Tensor) -> Tensor: + x1 = self.conv1(x) + x2 = self.maxpool1(x1) + x3 = self.maxpool2(x1) + x4 = self.maxpool3(x1) + y1 = self.mix1(torch.cat((x1, x2, x3, x4), dim=1)) + y2 = self.conv2(x) + return self.mix2(torch.cat((y1, y2), dim=1)) + + +class FastSPP(nn.Module): + """Fast spatial pyramid pooling module from YOLOv5. + + Args: + in_channels: Number of input channels that the module expects. + out_channels: Number of output channels that the module produces. + activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic", + "linear", or "none". + norm: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none". + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + activation: Optional[str] = "silu", + norm: Optional[str] = "batchnorm", + ): + super().__init__() + hidden_channels = in_channels // 2 + self.conv = Conv(in_channels, hidden_channels, kernel_size=1, stride=1, activation=activation, norm=norm) + self.maxpool = MaxPool(kernel_size=5, stride=1) + self.mix = Conv(hidden_channels * 4, out_channels, kernel_size=1, stride=1, activation=activation, norm=norm) + + def forward(self, x: Tensor) -> Tensor: + y1 = self.conv(x) + y2 = self.maxpool(y1) + y3 = self.maxpool(y2) + y4 = self.maxpool(y3) + return self.mix(torch.cat((y1, y2, y3, y4), dim=1)) + + +class YOLOV4TinyBackbone(nn.Module): + """Backbone of the "tiny" network architecture from YOLOv4. + + Args: + in_channels: Number of channels in the input image. + width: Number of channels in the narrowest convolutional layer. The wider convolutional layers will use a number + of channels that is a multiple of this value. + activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic", + "linear", or "none". + normalization: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none". + """ + + def __init__( + self, + in_channels: int = 3, + width: int = 32, + activation: Optional[str] = "leaky", + normalization: Optional[str] = "batchnorm", + ): + super().__init__() + + def smooth(num_channels: int) -> nn.Module: + return Conv(num_channels, num_channels, kernel_size=3, stride=1, activation=activation, norm=normalization) + + def downsample(in_channels: int, out_channels: int) -> nn.Module: + conv_module = Conv(in_channels, out_channels, kernel_size=3, stride=2, activation=activation, norm=normalization) + return nn.Sequential(OrderedDict([("downsample", conv_module), ("smooth", smooth(out_channels))])) + + def maxpool(out_channels: int) -> nn.Module: + return nn.Sequential( + OrderedDict( + [ + ("pad", nn.ZeroPad2d((0, 1, 0, 1))), + ("maxpool", MaxPool(kernel_size=2, stride=2)), + ("smooth", smooth(out_channels)), + ] + ) + ) + + def stage(out_channels: int, use_maxpool: bool) -> nn.Module: + if use_maxpool: + downsample_module = maxpool(out_channels) + else: + downsample_module = downsample(out_channels // 2, out_channels) + stage_module = TinyStage(out_channels, activation=activation, norm=normalization) + return nn.Sequential(OrderedDict([("downsample", downsample_module), ("stage", stage_module)])) + + stages = [ + Conv(in_channels, width, kernel_size=3, stride=2, activation=activation, norm=normalization), + stage(width * 2, False), + stage(width * 4, True), + stage(width * 8, True), + maxpool(width * 16), + ] + self.stages = nn.ModuleList(stages) + + def forward(self, x: Tensor) -> List[Tensor]: + c1 = self.stages[0](x) + c2, x = self.stages[1](c1) + c3, x = self.stages[2](x) + c4, x = self.stages[3](x) + c5 = self.stages[4](x) + return [c1, c2, c3, c4, c5] + + +class YOLOV4Backbone(nn.Module): + """A backbone that corresponds approximately to the Cross Stage Partial Network from YOLOv4. + + Args: + in_channels: Number of channels in the input image. + widths: Number of channels at each network stage. Typically ``(32, 64, 128, 256, 512, 1024)``. The P6 variant + adds one more stage with 1024 channels. + depths: Number of bottleneck layers at each network stage. Typically ``(1, 1, 2, 8, 8, 4)``. The P6 variant uses + ``(1, 1, 3, 15, 15, 7, 7)``. + activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic", + "linear", or "none". + normalization: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none". + """ + + def __init__( + self, + in_channels: int = 3, + widths: Sequence[int] = (32, 64, 128, 256, 512, 1024), + depths: Sequence[int] = (1, 1, 2, 8, 8, 4), + activation: Optional[str] = "silu", + normalization: Optional[str] = "batchnorm", + ) -> None: + super().__init__() + + if len(widths) != len(depths): + raise ValueError("Width and depth has to be given for an equal number of stages.") + + def conv3x3(in_channels: int, out_channels: int) -> nn.Module: + return Conv(in_channels, out_channels, kernel_size=3, stride=1, activation=activation, norm=normalization) + + def downsample(in_channels: int, out_channels: int) -> nn.Module: + return Conv(in_channels, out_channels, kernel_size=3, stride=2, activation=activation, norm=normalization) + + def stage(in_channels: int, out_channels: int, depth: int) -> nn.Module: + csp = CSPStage( + out_channels, + out_channels, + depth=depth, + shortcut=True, + activation=activation, + norm=normalization, + ) + return nn.Sequential( + OrderedDict( + [ + ("downsample", downsample(in_channels, out_channels)), + ("csp", csp), + ] + ) + ) + + convs = [conv3x3(in_channels, widths[0])] + [conv3x3(widths[0], widths[0]) for _ in range(depths[0] - 1)] + self.stem = nn.Sequential(*convs) + self.stages = nn.ModuleList( + stage(in_channels, out_channels, depth) + for in_channels, out_channels, depth in zip(widths[:-1], widths[1:], depths[1:]) + ) + + def forward(self, x: Tensor) -> List[Tensor]: + x = self.stem(x) + outputs: List[Tensor] = [] + for stage in self.stages: + x = stage(x) + outputs.append(x) + return outputs + + +class YOLOV5Backbone(nn.Module): + """The Cross Stage Partial Network backbone from YOLOv5. + + Args: + in_channels: Number of channels in the input image. + width: Number of channels in the narrowest convolutional layer. The wider convolutional layers will use a number + of channels that is a multiple of this value. The values used by the different variants are 16 (yolov5n), 32 + (yolov5s), 48 (yolov5m), 64 (yolov5l), and 80 (yolov5x). + depth: Repeat the bottleneck layers this many times. Can be used to make the network deeper. The values used by + the different variants are 1 (yolov5n, yolov5s), 2 (yolov5m), 3 (yolov5l), and 4 (yolov5x). + activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic", + "linear", or "none". + normalization: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none". + """ + + def __init__( + self, + in_channels: int = 3, + width: int = 64, + depth: int = 3, + activation: Optional[str] = "silu", + normalization: Optional[str] = "batchnorm", + ) -> None: + super().__init__() + + def downsample(in_channels: int, out_channels: int, kernel_size: int = 3) -> nn.Module: + return Conv( + in_channels, out_channels, kernel_size=kernel_size, stride=2, activation=activation, norm=normalization + ) + + def stage(in_channels: int, out_channels: int, depth: int) -> nn.Module: + csp = CSPStage( + out_channels, + out_channels, + depth=depth, + shortcut=True, + activation=activation, + norm=normalization, + ) + return nn.Sequential( + OrderedDict( + [ + ("downsample", downsample(in_channels, out_channels)), + ("csp", csp), + ] + ) + ) + + stages = [ + downsample(in_channels, width, kernel_size=6), + stage(width, width * 2, depth), + stage(width * 2, width * 4, depth * 2), + stage(width * 4, width * 8, depth * 3), + stage(width * 8, width * 16, depth), + ] + self.stages = nn.ModuleList(stages) + + def forward(self, x: Tensor) -> List[Tensor]: + c1 = self.stages[0](x) + c2 = self.stages[1](c1) + c3 = self.stages[2](c2) + c4 = self.stages[3](c3) + c5 = self.stages[4](c4) + return [c1, c2, c3, c4, c5] + + +class YOLOV7Backbone(nn.Module): + """A backbone that corresponds to the W6 variant of the Efficient Layer Aggregation Network from YOLOv7. + + Args: + in_channels: Number of channels in the input image. + widths: Number of channels at each network stage. Before the first stage there will be one extra split of + spatial resolution by a ``ReOrg`` layer, producing ``in_channels * 4`` channels. + depth: Number of computational blocks at each network stage. YOLOv7-W6 backbone uses 2. + block_depth: Number of convolutional layers in one computational block. YOLOv7-W6 backbone uses 2. + activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic", + "linear", or "none". + normalization: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none". + """ + + def __init__( + self, + in_channels: int = 3, + widths: Sequence[int] = (64, 128, 256, 512, 768, 1024), + depth: int = 2, + block_depth: int = 2, + activation: Optional[str] = "silu", + normalization: Optional[str] = "batchnorm", + ) -> None: + super().__init__() + + def conv3x3(in_channels: int, out_channels: int) -> nn.Module: + return Conv(in_channels, out_channels, kernel_size=3, stride=1, activation=activation, norm=normalization) + + def downsample(in_channels: int, out_channels: int) -> nn.Module: + return Conv(in_channels, out_channels, kernel_size=3, stride=2, activation=activation, norm=normalization) + + def stage(in_channels: int, out_channels: int) -> nn.Module: + elan = ELANStage( + out_channels, + out_channels, + depth=depth, + block_depth=block_depth, + activation=activation, + norm=normalization, + ) + return nn.Sequential( + OrderedDict( + [ + ("downsample", downsample(in_channels, out_channels)), + ("elan", elan), + ] + ) + ) + + self.stem = nn.Sequential(*[ReOrg(), conv3x3(in_channels * 4, widths[0])]) + self.stages = nn.ModuleList( + stage(in_channels, out_channels) for in_channels, out_channels in zip(widths[:-1], widths[1:]) + ) + + def forward(self, x: Tensor) -> List[Tensor]: + x = self.stem(x) + outputs: List[Tensor] = [] + for stage in self.stages: + x = stage(x) + outputs.append(x) + return outputs From b8f28a251a140ec8aa7b8af797a647e4a8516bed Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Fri, 3 Mar 2023 13:26:10 +0200 Subject: [PATCH 02/13] Fixed module imports --- torchvision/models/detection/__init__.py | 11 +++++++++-- torchvision/models/detection/yolo_networks.py | 8 ++++---- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/torchvision/models/detection/__init__.py b/torchvision/models/detection/__init__.py index faff00ff86a..6cc2e6e5515 100644 --- a/torchvision/models/detection/__init__.py +++ b/torchvision/models/detection/__init__.py @@ -1,4 +1,3 @@ -from .darknet_network import DarknetNetwork from .faster_rcnn import * from .fcos import * from .keypoint_rcnn import * @@ -7,4 +6,12 @@ from .ssd import * from .ssdlite import * from .yolo import * -from .yolo_networks import YOLOV4TinyNetwork, YOLOV4Network, YOLOV4P6Network, YOLOV5Network, YOLOV7Network, YOLOXNetwork +from .yolo_networks import ( + DarknetNetwork, + YOLOV4TinyNetwork, + YOLOV4Network, + YOLOV4P6Network, + YOLOV5Network, + YOLOV7Network, + YOLOXNetwork, +) diff --git a/torchvision/models/detection/yolo_networks.py b/torchvision/models/detection/yolo_networks.py index 630cbe27711..ac6c9550462 100644 --- a/torchvision/models/detection/yolo_networks.py +++ b/torchvision/models/detection/yolo_networks.py @@ -1189,10 +1189,10 @@ def detect(prior_shape_idxs: Sequence[int], range: float) -> DetectionLayer: widths=widths, depth=2, block_depth=2, activation=activation, normalization=normalization ) - w3 = widths[-4] # 256 - w4 = widths[-3] # 512 - w5 = widths[-2] # 768 - w6 = widths[-1] # 1024 + w3 = widths[-4] + w4 = widths[-3] + w5 = widths[-2] + w6 = widths[-1] self.spp = spp(w6, w6 // 2) From d0769d97e47c93164d272143501ce5d7c237653a Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Fri, 17 Mar 2023 20:27:39 +0200 Subject: [PATCH 03/13] SimOTA matches targets also based on the size of the anchors (size_range) --- test/test_models_detection_box_utils.py | 77 ++++++ test/test_models_detection_target_matching.py | 82 ++----- torchvision/models/detection/box_utils.py | 83 +++++++ .../models/detection/target_matching.py | 231 +++++++++--------- torchvision/models/detection/yolo.py | 34 ++- torchvision/models/detection/yolo_loss.py | 32 ++- torchvision/models/detection/yolo_networks.py | 101 +++++--- torchvision/models/yolo.py | 2 +- 8 files changed, 392 insertions(+), 250 deletions(-) create mode 100644 test/test_models_detection_box_utils.py create mode 100644 torchvision/models/detection/box_utils.py diff --git a/test/test_models_detection_box_utils.py b/test/test_models_detection_box_utils.py new file mode 100644 index 00000000000..20d85411695 --- /dev/null +++ b/test/test_models_detection_box_utils.py @@ -0,0 +1,77 @@ +import pytest +import torch +from torchvision.models.detection.anchor_utils import grid_centers +from torchvision.models.detection.box_utils import aligned_iou, compare_box_sizes, iou_below, is_inside_box + + +@pytest.mark.parametrize( + "dims1, dims2, expected_ious", + [ + ( + torch.tensor([[1.0, 1.0], [10.0, 1.0], [100.0, 10.0]]), + torch.tensor([[1.0, 10.0], [2.0, 20.0]]), + torch.tensor([[1.0 / 10.0, 1.0 / 40.0], [1.0 / 19.0, 2.0 / 48.0], [10.0 / 1000.0, 20.0 / 1020.0]]), + ) + ], +) +def test_aligned_iou(dims1, dims2, expected_ious): + torch.testing.assert_close(aligned_iou(dims1, dims2), expected_ious) + + +def test_iou_below(): + tl = torch.rand((10, 10, 3, 2)) * 100 + br = tl + 10 + pred_boxes = torch.cat((tl, br), -1) + target_boxes = torch.stack((pred_boxes[1, 1, 0], pred_boxes[3, 5, 1])) + result = iou_below(pred_boxes, target_boxes, 0.9) + assert result.shape == (10, 10, 3) + assert not result[1, 1, 0] + assert not result[3, 5, 1] + + +def test_is_inside_box(): + """ + centers: + [[1,1; 3,1; 5,1; 7,1; 9,1; 11,1; 13,1; 15,1; 17,1; 19,1] + [1,3; 3,3; 5,3; 7,3; 9,3; 11,3; 13,3; 15,3; 17,3; 19,3] + [1,5; 3,5; 5,5; 7,5; 9,5; 11,5; 13,5; 15,5; 17,5; 19,5] + [1,7; 3,7; 5,7; 7,7; 9,7; 11,7; 13,7; 15,7; 17,7; 19,7] + [1,9; 3,9; 5,9; 7,9; 9,9; 11,9; 13,9; 15,9; 17,9; 19,9]] + + is_inside[..., 0]: + [[F, F, F, F, F, F, F, F, F, F] + [F, T, T, F, F, F, F, F, F, F] + [F, T, T, F, F, F, F, F, F, F] + [F, F, F, F, F, F, F, F, F, F] + [F, F, F, F, F, F, F, F, F, F]] + + is_inside[..., 1]: + [[F, F, F, F, F, F, F, F, F, F] + [F, F, F, F, F, F, F, F, F, F] + [F, F, F, F, F, F, F, F, F, F] + [F, F, F, F, F, F, F, F, F, F] + [F, F, F, F, F, F, F, T, T, F]] + """ + size = torch.tensor([10, 5]) + centers = grid_centers(size) * 2.0 + centers = centers.view(-1, 2) + boxes = torch.tensor([[2, 2, 6, 6], [14, 8, 18, 10]]) + is_inside = is_inside_box(centers, boxes).view(5, 10, 2) + assert torch.count_nonzero(is_inside) == 6 + assert torch.all(is_inside[1:3, 1:3, 0]) + assert torch.all(is_inside[4, 7:9, 1]) + + +def test_compare_box_sizes(): + wh1 = torch.tensor([[24, 11], [12, 25], [26, 27], [15, 17]]) + wh2 = torch.tensor([[10, 30], [15, 9]]) + result = compare_box_sizes(wh1, wh2, 2.0) + assert result.shape == (4, 2) + assert not result[0, 0] # 24 / 10 >= 2 + assert result[0, 1] # 24 / 15 < 2, 11 / 9 < 2 + assert result[1, 0] # 12 / 10 < 2, 30 / 25 < 2 + assert not result[1, 1] # 25 / 9 >= 2 + assert not result[2, 0] # 26 / 10 >= 2 + assert not result[2, 1] # 27 / 9 >= 2 + assert result[3, 0] # 15 / 10 < 2, 30 / 17 < 2 + assert result[3, 1] # 15 / 15 < 2, 17 / 9 < 2 diff --git a/test/test_models_detection_target_matching.py b/test/test_models_detection_target_matching.py index dc99669a947..aaf678575e3 100644 --- a/test/test_models_detection_target_matching.py +++ b/test/test_models_detection_target_matching.py @@ -1,80 +1,26 @@ -import pytest import torch -from torchvision.models.detection.anchor_utils import grid_centers -from torchvision.models.detection.target_matching import aligned_iou, iou_below, is_inside_box, _sim_ota_match - -@pytest.mark.parametrize( - "dims1, dims2, expected_ious", - [ - ( - torch.tensor([[1.0, 1.0], [10.0, 1.0], [100.0, 10.0]]), - torch.tensor([[1.0, 10.0], [2.0, 20.0]]), - torch.tensor([[1.0 / 10.0, 1.0 / 40.0], [1.0 / 19.0, 2.0 / 48.0], [10.0 / 1000.0, 20.0 / 1020.0]]), - ) - ], -) -def test_aligned_iou(dims1, dims2, expected_ious): - torch.testing.assert_close(aligned_iou(dims1, dims2), expected_ious) - - -def test_iou_below(): - tl = torch.rand((10, 10, 3, 2)) * 100 - br = tl + 10 - pred_boxes = torch.cat((tl, br), -1) - target_boxes = torch.stack((pred_boxes[1, 1, 0], pred_boxes[3, 5, 1])) - result = iou_below(pred_boxes, target_boxes, 0.9) - assert result.shape == (10, 10, 3) - assert not result[1, 1, 0] - assert not result[3, 5, 1] - - -def test_is_inside_box(): - """ - centers: - [[1,1; 3,1; 5,1; 7,1; 9,1; 11,1; 13,1; 15,1; 17,1; 19,1] - [1,3; 3,3; 5,3; 7,3; 9,3; 11,3; 13,3; 15,3; 17,3; 19,3] - [1,5; 3,5; 5,5; 7,5; 9,5; 11,5; 13,5; 15,5; 17,5; 19,5] - [1,7; 3,7; 5,7; 7,7; 9,7; 11,7; 13,7; 15,7; 17,7; 19,7] - [1,9; 3,9; 5,9; 7,9; 9,9; 11,9; 13,9; 15,9; 17,9; 19,9]] - - is_inside[0]: - [[F, F, F, F, F, F, F, F, F, F] - [F, T, T, F, F, F, F, F, F, F] - [F, T, T, F, F, F, F, F, F, F] - [F, F, F, F, F, F, F, F, F, F] - [F, F, F, F, F, F, F, F, F, F]] - - is_inside[1]: - [[F, F, F, F, F, F, F, F, F, F] - [F, F, F, F, F, F, F, F, F, F] - [F, F, F, F, F, F, F, F, F, F] - [F, F, F, F, F, F, F, F, F, F] - [F, F, F, F, F, F, F, T, T, F]] - """ - size = torch.tensor([10, 5]) - centers = grid_centers(size) * 2.0 - centers = centers.view(-1, 2) - boxes = torch.tensor([[2, 2, 6, 6], [14, 8, 18, 10]]) - is_inside = is_inside_box(centers, boxes).view(2, 5, 10) - assert torch.count_nonzero(is_inside) == 6 - assert torch.all(is_inside[0, 1:3, 1:3]) - assert torch.all(is_inside[1, 4, 7:9]) +from torchvision.models.detection.target_matching import _sim_ota_match def test_sim_ota_match(): - # IoUs will determined that 2 and 1 predictions will be selected for the first and the second target. - ious = torch.tensor([[0.1, 0.1, 0.9, 0.9], [0.2, 0.3, 0.4, 0.1]]) + # For each of the two targets, k will be the sum of the IoUs. 2 and 1 predictions will be selected for the first and + # the second target respectively. + ious = torch.tensor([[0.1, 0.2], [0.1, 0.3], [0.9, 0.4], [0.9, 0.1]]) # Costs will determine that the first and the last prediction will be selected for the first target, and the first - # prediction will be selected for the second target. Since the first prediction was selected for both targets, it - # will be matched to the best target only (the second one). - costs = torch.tensor([[0.3, 0.5, 0.4, 0.3], [0.1, 0.2, 0.5, 0.3]]) + # prediction will be selected for the second target. The first prediction was selected for two targets, but it will + # be matched to the best target only (the second one). + costs = torch.tensor([[0.3, 0.1], [0.5, 0.2], [0.4, 0.5], [0.3, 0.3]]) matched_preds, matched_targets = _sim_ota_match(costs, ious) + + # The first and the last prediction were matched. assert len(matched_preds) == 4 assert matched_preds[0] assert not matched_preds[1] assert not matched_preds[2] assert matched_preds[3] - assert len(matched_targets) == 2 # Two predictions were matched. - assert matched_targets[0] == 1 # Which target was matched to the first prediction. - assert matched_targets[1] == 0 # Which target was matched to the last prediction. + + # The first prediction was matched to the target 1 and the last prediction was matched to target 0. + assert len(matched_targets) == 2 + assert matched_targets[0] == 1 + assert matched_targets[1] == 0 diff --git a/torchvision/models/detection/box_utils.py b/torchvision/models/detection/box_utils.py new file mode 100644 index 00000000000..af3481dba81 --- /dev/null +++ b/torchvision/models/detection/box_utils.py @@ -0,0 +1,83 @@ +import torch +from torch import Tensor + +from ...ops import box_iou + + +def aligned_iou(wh1: Tensor, wh2: Tensor) -> Tensor: + """Calculates a matrix of intersections over union from box dimensions, assuming that the boxes are located at + the same coordinates. + + Args: + wh1: An ``[N, 2]`` matrix of box shapes (width and height). + wh2: An ``[M, 2]`` matrix of box shapes (width and height). + + Returns: + An ``[N, M]`` matrix of pairwise IoU values for every element in ``wh1`` and ``wh2`` + """ + area1 = wh1[:, 0] * wh1[:, 1] # [N] + area2 = wh2[:, 0] * wh2[:, 1] # [M] + + inter_wh = torch.min(wh1[:, None, :], wh2) # [N, M, 2] + inter = inter_wh[:, :, 0] * inter_wh[:, :, 1] # [N, M] + union = area1[:, None] + area2 - inter # [N, M] + + return inter / union + + +def iou_below(pred_boxes: Tensor, target_boxes: Tensor, threshold: float) -> Tensor: + """Creates a binary mask whose value will be ``True``, unless the predicted box overlaps any target + significantly (IoU greater than ``threshold``). + + Args: + pred_boxes: The predicted corner coordinates. Tensor of size ``[height, width, boxes_per_cell, 4]``. + target_boxes: Corner coordinates of the target boxes. Tensor of size ``[height, width, boxes_per_cell, 4]``. + + Returns: + A boolean tensor sized ``[height, width, boxes_per_cell]``, with ``False`` where the predicted box overlaps a + target significantly and ``True`` elsewhere. + """ + shape = pred_boxes.shape[:-1] + pred_boxes = pred_boxes.view(-1, 4) + ious = box_iou(pred_boxes, target_boxes) + best_iou = ious.max(-1).values + below_threshold = best_iou <= threshold + return below_threshold.view(shape) + + +def is_inside_box(points: Tensor, boxes: Tensor) -> Tensor: + """Get pairwise truth values of whether the point is inside the box. + + Args: + points: Point (x, y) coordinates, a tensor shaped ``[points, 2]``. + boxes: Box (x1, y1, x2, y2) coordinates, a tensor shaped ``[boxes, 4]``. + + Returns: + A tensor shaped ``[points, boxes]`` containing pairwise truth values of whether the points are inside the boxes. + """ + lt = points[:, None, :] - boxes[None, :, :2] # [boxes, points, 2] + rb = boxes[None, :, 2:] - points[:, None, :] # [boxes, points, 2] + deltas = torch.cat((lt, rb), -1) # [points, boxes, 4] + return deltas.min(-1).values > 0.0 # [points, boxes] + + +def compare_box_sizes(wh1: Tensor, wh2: Tensor, threshold: float) -> Tensor: + """Compares the dimensions of the boxes pairwise and returns a mask that indicates which pairs have similar + sizes. + + For each pair of boxes, calculates the largest ratio that can be obtained by dividing the widths with each other or + dividing the heights with each other. Returns a mask that indicates which pairs have a ratio less than the given + threshold. + + Args: + wh1: An ``[N, 2]`` matrix of box shapes (width and height). + wh2: An ``[M, 2]`` matrix of box shapes (width and height). + threshold: A threshold for the size ratio. + + Returns: + An ``[N, M]`` matrix of truth values indicating which box pairs have the maximum size ratio below the threshold. + """ + wh_ratio = wh1[:, None, :] / wh2[None, :, :] # [M, N, 2] + wh_ratio = torch.max(wh_ratio, 1.0 / wh_ratio) + wh_ratio = wh_ratio.max(2).values # [M, N] + return wh_ratio < threshold diff --git a/torchvision/models/detection/target_matching.py b/torchvision/models/detection/target_matching.py index d78456a74f5..9a87c3c91e7 100644 --- a/torchvision/models/detection/target_matching.py +++ b/torchvision/models/detection/target_matching.py @@ -4,70 +4,12 @@ import torch from torch import Tensor -from ...ops import box_convert, box_iou +from ...ops import box_convert from .anchor_utils import grid_centers +from .box_utils import aligned_iou, iou_below, is_inside_box, compare_box_sizes from .yolo_loss import YOLOLoss -def aligned_iou(dims1: Tensor, dims2: Tensor) -> Tensor: - """Calculates a matrix of intersections over union from box dimensions, assuming that the boxes are located at - the same coordinates. - - Args: - dims1: Width and height of `N` boxes. Tensor of size ``[N, 2]``. - dims2: Width and height of `M` boxes. Tensor of size ``[M, 2]``. - - Returns: - Tensor of size ``[N, M]`` containing the pairwise IoU values for every element in ``dims1`` and ``dims2`` - """ - area1 = dims1[:, 0] * dims1[:, 1] # [N] - area2 = dims2[:, 0] * dims2[:, 1] # [M] - - inter_wh = torch.min(dims1[:, None, :], dims2) # [N, M, 2] - inter = inter_wh[:, :, 0] * inter_wh[:, :, 1] # [N, M] - union = area1[:, None] + area2 - inter # [N, M] - - return inter / union - - -def iou_below(pred_boxes: Tensor, target_boxes: Tensor, threshold: float) -> Tensor: - """Creates a binary mask whose value will be ``True``, unless the predicted box overlaps any target - significantly (IoU greater than ``threshold``). - - Args: - pred_boxes: The predicted corner coordinates. Tensor of size ``[height, width, boxes_per_cell, 4]``. - target_boxes: Corner coordinates of the target boxes. Tensor of size ``[height, width, boxes_per_cell, 4]``. - - Returns: - A boolean tensor sized ``[height, width, boxes_per_cell]``, with ``False`` where the predicted box overlaps a - target significantly and ``True`` elsewhere. - """ - shape = pred_boxes.shape[:-1] - pred_boxes = pred_boxes.view(-1, 4) - ious = box_iou(pred_boxes, target_boxes) - best_iou = ious.max(-1).values - below_threshold = best_iou <= threshold - return below_threshold.view(shape) - - -def is_inside_box(points: Tensor, boxes: Tensor) -> Tensor: - """Get pairwise truth values of whether the point is inside the box. - - Args: - points: point (x, y) coordinates, [points, 2] - boxes: box (x1, y1, x2, y2) coordinates, [boxes, 4] - - Returns: - A tensor shaped ``[boxes, points]`` containing pairwise truth values of whether the points are inside the boxes. - """ - points = points.unsqueeze(0) # [1, points, 2] - boxes = boxes.unsqueeze(1) # [boxes, 1, 4] - lt = points - boxes[..., :2] # [boxes, points, 2] - rb = boxes[..., 2:] - points # [boxes, points, 2] - deltas = torch.cat((lt, rb), -1) # [boxes, points, 4] - return deltas.min(-1).values > 0.0 # [boxes, points] - - class ShapeMatching(ABC): """Selects which anchors are used to predict each target, by comparing the shape of the target box to a set of prior shapes. @@ -249,12 +191,7 @@ def __init__( def match(self, wh: Tensor) -> Union[Tuple[Tensor, Tensor], Tensor]: prior_wh = torch.tensor(self.prior_shapes, dtype=wh.dtype, device=wh.device) - - wh_ratio = wh[:, None, :] / prior_wh[None, :, :] # [num_targets, num_anchors, 2] - wh_ratio = torch.max(wh_ratio, 1.0 / wh_ratio) - wh_ratio = wh_ratio.max(2).values # [num_targets, num_anchors] - below_threshold = (wh_ratio < self.threshold).nonzero() - return below_threshold.T + return compare_box_sizes(wh, prior_wh, self.threshold).nonzero().T def _sim_ota_match(costs: Tensor, ious: Tensor) -> Tuple[Tensor, Tensor]: @@ -265,35 +202,38 @@ def _sim_ota_match(costs: Tensor, ious: Tensor) -> Tuple[Tensor, Tensor]: predicted boxes. Args: - costs: Sum of losses for (prediction, target) pairs: ``[targets, predictions]`` - ious: IoUs for (prediction, target) pairs: ``[targets, predictions]`` + costs: A ``[predictions, targets]`` matrix of losses. + ious: A ``[predictions, targets]`` matrix of IoUs. Returns: A mask of predictions that were matched, and the indices of the matched targets. The latter contains as many elements as there are ``True`` values in the mask. """ + num_preds, num_targets = ious.shape + matching_matrix = torch.zeros_like(costs, dtype=torch.bool) if ious.numel() > 0: # For each target, define k as the sum of the 10 highest IoUs. - top10_iou = torch.topk(ious, min(10, ious.shape[1])).values.sum(1) + top10_iou = torch.topk(ious, min(10, num_preds), dim=0).values.sum(0) ks = torch.clip(top10_iou.int(), min=1) + assert len(ks) == num_targets - # For each target, select k predictions with lowest cost. - for target_idx, (cost, k) in enumerate(zip(costs, ks)): - prediction_idx = torch.topk(cost, k, largest=False).indices - matching_matrix[target_idx, prediction_idx] = True + # For each target, select k predictions with the lowest cost. + for target_idx, (target_costs, k) in enumerate(zip(costs.T, ks)): + pred_idx = torch.topk(target_costs, k, largest=False).indices + matching_matrix[pred_idx, target_idx] = True # If there's more than one match for some prediction, match it with the best target. Now we consider all # targets, regardless of whether they were originally matched with the prediction or not. - more_than_one_match = matching_matrix.sum(0) > 1 - best_targets = costs[:, more_than_one_match].argmin(0) - matching_matrix[:, more_than_one_match] = False - matching_matrix[best_targets, more_than_one_match] = True + more_than_one_match = matching_matrix.sum(1) > 1 + best_targets = costs[more_than_one_match, :].argmin(1) + matching_matrix[more_than_one_match, :] = False + matching_matrix[more_than_one_match, best_targets] = True # For those predictions that were matched, get the index of the target. - pred_mask = matching_matrix.sum(0) > 0 - target_selector = matching_matrix[:, pred_mask].int().argmax(0) + pred_mask = matching_matrix.sum(1) > 0 + target_selector = matching_matrix[pred_mask, :].int().argmax(1) return pred_mask, target_selector @@ -303,14 +243,29 @@ class SimOTAMatching: This is the matching rule used by YOLOX. Args: + prior_shapes: A list of all the prior box dimensions. The list should contain (width, height) tuples in the + network input resolution. + prior_shape_idxs: List of indices to ``prior_shapes`` that is used to select the (usually 3) prior shapes that + this layer uses. loss_func: A ``LossFunction`` object that can be used to calculate the pairwise costs. - range: For each target, restrict to the anchors that are within an `N x N` grid cell are centered at the target, - where `N` is the value of this parameter. + spatial_range: For each target, restrict to the anchors that are within an `N × N` grid cell are centered at the + target, where `N` is the value of this parameter. + size_range: For each target, restrict to the anchors whose prior dimensions are not larger than the target + dimensions multiplied by this value and not smaller than the target dimensions divided by this value. """ - def __init__(self, loss_func: YOLOLoss, range: float = 5.0) -> None: + def __init__( + self, + prior_shapes: Sequence[Tuple[int, int]], + prior_shape_idxs: Sequence[int], + loss_func: YOLOLoss, + spatial_range: float, + size_range: float, + ) -> None: + self.prior_shapes = [prior_shapes[idx] for idx in prior_shape_idxs] self.loss_func = loss_func - self.range = range + self.spatial_range = spatial_range + self.size_range = size_range def __call__( self, @@ -329,47 +284,89 @@ def __call__( A mask of predictions that were matched, background mask (inverse of the first mask), and the indices of the matched targets. The last tensor contains as many elements as there are ``True`` values in the first mask. """ - height, width, boxes_per_cell, num_classes = preds["classprobs"].shape - device = preds["boxes"].device + height, width, boxes_per_cell, _ = preds["boxes"].shape + prior_mask, anchor_inside_target = self._get_prior_mask(targets, image_size, width, height, boxes_per_cell) + prior_preds = { + "boxes": preds["boxes"][prior_mask], + "confidences": preds["confidences"][prior_mask], + "classprobs": preds["classprobs"][prior_mask], + } + + losses, ious = self.loss_func.pairwise(prior_preds, targets, input_is_normalized=False) + costs = losses.overlap + losses.confidence + losses.classification + costs += 100000.0 * ~anchor_inside_target + pred_mask, target_selector = _sim_ota_match(costs, ious) + + # Add the anchor dimension to the mask and replace True values with the results of the actual SimOTA matching. + prior_mask[prior_mask.nonzero().T.tolist()] = pred_mask + + background_mask = torch.logical_not(prior_mask) + return prior_mask, background_mask, target_selector + + def _get_prior_mask( + self, + targets: Dict[str, Tensor], + image_size: Tensor, + grid_width: int, + grid_height: int, + boxes_per_cell: int, + ) -> Tuple[Tensor, Tensor]: + """Creates a mask for selecting the "center prior" anchors. + + In the first step we restrict ourselves to the grid cells whose center is inside or close enough to one or more + targets. + + Args: + targets: Training targets for a single image. + image_size: Input image width and height. + grid_width: Width of the feature grid. + grid_height: Height of the feature grid. + boxes_per_cell: Number of boxes that will be predicted per feature grid cell. + + Returns: + Two masks, a ``[grid_height, grid_width, boxes_per_cell]`` mask for selecting anchors that are close and + similar in shape to a target, and an ``[anchors, targets]`` matrix that indicates which targets are inside + those anchors. + """ # A multiplier for scaling feature map coordinates to image coordinates - grid_size = torch.tensor([width, height], device=device) + grid_size = torch.tensor([grid_width, grid_height], device=targets["boxes"].device) grid_to_image = torch.true_divide(image_size, grid_size) - # Create a matrix for selecting the anchors that are inside the target bounding boxes. - centers = grid_centers(grid_size).view(-1, 2) * grid_to_image - inside_matrix = is_inside_box(centers, targets["boxes"]) - - # Set the width and height of all target bounding boxes to self.range grid cells and create a matrix for - # selecting the anchors that are now inside the boxes. If a small target has no anchors inside its bounding - # box, it will be matched to one of these anchors, but a high penalty will ensure that anchors that are inside - # the bounding box will be preferred. + # Get target center coordinates and dimensions. xywh = box_convert(targets["boxes"], in_fmt="xyxy", out_fmt="cxcywh") xy = xywh[:, :2] - wh = self.range * grid_to_image * torch.ones_like(xy) - xywh = torch.cat((xy, wh), -1) - boxes = box_convert(xywh, in_fmt="cxcywh", out_fmt="xyxy") - close_matrix = is_inside_box(centers, boxes) - - # In the first step we restrict ourselves to the grid cells whose center is inside or close enough to one or - # more targets. The prediction grids are flattened and masked using a [height * width] boolean vector. - mask = (inside_matrix | close_matrix).sum(0) > 0 - shape = (height * width, boxes_per_cell) - fg_preds = { - "boxes": preds["boxes"].view(*shape, 4)[mask].view(-1, 4), - "confidences": preds["confidences"].view(shape)[mask].view(-1), - "classprobs": preds["classprobs"].view(*shape, num_classes)[mask].view(-1, num_classes), - } + wh = xywh[:, 2:] - losses, ious = self.loss_func.pairwise(fg_preds, targets, input_is_normalized=False) - costs = losses.overlap + losses.confidence + losses.classification - costs += 100000.0 * ~inside_matrix[:, mask].repeat_interleave(boxes_per_cell, 1) - pred_mask, target_selector = _sim_ota_match(costs, ious) + # Create a [boxes_per_cell, targets] tensor for selecting prior shapes that are close enough to the target + # dimensions. + prior_wh = torch.tensor(self.prior_shapes, device=targets["boxes"].device) # XXX Enable size filtering. + shape_selector = compare_box_sizes(prior_wh, wh, self.size_range) # XXX Enable size filtering. - # Add the anchor dimension to the mask and replace True values with the results of the actual SimOTA matching. - mask = mask.view(height, width).unsqueeze(-1).repeat(1, 1, boxes_per_cell) - mask[mask.nonzero().T.tolist()] = pred_mask + # Create a [grid_cells, targets] tensor for selecting spatial locations that are inside target bounding boxes. + centers = grid_centers(grid_size).view(-1, 2) * grid_to_image + inside_selector = is_inside_box(centers, targets["boxes"]) + + # Combine the above selectors into a [grid_cells, boxes_per_cell, targets] tensor for selecting anchors that are + # inside target bounding boxes and close enough shape. + inside_selector = inside_selector[:, None, :].repeat(1, boxes_per_cell, 1) + inside_selector = torch.logical_and(inside_selector, shape_selector) # XXX Enable size filtering. + + # Set the width and height of all target bounding boxes to self.range grid cells and create a selector for + # anchors that are now inside the boxes. If a small target has no anchors inside its bounding box, it will be + # matched to one of these anchors, but a high penalty will ensure that anchors that are inside the bounding box + # will be preferred. + wh = self.spatial_range * grid_to_image * torch.ones_like(xy) + xywh = torch.cat((xy, wh), -1) + boxes = box_convert(xywh, in_fmt="cxcywh", out_fmt="xyxy") + close_selector = is_inside_box(centers, boxes) - background_mask = torch.logical_not(mask) + # Create a [grid_cells, boxes_per_cell, targets] tensor for selecting anchors that are spatially close to a + # target and whose shape is close enough to the target. + close_selector = close_selector[:, None, :].repeat(1, boxes_per_cell, 1) + close_selector = torch.logical_and(close_selector, shape_selector) # XXX Enable size filtering. - return mask, background_mask, target_selector + mask = torch.logical_or(inside_selector, close_selector).sum(-1) > 0 + mask = mask.view(grid_height, grid_width, boxes_per_cell) + inside_selector = inside_selector.view(grid_height, grid_width, boxes_per_cell, -1) + return mask, inside_selector[mask] diff --git a/torchvision/models/detection/yolo.py b/torchvision/models/detection/yolo.py index 969ef8ed409..f5a2d1cfcea 100644 --- a/torchvision/models/detection/yolo.py +++ b/torchvision/models/detection/yolo.py @@ -16,14 +16,13 @@ TARGETS = List[TARGET] -def validate_batch(images: List[Tensor], targets: TARGETS) -> None: - """Reads a batch of data, validates the format, and stacks the images into a single tensor. +def validate_batch(images: List[Tensor], targets: Optional[TARGETS]) -> None: + """Validates the format of a batch of data. Args: - batch: The batch of data read by the :class:`~torch.utils.data.DataLoader`. - - Returns: - The input batch with images stacked into a single tensor. + images: A list of image tensors. + targets: A list of target dictionaries or ``None``. If a list is provided, there should be as many target + dictionaries as there are images. """ if not images: raise ValueError("No images in batch.") @@ -80,19 +79,21 @@ class YOLO(nn.Module): have different image sizes, as long as the size is divisible by the ratio in which the network downsamples the input. - During training, the model expects both the image tensors and a list of targets. *Each target is a dictionary - containing the following tensors*: + During training, the model expects both the image tensors and a list of targets. It's possible to train a model + using one integer class label per box, but the YOLO model supports also multiple classes per box. For multi-class + training, simply use a boolean matrix that indicates which classes are assigned to which boxes, in place of the + class labels. *Each target is a dictionary containing the following tensors*: - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in `(x1, y1, x2, y2)` format - labels (``Int64Tensor[N]`` or ``BoolTensor[N, classes]``): the class label or a boolean class mask for each ground-truth box - :func:`~.yolo_module.YOLO.forward` method returns all predictions from all detection layers in one tensor with shape + :func:`~.yolo.YOLO.forward` method returns all predictions from all detection layers in one tensor with shape ``[N, anchors, classes + 5]``, where ``anchors`` is the total number of anchors in all detection layers. The coordinates are scaled to the input image size. During training it also returns a dictionary containing the classification, box overlap, and confidence losses. - During inference, the model requires only the image tensor. :func:`~.yolo_module.YOLO.infer` method filters and + During inference, the model requires only the image tensor. :func:`~.yolo.YOLO.infer` method filters and processes the predictions. If a prediction has a high score for more than one class, it will be duplicated. *The processed output is returned in a dictionary containing the following tensors*: @@ -147,11 +148,7 @@ def __init__( self.nms_threshold = nms_threshold self.detections_per_image = detections_per_image - def forward( - self, - images: List[Tensor], - targets: Optional[TARGETS] = None, - ) -> Union[Tensor, Tuple[Tensor, Tensor]]: + def forward(self, images: List[Tensor], targets: Optional[TARGETS] = None) -> Union[Tensor, Tuple[Tensor, Tensor]]: """Runs a forward pass through the network (all layers listed in ``self.network``), and if training targets are provided, computes the losses from the detection layers. @@ -159,10 +156,9 @@ def forward( that depends on the size of the feature map and the number of anchors per feature map cell. Args: - images: Images to be processed. Tensor of size - ``[batch_size, channels, height, width]``. - targets: If set, computes losses from detection layers against these targets. A list of - target dictionaries, one for each image. + images: Images to be processed. Tensor of size ``[batch_size, channels, height, width]``. + targets: If given, computes losses from detection layers against these targets. A list of target + dictionaries, one for each image. Returns: detections (:class:`~torch.Tensor`), losses (:class:`~torch.Tensor`): Detections, and if targets were diff --git a/torchvision/models/detection/yolo_loss.py b/torchvision/models/detection/yolo_loss.py index d8f05f93d49..ec8d66c9fd0 100644 --- a/torchvision/models/detection/yolo_loss.py +++ b/torchvision/models/detection/yolo_loss.py @@ -75,27 +75,27 @@ def _pairwise_confidence_loss( Args: preds: An ``[N]`` vector of predicted confidences. - overlap: An ``[M, N]`` matrix of overlaps between all target and predicted bounding boxes. + overlap: An ``[N, M]`` matrix of overlaps between all predicted and target bounding boxes. bce_func: A function for calculating binary cross entropy. predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that target confidence is one if there's an object, and 1.0 means that the target confidence is the overlap. Returns: - An ``[M, N]`` matrix of confidence losses between all targets and predictions. + An ``[N, M]`` matrix of confidence losses between all predictions and targets. """ if predict_overlap is not None: # When predicting overlap, target confidence is different for each pair of a prediction and a target. The - # tensors have to be broadcasted to [M, N]. - preds = preds.unsqueeze(0).expand(overlap.shape) + # tensors have to be broadcasted to [N, M]. + preds = preds.unsqueeze(1).expand(overlap.shape) targets = torch.ones_like(preds) - predict_overlap # Distance-IoU may return negative "overlaps", so we have to make sure that the targets are not negative. targets += predict_overlap * overlap.detach().clamp(min=0) return bce_func(preds, targets, reduction="none") else: - # When not predicting overlap, target confidence is the same for every target, but we should still return a + # When not predicting overlap, target confidence is the same for every prediction, but we should still return a # matrix. targets = torch.ones_like(preds) - return bce_func(preds, targets, reduction="none").unsqueeze(0).expand(overlap.shape) + return bce_func(preds, targets, reduction="none").unsqueeze(1).expand(overlap.shape) def _foreground_confidence_loss( @@ -216,28 +216,36 @@ def pairwise( This method is called for obtaining costs for SimOTA matching. Args: - preds: A dictionary of predictions, containing "boxes", "confidences", and "classprobs". - targets: A dictionary of training targets, containing "boxes" and "labels". + preds: A dictionary of predictions, containing "boxes", "confidences", and "classprobs". Each tensor + contains `N` rows. + targets: A dictionary of training targets, containing "boxes" and "labels". Each tensor contains `M` rows. input_is_normalized: If ``False``, input is logits, if ``True``, input is normalized to `0..1`. Returns: - Loss matrices and an overlap matrix. + Loss matrices and an overlap matrix. Each matrix is shaped ``[N, M]``. """ + loss_shape = torch.Size([len(preds["boxes"]), len(targets["boxes"])]) + if input_is_normalized: bce_func = binary_cross_entropy else: bce_func = binary_cross_entropy_with_logits - overlap = self._pairwise_overlap(targets["boxes"], preds["boxes"]) + overlap = self._pairwise_overlap(preds["boxes"], targets["boxes"]) + assert overlap.shape == loss_shape + overlap_loss = 1.0 - overlap + assert overlap_loss.shape == loss_shape confidence_loss = _pairwise_confidence_loss(preds["confidences"], overlap, bce_func, self.predict_overlap) + assert confidence_loss.shape == loss_shape - pred_probs = preds["classprobs"].unsqueeze(0) # [1, preds, classes] + pred_probs = preds["classprobs"].unsqueeze(1) # [N, 1, classes] target_probs = _target_labels_to_probs(targets["labels"], pred_probs.shape[-1], pred_probs.dtype) - target_probs = target_probs.unsqueeze(1) # [targets, 1, classes] + target_probs = target_probs.unsqueeze(0) # [1, M, classes] pred_probs, target_probs = torch.broadcast_tensors(pred_probs, target_probs) class_loss = bce_func(pred_probs, target_probs, reduction="none").sum(-1) + assert class_loss.shape == loss_shape losses = Losses( overlap_loss * self.overlap_multiplier, diff --git a/torchvision/models/detection/yolo_networks.py b/torchvision/models/detection/yolo_networks.py index ac6c9550462..e720205d174 100644 --- a/torchvision/models/detection/yolo_networks.py +++ b/torchvision/models/detection/yolo_networks.py @@ -252,7 +252,8 @@ def _create_detection_layer( prior_shape_idxs: Sequence[int], matching_algorithm: Optional[str] = None, matching_threshold: Optional[float] = None, - sim_ota_range: float = 5.0, + spatial_range: float = 5.0, + size_range: float = 4.0, ignore_bg_threshold: float = 0.7, overlap_func: Union[str, Callable] = "ciou", predict_overlap: float = 1.0, @@ -274,8 +275,10 @@ def _create_detection_layer( ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the prior shape that gives the highest IoU, default). matching_threshold: Threshold for "size" and "iou" matching algorithms. - sim_ota_range: The "simota" matching algorithm will restrict to the anchors that are within an `N x N` grid cell + spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N × N` grid cell area centered at the target, where `N` is the value of this parameter. + size_range: The "simota" matching algorithm will restrict to anchors whose dimensions are no more than `N` and + no less than `1/N` times the target dimensions, where `N` is the value of this parameter. ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the corresponding anchor has IoU with some target greater than this threshold, the predictor will not be taken into account when calculating the confidence loss. @@ -301,7 +304,7 @@ def _create_detection_layer( loss_func = YOLOLoss( overlap_func, None, overlap_loss_multiplier, confidence_loss_multiplier, class_loss_multiplier ) - matching_func = SimOTAMatching(loss_func, sim_ota_range) + matching_func = SimOTAMatching(prior_shapes, prior_shape_idxs, loss_func, spatial_range, size_range) elif matching_algorithm == "size": if matching_threshold is None: raise ValueError("matching_threshold is required with size ratio matching.") @@ -441,6 +444,10 @@ class YOLOV4TinyNetwork(nn.Module): ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the prior shape that gives the highest IoU, default). matching_threshold: Threshold for "size" and "iou" matching algorithms. + spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N × N` grid cell + area centered at the target, where `N` is the value of this parameter. + size_range: The "simota" matching algorithm will restrict to anchors whose dimensions are no more than `N` and + no less than `1/N` times the target dimensions, where `N` is the value of this parameter. ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the prior shape has IoU with some target greater than this threshold, the predictor will not be taken into account when calculating the confidence loss. @@ -571,6 +578,10 @@ class YOLOV4Network(nn.Module): ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the prior shape that gives the highest IoU, default). matching_threshold: Threshold for "size" and "iou" matching algorithms. + spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N × N` grid cell + area centered at the target, where `N` is the value of this parameter. + size_range: The "simota" matching algorithm will restrict to anchors whose dimensions are no more than `N` and + no less than `1/N` times the target dimensions, where `N` is the value of this parameter. ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the prior shape has IoU with some target greater than this threshold, the predictor will not be taken into account when calculating the confidence loss. @@ -732,6 +743,10 @@ class YOLOV4P6Network(nn.Module): ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the prior shape that gives the highest IoU, default). matching_threshold: Threshold for "size" and "iou" matching algorithms. + spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N × N` grid cell + area centered at the target, where `N` is the value of this parameter. + size_range: The "simota" matching algorithm will restrict to anchors whose dimensions are no more than `N` and + no less than `1/N` times the target dimensions, where `N` is the value of this parameter. ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the prior shape has IoU with some target greater than this threshold, the predictor will not be taken into account when calculating the confidence loss. @@ -922,6 +937,10 @@ class YOLOV5Network(nn.Module): ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the prior shape that gives the highest IoU, default). matching_threshold: Threshold for "size" and "iou" matching algorithms. + spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N × N` grid cell + area centered at the target, where `N` is the value of this parameter. + size_range: The "simota" matching algorithm will restrict to anchors whose dimensions are no more than `N` and + no less than `1/N` times the target dimensions, where `N` is the value of this parameter. ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the prior shape has IoU with some target greater than this threshold, the predictor will not be taken into account when calculating the confidence loss. @@ -1081,6 +1100,8 @@ class YOLOV7Network(nn.Module): ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the prior shape that gives the highest IoU, default). matching_threshold: Threshold for "size" and "iou" matching algorithms. + size_range: The "simota" matching algorithm will restrict to anchors whose dimensions are no more than `N` and + no less than `1/N` times the target dimensions, where `N` is the value of this parameter. ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the prior shape has IoU with some target greater than this threshold, the predictor will not be taken into account when calculating the confidence loss. @@ -1176,7 +1197,7 @@ def detect(prior_shape_idxs: Sequence[int], range: float) -> DetectionLayer: return _create_detection_layer( prior_shapes, prior_shape_idxs, - sim_ota_range=range, + spatial_range=range, num_classes=num_classes, input_is_normalized=False, **kwargs, @@ -1391,6 +1412,10 @@ class YOLOXNetwork(nn.Module): ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the prior shape that gives the highest IoU, default). matching_threshold: Threshold for "size" and "iou" matching algorithms. + spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N × N` grid cell + area centered at the target, where `N` is the value of this parameter. + size_range: The "simota" matching algorithm will restrict to anchors whose dimensions are no more than `N` and + no less than `1/N` times the target dimensions, where `N` is the value of this parameter. ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the prior shape has IoU with some target greater than this threshold, the predictor will not be taken into account when calculating the confidence loss. @@ -1525,39 +1550,41 @@ def forward(self, x: Tensor, targets: Optional[TARGETS] = None) -> NETWORK_OUTPU class DarknetNetwork(nn.Module): - """This class can be used to parse the configuration files of the Darknet YOLOv4 implementation.""" + """This class can be used to parse the configuration files of the Darknet YOLOv4 implementation. + + Iterates through the layers from the configuration and creates corresponding PyTorch modules. If ``weights_path`` is + given and points to a Darknet model file, loads the convolutional layer weights from the file. + + Args: + config_path: Path to a Darknet configuration file that defines the network architecture. + weights_path: Path to a Darknet model file. If given, the model weights will be read from this file. + in_channels: Number of channels in the input image. + matching_algorithm: Which algorithm to use for matching targets to anchors. "simota" (the SimOTA matching rule + from YOLOX), "size" (match those prior shapes, whose width and height relative to the target is below given + ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the prior shape that + gives the highest IoU, default). + matching_threshold: Threshold for "size" and "iou" matching algorithms. + spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N × N` grid cell + area centered at the target, where `N` is the value of this parameter. + size_range: The "simota" matching algorithm will restrict to anchors whose dimensions are no more than `N` and + no less than `1/N` times the target dimensions, where `N` is the value of this parameter. + ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the corresponding anchor + has IoU with some target greater than this threshold, the predictor will not be taken into account when + calculating the confidence loss. + overlap_func: A function for calculating the pairwise overlaps between two sets of boxes. Either a string or a + function that returns a matrix of pairwise overlaps. Valid string values are "iou", "giou", "diou", and + "ciou". + predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that target + confidence is one if there's an object, and 1.0 means that the target confidence is the output of + ``overlap_func``. + overlap_loss_multiplier: Overlap loss will be scaled by this value. + confidence_loss_multiplier: Confidence loss will be scaled by this value. + class_loss_multiplier: Classification loss will be scaled by this value. + """ def __init__( self, config_path: str, weights_path: Optional[str] = None, in_channels: Optional[int] = None, **kwargs: Any ) -> None: - """Parses a Darknet configuration file and creates the network structure. - - Iterates through the layers from the configuration and creates corresponding PyTorch modules. If - ``weights_path`` is given and points to a Darknet model file, loads the convolutional layer weights from the - file. - - Args: - config_path: Path to a Darknet configuration file that defines the network architecture. - weights_path: Path to a Darknet model file. If given, the model weights will be read from this file. - in_channels: Number of channels in the input image. - matching_algorithm: Which algorithm to use for matching targets to anchors. "simota" (the SimOTA matching - rule from YOLOX), "size" (match those prior shapes, whose width and height relative to the target is - below given ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the - prior shape that gives the highest IoU, default). - matching_threshold: Threshold for "size" and "iou" matching algorithms. - ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the corresponding - anchor has IoU with some target greater than this threshold, the predictor will not be taken into - account when calculating the confidence loss. - overlap_func: A function for calculating the pairwise overlaps between two sets of boxes. Either a string or - a function that returns a matrix of pairwise overlaps. Valid string values are "iou", "giou", "diou", - and "ciou". - predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that target - confidence is one if there's an object, and 1.0 means that the target confidence is the output of - ``overlap_func``. - overlap_loss_multiplier: Overlap loss will be scaled by this value. - confidence_loss_multiplier: Confidence loss will be scaled by this value. - class_loss_multiplier: Classification loss will be scaled by this value. - """ super().__init__() with open(config_path) as config_file: @@ -1902,6 +1929,8 @@ def _create_yolo( prior_shapes: Optional[List[Tuple[int, int]]] = None, matching_algorithm: Optional[str] = None, matching_threshold: Optional[float] = None, + spatial_range: float = 5.0, + size_range: float = 4.0, ignore_bg_threshold: Optional[float] = None, overlap_func: Optional[Union[str, Callable]] = None, predict_overlap: float = 1.0, @@ -1925,6 +1954,10 @@ def _create_yolo( ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the prior shape that gives the highest IoU, default). matching_threshold: Threshold for "size" and "iou" matching algorithms. + spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N × N` grid cell + area centered at the target, where `N` is the value of this parameter. + size_range: The "simota" matching algorithm will restrict to anchors whose dimensions are no more than `N` and + no less than `1/N` times the target dimensions, where `N` is the value of this parameter. ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the corresponding anchor has IoU with some target greater than this threshold, the predictor will not be taken into account when calculating the confidence loss. @@ -1968,6 +2001,8 @@ def _create_yolo( prior_shape_idxs=config["mask"], matching_algorithm=matching_algorithm, matching_threshold=matching_threshold, + spatial_range=spatial_range, + size_range=size_range, ignore_bg_threshold=ignore_bg_threshold, overlap_func=overlap_func, predict_overlap=predict_overlap, diff --git a/torchvision/models/yolo.py b/torchvision/models/yolo.py index 695b613353f..6ae2f334338 100644 --- a/torchvision/models/yolo.py +++ b/torchvision/models/yolo.py @@ -256,7 +256,7 @@ def __init__( self.conv2 = Conv(hidden_channels, hidden_channels, kernel_size=3, stride=1, activation=activation, norm=norm) self.mix = Conv(num_channels, num_channels, kernel_size=1, stride=1, activation=activation, norm=norm) - def forward(self, x: Tensor) -> Tensor: + def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: partial = torch.chunk(x, 2, dim=1)[1] y1 = self.conv1(partial) y2 = self.conv2(y1) From ab88372fade11bf4e8c76ecb87c53e906bed1480 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Tue, 4 Apr 2023 16:20:44 +0300 Subject: [PATCH 04/13] Added support for label smoothing --- test/test_models_detection_box_utils.py | 22 +- torchvision/models/detection/anchor_utils.py | 2 +- torchvision/models/detection/box_utils.py | 13 +- .../models/detection/target_matching.py | 14 +- torchvision/models/detection/yolo.py | 59 +++--- torchvision/models/detection/yolo_loss.py | 65 ++++-- torchvision/models/detection/yolo_networks.py | 196 ++++++++++-------- torchvision/models/yolo.py | 4 +- 8 files changed, 215 insertions(+), 160 deletions(-) diff --git a/test/test_models_detection_box_utils.py b/test/test_models_detection_box_utils.py index 20d85411695..75f748249be 100644 --- a/test/test_models_detection_box_utils.py +++ b/test/test_models_detection_box_utils.py @@ -1,7 +1,7 @@ import pytest import torch from torchvision.models.detection.anchor_utils import grid_centers -from torchvision.models.detection.box_utils import aligned_iou, compare_box_sizes, iou_below, is_inside_box +from torchvision.models.detection.box_utils import aligned_iou, box_size_ratio, iou_below, is_inside_box @pytest.mark.parametrize( @@ -62,16 +62,16 @@ def test_is_inside_box(): assert torch.all(is_inside[4, 7:9, 1]) -def test_compare_box_sizes(): +def test_box_size_ratio(): wh1 = torch.tensor([[24, 11], [12, 25], [26, 27], [15, 17]]) wh2 = torch.tensor([[10, 30], [15, 9]]) - result = compare_box_sizes(wh1, wh2, 2.0) + result = box_size_ratio(wh1, wh2) assert result.shape == (4, 2) - assert not result[0, 0] # 24 / 10 >= 2 - assert result[0, 1] # 24 / 15 < 2, 11 / 9 < 2 - assert result[1, 0] # 12 / 10 < 2, 30 / 25 < 2 - assert not result[1, 1] # 25 / 9 >= 2 - assert not result[2, 0] # 26 / 10 >= 2 - assert not result[2, 1] # 27 / 9 >= 2 - assert result[3, 0] # 15 / 10 < 2, 30 / 17 < 2 - assert result[3, 1] # 15 / 15 < 2, 17 / 9 < 2 + assert result[0, 0] == 30 / 11 + assert result[0, 1] == 24 / 15 + assert result[1, 0] == 12 / 10 + assert result[1, 1] == 25 / 9 + assert result[2, 0] == 26 / 10 + assert result[2, 1] == 27 / 9 + assert result[3, 0] == 30 / 17 + assert result[3, 1] == 17 / 9 diff --git a/torchvision/models/detection/anchor_utils.py b/torchvision/models/detection/anchor_utils.py index d7c01a4af11..7a41a5b0e0c 100644 --- a/torchvision/models/detection/anchor_utils.py +++ b/torchvision/models/detection/anchor_utils.py @@ -279,7 +279,7 @@ def grid_offsets(grid_size: Tensor) -> Tensor: """ x_range = torch.arange(grid_size[0].item(), device=grid_size.device) y_range = torch.arange(grid_size[1].item(), device=grid_size.device) - grid_y, grid_x = torch.meshgrid((y_range, x_range), indexing="ij") + grid_y, grid_x = torch.meshgrid([y_range, x_range], indexing="ij") return torch.stack((grid_x, grid_y), -1) diff --git a/torchvision/models/detection/box_utils.py b/torchvision/models/detection/box_utils.py index af3481dba81..ac188ebaebf 100644 --- a/torchvision/models/detection/box_utils.py +++ b/torchvision/models/detection/box_utils.py @@ -61,23 +61,20 @@ def is_inside_box(points: Tensor, boxes: Tensor) -> Tensor: return deltas.min(-1).values > 0.0 # [points, boxes] -def compare_box_sizes(wh1: Tensor, wh2: Tensor, threshold: float) -> Tensor: - """Compares the dimensions of the boxes pairwise and returns a mask that indicates which pairs have similar - sizes. +def box_size_ratio(wh1: Tensor, wh2: Tensor) -> Tensor: + """Compares the dimensions of the boxes pairwise. For each pair of boxes, calculates the largest ratio that can be obtained by dividing the widths with each other or - dividing the heights with each other. Returns a mask that indicates which pairs have a ratio less than the given - threshold. + dividing the heights with each other. Args: wh1: An ``[N, 2]`` matrix of box shapes (width and height). wh2: An ``[M, 2]`` matrix of box shapes (width and height). - threshold: A threshold for the size ratio. Returns: - An ``[N, M]`` matrix of truth values indicating which box pairs have the maximum size ratio below the threshold. + An ``[N, M]`` matrix of ratios of width or height dimensions, whichever is larger. """ wh_ratio = wh1[:, None, :] / wh2[None, :, :] # [M, N, 2] wh_ratio = torch.max(wh_ratio, 1.0 / wh_ratio) wh_ratio = wh_ratio.max(2).values # [M, N] - return wh_ratio < threshold + return wh_ratio diff --git a/torchvision/models/detection/target_matching.py b/torchvision/models/detection/target_matching.py index 9a87c3c91e7..3b8eb1b7d7b 100644 --- a/torchvision/models/detection/target_matching.py +++ b/torchvision/models/detection/target_matching.py @@ -6,7 +6,7 @@ from ...ops import box_convert from .anchor_utils import grid_centers -from .box_utils import aligned_iou, iou_below, is_inside_box, compare_box_sizes +from .box_utils import aligned_iou, box_size_ratio, iou_below, is_inside_box from .yolo_loss import YOLOLoss @@ -191,7 +191,7 @@ def __init__( def match(self, wh: Tensor) -> Union[Tuple[Tensor, Tensor], Tensor]: prior_wh = torch.tensor(self.prior_shapes, dtype=wh.dtype, device=wh.device) - return compare_box_sizes(wh, prior_wh, self.threshold).nonzero().T + return (box_size_ratio(wh, prior_wh) < self.threshold).nonzero().T def _sim_ota_match(costs: Tensor, ious: Tensor) -> Tuple[Tensor, Tensor]: @@ -247,7 +247,7 @@ class SimOTAMatching: network input resolution. prior_shape_idxs: List of indices to ``prior_shapes`` that is used to select the (usually 3) prior shapes that this layer uses. - loss_func: A ``LossFunction`` object that can be used to calculate the pairwise costs. + loss_func: A ``YOLOLoss`` object that can be used to calculate the pairwise costs. spatial_range: For each target, restrict to the anchors that are within an `N × N` grid cell are centered at the target, where `N` is the value of this parameter. size_range: For each target, restrict to the anchors whose prior dimensions are not larger than the target @@ -340,8 +340,8 @@ def _get_prior_mask( # Create a [boxes_per_cell, targets] tensor for selecting prior shapes that are close enough to the target # dimensions. - prior_wh = torch.tensor(self.prior_shapes, device=targets["boxes"].device) # XXX Enable size filtering. - shape_selector = compare_box_sizes(prior_wh, wh, self.size_range) # XXX Enable size filtering. + prior_wh = torch.tensor(self.prior_shapes, device=targets["boxes"].device) + shape_selector = box_size_ratio(prior_wh, wh) < self.size_range # Create a [grid_cells, targets] tensor for selecting spatial locations that are inside target bounding boxes. centers = grid_centers(grid_size).view(-1, 2) * grid_to_image @@ -350,7 +350,7 @@ def _get_prior_mask( # Combine the above selectors into a [grid_cells, boxes_per_cell, targets] tensor for selecting anchors that are # inside target bounding boxes and close enough shape. inside_selector = inside_selector[:, None, :].repeat(1, boxes_per_cell, 1) - inside_selector = torch.logical_and(inside_selector, shape_selector) # XXX Enable size filtering. + inside_selector = torch.logical_and(inside_selector, shape_selector) # Set the width and height of all target bounding boxes to self.range grid cells and create a selector for # anchors that are now inside the boxes. If a small target has no anchors inside its bounding box, it will be @@ -364,7 +364,7 @@ def _get_prior_mask( # Create a [grid_cells, boxes_per_cell, targets] tensor for selecting anchors that are spatially close to a # target and whose shape is close enough to the target. close_selector = close_selector[:, None, :].repeat(1, boxes_per_cell, 1) - close_selector = torch.logical_and(close_selector, shape_selector) # XXX Enable size filtering. + close_selector = torch.logical_and(close_selector, shape_selector) mask = torch.logical_or(inside_selector, close_selector).sum(-1) > 0 mask = mask.view(grid_height, grid_width, boxes_per_cell) diff --git a/torchvision/models/detection/yolo.py b/torchvision/models/detection/yolo.py index f5a2d1cfcea..9f1bb788b7a 100644 --- a/torchvision/models/detection/yolo.py +++ b/torchvision/models/detection/yolo.py @@ -12,31 +12,38 @@ from .backbone_utils import _validate_trainable_layers from .yolo_networks import DarknetNetwork, YOLOV4Network +IMAGES = Union[Tuple[Tensor, ...], List[Tensor]] +PRED = Dict[str, Any] +PREDS = Union[Tuple[PRED, ...], List[PRED]] TARGET = Dict[str, Any] -TARGETS = List[TARGET] +TARGETS = Union[Tuple[TARGET, ...], List[TARGET]] -def validate_batch(images: List[Tensor], targets: Optional[TARGETS]) -> None: +def validate_batch(images: Union[Tensor, IMAGES], targets: Optional[TARGETS]) -> None: """Validates the format of a batch of data. Args: - images: A list of image tensors. + images: A tensor containing a batch of images or a list of image tensors. targets: A list of target dictionaries or ``None``. If a list is provided, there should be as many target dictionaries as there are images. """ - if not images: - raise ValueError("No images in batch.") - - shape = images[0].shape - for image in images: - if not isinstance(image, Tensor): - raise ValueError(f"Expected image to be of type Tensor, got {type(image).__name__}.") - if image.shape != shape: - raise ValueError(f"Images with different shapes in one batch: {shape} and {image.shape}") + if not isinstance(images, Tensor): + if not isinstance(images, (tuple, list)): + raise TypeError(f"Expected images to be a Tensor, tuple, or a list, got {type(images).__name__}.") + if not images: + raise ValueError("No images in batch.") + shape = images[0].shape + for image in images: + if not isinstance(image, Tensor): + raise ValueError(f"Expected image to be of type Tensor, got {type(image).__name__}.") + if image.shape != shape: + raise ValueError(f"Images with different shapes in one batch: {shape} and {image.shape}") if targets is None: return + if not isinstance(targets, (tuple, list)): + raise TypeError(f"Expected targets to be a tuple or a list, got {type(images).__name__}.") if len(images) != len(targets): raise ValueError(f"Got {len(images)} images, but targets for {len(targets)} images.") @@ -80,7 +87,7 @@ class YOLO(nn.Module): input. During training, the model expects both the image tensors and a list of targets. It's possible to train a model - using one integer class label per box, but the YOLO model supports also multiple classes per box. For multi-class + using one integer class label per box, but the YOLO model supports also multiple labels per box. For multi-label training, simply use a boolean matrix that indicates which classes are assigned to which boxes, in place of the class labels. *Each target is a dictionary containing the following tensors*: @@ -148,7 +155,9 @@ def __init__( self.nms_threshold = nms_threshold self.detections_per_image = detections_per_image - def forward(self, images: List[Tensor], targets: Optional[TARGETS] = None) -> Union[Tensor, Tuple[Tensor, Tensor]]: + def forward( + self, images: Union[Tensor, IMAGES], targets: Optional[TARGETS] = None + ) -> Union[Tensor, Tuple[Tensor, Tensor]]: """Runs a forward pass through the network (all layers listed in ``self.network``), and if training targets are provided, computes the losses from the detection layers. @@ -156,7 +165,8 @@ def forward(self, images: List[Tensor], targets: Optional[TARGETS] = None) -> Un that depends on the size of the feature map and the number of anchors per feature map cell. Args: - images: Images to be processed. Tensor of size ``[batch_size, channels, height, width]``. + images: A tensor of size ``[batch_size, channels, height, width]`` containing a batch of images or a list of + image tensors. targets: If given, computes losses from detection layers against these targets. A list of target dictionaries, one for each image. @@ -167,7 +177,8 @@ def forward(self, images: List[Tensor], targets: Optional[TARGETS] = None) -> Un coordinates are in `(x1, y1, x2, y2)` format and scaled to the input image size. """ validate_batch(images, targets) - detections, losses, hits = self.network(torch.stack(images), targets) + images_tensor = images if isinstance(images, Tensor) else torch.stack(images) + detections, losses, hits = self.network(images_tensor, targets) detections = torch.cat(detections, 1) if targets is None: @@ -176,7 +187,7 @@ def forward(self, images: List[Tensor], targets: Optional[TARGETS] = None) -> Un losses = torch.stack(losses).sum(0) return detections, losses, hits - def infer(self, image: Tensor) -> Dict[str, Tensor]: + def infer(self, image: Tensor) -> PRED: """Feeds an image to the network and returns the detected bounding boxes, confidence scores, and class labels. @@ -204,7 +215,7 @@ def infer(self, image: Tensor) -> Dict[str, Tensor]: self.train() return detections - def process_detections(self, preds: Tensor) -> List[Dict[str, Tensor]]: + def process_detections(self, preds: Tensor) -> List[PRED]: """Splits the detection tensor returned by a forward pass into a list of prediction dictionaries, and filters them based on confidence threshold, non-maximum suppression (NMS), and maximum number of predictions. @@ -241,7 +252,7 @@ def process(boxes: Tensor, confidences: Tensor, classprobs: Tensor) -> Dict[str, return [process(p[..., :4], p[..., 4], p[..., 5:]) for p in preds] - def process_targets(self, targets: TARGETS) -> TARGETS: + def process_targets(self, targets: TARGETS) -> List[TARGET]: """Duplicates multi-label targets to create one target for each label. Args: @@ -333,8 +344,8 @@ def yolov4( nms_threshold: Non-maximum suppression will remove bounding boxes whose IoU with a higher confidence box is higher than this threshold, if the predicted categories are equal. detections_per_image: Keep at most this number of highest-confidence detections per image. - **kwargs: Parameters passed to the ``.YOLOV4Network`` class. Please refer to the `source code - `_ + **kwargs: Parameters passed to the ``torchvision.models.detection.YOLOV4Network`` class. Please refer to the + `source code `_ for more details about this class. .. autoclass:: .YOLOV4_Weights @@ -402,9 +413,9 @@ def yolo_darknet( nms_threshold: Non-maximum suppression will remove bounding boxes whose IoU with a higher confidence box is higher than this threshold, if the predicted categories are equal. detections_per_image: Keep at most this number of highest-confidence detections per image. - **kwargs: Parameters passed to the ``.YOLOV4Network`` class. Please refer to the `source code - `_ + **kwargs: Parameters passed to the ``torchvision.models.detection.DarknetNetwork`` class. Please refer to the + `source code `_ for more details about this class. """ - network = DarknetNetwork(config_path, weights_path) + network = DarknetNetwork(config_path, weights_path, **kwargs) return YOLO(network, confidence_threshold, nms_threshold, detections_per_image) diff --git a/torchvision/models/detection/yolo_loss.py b/torchvision/models/detection/yolo_loss.py index ec8d66c9fd0..31efb0d3742 100644 --- a/torchvision/models/detection/yolo_loss.py +++ b/torchvision/models/detection/yolo_loss.py @@ -3,16 +3,16 @@ import torch from torch import Tensor -from torch.nn.functional import binary_cross_entropy, binary_cross_entropy_with_logits +from torch.nn.functional import binary_cross_entropy, binary_cross_entropy_with_logits, one_hot from torchvision.ops import ( box_iou, - generalized_box_iou, - generalized_box_iou_loss, - distance_box_iou, - distance_box_iou_loss, complete_box_iou, complete_box_iou_loss, + distance_box_iou, + distance_box_iou_loss, + generalized_box_iou, + generalized_box_iou_loss, ) @@ -70,15 +70,16 @@ def _pairwise_confidence_loss( ) -> Tensor: """Calculates the confidence loss for every pair of a foreground anchor and a target. - If ``predict_overlap`` is ``True``, ``overlap`` will be used as the target confidence. Otherwise the target - confidence is 1. The method returns a matrix of losses for target/prediction pairs. + If ``predict_overlap`` is ``None``, the target confidence will be 1. If ``predict_overlap`` is 1.0, ``overlap`` will + be used as the target confidence. Otherwise this parameter defines a balance between these two targets. The method + returns a vector of losses for each foreground anchor. Args: preds: An ``[N]`` vector of predicted confidences. overlap: An ``[N, M]`` matrix of overlaps between all predicted and target bounding boxes. bce_func: A function for calculating binary cross entropy. - predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that target - confidence is one if there's an object, and 1.0 means that the target confidence is the overlap. + predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that the target + confidence is 1 if there's an object, and 1.0 means that the target confidence is the overlap. Returns: An ``[N, M]`` matrix of confidence losses between all predictions and targets. @@ -103,15 +104,16 @@ def _foreground_confidence_loss( ) -> Tensor: """Calculates the sum of the confidence losses for foreground anchors and their matched targets. - If ``predict_overlap`` is ``True``, ``overlap`` will be used as the target confidence. Otherwise the target - confidence is 1. The method returns a vector of losses for each foreground anchor. + If ``predict_overlap`` is ``None``, the target confidence will be 1. If ``predict_overlap`` is 1.0, ``overlap`` will + be used as the target confidence. Otherwise this parameter defines a balance between these two targets. The method + returns a vector of losses for each foreground anchor. Args: preds: A vector of predicted confidences. overlap: A vector of overlaps between matched target and predicted bounding boxes. bce_func: A function for calculating binary cross entropy. - predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that target - confidence is one if there's an object, and 1.0 means that the target confidence is the overlap. + predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that the target + confidence is 1, and 1.0 means that the target confidence is the overlap. Returns: The sum of the confidence losses for foreground anchors. @@ -138,14 +140,22 @@ def _background_confidence_loss(preds: Tensor, bce_func: Callable) -> Tensor: return bce_func(preds, targets, reduction="sum") -def _target_labels_to_probs(targets: Tensor, num_classes: int, dtype: torch.dtype) -> Tensor: +def _target_labels_to_probs( + targets: Tensor, num_classes: int, dtype: torch.dtype, label_smoothing: Optional[float] = None +) -> Tensor: """If ``targets`` is a vector of class labels, converts it to a matrix of one-hot class probabilities. + If label smoothing is disabled, the returned target probabilities will be binary. If label smoothing is enabled, the + target probabilities will be, ``(label_smoothing / 2)`` or ``(label_smoothing / 2) + (1.0 - label_smoothing)``. That + corresponds to label smoothing with two categories, since the YOLO model does multi-label classification. + Args: targets: An ``[M, C]`` matrix of target class probabilities or an ``[M]`` vector of class labels. num_classes: The number of classes (C dimension) for the new targets. If ``targets`` is already two-dimensional, checks that the length of the second dimension matches this number. dtype: Floating-point data type to be used for the one-hot targets. + label_smoothing: The epsilon parameter (weight) for label smoothing. 0.0 means no smoothing (binary targets), + and 1.0 means that the target probabilities are always 0.5. Returns: An ``[M, C]`` matrix of target class probabilities. @@ -155,13 +165,16 @@ def _target_labels_to_probs(targets: Tensor, num_classes: int, dtype: torch.dtyp # greater than the number of predicted classes, it will be mapped to the last class. last_class = torch.tensor(num_classes - 1, device=targets.device) targets = torch.min(targets, last_class) - targets = torch.nn.functional.one_hot(targets, num_classes) + targets = one_hot(targets, num_classes) elif targets.shape[-1] != num_classes: raise ValueError( f"The number of classes in the data ({targets.shape[-1]}) doesn't match the number of classes " f"predicted by the model ({num_classes})." ) - return targets.to(dtype=dtype) + targets = targets.to(dtype=dtype) + if label_smoothing is not None: + targets = (label_smoothing / 2) + targets * (1.0 - label_smoothing) + return targets @dataclass @@ -174,13 +187,19 @@ class Losses: class YOLOLoss: """A class for calculating the YOLO losses from predictions and targets. + If label smoothing is enabled, the target class probabilities will be ``(label_smoothing / 2)`` or + ``(label_smoothing / 2) + (1.0 - label_smoothing)``, instead of 0 or 1. That corresponds to label smoothing with two + categories, since the YOLO model does multi-label classification. + Args: overlap_func: A function for calculating the pairwise overlaps between two sets of boxes. Either a string or a function that returns a matrix of pairwise overlaps. Valid string values are "iou", "giou", "diou", and "ciou" (default). - predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that target - confidence is one if there's an object, and 1.0 means that the target confidence is the output of + predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that the target + confidence is 1 if there's an object, and 1.0 means that the target confidence is the output of ``overlap_func``. + label_smoothing: The epsilon parameter (weight) for class label smoothing. 0.0 means no smoothing (binary + targets), and 1.0 means that the target probabilities are always 0.5. overlap_loss_multiplier: Overlap loss will be scaled by this value. confidence_loss_multiplier: Confidence loss will be scaled by this value. class_loss_multiplier: Classification loss will be scaled by this value. @@ -190,6 +209,7 @@ def __init__( self, overlap_func: Union[str, Callable] = "ciou", predict_overlap: Optional[float] = None, + label_smoothing: Optional[float] = None, overlap_multiplier: float = 5.0, confidence_multiplier: float = 1.0, class_multiplier: float = 1.0, @@ -201,6 +221,7 @@ def __init__( self._pairwise_overlap, self._elementwise_overlap_loss = _get_iou_and_loss_functions(overlap_func) self.predict_overlap = predict_overlap + self.label_smoothing = label_smoothing self.overlap_multiplier = overlap_multiplier self.confidence_multiplier = confidence_multiplier self.class_multiplier = class_multiplier @@ -241,7 +262,9 @@ def pairwise( assert confidence_loss.shape == loss_shape pred_probs = preds["classprobs"].unsqueeze(1) # [N, 1, classes] - target_probs = _target_labels_to_probs(targets["labels"], pred_probs.shape[-1], pred_probs.dtype) + target_probs = _target_labels_to_probs( + targets["labels"], pred_probs.shape[-1], pred_probs.dtype, self.label_smoothing + ) target_probs = target_probs.unsqueeze(0) # [1, M, classes] pred_probs, target_probs = torch.broadcast_tensors(pred_probs, target_probs) class_loss = bce_func(pred_probs, target_probs, reduction="none").sum(-1) @@ -287,7 +310,9 @@ def elementwise_sums( confidence_loss += _background_confidence_loss(preds["bg_confidences"], bce_func) pred_probs = preds["classprobs"] - target_probs = _target_labels_to_probs(targets["labels"], pred_probs.shape[-1], pred_probs.dtype) + target_probs = _target_labels_to_probs( + targets["labels"], pred_probs.shape[-1], pred_probs.dtype, self.label_smoothing + ) class_loss = bce_func(pred_probs, target_probs, reduction="sum") losses = Losses( diff --git a/torchvision/models/detection/yolo_networks.py b/torchvision/models/detection/yolo_networks.py index e720205d174..c99220b796d 100644 --- a/torchvision/models/detection/yolo_networks.py +++ b/torchvision/models/detection/yolo_networks.py @@ -25,19 +25,15 @@ YOLOV7Backbone, ) from .anchor_utils import global_xy -from .target_matching import ( - HighestIoUMatching, - IoUThresholdMatching, - ShapeMatching, - SimOTAMatching, - SizeRatioMatching, -) +from .target_matching import HighestIoUMatching, IoUThresholdMatching, ShapeMatching, SimOTAMatching, SizeRatioMatching from .yolo_loss import YOLOLoss CONFIG = Dict[str, Any] CREATE_LAYER_OUTPUT = Tuple[nn.Module, int] # layer, num_outputs +PRED = Dict[str, Any] +PREDS = Union[Tuple[PRED, ...], List[PRED]] TARGET = Dict[str, Any] -TARGETS = List[TARGET] +TARGETS = Union[Tuple[TARGET, ...], List[TARGET]] NETWORK_OUTPUT = Tuple[List[Tensor], List[Tensor], List[int]] # detections, losses, hits @@ -51,7 +47,7 @@ class DetectionLayer(nn.Module): prior_shapes: A list of prior box dimensions for this layer, used for scaling the predicted dimensions. The list should contain (width, height) tuples in the network input resolution. matching_func: The matching algorithm to be used for assigning targets to anchors. - loss_func: ``LossFunction`` object for calculating the losses. + loss_func: ``YOLOLoss`` object for calculating the losses. xy_scale: Eliminate "grid sensitivity" by scaling the box coordinates by this factor. Using a value > 1.0 helps to produce coordinate values close to one. input_is_normalized: The input is normalized by logistic activation in the previous layer. In this case the @@ -78,7 +74,7 @@ def __init__( self.xy_scale = xy_scale self.input_is_normalized = input_is_normalized - def forward(self, x: Tensor, image_size: Tensor) -> Tuple[Tensor, List[Dict[str, Tensor]]]: + def forward(self, x: Tensor, image_size: Tensor) -> Tuple[Tensor, PREDS]: """Runs a forward pass through this YOLO detection layer. Maps cell-local coordinates to global coordinates in the image space, scales the bounding boxes with the @@ -147,11 +143,11 @@ def forward(self, x: Tensor, image_size: Tensor) -> Tuple[Tensor, List[Dict[str, def match_targets( self, - preds: List[Dict[str, Tensor]], - return_preds: List[Dict[str, Tensor]], - targets: List[Dict[str, Tensor]], + preds: PREDS, + return_preds: PREDS, + targets: TARGETS, image_size: Tensor, - ) -> Tuple[Dict[str, Tensor], Dict[str, Tensor]]: + ) -> Tuple[PRED, TARGET]: """Matches the predictions to targets. Args: @@ -187,16 +183,15 @@ def match_targets( "labels": image_targets["labels"][target_selector], } else: - device = image_preds["confidences"].device matched_preds = { - "boxes": torch.empty((0, 4), device=device), - "confidences": torch.empty(0, device=device), - "bg_confidences": image_preds["confidences"].flatten(), - "classprobs": torch.empty((0, self.num_classes), device=device), + "boxes": torch.empty((0, 4), device=image_return_preds["boxes"].device), + "confidences": torch.empty(0, device=image_return_preds["confidences"].device), + "bg_confidences": image_return_preds["confidences"].flatten(), + "classprobs": torch.empty((0, self.num_classes), device=image_return_preds["classprobs"].device), } matched_targets = { - "boxes": torch.empty((0, 4), device=device), - "labels": torch.empty(0, dtype=torch.int64, device=device), + "boxes": torch.empty((0, 4), device=image_targets["boxes"].device), + "labels": torch.empty(0, dtype=torch.int64, device=image_targets["labels"].device), } matches.append((matched_preds, matched_targets)) @@ -214,10 +209,10 @@ def match_targets( def calculate_losses( self, - preds: List[Dict[str, Tensor]], - targets: List[Dict[str, Tensor]], + preds: PREDS, + targets: TARGETS, image_size: Tensor, - loss_preds: Optional[List[Dict[str, Tensor]]] = None, + loss_preds: Optional[PREDS] = None, ) -> Tuple[Tensor, int]: """Matches the predictions to targets and computes the losses. @@ -247,7 +242,7 @@ def calculate_losses( return losses, hits -def _create_detection_layer( +def create_detection_layer( prior_shapes: Sequence[Tuple[int, int]], prior_shape_idxs: Sequence[int], matching_algorithm: Optional[str] = None, @@ -256,7 +251,8 @@ def _create_detection_layer( size_range: float = 4.0, ignore_bg_threshold: float = 0.7, overlap_func: Union[str, Callable] = "ciou", - predict_overlap: float = 1.0, + predict_overlap: Optional[float] = None, + label_smoothing: Optional[float] = None, overlap_loss_multiplier: float = 5.0, confidence_loss_multiplier: float = 1.0, class_loss_multiplier: float = 1.0, @@ -285,9 +281,11 @@ def _create_detection_layer( overlap_func: A function for calculating the pairwise overlaps between two sets of boxes. Either a string or a function that returns a matrix of pairwise overlaps. Valid string values are "iou", "giou", "diou", and "ciou" (default). - predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that target - confidence is one if there's an object, and 1.0 means that the target confidence is the output of + predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that the target + confidence is 1 if there's an object, and 1.0 means that the target confidence is the output of ``overlap_func``. + label_smoothing: The epsilon parameter (weight) for class label smoothing. 0.0 means no smoothing (binary + targets), and 1.0 means that the target probabilities are always 0.5. overlap_loss_multiplier: Overlap loss will be scaled by this value. confidence_loss_multiplier: Confidence loss will be scaled by this value. class_loss_multiplier: Classification loss will be scaled by this value. @@ -302,7 +300,7 @@ def _create_detection_layer( matching_func: Union[ShapeMatching, SimOTAMatching] if matching_algorithm == "simota": loss_func = YOLOLoss( - overlap_func, None, overlap_loss_multiplier, confidence_loss_multiplier, class_loss_multiplier + overlap_func, None, None, overlap_loss_multiplier, confidence_loss_multiplier, class_loss_multiplier ) matching_func = SimOTAMatching(prior_shapes, prior_shape_idxs, loss_func, spatial_range, size_range) elif matching_algorithm == "size": @@ -319,17 +317,21 @@ def _create_detection_layer( raise ValueError(f"Matching algorithm `{matching_algorithm}´ is unknown.") loss_func = YOLOLoss( - overlap_func, predict_overlap, overlap_loss_multiplier, confidence_loss_multiplier, class_loss_multiplier + overlap_func, + predict_overlap, + label_smoothing, + overlap_loss_multiplier, + confidence_loss_multiplier, + class_loss_multiplier, ) - layer_shapes = [prior_shapes[i] for i in prior_shape_idxs] return DetectionLayer(prior_shapes=layer_shapes, matching_func=matching_func, loss_func=loss_func, **kwargs) -def _run_detection( +def run_detection( detection_layer: DetectionLayer, layer_input: Tensor, - targets: Optional[List[Dict[str, Tensor]]], + targets: Optional[TARGETS], image_size: Tensor, detections: List[Tensor], losses: List[Tensor], @@ -357,12 +359,12 @@ def _run_detection( hits.append(layer_hits) -def _run_detection_with_aux_head( +def run_detection_with_aux_head( detection_layer: DetectionLayer, aux_detection_layer: DetectionLayer, layer_input: Tensor, aux_input: Tensor, - targets: Optional[List[Dict[str, Tensor]]], + targets: Optional[TARGETS], image_size: Tensor, aux_weight: float, detections: List[Tensor], @@ -405,7 +407,7 @@ def _run_detection_with_aux_head( @torch.jit.script -def _get_image_size(images: Tensor) -> Tensor: +def get_image_size(images: Tensor) -> Tensor: """Get the image size from an input tensor. The function needs the ``@torch.jit.script`` decorator in order for ONNX generation to work. The tracing based @@ -454,9 +456,11 @@ class YOLOV4TinyNetwork(nn.Module): overlap_func: A function for calculating the pairwise overlaps between two sets of boxes. Either a string or a function that returns a matrix of pairwise overlaps. Valid string values are "iou", "giou", "diou", and "ciou" (default). - predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that target - confidence is one if there's an object, and 1.0 means that the target confidence is the output of + predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that the target + confidence is 1 if there's an object, and 1.0 means that the target confidence is the output of ``overlap_func``. + label_smoothing: The epsilon parameter (weight) for class label smoothing. 0.0 means no smoothing (binary + targets), and 1.0 means that the target probabilities are always 0.5. overlap_loss_multiplier: Overlap loss will be scaled by this value. confidence_loss_multiplier: Confidence loss will be scaled by this value. class_loss_multiplier: Classification loss will be scaled by this value. @@ -509,7 +513,7 @@ def outputs(in_channels: int) -> nn.Module: def detect(prior_shape_idxs: Sequence[int]) -> DetectionLayer: assert prior_shapes is not None - return _create_detection_layer( + return create_detection_layer( prior_shapes, prior_shape_idxs, num_classes=num_classes, input_is_normalized=False, **kwargs ) @@ -542,7 +546,7 @@ def forward(self, x: Tensor, targets: Optional[TARGETS] = None) -> NETWORK_OUTPU losses: List[Tensor] = [] # Losses from detection layers hits: List[int] = [] # Number of targets each detection layer was responsible for - image_size = _get_image_size(x) + image_size = get_image_size(x) c3, c4, c5 = self.backbone(x)[-3:] @@ -552,9 +556,9 @@ def forward(self, x: Tensor, targets: Optional[TARGETS] = None) -> NETWORK_OUTPU x = torch.cat((self.upsample4(p4), c3), dim=1) p3 = self.fpn3(x) - _run_detection(self.detect5, self.out5(p5), targets, image_size, detections, losses, hits) - _run_detection(self.detect4, self.out4(p4), targets, image_size, detections, losses, hits) - _run_detection(self.detect3, self.out3(p3), targets, image_size, detections, losses, hits) + run_detection(self.detect5, self.out5(p5), targets, image_size, detections, losses, hits) + run_detection(self.detect4, self.out4(p4), targets, image_size, detections, losses, hits) + run_detection(self.detect3, self.out3(p3), targets, image_size, detections, losses, hits) return detections, losses, hits @@ -588,9 +592,11 @@ class YOLOV4Network(nn.Module): overlap_func: A function for calculating the pairwise overlaps between two sets of boxes. Either a string or a function that returns a matrix of pairwise overlaps. Valid string values are "iou", "giou", "diou", and "ciou" (default). - predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that target - confidence is one if there's an object, and 1.0 means that the target confidence is the output of + predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that the target + confidence is 1 if there's an object, and 1.0 means that the target confidence is the output of ``overlap_func``. + label_smoothing: The epsilon parameter (weight) for class label smoothing. 0.0 means no smoothing (binary + targets), and 1.0 means that the target probabilities are always 0.5. overlap_loss_multiplier: Overlap loss will be scaled by this value. confidence_loss_multiplier: Confidence loss will be scaled by this value. class_loss_multiplier: Classification loss will be scaled by this value. @@ -661,7 +667,7 @@ def downsample(in_channels: int, out_channels: int) -> nn.Module: def detect(prior_shape_idxs: Sequence[int]) -> DetectionLayer: assert prior_shapes is not None - return _create_detection_layer( + return create_detection_layer( prior_shapes, prior_shape_idxs, num_classes=num_classes, input_is_normalized=False, **kwargs ) @@ -703,7 +709,7 @@ def forward(self, x: Tensor, targets: Optional[TARGETS] = None) -> NETWORK_OUTPU losses: List[Tensor] = [] # Losses from detection layers hits: List[int] = [] # Number of targets each detection layer was responsible for - image_size = _get_image_size(x) + image_size = get_image_size(x) c3, c4, x = self.backbone(x)[-3:] c5 = self.spp(x) @@ -717,9 +723,9 @@ def forward(self, x: Tensor, targets: Optional[TARGETS] = None) -> NETWORK_OUTPU x = torch.cat((self.downsample4(n4), c5), dim=1) n5 = self.pan5(x) - _run_detection(self.detect3, self.out3(n3), targets, image_size, detections, losses, hits) - _run_detection(self.detect4, self.out4(n4), targets, image_size, detections, losses, hits) - _run_detection(self.detect5, self.out5(n5), targets, image_size, detections, losses, hits) + run_detection(self.detect3, self.out3(n3), targets, image_size, detections, losses, hits) + run_detection(self.detect4, self.out4(n4), targets, image_size, detections, losses, hits) + run_detection(self.detect5, self.out5(n5), targets, image_size, detections, losses, hits) return detections, losses, hits @@ -753,9 +759,11 @@ class YOLOV4P6Network(nn.Module): overlap_func: A function for calculating the pairwise overlaps between two sets of boxes. Either a string or a function that returns a matrix of pairwise overlaps. Valid string values are "iou", "giou", "diou", and "ciou" (default). - predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that target - confidence is one if there's an object, and 1.0 means that the target confidence is the output of + predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that the target + confidence is 1 if there's an object, and 1.0 means that the target confidence is the output of ``overlap_func``. + label_smoothing: The epsilon parameter (weight) for class label smoothing. 0.0 means no smoothing (binary + targets), and 1.0 means that the target probabilities are always 0.5. overlap_loss_multiplier: Overlap loss will be scaled by this value. confidence_loss_multiplier: Confidence loss will be scaled by this value. class_loss_multiplier: Classification loss will be scaled by this value. @@ -833,7 +841,7 @@ def downsample(in_channels: int, out_channels: int) -> nn.Module: def detect(prior_shape_idxs: Sequence[int]) -> DetectionLayer: assert prior_shapes is not None - return _create_detection_layer( + return create_detection_layer( prior_shapes, prior_shape_idxs, num_classes=num_classes, input_is_normalized=False, **kwargs ) @@ -887,7 +895,7 @@ def forward(self, x: Tensor, targets: Optional[TARGETS] = None) -> NETWORK_OUTPU losses: List[Tensor] = [] # Losses from detection layers hits: List[int] = [] # Number of targets each detection layer was responsible for - image_size = _get_image_size(x) + image_size = get_image_size(x) c3, c4, c5, x = self.backbone(x)[-4:] c6 = self.spp(x) @@ -905,10 +913,10 @@ def forward(self, x: Tensor, targets: Optional[TARGETS] = None) -> NETWORK_OUTPU x = torch.cat((self.downsample5(n5), c6), dim=1) n6 = self.pan6(x) - _run_detection(self.detect3, self.out3(n3), targets, image_size, detections, losses, hits) - _run_detection(self.detect4, self.out4(n4), targets, image_size, detections, losses, hits) - _run_detection(self.detect5, self.out5(n5), targets, image_size, detections, losses, hits) - _run_detection(self.detect6, self.out6(n6), targets, image_size, detections, losses, hits) + run_detection(self.detect3, self.out3(n3), targets, image_size, detections, losses, hits) + run_detection(self.detect4, self.out4(n4), targets, image_size, detections, losses, hits) + run_detection(self.detect5, self.out5(n5), targets, image_size, detections, losses, hits) + run_detection(self.detect6, self.out6(n6), targets, image_size, detections, losses, hits) return detections, losses, hits @@ -947,9 +955,11 @@ class YOLOV5Network(nn.Module): overlap_func: A function for calculating the pairwise overlaps between two sets of boxes. Either a string or a function that returns a matrix of pairwise overlaps. Valid string values are "iou", "giou", "diou", and "ciou" (default). - predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that target - confidence is one if there's an object, and 1.0 means that the target confidence is the output of + predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that the target + confidence is 1 if there's an object, and 1.0 means that the target confidence is the output of ``overlap_func``. + label_smoothing: The epsilon parameter (weight) for class label smoothing. 0.0 means no smoothing (binary + targets), and 1.0 means that the target probabilities are always 0.5. overlap_loss_multiplier: Overlap loss will be scaled by this value. confidence_loss_multiplier: Confidence loss will be scaled by this value. class_loss_multiplier: Classification loss will be scaled by this value. @@ -1015,7 +1025,7 @@ def csp(in_channels: int, out_channels: int) -> nn.Module: def detect(prior_shape_idxs: Sequence[int]) -> DetectionLayer: assert prior_shapes is not None - return _create_detection_layer( + return create_detection_layer( prior_shapes, prior_shape_idxs, num_classes=num_classes, input_is_normalized=False, **kwargs ) @@ -1057,7 +1067,7 @@ def forward(self, x: Tensor, targets: Optional[TARGETS] = None) -> NETWORK_OUTPU losses: List[Tensor] = [] # Losses from detection layers hits: List[int] = [] # Number of targets each detection layer was responsible for - image_size = _get_image_size(x) + image_size = get_image_size(x) c3, c4, x = self.backbone(x)[-3:] c5 = self.spp(x) @@ -1073,9 +1083,9 @@ def forward(self, x: Tensor, targets: Optional[TARGETS] = None) -> NETWORK_OUTPU x = torch.cat((self.downsample4(n4), p5), dim=1) n5 = self.pan5(x) - _run_detection(self.detect3, self.out3(n3), targets, image_size, detections, losses, hits) - _run_detection(self.detect4, self.out4(n4), targets, image_size, detections, losses, hits) - _run_detection(self.detect5, self.out5(n5), targets, image_size, detections, losses, hits) + run_detection(self.detect3, self.out3(n3), targets, image_size, detections, losses, hits) + run_detection(self.detect4, self.out4(n4), targets, image_size, detections, losses, hits) + run_detection(self.detect5, self.out5(n5), targets, image_size, detections, losses, hits) return detections, losses, hits @@ -1108,9 +1118,11 @@ class YOLOV7Network(nn.Module): overlap_func: A function for calculating the pairwise overlaps between two sets of boxes. Either a string or a function that returns a matrix of pairwise overlaps. Valid string values are "iou", "giou", "diou", and "ciou" (default). - predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that target - confidence is one if there's an object, and 1.0 means that the target confidence is the output of + predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that the target + confidence is 1 if there's an object, and 1.0 means that the target confidence is the output of ``overlap_func``. + label_smoothing: The epsilon parameter (weight) for class label smoothing. 0.0 means no smoothing (binary + targets), and 1.0 means that the target probabilities are always 0.5. overlap_loss_multiplier: Overlap loss will be scaled by this value. confidence_loss_multiplier: Confidence loss will be scaled by this value. class_loss_multiplier: Classification loss will be scaled by this value. @@ -1122,7 +1134,7 @@ def __init__( self, num_classes: int, backbone: Optional[nn.Module] = None, - widths: Sequence[int] = (32, 64, 128, 256, 512, 1024, 1024), + widths: Sequence[int] = (64, 128, 256, 512, 768, 1024), activation: Optional[str] = "silu", normalization: Optional[str] = "batchnorm", prior_shapes: Optional[List[Tuple[int, int]]] = None, @@ -1194,7 +1206,7 @@ def downsample(in_channels: int, out_channels: int) -> nn.Module: def detect(prior_shape_idxs: Sequence[int], range: float) -> DetectionLayer: assert prior_shapes is not None - return _create_detection_layer( + return create_detection_layer( prior_shapes, prior_shape_idxs, spatial_range=range, @@ -1261,7 +1273,7 @@ def forward(self, x: Tensor, targets: Optional[TARGETS] = None) -> NETWORK_OUTPU losses: List[Tensor] = [] # Losses from detection layers hits: List[int] = [] # Number of targets each detection layer was responsible for - image_size = _get_image_size(x) + image_size = get_image_size(x) c3, c4, c5, x = self.backbone(x)[-4:] c6 = self.spp(x) @@ -1279,7 +1291,7 @@ def forward(self, x: Tensor, targets: Optional[TARGETS] = None) -> NETWORK_OUTPU x = torch.cat((self.downsample5(n5), c6), dim=1) n6 = self.pan6(x) - _run_detection_with_aux_head( + run_detection_with_aux_head( self.detect3, self.aux_detect3, self.out3(n3), @@ -1291,7 +1303,7 @@ def forward(self, x: Tensor, targets: Optional[TARGETS] = None) -> NETWORK_OUTPU losses, hits, ) - _run_detection_with_aux_head( + run_detection_with_aux_head( self.detect4, self.aux_detect4, self.out4(n4), @@ -1303,7 +1315,7 @@ def forward(self, x: Tensor, targets: Optional[TARGETS] = None) -> NETWORK_OUTPU losses, hits, ) - _run_detection_with_aux_head( + run_detection_with_aux_head( self.detect5, self.aux_detect5, self.out5(n5), @@ -1315,7 +1327,7 @@ def forward(self, x: Tensor, targets: Optional[TARGETS] = None) -> NETWORK_OUTPU losses, hits, ) - _run_detection_with_aux_head( + run_detection_with_aux_head( self.detect6, self.aux_detect6, self.out6(n6), @@ -1422,9 +1434,11 @@ class YOLOXNetwork(nn.Module): overlap_func: A function for calculating the pairwise overlaps between two sets of boxes. Either a string or a function that returns a matrix of pairwise overlaps. Valid string values are "iou", "giou", "diou", and "ciou" (default). - predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that target - confidence is one if there's an object, and 1.0 means that the target confidence is the output of + predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that the target + confidence is 1 if there's an object, and 1.0 means that the target confidence is the output of ``overlap_func``. + label_smoothing: The epsilon parameter (weight) for class label smoothing. 0.0 means no smoothing (binary + targets), and 1.0 means that the target probabilities are always 0.5. overlap_loss_multiplier: Overlap loss will be scaled by this value. confidence_loss_multiplier: Confidence loss will be scaled by this value. class_loss_multiplier: Classification loss will be scaled by this value. @@ -1485,7 +1499,7 @@ def head(in_channels: int, hidden_channels: int) -> YOLOXHead: def detect(prior_shape_idxs: Sequence[int]) -> DetectionLayer: assert prior_shapes is not None - return _create_detection_layer( + return create_detection_layer( prior_shapes, prior_shape_idxs, num_classes=num_classes, input_is_normalized=False, **kwargs ) @@ -1527,7 +1541,7 @@ def forward(self, x: Tensor, targets: Optional[TARGETS] = None) -> NETWORK_OUTPU losses: List[Tensor] = [] # Losses from detection layers hits: List[int] = [] # Number of targets each detection layer was responsible for - image_size = _get_image_size(x) + image_size = get_image_size(x) c3, c4, x = self.backbone(x)[-3:] c5 = self.spp(x) @@ -1543,9 +1557,9 @@ def forward(self, x: Tensor, targets: Optional[TARGETS] = None) -> NETWORK_OUTPU x = torch.cat((self.downsample4(n4), p5), dim=1) n5 = self.pan5(x) - _run_detection(self.detect3, self.out3(n3), targets, image_size, detections, losses, hits) - _run_detection(self.detect4, self.out4(n4), targets, image_size, detections, losses, hits) - _run_detection(self.detect5, self.out5(n5), targets, image_size, detections, losses, hits) + run_detection(self.detect3, self.out3(n3), targets, image_size, detections, losses, hits) + run_detection(self.detect4, self.out4(n4), targets, image_size, detections, losses, hits) + run_detection(self.detect5, self.out5(n5), targets, image_size, detections, losses, hits) return detections, losses, hits @@ -1574,9 +1588,11 @@ class DarknetNetwork(nn.Module): overlap_func: A function for calculating the pairwise overlaps between two sets of boxes. Either a string or a function that returns a matrix of pairwise overlaps. Valid string values are "iou", "giou", "diou", and "ciou". - predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that target - confidence is one if there's an object, and 1.0 means that the target confidence is the output of + predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that the target + confidence is 1 if there's an object, and 1.0 means that the target confidence is the output of ``overlap_func``. + label_smoothing: The epsilon parameter (weight) for class label smoothing. 0.0 means no smoothing (binary + targets), and 1.0 means that the target probabilities are always 0.5. overlap_loss_multiplier: Overlap loss will be scaled by this value. confidence_loss_multiplier: Confidence loss will be scaled by this value. class_loss_multiplier: Classification loss will be scaled by this value. @@ -1621,7 +1637,7 @@ def forward(self, x: Tensor, targets: Optional[TARGETS] = None) -> NETWORK_OUTPU losses: List[Tensor] = [] # Losses from detection layers hits: List[int] = [] # Number of targets each detection layer was responsible for - image_size = _get_image_size(x) + image_size = get_image_size(x) for layer in self.layers: if isinstance(layer, (RouteLayer, ShortcutLayer)): @@ -1933,7 +1949,8 @@ def _create_yolo( size_range: float = 4.0, ignore_bg_threshold: Optional[float] = None, overlap_func: Optional[Union[str, Callable]] = None, - predict_overlap: float = 1.0, + predict_overlap: Optional[float] = None, + label_smoothing: Optional[float] = None, overlap_loss_multiplier: Optional[float] = None, confidence_loss_multiplier: Optional[float] = None, class_loss_multiplier: Optional[float] = None, @@ -1964,9 +1981,11 @@ def _create_yolo( overlap_func: A function for calculating the pairwise overlaps between two sets of boxes. Either a string or a function that returns a matrix of pairwise overlaps. Valid string values are "iou", "giou", "diou", and "ciou". - predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that target - confidence is one if there's an object, and 1.0 means that the target confidence is the output of + predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that the target + confidence is 1 if there's an object, and 1.0 means that the target confidence is the output of ``overlap_func``. + label_smoothing: The epsilon parameter (weight) for class label smoothing. 0.0 means no smoothing (binary + targets), and 1.0 means that the target probabilities are always 0.5. overlap_loss_multiplier: Overlap loss will be scaled by this value. confidence_loss_multiplier: Confidence loss will be scaled by this value. class_loss_multiplier: Classification loss will be scaled by this value. @@ -1995,7 +2014,7 @@ def _create_yolo( class_loss_multiplier = config.get("cls_normalizer", 1.0) assert isinstance(class_loss_multiplier, float) - layer = _create_detection_layer( + layer = create_detection_layer( num_classes=config["classes"], prior_shapes=prior_shapes, prior_shape_idxs=config["mask"], @@ -2006,6 +2025,7 @@ def _create_yolo( ignore_bg_threshold=ignore_bg_threshold, overlap_func=overlap_func, predict_overlap=predict_overlap, + label_smoothing=label_smoothing, overlap_loss_multiplier=overlap_loss_multiplier, confidence_loss_multiplier=confidence_loss_multiplier, class_loss_multiplier=class_loss_multiplier, diff --git a/torchvision/models/yolo.py b/torchvision/models/yolo.py index 6ae2f334338..aaaa8a1e449 100644 --- a/torchvision/models/yolo.py +++ b/torchvision/models/yolo.py @@ -493,7 +493,9 @@ def smooth(num_channels: int) -> nn.Module: return Conv(num_channels, num_channels, kernel_size=3, stride=1, activation=activation, norm=normalization) def downsample(in_channels: int, out_channels: int) -> nn.Module: - conv_module = Conv(in_channels, out_channels, kernel_size=3, stride=2, activation=activation, norm=normalization) + conv_module = Conv( + in_channels, out_channels, kernel_size=3, stride=2, activation=activation, norm=normalization + ) return nn.Sequential(OrderedDict([("downsample", conv_module), ("smooth", smooth(out_channels))])) def maxpool(out_channels: int) -> nn.Module: From 154df0c9323a82718f2088289e5c9ad38eaad228 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Wed, 5 Apr 2023 20:08:09 +0300 Subject: [PATCH 05/13] Fixed type annotations and yolov4() helper function --- torchvision/models/detection/yolo.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/torchvision/models/detection/yolo.py b/torchvision/models/detection/yolo.py index 9f1bb788b7a..d79c081d98b 100644 --- a/torchvision/models/detection/yolo.py +++ b/torchvision/models/detection/yolo.py @@ -2,6 +2,7 @@ import torch import torch.nn as nn +import warnings from torch import Tensor from ...ops import batched_nms @@ -157,7 +158,7 @@ def __init__( def forward( self, images: Union[Tensor, IMAGES], targets: Optional[TARGETS] = None - ) -> Union[Tensor, Tuple[Tensor, Tensor]]: + ) -> Union[Tensor, Tuple[Tensor, Tensor, List[int]]]: """Runs a forward pass through the network (all layers listed in ``self.network``), and if training targets are provided, computes the losses from the detection layers. @@ -171,10 +172,10 @@ def forward( dictionaries, one for each image. Returns: - detections (:class:`~torch.Tensor`), losses (:class:`~torch.Tensor`): Detections, and if targets were - provided, a dictionary of losses. Detections are shaped ``[batch_size, anchors, classes + 5]``, where - ``anchors`` is the feature map size (width * height) times the number of anchors per cell. The predicted box - coordinates are in `(x1, y1, x2, y2)` format and scaled to the input image size. + detections (:class:`~torch.Tensor`), losses (:class:`~torch.Tensor`), hits (List[int]): Detections, and if + targets were provided, a dictionary of losses. Detections are shaped ``[batch_size, anchors, classes + 5]``, + where ``anchors`` is the feature map size (width * height) times the number of anchors per cell. The + predicted box coordinates are in `(x1, y1, x2, y2)` format and scaled to the input image size. """ validate_batch(images, targets) images_tensor = images if isinstance(images, Tensor) else torch.stack(images) @@ -289,7 +290,9 @@ def freeze_backbone_layers(backbone: nn.Module, trainable_layers: Optional[int], is_trained: Set to ``True`` when using pre-trained weights. Otherwise will issue a warning if ``trainable_layers`` is set. """ - num_layers = len(backbone.stages) + if not hasattr(backbone, "stages"): + warnings.warn("Cannot freeze backbone layers. Backbone object has no 'stages' attribute.") + num_layers = len(backbone.stages) # type: ignore trainable_layers = _validate_trainable_layers(is_trained, trainable_layers, num_layers, 3) layers_to_train = [f"stages.{idx}" for idx in range(num_layers - trainable_layers, num_layers)] @@ -370,7 +373,7 @@ def yolov4( freeze_backbone_layers(backbone, trainable_backbone_layers, is_trained) if weights_backbone is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) + backbone.load_state_dict(weights_backbone.get_state_dict(progress=progress)) network = YOLOV4Network(num_classes, backbone, **kwargs) model = YOLO(network, confidence_threshold, nms_threshold, detections_per_image) From 7cfa5af410c12892449ed8a6f58c5acc8c843f45 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Thu, 6 Apr 2023 08:55:27 +0300 Subject: [PATCH 06/13] Fixed some unit tests --- .../ModelTester.test_yolo_darknet_expect.pkl | Bin 0 -> 1179 bytes .../expect/ModelTester.test_yolov4_expect.pkl | Bin 0 -> 1155 bytes test/test_models.py | 23 ++-- torchvision/models/detection/__init__.py | 2 +- torchvision/models/detection/yolo.py | 120 ++++++++++-------- 5 files changed, 81 insertions(+), 64 deletions(-) create mode 100644 test/expect/ModelTester.test_yolo_darknet_expect.pkl create mode 100644 test/expect/ModelTester.test_yolov4_expect.pkl diff --git a/test/expect/ModelTester.test_yolo_darknet_expect.pkl b/test/expect/ModelTester.test_yolo_darknet_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..d3e297368cf4a7974ddeb2c9bf4f218f3cbaba06 GIT binary patch literal 1179 zcmb7D%SyvQ6rDCTt*NgrR8Xj(n-=XOiVMMo1nI&c#;Oq9gftnNVoW;8r07C$Cn5;$ z6v4e8;Agndm471W#<`7;K2X~`3>RkZoO9=#8Aaw|2qhAzD{^QVZMr75cd<{gm!~Xg z$FA*artTd$m}>Z_jtyEcb*ks<2X;vk(SE;i8a*K>qp}{*A=VudF{tYqR$dEe&GrS& z!+R#YSdD#x->4+t+u zD2Z>&D4RVu4A-_<_chn?g)Ps-9yZr&1{FvmTPA~)2xXFNp7g0nhAc*EaaM>al3}Zf zH4oc}0#VF3+wXpBXDmYTgadh0dxaHZfas)A!Bk$eo=CUUVLduW==6zd{JsrX=YAJd`W6v zaeh&JnGshZlZHkFJ5YTwP-S9zY9TX33ztiNX;MyVFhrEAkR^f6tV|;GiYjdCMW0TRmh+6Ls*qT znwSvQH*j+j#}&SN>?=GfZS5RDwyI6Dw*W;X2;;L=l_*;c-JBG0*@}qk9L^6@OalGE z_OI!C5K7p7H3k(z_Lrp=6$2y2%}Ifv1B|&qKEF6GJ(LOPYY-0bW&~02G>se=>L3Xe zKx7I>HxxOjRZ$E~0;W154AnP4HxfDYP>sAnq>+Z`Mj{7`8j33;fI&pSmC#5C@MdGv yf$EiG)`gn_%CR5-qy0g-4G1oQ10NI(Y@kfapa@j}k`C}@Wdn(^03k>{L@fXSB@rJ0 literal 0 HcmV?d00001 diff --git a/test/test_models.py b/test/test_models.py index f9128d7d0e1..4dfbebad0fe 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -475,7 +475,7 @@ def is_skippable(model_name, device): }, "yolov4": { "max_trainable": 5, - "n_trn_params_per_layer": [1, 1, 1, 1, 1, 1], # TODO: Fill the correct values. + "n_trn_params_per_layer": [138, 174, 234, 294, 318, 339], }, } @@ -793,7 +793,7 @@ def check_out(out): _check_input_backprop(model, x) -def check_model_output(out): +def check_model_output(out, model_name): assert len(out) == 1 def compact(tensor): @@ -874,7 +874,7 @@ def test_detection_model(model_fn, dev): out = model(model_input) assert model_input[0] is x - full_validation = check_model_output(out) + full_validation = check_model_output(out, model_name) _check_jit_scriptable(model, ([x],), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out) if dev == "cuda": @@ -900,8 +900,9 @@ def test_detection_model(model_fn, dev): @pytest.mark.parametrize("dev", cpu_and_gpu()) def test_yolo_darknet(dev): set_rng_seed(0) + model_name = "yolo_darknet" dtype = torch.get_default_dtype() - input_shape = (3, 300, 300) + input_shape = (3, 224, 224) model = yolo_darknet(DARKNET_CONFIG) model.eval().to(device=dev, dtype=dtype) @@ -911,13 +912,13 @@ def test_yolo_darknet(dev): out = model(model_input) assert model_input[0] is x - full_validation = check_model_output(out) + full_validation = check_model_output(out, model_name) _check_jit_scriptable(model, ([x],), unwrapper=None, eager_out=out) if dev == "cuda": with torch.cuda.amp.autocast(), torch.no_grad(), freeze_rng_state(): out = model(model_input) - full_validation &= check_model_output(out) + full_validation &= check_model_output(out, model_name) if not full_validation: msg = ( @@ -936,28 +937,28 @@ def test_yolo_darknet(dev): def test_detection_model_validation(model_fn): set_rng_seed(0) model = model_fn(num_classes=50, weights=None, weights_backbone=None) - input_shape = (3, 300, 300) + input_shape = (3, 256, 256) # YOLO models expect the input dimensions to be a multiple of 32 or 64. x = [torch.rand(input_shape)] # validate that targets are present in training - with pytest.raises(AssertionError): + with pytest.raises((AssertionError, ValueError)): model(x) # validate type targets = [{"boxes": 0.0}] - with pytest.raises(AssertionError): + with pytest.raises((AssertionError, TypeError)): model(x, targets=targets) # validate boxes shape for boxes in (torch.rand((4,)), torch.rand((1, 5))): targets = [{"boxes": boxes}] - with pytest.raises(AssertionError): + with pytest.raises((AssertionError, ValueError)): model(x, targets=targets) # validate that no degenerate boxes are present boxes = torch.tensor([[1, 3, 1, 4], [2, 4, 3, 4]]) targets = [{"boxes": boxes}] - with pytest.raises(AssertionError): + with pytest.raises((AssertionError, ValueError)): model(x, targets=targets) diff --git a/torchvision/models/detection/__init__.py b/torchvision/models/detection/__init__.py index 6cc2e6e5515..9be3dea4594 100644 --- a/torchvision/models/detection/__init__.py +++ b/torchvision/models/detection/__init__.py @@ -5,7 +5,7 @@ from .retinanet import * from .ssd import * from .ssdlite import * -from .yolo import * +from .yolo import YOLO, YOLOV4_Backbone_Weights, YOLOV4_Weights, yolov4, yolo_darknet from .yolo_networks import ( DarknetNetwork, YOLOV4TinyNetwork, diff --git a/torchvision/models/detection/yolo.py b/torchvision/models/detection/yolo.py index d79c081d98b..fa1f1ff5e97 100644 --- a/torchvision/models/detection/yolo.py +++ b/torchvision/models/detection/yolo.py @@ -7,7 +7,7 @@ from ...ops import batched_nms from ...transforms import functional as F -from .._api import register_model, WeightsEnum +from .._api import register_model, Weights, WeightsEnum from .._utils import _ovewrite_value_param from ..yolo import YOLOV4Backbone from .backbone_utils import _validate_trainable_layers @@ -20,63 +20,20 @@ TARGETS = Union[Tuple[TARGET, ...], List[TARGET]] -def validate_batch(images: Union[Tensor, IMAGES], targets: Optional[TARGETS]) -> None: - """Validates the format of a batch of data. - - Args: - images: A tensor containing a batch of images or a list of image tensors. - targets: A list of target dictionaries or ``None``. If a list is provided, there should be as many target - dictionaries as there are images. - """ - if not isinstance(images, Tensor): - if not isinstance(images, (tuple, list)): - raise TypeError(f"Expected images to be a Tensor, tuple, or a list, got {type(images).__name__}.") - if not images: - raise ValueError("No images in batch.") - shape = images[0].shape - for image in images: - if not isinstance(image, Tensor): - raise ValueError(f"Expected image to be of type Tensor, got {type(image).__name__}.") - if image.shape != shape: - raise ValueError(f"Images with different shapes in one batch: {shape} and {image.shape}") - - if targets is None: - return - - if not isinstance(targets, (tuple, list)): - raise TypeError(f"Expected targets to be a tuple or a list, got {type(images).__name__}.") - if len(images) != len(targets): - raise ValueError(f"Got {len(images)} images, but targets for {len(targets)} images.") - - for target in targets: - boxes = target["boxes"] - if not isinstance(boxes, Tensor): - raise ValueError(f"Expected target boxes to be of type Tensor, got {type(boxes).__name__}.") - if (boxes.ndim != 2) or (boxes.shape[-1] != 4): - raise ValueError(f"Expected target boxes to be tensors of shape [N, 4], got {list(boxes.shape)}.") - labels = target["labels"] - if not isinstance(labels, Tensor): - raise ValueError(f"Expected target labels to be of type Tensor, got {type(labels).__name__}.") - if (labels.ndim < 1) or (labels.ndim > 2) or (len(labels) != len(boxes)): - raise ValueError( - f"Expected target labels to be tensors of shape [N] or [N, num_classes], got {list(labels.shape)}." - ) - - class YOLO(nn.Module): """YOLO implementation that supports the most important features of YOLOv3, YOLOv4, YOLOv5, YOLOv7, Scaled-YOLOv4, and YOLOX. - *YOLOv3 paper*: `Joseph Redmon and Ali Farhadi `_ + *YOLOv3 paper*: `Joseph Redmon and Ali Farhadi `__ - *YOLOv4 paper*: `Alexey Bochkovskiy, Chien-Yao Wang, and Hong-Yuan Mark Liao `_ + *YOLOv4 paper*: `Alexey Bochkovskiy, Chien-Yao Wang, and Hong-Yuan Mark Liao `__ - *YOLOv7 paper*: `Chien-Yao Wang, Alexey Bochkovskiy, and Hong-Yuan Mark Liao `_ + *YOLOv7 paper*: `Chien-Yao Wang, Alexey Bochkovskiy, and Hong-Yuan Mark Liao `__ *Scaled-YOLOv4 paper*: `Chien-Yao Wang, Alexey Bochkovskiy, and Hong-Yuan Mark Liao - `_ + `__ - *YOLOX paper*: `Zheng Ge, Songtao Liu, Feng Wang, Zeming Li, and Jian Sun `_ + *YOLOX paper*: `Zheng Ge, Songtao Liu, Feng Wang, Zeming Li, and Jian Sun `__ The network architecture can be written in PyTorch, or read from a Darknet configuration file using the :class:`~.yolo_networks.DarknetNetwork` class. ``DarknetNetwork`` is also able to read weights that have been saved @@ -177,7 +134,7 @@ def forward( where ``anchors`` is the feature map size (width * height) times the number of anchors per cell. The predicted box coordinates are in `(x1, y1, x2, y2)` format and scaled to the input image size. """ - validate_batch(images, targets) + self.validate_batch(images, targets) images_tensor = images if isinstance(images, Tensor) else torch.stack(images) detections, losses, hits = self.network(images_tensor, targets) @@ -272,13 +229,72 @@ def process(boxes: Tensor, labels: Tensor, **other: Any) -> Dict[str, Any]: return [process(**t) for t in targets] + def validate_batch(self, images: Union[Tensor, IMAGES], targets: Optional[TARGETS]) -> None: + """Validates the format of a batch of data. + + Args: + images: A tensor containing a batch of images or a list of image tensors. + targets: A list of target dictionaries or ``None``. If a list is provided, there should be as many target + dictionaries as there are images. + """ + if not isinstance(images, Tensor): + if not isinstance(images, (tuple, list)): + raise TypeError(f"Expected images to be a Tensor, tuple, or a list, got {type(images).__name__}.") + if not images: + raise ValueError("No images in batch.") + shape = images[0].shape + for image in images: + if not isinstance(image, Tensor): + raise ValueError(f"Expected image to be of type Tensor, got {type(image).__name__}.") + if image.shape != shape: + raise ValueError(f"Images with different shapes in one batch: {shape} and {image.shape}") + + if targets is None: + if self.training: + raise ValueError("Targets should be given in training mode.") + else: + return + + if not isinstance(targets, (tuple, list)): + raise TypeError(f"Expected targets to be a tuple or a list, got {type(images).__name__}.") + if len(images) != len(targets): + raise ValueError(f"Got {len(images)} images, but targets for {len(targets)} images.") + + for target in targets: + if "boxes" not in target: + raise ValueError("Target dictionary doesn't contain boxes.") + boxes = target["boxes"] + if not isinstance(boxes, Tensor): + raise TypeError(f"Expected target boxes to be of type Tensor, got {type(boxes).__name__}.") + if (boxes.ndim != 2) or (boxes.shape[-1] != 4): + raise ValueError(f"Expected target boxes to be tensors of shape [N, 4], got {list(boxes.shape)}.") + if "labels" not in target: + raise ValueError("Target dictionary doesn't contain labels.") + labels = target["labels"] + if not isinstance(labels, Tensor): + raise ValueError(f"Expected target labels to be of type Tensor, got {type(labels).__name__}.") + if (labels.ndim < 1) or (labels.ndim > 2) or (len(labels) != len(boxes)): + raise ValueError( + f"Expected target labels to be tensors of shape [N] or [N, num_classes], got {list(labels.shape)}." + ) + class YOLOV4_Backbone_Weights(WeightsEnum): - DEFAULT = None + # TODO: Create pretrained weights. + DEFAULT = Weights( + url="", + transforms=lambda x: x, + meta={}, + ) class YOLOV4_Weights(WeightsEnum): - DEFAULT = None + # TODO: Create pretrained weights. + DEFAULT = Weights( + url="", + transforms=lambda x: x, + meta={}, + ) def freeze_backbone_layers(backbone: nn.Module, trainable_layers: Optional[int], is_trained: bool) -> None: From 5cc855842a7eb6edb37427b6695c7af4a80c7ea3 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Wed, 19 Apr 2023 19:04:42 +0300 Subject: [PATCH 07/13] Fixed TorchScript compilation --- .../models/detection/target_matching.py | 279 +++++--- torchvision/models/detection/yolo.py | 8 +- torchvision/models/detection/yolo_loss.py | 160 +++-- torchvision/models/detection/yolo_networks.py | 659 +++++++++--------- 4 files changed, 602 insertions(+), 504 deletions(-) diff --git a/torchvision/models/detection/target_matching.py b/torchvision/models/detection/target_matching.py index 3b8eb1b7d7b..f0c2f06fda8 100644 --- a/torchvision/models/detection/target_matching.py +++ b/torchvision/models/detection/target_matching.py @@ -1,5 +1,4 @@ -from abc import ABC, abstractmethod -from typing import Dict, List, Sequence, Tuple, Union +from typing import Dict, List, Tuple import torch from torch import Tensor @@ -9,94 +8,45 @@ from .box_utils import aligned_iou, box_size_ratio, iou_below, is_inside_box from .yolo_loss import YOLOLoss +PRIOR_SHAPES = List[List[int]] # TorchScript doesn't allow a list of tuples. -class ShapeMatching(ABC): - """Selects which anchors are used to predict each target, by comparing the shape of the target box to a set of - prior shapes. - Most YOLO variants match targets to anchors based on prior shapes that are assigned to the anchors in the model - configuration. The subclasses of ``ShapeMatching`` implement matching rules that compare the width and height of - the targets to each prior shape (regardless of the location where the target is). When the model includes multiple - detection layers, different shapes are defined for each layer. Usually there are three detection layers and three - prior shapes per layer. +def target_boxes_to_grid(preds: Tensor, targets: Tensor, image_size: Tensor) -> Tensor: + """Scales target bounding boxes to feature map coordinates. - Args: - ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the prior shape has IoU - with some target greater than this threshold, the predictor will not be taken into account when calculating - the confidence loss. - """ + It would be better to implement this in a super class, but TorchScript doesn't allow class inheritance. - def __init__(self, ignore_bg_threshold: float = 0.7) -> None: - self.ignore_bg_threshold = ignore_bg_threshold - - def __call__( - self, - preds: Dict[str, Tensor], - targets: Dict[str, Tensor], - image_size: Tensor, - ) -> Tuple[List[Tensor], Tensor, Tensor]: - """For each target, selects predictions from the same grid cell, where the center of the target box is. - - Typically there are three predictions per grid cell. Subclasses implement ``match()``, which selects the - predictions within the grid cell. - - Args: - preds: Predictions for a single image. - targets: Training targets for a single image. - image_size: Input image width and height. - - Returns: - The indices of the matched predictions, background mask, and a mask for selecting the matched targets. - """ - height, width = preds["boxes"].shape[:2] - device = preds["boxes"].device - - # A multiplier for scaling image coordinates to feature map coordinates - grid_size = torch.tensor([width, height], device=device) - image_to_grid = torch.true_divide(grid_size, image_size) - - # Bounding box center coordinates are converted to the feature map dimensions so that the whole number tells the - # cell index and the fractional part tells the location inside the cell. - xywh = box_convert(targets["boxes"], in_fmt="xyxy", out_fmt="cxcywh") - grid_xy = xywh[:, :2] * image_to_grid - cell_i = grid_xy[:, 0].to(torch.int64).clamp(0, width - 1) - cell_j = grid_xy[:, 1].to(torch.int64).clamp(0, height - 1) - - target_selector, anchor_selector = self.match(xywh[:, 2:]) - cell_i = cell_i[target_selector] - cell_j = cell_j[target_selector] - - # Background mask is used to select anchors that are not responsible for predicting any object, for - # calculating the part of the confidence loss with zero as the target confidence. It is set to False, if a - # predicted box overlaps any target significantly, or if a prediction is matched to a target. - background_mask = iou_below(preds["boxes"], targets["boxes"], self.ignore_bg_threshold) - background_mask[cell_j, cell_i, anchor_selector] = False - - pred_selector = [cell_j, cell_i, anchor_selector] + Args: + preds: Predicted bounding boxes for a single image. + targets: Target bounding boxes for a single image. + image_size: Input image width and height. - return pred_selector, background_mask, target_selector + Returns: + A tensor containing target x, y, width, and height in the feature map coordinates. + """ + height, width = preds.shape[:2] - @abstractmethod - def match(self, wh: Tensor) -> Union[Tuple[Tensor, Tensor], Tensor]: - """Selects anchors for each target based on the predicted shapes. The subclasses implement this method. + # A multiplier for scaling image coordinates to feature map coordinates + grid_size = torch.tensor([width, height], device=image_size.device) + image_to_grid = torch.true_divide(grid_size, image_size) - Args: - wh: A matrix of predicted width and height values. + # Bounding box center coordinates are converted to the feature map dimensions so that the whole number tells the + # cell index and the fractional part tells the location inside the cell. + xywh = box_convert(targets, in_fmt="xyxy", out_fmt="cxcywh") + grid_xy = xywh[:, :2] * image_to_grid + cell_i = grid_xy[:, 0].to(torch.int64).clamp(0, width - 1) + cell_j = grid_xy[:, 1].to(torch.int64).clamp(0, height - 1) - Returns: - matched_targets, matched_anchors: Two vectors or a `2xN` matrix. The first vector is used to select the - targets that this layer matched and the second one lists the matching anchors within the grid cell. - """ - pass + return torch.cat((cell_i.unsqueeze(1), cell_j.unsqueeze(1), xywh[:, 2:]), 1) -class HighestIoUMatching(ShapeMatching): +class HighestIoUMatching: """For each target, select the prior shape that gives the highest IoU. This is the original YOLO matching rule. Args: - prior_shapes: A list of all the prior box dimensions. The list should contain (width, height) tuples in the + prior_shapes: A list of all the prior box dimensions. The list should contain [width, height] pairs in the network input resolution. prior_shape_idxs: List of indices to ``prior_shapes`` that is used to select the (usually 3) prior shapes that this layer uses. @@ -106,17 +56,26 @@ class HighestIoUMatching(ShapeMatching): """ def __init__( - self, prior_shapes: Sequence[Tuple[int, int]], prior_shape_idxs: Sequence[int], ignore_bg_threshold: float = 0.7 + self, prior_shapes: PRIOR_SHAPES, prior_shape_idxs: List[int], ignore_bg_threshold: float = 0.7 ) -> None: - super().__init__(ignore_bg_threshold) self.prior_shapes = prior_shapes # anchor_map maps the anchor indices to anchors in this layer, or to -1 if it's not an anchor of this layer. # This layer ignores the target if all the selected anchors are in another layer. self.anchor_map = [ prior_shape_idxs.index(idx) if idx in prior_shape_idxs else -1 for idx in range(len(prior_shapes)) ] + self.ignore_bg_threshold = ignore_bg_threshold + + def match(self, wh: Tensor) -> Tuple[Tensor, Tensor]: + """Selects anchors for each target based on the predicted shapes. The subclasses implement this method. - def match(self, wh: Tensor) -> Union[Tuple[Tensor, Tensor], Tensor]: + Args: + wh: A matrix of predicted width and height values. + + Returns: + matched_targets, matched_anchors: Two vectors. The first vector is used to select the targets that this + layer matched and the second one lists the matching anchors within the grid cell. + """ prior_wh = torch.tensor(self.prior_shapes, dtype=wh.dtype, device=wh.device) anchor_map = torch.tensor(self.anchor_map, dtype=torch.int64, device=wh.device) @@ -127,12 +86,48 @@ def match(self, wh: Tensor) -> Union[Tuple[Tensor, Tensor], Tensor]: matched_anchors = highest_iou_anchors[matched_targets] return matched_targets, matched_anchors + def __call__( + self, + preds: Dict[str, Tensor], + targets: Dict[str, Tensor], + image_size: Tensor, + ) -> Tuple[List[Tensor], Tensor, Tensor]: + """For each target, selects predictions from the same grid cell, where the center of the target box is. + + Typically there are three predictions per grid cell. Subclasses implement ``match()``, which selects the + predictions within the grid cell. + + Args: + preds: Predictions for a single image. + targets: Training targets for a single image. + image_size: Input image width and height. + + Returns: + The indices of the matched predictions, background mask, and a mask for selecting the matched targets. + """ + scaled_targets = target_boxes_to_grid(preds["boxes"], targets["boxes"], image_size) + target_selector, anchor_selector = self.match(scaled_targets[:, 2:]) + + scaled_targets = scaled_targets[target_selector] + cell_i = scaled_targets[:, 0] + cell_j = scaled_targets[:, 1] + + # Background mask is used to select anchors that are not responsible for predicting any object, for + # calculating the part of the confidence loss with zero as the target confidence. It is set to False, if a + # predicted box overlaps any target significantly, or if a prediction is matched to a target. + background_mask = iou_below(preds["boxes"], targets["boxes"], self.ignore_bg_threshold) + background_mask[cell_j, cell_i, anchor_selector] = False + + pred_selector = [cell_j, cell_i, anchor_selector] + + return pred_selector, background_mask, target_selector -class IoUThresholdMatching(ShapeMatching): + +class IoUThresholdMatching: """For each target, select all prior shapes that give a high enough IoU. Args: - prior_shapes: A list of all the prior box dimensions. The list should contain (width, height) tuples in the + prior_shapes: A list of all the prior box dimensions. The list should contain [width, height] pairs in the network input resolution. prior_shape_idxs: List of indices to ``prior_shapes`` that is used to select the (usually 3) prior shapes that this layer uses. @@ -144,31 +139,76 @@ class IoUThresholdMatching(ShapeMatching): def __init__( self, - prior_shapes: Sequence[Tuple[int, int]], - prior_shape_idxs: Sequence[int], + prior_shapes: PRIOR_SHAPES, + prior_shape_idxs: List[int], threshold: float, ignore_bg_threshold: float = 0.7, ) -> None: - super().__init__(ignore_bg_threshold) self.prior_shapes = [prior_shapes[idx] for idx in prior_shape_idxs] self.threshold = threshold + self.ignore_bg_threshold = ignore_bg_threshold - def match(self, wh: Tensor) -> Union[Tuple[Tensor, Tensor], Tensor]: + def match(self, wh: Tensor) -> Tuple[Tensor, Tensor]: + """Selects anchors for each target based on the predicted shapes. The subclasses implement this method. + + Args: + wh: A matrix of predicted width and height values. + + Returns: + matched_targets, matched_anchors: Two vectors. The first vector is used to select the targets that this + layer matched and the second one lists the matching anchors within the grid cell. + """ prior_wh = torch.tensor(self.prior_shapes, dtype=wh.dtype, device=wh.device) ious = aligned_iou(wh, prior_wh) above_threshold = (ious > self.threshold).nonzero() - return above_threshold.T + return above_threshold[:, 0], above_threshold[:, 1] + + def __call__( + self, + preds: Dict[str, Tensor], + targets: Dict[str, Tensor], + image_size: Tensor, + ) -> Tuple[List[Tensor], Tensor, Tensor]: + """For each target, selects predictions from the same grid cell, where the center of the target box is. + + Typically there are three predictions per grid cell. Subclasses implement ``match()``, which selects the + predictions within the grid cell. + + Args: + preds: Predictions for a single image. + targets: Training targets for a single image. + image_size: Input image width and height. + + Returns: + The indices of the matched predictions, background mask, and a mask for selecting the matched targets. + """ + scaled_targets = target_boxes_to_grid(preds["boxes"], targets["boxes"], image_size) + target_selector, anchor_selector = self.match(scaled_targets[:, 2:]) + + scaled_targets = scaled_targets[target_selector] + cell_i = scaled_targets[:, 0] + cell_j = scaled_targets[:, 1] + + # Background mask is used to select anchors that are not responsible for predicting any object, for + # calculating the part of the confidence loss with zero as the target confidence. It is set to False, if a + # predicted box overlaps any target significantly, or if a prediction is matched to a target. + background_mask = iou_below(preds["boxes"], targets["boxes"], self.ignore_bg_threshold) + background_mask[cell_j, cell_i, anchor_selector] = False + + pred_selector = [cell_j, cell_i, anchor_selector] + + return pred_selector, background_mask, target_selector -class SizeRatioMatching(ShapeMatching): +class SizeRatioMatching: """For each target, select those prior shapes, whose width and height relative to the target is below given ratio. This is the matching rule used by Ultralytics YOLOv5 implementation. Args: - prior_shapes: A list of all the prior box dimensions. The list should contain (width, height) tuples in the + prior_shapes: A list of all the prior box dimensions. The list should contain [width, height] pairs in the network input resolution. prior_shape_idxs: List of indices to ``prior_shapes`` that is used to select the (usually 3) prior shapes that this layer uses. @@ -180,18 +220,64 @@ class SizeRatioMatching(ShapeMatching): def __init__( self, - prior_shapes: Sequence[Tuple[int, int]], - prior_shape_idxs: Sequence[int], + prior_shapes: PRIOR_SHAPES, + prior_shape_idxs: List[int], threshold: float, ignore_bg_threshold: float = 0.7, ) -> None: - super().__init__(ignore_bg_threshold) self.prior_shapes = [prior_shapes[idx] for idx in prior_shape_idxs] self.threshold = threshold + self.ignore_bg_threshold = ignore_bg_threshold - def match(self, wh: Tensor) -> Union[Tuple[Tensor, Tensor], Tensor]: + def match(self, wh: Tensor) -> Tuple[Tensor, Tensor]: + """Selects anchors for each target based on the predicted shapes. The subclasses implement this method. + + Args: + wh: A matrix of predicted width and height values. + + Returns: + matched_targets, matched_anchors: Two vectors. The first vector is used to select the targets that this + layer matched and the second one lists the matching anchors within the grid cell. + """ prior_wh = torch.tensor(self.prior_shapes, dtype=wh.dtype, device=wh.device) - return (box_size_ratio(wh, prior_wh) < self.threshold).nonzero().T + below_threshold = (box_size_ratio(wh, prior_wh) < self.threshold).nonzero() + return below_threshold[:, 0], below_threshold[:, 1] + + def __call__( + self, + preds: Dict[str, Tensor], + targets: Dict[str, Tensor], + image_size: Tensor, + ) -> Tuple[List[Tensor], Tensor, Tensor]: + """For each target, selects predictions from the same grid cell, where the center of the target box is. + + Typically there are three predictions per grid cell. Subclasses implement ``match()``, which selects the + predictions within the grid cell. + + Args: + preds: Predictions for a single image. + targets: Training targets for a single image. + image_size: Input image width and height. + + Returns: + The indices of the matched predictions, background mask, and a mask for selecting the matched targets. + """ + scaled_targets = target_boxes_to_grid(preds["boxes"], targets["boxes"], image_size) + target_selector, anchor_selector = self.match(scaled_targets[:, 2:]) + + scaled_targets = scaled_targets[target_selector] + cell_i = scaled_targets[:, 0] + cell_j = scaled_targets[:, 1] + + # Background mask is used to select anchors that are not responsible for predicting any object, for + # calculating the part of the confidence loss with zero as the target confidence. It is set to False, if a + # predicted box overlaps any target significantly, or if a prediction is matched to a target. + background_mask = iou_below(preds["boxes"], targets["boxes"], self.ignore_bg_threshold) + background_mask[cell_j, cell_i, anchor_selector] = False + + pred_selector = [cell_j, cell_i, anchor_selector] + + return pred_selector, background_mask, target_selector def _sim_ota_match(costs: Tensor, ious: Tensor) -> Tuple[Tensor, Tensor]: @@ -243,7 +329,7 @@ class SimOTAMatching: This is the matching rule used by YOLOX. Args: - prior_shapes: A list of all the prior box dimensions. The list should contain (width, height) tuples in the + prior_shapes: A list of all the prior box dimensions. The list should contain [width, height] pairs in the network input resolution. prior_shape_idxs: List of indices to ``prior_shapes`` that is used to select the (usually 3) prior shapes that this layer uses. @@ -256,8 +342,8 @@ class SimOTAMatching: def __init__( self, - prior_shapes: Sequence[Tuple[int, int]], - prior_shape_idxs: Sequence[int], + prior_shapes: PRIOR_SHAPES, + prior_shape_idxs: List[int], loss_func: YOLOLoss, spatial_range: float, size_range: float, @@ -298,7 +384,8 @@ def __call__( pred_mask, target_selector = _sim_ota_match(costs, ious) # Add the anchor dimension to the mask and replace True values with the results of the actual SimOTA matching. - prior_mask[prior_mask.nonzero().T.tolist()] = pred_mask + pred_selector = prior_mask.nonzero().T.tolist() + prior_mask[pred_selector] = pred_mask background_mask = torch.logical_not(prior_mask) diff --git a/torchvision/models/detection/yolo.py b/torchvision/models/detection/yolo.py index fa1f1ff5e97..b7fce691faf 100644 --- a/torchvision/models/detection/yolo.py +++ b/torchvision/models/detection/yolo.py @@ -11,13 +11,9 @@ from .._utils import _ovewrite_value_param from ..yolo import YOLOV4Backbone from .backbone_utils import _validate_trainable_layers -from .yolo_networks import DarknetNetwork, YOLOV4Network +from .yolo_networks import DarknetNetwork, YOLOV4Network, PRED, TARGET, TARGETS -IMAGES = Union[Tuple[Tensor, ...], List[Tensor]] -PRED = Dict[str, Any] -PREDS = Union[Tuple[PRED, ...], List[PRED]] -TARGET = Dict[str, Any] -TARGETS = Union[Tuple[TARGET, ...], List[TARGET]] +IMAGES = List[Tensor] # TorchScript doesn't allow a tuple. class YOLO(nn.Module): diff --git a/torchvision/models/detection/yolo_loss.py b/torchvision/models/detection/yolo_loss.py index 31efb0d3742..3de448da411 100644 --- a/torchvision/models/detection/yolo_loss.py +++ b/torchvision/models/detection/yolo_loss.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Callable, Dict, Optional, Tuple, Union +from typing import Dict, Optional, Tuple import torch from torch import Tensor @@ -16,38 +16,35 @@ ) -def box_iou_loss(boxes1: Tensor, boxes2: Tensor) -> Tensor: - return 1.0 - box_iou(boxes1, boxes2).diagonal() - - -_iou_and_loss_functions = { - "iou": (box_iou, box_iou_loss), - "giou": (generalized_box_iou, generalized_box_iou_loss), - "diou": (distance_box_iou, distance_box_iou_loss), - "ciou": (complete_box_iou, complete_box_iou_loss), -} - +def _binary_cross_entropy( + inputs: Tensor, + targets: Tensor, + reduction: str = "mean", + input_is_normalized: bool = True): + """Returns the binary cross entropy from either normalized inputs or logits. -def _get_iou_and_loss_functions(name: str) -> Tuple[Callable, Callable]: - """Returns functions for calculating the IoU and the IoU loss, given the IoU variant name. + It would be more convenient to pass the correct cross entropy function to every function that uses it, but + TorchScript doesn't allow passing functions. Args: - name: Name of the IoU variant. Either "iou", "giou", "diou", or "ciou". - - Returns: - A tuple of two functions. The first function calculates the pairwise IoU and the second function calculates the - elementwise loss. + inputs: Probabilities in a tensor of an arbitrary shape. + targets: Targets in a tensor of the same shape as ``input``. + reduction: Specifies the reduction to apply to the output. ``'none'``: no reduction will be applied, ``'mean'``: + the sum of the output will be divided by the number of elements in the output, ``'sum'``: the output will be + summed. + input_is_normalized: If ``False``, input is logits, if ``True``, input is normalized to `0..1`. """ - if name not in _iou_and_loss_functions: - raise ValueError(f"Unknown IoU function '{name}'.") - iou_func, loss_func = _iou_and_loss_functions[name] - if not callable(iou_func): - raise ValueError(f"The IoU function '{name}' is not supported by the installed version of Torchvision.") - assert callable(loss_func) - return iou_func, loss_func + if input_is_normalized: + return binary_cross_entropy(inputs, targets, reduction=reduction) + else: + return binary_cross_entropy_with_logits(inputs, targets, reduction=reduction) + + +def box_iou_loss(boxes1: Tensor, boxes2: Tensor) -> Tensor: + return 1.0 - box_iou(boxes1, boxes2).diagonal() -def _size_compensation(targets: Tensor, image_size: Tensor) -> Tuple[Tensor, Tensor]: +def _size_compensation(targets: Tensor, image_size: Tensor) -> Tensor: """Calcuates the size compensation factor for the overlap loss. The overlap losses for each target should be multiplied by the returned weight. The returned value is @@ -66,7 +63,7 @@ def _size_compensation(targets: Tensor, image_size: Tensor) -> Tuple[Tensor, Ten def _pairwise_confidence_loss( - preds: Tensor, overlap: Tensor, bce_func: Callable, predict_overlap: Optional[float] + preds: Tensor, overlap: Tensor, input_is_normalized: bool, predict_overlap: Optional[float] ) -> Tensor: """Calculates the confidence loss for every pair of a foreground anchor and a target. @@ -77,7 +74,7 @@ def _pairwise_confidence_loss( Args: preds: An ``[N]`` vector of predicted confidences. overlap: An ``[N, M]`` matrix of overlaps between all predicted and target bounding boxes. - bce_func: A function for calculating binary cross entropy. + input_is_normalized: If ``False``, input is logits, if ``True``, input is normalized to `0..1`. predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that the target confidence is 1 if there's an object, and 1.0 means that the target confidence is the overlap. @@ -91,16 +88,17 @@ def _pairwise_confidence_loss( targets = torch.ones_like(preds) - predict_overlap # Distance-IoU may return negative "overlaps", so we have to make sure that the targets are not negative. targets += predict_overlap * overlap.detach().clamp(min=0) - return bce_func(preds, targets, reduction="none") + return _binary_cross_entropy(preds, targets, reduction="none", input_is_normalized=input_is_normalized) else: # When not predicting overlap, target confidence is the same for every prediction, but we should still return a # matrix. targets = torch.ones_like(preds) - return bce_func(preds, targets, reduction="none").unsqueeze(1).expand(overlap.shape) + result = _binary_cross_entropy(preds, targets, reduction="none", input_is_normalized=input_is_normalized) + return result.unsqueeze(1).expand(overlap.shape) def _foreground_confidence_loss( - preds: Tensor, overlap: Tensor, bce_func: Callable, predict_overlap: Optional[float] + preds: Tensor, overlap: Tensor, input_is_normalized: bool, predict_overlap: Optional[float] ) -> Tensor: """Calculates the sum of the confidence losses for foreground anchors and their matched targets. @@ -111,7 +109,7 @@ def _foreground_confidence_loss( Args: preds: A vector of predicted confidences. overlap: A vector of overlaps between matched target and predicted bounding boxes. - bce_func: A function for calculating binary cross entropy. + input_is_normalized: If ``False``, input is logits, if ``True``, input is normalized to `0..1`. predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that the target confidence is 1, and 1.0 means that the target confidence is the overlap. @@ -123,21 +121,21 @@ def _foreground_confidence_loss( targets -= predict_overlap # Distance-IoU may return negative "overlaps", so we have to make sure that the targets are not negative. targets += predict_overlap * overlap.detach().clamp(min=0) - return bce_func(preds, targets, reduction="sum") + return _binary_cross_entropy(preds, targets, reduction="sum", input_is_normalized=input_is_normalized) -def _background_confidence_loss(preds: Tensor, bce_func: Callable) -> Tensor: +def _background_confidence_loss(preds: Tensor, input_is_normalized: bool) -> Tensor: """Calculates the sum of the confidence losses for background anchors. Args: preds: A vector of predicted confidences for background anchors. - bce_func: A function for calculating binary cross entropy. + input_is_normalized: If ``False``, input is logits, if ``True``, input is normalized to `0..1`. Returns: The sum of the background confidence losses. """ targets = torch.zeros_like(preds) - return bce_func(preds, targets, reduction="sum") + return _binary_cross_entropy(preds, targets, reduction="sum", input_is_normalized=input_is_normalized) def _target_labels_to_probs( @@ -177,6 +175,7 @@ def _target_labels_to_probs( return targets +@torch.jit.script @dataclass class Losses: overlap: Tensor @@ -192,9 +191,8 @@ class YOLOLoss: categories, since the YOLO model does multi-label classification. Args: - overlap_func: A function for calculating the pairwise overlaps between two sets of boxes. Either a string or a - function that returns a matrix of pairwise overlaps. Valid string values are "iou", "giou", "diou", and - "ciou" (default). + overlap_func: Which function to use for calculating the IoU between two sets of boxes. Valid values are "iou", + "giou", "diou", and "ciou". predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that the target confidence is 1 if there's an object, and 1.0 means that the target confidence is the output of ``overlap_func``. @@ -207,19 +205,14 @@ class YOLOLoss: def __init__( self, - overlap_func: Union[str, Callable] = "ciou", + overlap_func: str = "ciou", predict_overlap: Optional[float] = None, label_smoothing: Optional[float] = None, overlap_multiplier: float = 5.0, confidence_multiplier: float = 1.0, class_multiplier: float = 1.0, ): - if callable(overlap_func): - self._pairwise_overlap = overlap_func - self._elementwise_overlap_loss = lambda boxes1, boxes2: 1.0 - overlap_func(boxes1, boxes2).diagonal() - else: - self._pairwise_overlap, self._elementwise_overlap_loss = _get_iou_and_loss_functions(overlap_func) - + self.overlap_func = overlap_func self.predict_overlap = predict_overlap self.label_smoothing = label_smoothing self.overlap_multiplier = overlap_multiplier @@ -247,18 +240,13 @@ def pairwise( """ loss_shape = torch.Size([len(preds["boxes"]), len(targets["boxes"])]) - if input_is_normalized: - bce_func = binary_cross_entropy - else: - bce_func = binary_cross_entropy_with_logits - overlap = self._pairwise_overlap(preds["boxes"], targets["boxes"]) assert overlap.shape == loss_shape overlap_loss = 1.0 - overlap assert overlap_loss.shape == loss_shape - confidence_loss = _pairwise_confidence_loss(preds["confidences"], overlap, bce_func, self.predict_overlap) + confidence_loss = _pairwise_confidence_loss(preds["confidences"], overlap, input_is_normalized, self.predict_overlap) assert confidence_loss.shape == loss_shape pred_probs = preds["classprobs"].unsqueeze(1) # [N, 1, classes] @@ -267,7 +255,10 @@ def pairwise( ) target_probs = target_probs.unsqueeze(0) # [1, M, classes] pred_probs, target_probs = torch.broadcast_tensors(pred_probs, target_probs) - class_loss = bce_func(pred_probs, target_probs, reduction="none").sum(-1) + class_loss = _binary_cross_entropy( + pred_probs, target_probs, reduction="none", input_is_normalized=input_is_normalized + ) + class_loss = class_loss.sum(-1) assert class_loss.shape == loss_shape losses = Losses( @@ -297,23 +288,20 @@ def elementwise_sums( Returns: The final losses. """ - if input_is_normalized: - bce_func = binary_cross_entropy - else: - bce_func = binary_cross_entropy_with_logits - overlap_loss = self._elementwise_overlap_loss(targets["boxes"], preds["boxes"]) overlap = 1.0 - overlap_loss overlap_loss = (overlap_loss * _size_compensation(targets["boxes"], image_size)).sum() - confidence_loss = _foreground_confidence_loss(preds["confidences"], overlap, bce_func, self.predict_overlap) - confidence_loss += _background_confidence_loss(preds["bg_confidences"], bce_func) + confidence_loss = _foreground_confidence_loss(preds["confidences"], overlap, input_is_normalized, self.predict_overlap) + confidence_loss += _background_confidence_loss(preds["bg_confidences"], input_is_normalized) pred_probs = preds["classprobs"] target_probs = _target_labels_to_probs( targets["labels"], pred_probs.shape[-1], pred_probs.dtype, self.label_smoothing ) - class_loss = bce_func(pred_probs, target_probs, reduction="sum") + class_loss = _binary_cross_entropy( + pred_probs, target_probs, reduction="sum", input_is_normalized=input_is_normalized + ) losses = Losses( overlap_loss * self.overlap_multiplier, @@ -322,3 +310,51 @@ def elementwise_sums( ) return losses + + def _pairwise_overlap(self, boxes1: Tensor, boxes2: Tensor) -> Tensor: + """Returns the pairwise intersection-over-union values between two sets of boxes. + + Uses the IoU function specified in ``self.overlap_func``. It would be better to save the function in a variable, + but TorchScript doesn't allow this. + + Args: + boxes1: first set of boxes + boxes2: second set of boxes + + Returns: + A matrix containing the pairwise IoU values for every element in ``boxes1`` and ``boxes2``. + """ + if self.overlap_func == "iou": + return box_iou(boxes1, boxes2) + elif self.overlap_func == "giou": + return generalized_box_iou(boxes1, boxes2) + elif self.overlap_func == "diou": + return distance_box_iou(boxes1, boxes2) + elif self.overlap_func == "ciou": + return complete_box_iou(boxes1, boxes2) + else: + raise ValueError(f"Unknown IoU function '{self.overlap_func}'.") + + def _elementwise_overlap_loss(self, boxes1: Tensor, boxes2: Tensor) -> Tensor: + """Returns the elementwise intersection-over-union losses between two sets of boxes. + + Uses the IoU loss function specified in ``self.overlap_func``. It would be better to save the function in a + variable, but TorchScript doesn't allow this. + + Args: + boxes1: first set of boxes + boxes2: second set of boxes + + Returns: + A vector containing the IoU losses between corresponding elements in ``boxes1`` and ``boxes2``. + """ + if self.overlap_func == "iou": + return box_iou_loss(boxes1, boxes2) + elif self.overlap_func == "giou": + return generalized_box_iou_loss(boxes1, boxes2) + elif self.overlap_func == "diou": + return distance_box_iou_loss(boxes1, boxes2) + elif self.overlap_func == "ciou": + return complete_box_iou_loss(boxes1, boxes2) + else: + raise ValueError(f"Unknown IoU function '{self.overlap_func}'.") diff --git a/torchvision/models/detection/yolo_networks.py b/torchvision/models/detection/yolo_networks.py index c99220b796d..47e3e145b2e 100644 --- a/torchvision/models/detection/yolo_networks.py +++ b/torchvision/models/detection/yolo_networks.py @@ -25,15 +25,15 @@ YOLOV7Backbone, ) from .anchor_utils import global_xy -from .target_matching import HighestIoUMatching, IoUThresholdMatching, ShapeMatching, SimOTAMatching, SizeRatioMatching +from .target_matching import HighestIoUMatching, IoUThresholdMatching, PRIOR_SHAPES, SimOTAMatching, SizeRatioMatching from .yolo_loss import YOLOLoss -CONFIG = Dict[str, Any] +DARKNET_CONFIG = Dict[str, Any] CREATE_LAYER_OUTPUT = Tuple[nn.Module, int] # layer, num_outputs -PRED = Dict[str, Any] -PREDS = Union[Tuple[PRED, ...], List[PRED]] -TARGET = Dict[str, Any] -TARGETS = Union[Tuple[TARGET, ...], List[TARGET]] +PRED = Dict[str, Tensor] +PREDS = List[PRED] # TorchScript doesn't allow a tuple +TARGET = Dict[str, Tensor] +TARGETS = List[TARGET] # TorchScript doesn't allow a tuple NETWORK_OUTPUT = Tuple[List[Tensor], List[Tensor], List[int]] # detections, losses, hits @@ -45,7 +45,7 @@ class DetectionLayer(nn.Module): Args: num_classes: Number of different classes that this layer predicts. prior_shapes: A list of prior box dimensions for this layer, used for scaling the predicted dimensions. The list - should contain (width, height) tuples in the network input resolution. + should contain [width, height] pairs in the network input resolution. matching_func: The matching algorithm to be used for assigning targets to anchors. loss_func: ``YOLOLoss`` object for calculating the losses. xy_scale: Eliminate "grid sensitivity" by scaling the box coordinates by this factor. Using a value > 1.0 helps @@ -59,7 +59,7 @@ class DetectionLayer(nn.Module): def __init__( self, num_classes: int, - prior_shapes: List[Tuple[int, int]], + prior_shapes: PRIOR_SHAPES, matching_func: Callable, loss_func: YOLOLoss, xy_scale: float = 1.0, @@ -98,7 +98,7 @@ def forward(self, x: Tensor, image_size: Tensor) -> Tuple[Tensor, PREDS]: """ batch_size, num_features, height, width = x.shape num_attrs = self.num_classes + 5 - anchors_per_cell = int(torch.div(num_features, num_attrs, rounding_mode="floor")) + anchors_per_cell = num_features // num_attrs if anchors_per_cell != len(self.prior_shapes): raise ValueError( "The model predicts {} bounding boxes per spatial location, but {} prior box dimensions are defined " @@ -166,44 +166,40 @@ def match_targets( if (len(targets) != batch_size) or (len(return_preds) != batch_size): raise ValueError("Different batch size for predictions and targets.") - matches = [] + # Creating lists that are concatenated in the end will confuse TorchScript compilation. Instead, we'll create + # tensors and concatenate new matches immediately. + pred_boxes = torch.empty((0, 4), device=return_preds[0]["boxes"].device) + pred_confidences = torch.empty(0, device=return_preds[0]["confidences"].device) + pred_bg_confidences = torch.empty(0, device=return_preds[0]["confidences"].device) + pred_classprobs = torch.empty((0, self.num_classes), device=return_preds[0]["classprobs"].device) + target_boxes = torch.empty((0, 4), device=targets[0]["boxes"].device) + target_labels = torch.empty(0, dtype=torch.int64, device=targets[0]["labels"].device) + for image_preds, image_return_preds, image_targets in zip(preds, return_preds, targets): if image_targets["boxes"].shape[0] > 0: pred_selector, background_selector, target_selector = self.matching_func( image_preds, image_targets, image_size ) - matched_preds = { - "boxes": image_return_preds["boxes"][pred_selector], - "confidences": image_return_preds["confidences"][pred_selector], - "bg_confidences": image_return_preds["confidences"][background_selector], - "classprobs": image_return_preds["classprobs"][pred_selector], - } - matched_targets = { - "boxes": image_targets["boxes"][target_selector], - "labels": image_targets["labels"][target_selector], - } + pred_boxes = torch.cat((pred_boxes, image_return_preds["boxes"][pred_selector])) + pred_confidences = torch.cat((pred_confidences, image_return_preds["confidences"][pred_selector])) + pred_bg_confidences = torch.cat( + (pred_bg_confidences, image_return_preds["confidences"][background_selector]) + ) + pred_classprobs = torch.cat((pred_classprobs, image_return_preds["classprobs"][pred_selector])) + target_boxes = torch.cat((target_boxes, image_targets["boxes"][target_selector])) + target_labels = torch.cat((target_labels, image_targets["labels"][target_selector])) else: - matched_preds = { - "boxes": torch.empty((0, 4), device=image_return_preds["boxes"].device), - "confidences": torch.empty(0, device=image_return_preds["confidences"].device), - "bg_confidences": image_return_preds["confidences"].flatten(), - "classprobs": torch.empty((0, self.num_classes), device=image_return_preds["classprobs"].device), - } - matched_targets = { - "boxes": torch.empty((0, 4), device=image_targets["boxes"].device), - "labels": torch.empty(0, dtype=torch.int64, device=image_targets["labels"].device), - } - matches.append((matched_preds, matched_targets)) + pred_bg_confidences = torch.cat((pred_bg_confidences, image_return_preds["confidences"].flatten())) matched_preds = { - "boxes": torch.cat(tuple(m[0]["boxes"] for m in matches)), - "confidences": torch.cat(tuple(m[0]["confidences"] for m in matches)), - "bg_confidences": torch.cat(tuple(m[0]["bg_confidences"] for m in matches)), - "classprobs": torch.cat(tuple(m[0]["classprobs"] for m in matches)), + "boxes": pred_boxes, + "confidences": pred_confidences, + "bg_confidences": pred_bg_confidences, + "classprobs": pred_classprobs, } matched_targets = { - "boxes": torch.cat(tuple(m[1]["boxes"] for m in matches)), - "labels": torch.cat(tuple(m[1]["labels"] for m in matches)), + "boxes": target_boxes, + "labels": target_labels, } return matched_preds, matched_targets @@ -243,14 +239,14 @@ def calculate_losses( def create_detection_layer( - prior_shapes: Sequence[Tuple[int, int]], - prior_shape_idxs: Sequence[int], + prior_shapes: PRIOR_SHAPES, + prior_shape_idxs: List[int], matching_algorithm: Optional[str] = None, matching_threshold: Optional[float] = None, spatial_range: float = 5.0, size_range: float = 4.0, ignore_bg_threshold: float = 0.7, - overlap_func: Union[str, Callable] = "ciou", + overlap_func: str = "ciou", predict_overlap: Optional[float] = None, label_smoothing: Optional[float] = None, overlap_loss_multiplier: float = 5.0, @@ -262,7 +258,7 @@ def create_detection_layer( Args: prior_shapes: A list of all the prior box dimensions, used for scaling the predicted dimensions and possibly for - matching the targets to the anchors. The list should contain (width, height) tuples in the network input + matching the targets to the anchors. The list should contain [width, height] pairs in the network input resolution. prior_shape_idxs: List of indices to ``prior_shapes`` that is used to select the (usually 3) prior shapes that this layer uses. @@ -278,9 +274,8 @@ def create_detection_layer( ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the corresponding anchor has IoU with some target greater than this threshold, the predictor will not be taken into account when calculating the confidence loss. - overlap_func: A function for calculating the pairwise overlaps between two sets of boxes. Either a string or a - function that returns a matrix of pairwise overlaps. Valid string values are "iou", "giou", "diou", and - "ciou" (default). + overlap_func: Which function to use for calculating the IoU between two sets of boxes. Valid values are "iou", + "giou", "diou", and "ciou". predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that the target confidence is 1 if there's an object, and 1.0 means that the target confidence is the output of ``overlap_func``. @@ -297,7 +292,7 @@ def create_detection_layer( height are scaled up so that the maximum value is four times the anchor dimension. This is used by the Darknet configurations of Scaled-YOLOv4. """ - matching_func: Union[ShapeMatching, SimOTAMatching] + matching_func: Callable if matching_algorithm == "simota": loss_func = YOLOLoss( overlap_func, None, None, overlap_loss_multiplier, confidence_loss_multiplier, class_loss_multiplier @@ -328,82 +323,101 @@ def create_detection_layer( return DetectionLayer(prior_shapes=layer_shapes, matching_func=matching_func, loss_func=loss_func, **kwargs) -def run_detection( - detection_layer: DetectionLayer, - layer_input: Tensor, - targets: Optional[TARGETS], - image_size: Tensor, - detections: List[Tensor], - losses: List[Tensor], - hits: List[int], -) -> None: - """Runs the detection layer on the inputs and appends the output to the ``detections`` list. +class DetectionStage(nn.Module): + """This is a convenience class for running a detection layer. - If ``targets`` is given, also calculates the losses and appends to the ``losses`` list. - - Args: - detection_layer: The detection layer. - layer_input: Input to the detection layer. - targets: List of training targets for each image. - image_size: Width and height in a vector that defines the scale of the target coordinates. - detections: A list where a tensor containing the detections will be appended to. - losses: A list where a tensor containing the losses will be appended to, if ``targets`` is given. - hits: A list where the number of targets that matched this layer will be appended to, if ``targets`` is given. + It might be cleaner to implement this as a function, but TorchScript allows only specific types in function + arguments, not modules. """ - output, preds = detection_layer(layer_input, image_size) - detections.append(output) - - if targets is not None: - layer_losses, layer_hits = detection_layer.calculate_losses(preds, targets, image_size) - losses.append(layer_losses) - hits.append(layer_hits) - - -def run_detection_with_aux_head( - detection_layer: DetectionLayer, - aux_detection_layer: DetectionLayer, - layer_input: Tensor, - aux_input: Tensor, - targets: Optional[TARGETS], - image_size: Tensor, - aux_weight: float, - detections: List[Tensor], - losses: List[Tensor], - hits: List[int], -) -> None: - """Runs the detection layer on the inputs and appends the output to the ``detections`` list. - - If ``targets`` is given, also runs the auxiliary detection layer on the auxiliary inputs, calculates the losses, and - appends the losses to the ``losses`` list. + def __init__(self, **kwargs: Any) -> None: + super().__init__() + self.detection_layer = create_detection_layer(**kwargs) + + def forward( + self, + layer_input: Tensor, + targets: Optional[TARGETS], + image_size: Tensor, + detections: List[Tensor], + losses: List[Tensor], + hits: List[int], + ) -> None: + """Runs the detection layer on the inputs and appends the output to the ``detections`` list. + + If ``targets`` is given, also calculates the losses and appends to the ``losses`` list. + + Args: + layer_input: Input to the detection layer. + targets: List of training targets for each image. + image_size: Width and height in a vector that defines the scale of the target coordinates. + detections: A list where a tensor containing the detections will be appended to. + losses: A list where a tensor containing the losses will be appended to, if ``targets`` is given. + hits: A list where the number of targets that matched this layer will be appended to, if ``targets`` is given. + """ + output, preds = self.detection_layer(layer_input, image_size) + detections.append(output) + + if targets is not None: + layer_losses, layer_hits = self.detection_layer.calculate_losses(preds, targets, image_size) + losses.append(layer_losses) + hits.append(layer_hits) + + +class DetectionStageWithAux(nn.Module): + """This class represents a combination of a lead and an auxiliary detection layer. Args: - detection_layer: The lead detection layer. - aux_detection_layer: The auxiliary detection layer. - layer_input: Input to the lead detection layer. - aux_input: Input to the auxiliary detection layer. - targets: List of training targets for each image. - image_size: Width and height in a vector that defines the scale of the target coordinates. - aux_weight: Weight of the auxiliary loss. - detections: A list where a tensor containing the detections will be appended to. - losses: A list where a tensor containing the losses will be appended to, if ``targets`` is given. - hits: A list where the number of targets that matched this layer will be appended to, if ``targets`` is given. + spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N × N` grid cell + area centered at the target. This parameter specifies `N` for the lead head. + aux_spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N × N` grid cell + area centered at the target. This parameter specifies `N` for the auxiliary head. + aux_weight: Weight for the loss from the auxiliary head. """ - output, preds = detection_layer(layer_input, image_size) - detections.append(output) - - if targets is not None: - # Match lead head predictions to targets and calculate losses from lead head outputs. - layer_losses, layer_hits = detection_layer.calculate_losses(preds, targets, image_size) - losses.append(layer_losses) - hits.append(layer_hits) - - # Match lead head predictions to targets and calculate losses from auxiliary head outputs. - _, aux_preds = aux_detection_layer(aux_input, image_size) - layer_losses, layer_hits = aux_detection_layer.calculate_losses( - preds, targets, image_size, loss_preds=aux_preds - ) - losses.append(layer_losses * aux_weight) - hits.append(layer_hits) + def __init__(self, spatial_range: float = 5.0, aux_spatial_range: float = 3.0, aux_weight: float = 0.25, **kwargs: Any) -> None: + self.detection_layer = create_detection_layer(spatial_range=spatial_range, **kwargs) + self.aux_detection_layer = create_detection_layer(spatial_range=aux_spatial_range, **kwargs) + self.aux_weight = aux_weight + + def forward( + self, + layer_input: Tensor, + aux_input: Tensor, + targets: Optional[TARGETS], + image_size: Tensor, + detections: List[Tensor], + losses: List[Tensor], + hits: List[int], + ) -> None: + """Runs the detection layer and the auxiliary detection layer on their respective inputs and appends the outputs + to the ``detections`` list. + + If ``targets`` is given, also calculates the losses and appends to the ``losses`` list. + + Args: + layer_input: Input to the lead detection layer. + aux_input: Input to the auxiliary detection layer. + targets: List of training targets for each image. + image_size: Width and height in a vector that defines the scale of the target coordinates. + detections: A list where a tensor containing the detections will be appended to. + losses: A list where a tensor containing the losses will be appended to, if ``targets`` is given. + hits: A list where the number of targets that matched this layer will be appended to, if ``targets`` is given. + """ + output, preds = self.detection_layer(layer_input, image_size) + detections.append(output) + + if targets is not None: + # Match lead head predictions to targets and calculate losses from lead head outputs. + layer_losses, layer_hits = self.detection_layer.calculate_losses(preds, targets, image_size) + losses.append(layer_losses) + hits.append(layer_hits) + + # Match lead head predictions to targets and calculate losses from auxiliary head outputs. + _, aux_preds = self.aux_detection_layer(aux_input, image_size) + layer_losses, layer_hits = self.aux_detection_layer.calculate_losses( + preds, targets, image_size, loss_preds=aux_preds + ) + losses.append(layer_losses * self.aux_weight) + hits.append(layer_hits) @torch.jit.script @@ -437,10 +451,10 @@ class YOLOV4TinyNetwork(nn.Module): "linear", or "none". normalization: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none". prior_shapes: A list of prior box dimensions, used for scaling the predicted dimensions and possibly for - matching the targets to the anchors. The list should contain (width, height) tuples in the network input - resolution. There should be `3N` tuples, where `N` defines the number of anchors per spatial location. They - are assigned to the layers from the lowest (high-resolution) to the highest (low-resolution) layer, meaning - that you typically want to sort the shapes from the smallest to the largest. + matching the targets to the anchors. The list should contain [width, height] pairs in the network input + resolution. There should be `3N` pairs, where `N` is the number of anchors per spatial location. They are + assigned to the layers from the lowest (high-resolution) to the highest (low-resolution) layer, meaning that + you typically want to sort the shapes from the smallest to the largest. matching_algorithm: Which algorithm to use for matching targets to anchors. "simota" (the SimOTA matching rule from YOLOX), "size" (match those prior shapes, whose width and height relative to the target is below given ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the prior shape that @@ -453,9 +467,8 @@ class YOLOV4TinyNetwork(nn.Module): ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the prior shape has IoU with some target greater than this threshold, the predictor will not be taken into account when calculating the confidence loss. - overlap_func: A function for calculating the pairwise overlaps between two sets of boxes. Either a string or a - function that returns a matrix of pairwise overlaps. Valid string values are "iou", "giou", "diou", and - "ciou" (default). + overlap_func: Which function to use for calculating the IoU between two sets of boxes. Valid values are "iou", + "giou", "diou", and "ciou". predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that the target confidence is 1 if there's an object, and 1.0 means that the target confidence is the output of ``overlap_func``. @@ -475,7 +488,7 @@ def __init__( width: int = 32, activation: Optional[str] = "leaky", normalization: Optional[str] = "batchnorm", - prior_shapes: Optional[List[Tuple[int, int]]] = None, + prior_shapes: Optional[PRIOR_SHAPES] = None, **kwargs: Any, ) -> None: super().__init__() @@ -483,15 +496,15 @@ def __init__( # By default use the prior shapes that have been learned from the COCO data. if prior_shapes is None: prior_shapes = [ - (12, 16), - (19, 36), - (40, 28), - (36, 75), - (76, 55), - (72, 146), - (142, 110), - (192, 243), - (459, 401), + [12, 16], + [19, 36], + [40, 28], + [36, 75], + [76, 55], + [72, 146], + [142, 110], + [192, 243], + [459, 401], ] anchors_per_cell = 3 else: @@ -511,10 +524,14 @@ def upsample(in_channels: int, out_channels: int) -> nn.Module: def outputs(in_channels: int) -> nn.Module: return nn.Conv2d(in_channels, num_outputs, kernel_size=1, stride=1, bias=True) - def detect(prior_shape_idxs: Sequence[int]) -> DetectionLayer: + def detect(prior_shape_idxs: Sequence[int]) -> DetectionStage: assert prior_shapes is not None - return create_detection_layer( - prior_shapes, prior_shape_idxs, num_classes=num_classes, input_is_normalized=False, **kwargs + return DetectionStage( + prior_shapes=prior_shapes, + prior_shape_idxs=list(prior_shape_idxs), + num_classes=num_classes, + input_is_normalized=False, + **kwargs, ) self.backbone = backbone or YOLOV4TinyBackbone(width=width, activation=activation, normalization=normalization) @@ -556,9 +573,9 @@ def forward(self, x: Tensor, targets: Optional[TARGETS] = None) -> NETWORK_OUTPU x = torch.cat((self.upsample4(p4), c3), dim=1) p3 = self.fpn3(x) - run_detection(self.detect5, self.out5(p5), targets, image_size, detections, losses, hits) - run_detection(self.detect4, self.out4(p4), targets, image_size, detections, losses, hits) - run_detection(self.detect3, self.out3(p3), targets, image_size, detections, losses, hits) + self.detect5(self.out5(p5), targets, image_size, detections, losses, hits) + self.detect4(self.out4(p4), targets, image_size, detections, losses, hits) + self.detect3(self.out3(p3), targets, image_size, detections, losses, hits) return detections, losses, hits @@ -573,10 +590,10 @@ class YOLOV4Network(nn.Module): "linear", or "none". normalization: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none". prior_shapes: A list of prior box dimensions, used for scaling the predicted dimensions and possibly for - matching the targets to the anchors. The list should contain (width, height) tuples in the network input - resolution. There should be `3N` tuples, where `N` defines the number of anchors per spatial location. They - are assigned to the layers from the lowest (high-resolution) to the highest (low-resolution) layer, meaning - that you typically want to sort the shapes from the smallest to the largest. + matching the targets to the anchors. The list should contain [width, height] pairs in the network input + resolution. There should be `3N` pairs, where `N` is the number of anchors per spatial location. They are + assigned to the layers from the lowest (high-resolution) to the highest (low-resolution) layer, meaning that + you typically want to sort the shapes from the smallest to the largest. matching_algorithm: Which algorithm to use for matching targets to anchors. "simota" (the SimOTA matching rule from YOLOX), "size" (match those prior shapes, whose width and height relative to the target is below given ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the prior shape that @@ -589,9 +606,8 @@ class YOLOV4Network(nn.Module): ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the prior shape has IoU with some target greater than this threshold, the predictor will not be taken into account when calculating the confidence loss. - overlap_func: A function for calculating the pairwise overlaps between two sets of boxes. Either a string or a - function that returns a matrix of pairwise overlaps. Valid string values are "iou", "giou", "diou", and - "ciou" (default). + overlap_func: Which function to use for calculating the IoU between two sets of boxes. Valid values are "iou", + "giou", "diou", and "ciou". predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that the target confidence is 1 if there's an object, and 1.0 means that the target confidence is the output of ``overlap_func``. @@ -611,7 +627,7 @@ def __init__( widths: Sequence[int] = (32, 64, 128, 256, 512, 1024), activation: Optional[str] = "silu", normalization: Optional[str] = "batchnorm", - prior_shapes: Optional[List[Tuple[int, int]]] = None, + prior_shapes: Optional[PRIOR_SHAPES] = None, **kwargs: Any, ) -> None: super().__init__() @@ -619,15 +635,15 @@ def __init__( # By default use the prior shapes that have been learned from the COCO data. if prior_shapes is None: prior_shapes = [ - (12, 16), - (19, 36), - (40, 28), - (36, 75), - (76, 55), - (72, 146), - (142, 110), - (192, 243), - (459, 401), + [12, 16], + [19, 36], + [40, 28], + [36, 75], + [76, 55], + [72, 146], + [142, 110], + [192, 243], + [459, 401], ] anchors_per_cell = 3 else: @@ -665,10 +681,14 @@ def upsample(in_channels: int, out_channels: int) -> nn.Module: def downsample(in_channels: int, out_channels: int) -> nn.Module: return Conv(in_channels, out_channels, kernel_size=3, stride=2, activation=activation, norm=normalization) - def detect(prior_shape_idxs: Sequence[int]) -> DetectionLayer: + def detect(prior_shape_idxs: Sequence[int]) -> DetectionStage: assert prior_shapes is not None - return create_detection_layer( - prior_shapes, prior_shape_idxs, num_classes=num_classes, input_is_normalized=False, **kwargs + return DetectionStage( + prior_shapes=prior_shapes, + prior_shape_idxs=list(prior_shape_idxs), + num_classes=num_classes, + input_is_normalized=False, + **kwargs, ) if backbone is not None: @@ -723,9 +743,9 @@ def forward(self, x: Tensor, targets: Optional[TARGETS] = None) -> NETWORK_OUTPU x = torch.cat((self.downsample4(n4), c5), dim=1) n5 = self.pan5(x) - run_detection(self.detect3, self.out3(n3), targets, image_size, detections, losses, hits) - run_detection(self.detect4, self.out4(n4), targets, image_size, detections, losses, hits) - run_detection(self.detect5, self.out5(n5), targets, image_size, detections, losses, hits) + self.detect3(self.out3(n3), targets, image_size, detections, losses, hits) + self.detect4(self.out4(n4), targets, image_size, detections, losses, hits) + self.detect5(self.out5(n5), targets, image_size, detections, losses, hits) return detections, losses, hits @@ -740,10 +760,10 @@ class YOLOV4P6Network(nn.Module): "linear", or "none". normalization: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none". prior_shapes: A list of prior box dimensions, used for scaling the predicted dimensions and possibly for - matching the targets to the anchors. The list should contain (width, height) tuples in the network input - resolution. There should be `3N` tuples, where `N` defines the number of anchors per spatial location. They - are assigned to the layers from the lowest (high-resolution) to the highest (low-resolution) layer, meaning - that you typically want to sort the shapes from the smallest to the largest. + matching the targets to the anchors. The list should contain [width, height] pairs in the network input + resolution. There should be `4N` pairs, where `N` is the number of anchors per spatial location. They are + assigned to the layers from the lowest (high-resolution) to the highest (low-resolution) layer, meaning that + you typically want to sort the shapes from the smallest to the largest. matching_algorithm: Which algorithm to use for matching targets to anchors. "simota" (the SimOTA matching rule from YOLOX), "size" (match those prior shapes, whose width and height relative to the target is below given ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the prior shape that @@ -756,9 +776,8 @@ class YOLOV4P6Network(nn.Module): ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the prior shape has IoU with some target greater than this threshold, the predictor will not be taken into account when calculating the confidence loss. - overlap_func: A function for calculating the pairwise overlaps between two sets of boxes. Either a string or a - function that returns a matrix of pairwise overlaps. Valid string values are "iou", "giou", "diou", and - "ciou" (default). + overlap_func: Which function to use for calculating the IoU between two sets of boxes. Valid values are "iou", + "giou", "diou", and "ciou". predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that the target confidence is 1 if there's an object, and 1.0 means that the target confidence is the output of ``overlap_func``. @@ -778,7 +797,7 @@ def __init__( widths: Sequence[int] = (32, 64, 128, 256, 512, 1024, 1024), activation: Optional[str] = "silu", normalization: Optional[str] = "batchnorm", - prior_shapes: Optional[List[Tuple[int, int]]] = None, + prior_shapes: Optional[PRIOR_SHAPES] = None, **kwargs: Any, ) -> None: super().__init__() @@ -786,22 +805,22 @@ def __init__( # By default use the prior shapes that have been learned from the COCO data. if prior_shapes is None: prior_shapes = [ - (13, 17), - (31, 25), - (24, 51), - (61, 45), - (61, 45), - (48, 102), - (119, 96), - (97, 189), - (97, 189), - (217, 184), - (171, 384), - (324, 451), - (324, 451), - (545, 357), - (616, 618), - (1024, 1024), + [13, 17], + [31, 25], + [24, 51], + [61, 45], + [61, 45], + [48, 102], + [119, 96], + [97, 189], + [97, 189], + [217, 184], + [171, 384], + [324, 451], + [324, 451], + [545, 357], + [616, 618], + [1024, 1024], ] anchors_per_cell = 4 else: @@ -839,10 +858,14 @@ def upsample(in_channels: int, out_channels: int) -> nn.Module: def downsample(in_channels: int, out_channels: int) -> nn.Module: return Conv(in_channels, out_channels, kernel_size=3, stride=2, activation=activation, norm=normalization) - def detect(prior_shape_idxs: Sequence[int]) -> DetectionLayer: + def detect(prior_shape_idxs: Sequence[int]) -> DetectionStage: assert prior_shapes is not None - return create_detection_layer( - prior_shapes, prior_shape_idxs, num_classes=num_classes, input_is_normalized=False, **kwargs + return DetectionStage( + prior_shapes=prior_shapes, + prior_shape_idxs=list(prior_shape_idxs), + num_classes=num_classes, + input_is_normalized=False, + **kwargs, ) if backbone is not None: @@ -913,10 +936,10 @@ def forward(self, x: Tensor, targets: Optional[TARGETS] = None) -> NETWORK_OUTPU x = torch.cat((self.downsample5(n5), c6), dim=1) n6 = self.pan6(x) - run_detection(self.detect3, self.out3(n3), targets, image_size, detections, losses, hits) - run_detection(self.detect4, self.out4(n4), targets, image_size, detections, losses, hits) - run_detection(self.detect5, self.out5(n5), targets, image_size, detections, losses, hits) - run_detection(self.detect6, self.out6(n6), targets, image_size, detections, losses, hits) + self.detect3(self.out3(n3), targets, image_size, detections, losses, hits) + self.detect4(self.out4(n4), targets, image_size, detections, losses, hits) + self.detect5(self.out5(n5), targets, image_size, detections, losses, hits) + self.detect6(self.out6(n6), targets, image_size, detections, losses, hits) return detections, losses, hits @@ -936,10 +959,10 @@ class YOLOV5Network(nn.Module): "linear", or "none". normalization: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none". prior_shapes: A list of prior box dimensions, used for scaling the predicted dimensions and possibly for - matching the targets to the anchors. The list should contain (width, height) tuples in the network input - resolution. There should be `3N` tuples, where `N` defines the number of anchors per spatial location. They - are assigned to the layers from the lowest (high-resolution) to the highest (low-resolution) layer, meaning - that you typically want to sort the shapes from the smallest to the largest. + matching the targets to the anchors. The list should contain [width, height] pairs in the network input + resolution. There should be `3N` pairs, where `N` is the number of anchors per spatial location. They are + assigned to the layers from the lowest (high-resolution) to the highest (low-resolution) layer, meaning that + you typically want to sort the shapes from the smallest to the largest. matching_algorithm: Which algorithm to use for matching targets to anchors. "simota" (the SimOTA matching rule from YOLOX), "size" (match those prior shapes, whose width and height relative to the target is below given ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the prior shape that @@ -952,9 +975,8 @@ class YOLOV5Network(nn.Module): ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the prior shape has IoU with some target greater than this threshold, the predictor will not be taken into account when calculating the confidence loss. - overlap_func: A function for calculating the pairwise overlaps between two sets of boxes. Either a string or a - function that returns a matrix of pairwise overlaps. Valid string values are "iou", "giou", "diou", and - "ciou" (default). + overlap_func: Which function to use for calculating the IoU between two sets of boxes. Valid values are "iou", + "giou", "diou", and "ciou". predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that the target confidence is 1 if there's an object, and 1.0 means that the target confidence is the output of ``overlap_func``. @@ -975,7 +997,7 @@ def __init__( depth: int = 3, activation: Optional[str] = "silu", normalization: Optional[str] = "batchnorm", - prior_shapes: Optional[List[Tuple[int, int]]] = None, + prior_shapes: Optional[PRIOR_SHAPES] = None, **kwargs: Any, ) -> None: super().__init__() @@ -983,15 +1005,15 @@ def __init__( # By default use the prior shapes that have been learned from the COCO data. if prior_shapes is None: prior_shapes = [ - (12, 16), - (19, 36), - (40, 28), - (36, 75), - (76, 55), - (72, 146), - (142, 110), - (192, 243), - (459, 401), + [12, 16], + [19, 36], + [40, 28], + [36, 75], + [76, 55], + [72, 146], + [142, 110], + [192, 243], + [459, 401], ] anchors_per_cell = 3 else: @@ -1023,10 +1045,14 @@ def csp(in_channels: int, out_channels: int) -> nn.Module: activation=activation, ) - def detect(prior_shape_idxs: Sequence[int]) -> DetectionLayer: + def detect(prior_shape_idxs: Sequence[int]) -> DetectionStage: assert prior_shapes is not None - return create_detection_layer( - prior_shapes, prior_shape_idxs, num_classes=num_classes, input_is_normalized=False, **kwargs + return DetectionStage( + prior_shapes=prior_shapes, + prior_shape_idxs=list(prior_shape_idxs), + num_classes=num_classes, + input_is_normalized=False, + **kwargs, ) self.backbone = backbone or YOLOV5Backbone( @@ -1083,9 +1109,9 @@ def forward(self, x: Tensor, targets: Optional[TARGETS] = None) -> NETWORK_OUTPU x = torch.cat((self.downsample4(n4), p5), dim=1) n5 = self.pan5(x) - run_detection(self.detect3, self.out3(n3), targets, image_size, detections, losses, hits) - run_detection(self.detect4, self.out4(n4), targets, image_size, detections, losses, hits) - run_detection(self.detect5, self.out5(n5), targets, image_size, detections, losses, hits) + self.detect3(self.out3(n3), targets, image_size, detections, losses, hits) + self.detect4(self.out4(n4), targets, image_size, detections, losses, hits) + self.detect5(self.out5(n5), targets, image_size, detections, losses, hits) return detections, losses, hits @@ -1100,24 +1126,26 @@ class YOLOV7Network(nn.Module): "linear", or "none". normalization: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none". prior_shapes: A list of prior box dimensions, used for scaling the predicted dimensions and possibly for - matching the targets to the anchors. The list should contain (width, height) tuples in the network input - resolution. There should be `3N` tuples, where `N` defines the number of anchors per spatial location. They - are assigned to the layers from the lowest (high-resolution) to the highest (low-resolution) layer, meaning - that you typically want to sort the shapes from the smallest to the largest. - aux_weight: Weight for the loss from the auxiliary heads. + matching the targets to the anchors. The list should contain [width, height] pairs in the network input + resolution. There should be `4N` pairs, where `N` is the number of anchors per spatial location. They are + assigned to the layers from the lowest (high-resolution) to the highest (low-resolution) layer, meaning that + you typically want to sort the shapes from the smallest to the largest. matching_algorithm: Which algorithm to use for matching targets to anchors. "simota" (the SimOTA matching rule from YOLOX), "size" (match those prior shapes, whose width and height relative to the target is below given ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the prior shape that gives the highest IoU, default). matching_threshold: Threshold for "size" and "iou" matching algorithms. + spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N × N` grid cell + area centered at the target. This parameter specifies `N` for the lead head. + aux_spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N × N` grid cell + area centered at the target. This parameter specifies `N` for the auxiliary head. size_range: The "simota" matching algorithm will restrict to anchors whose dimensions are no more than `N` and no less than `1/N` times the target dimensions, where `N` is the value of this parameter. ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the prior shape has IoU with some target greater than this threshold, the predictor will not be taken into account when calculating the confidence loss. - overlap_func: A function for calculating the pairwise overlaps between two sets of boxes. Either a string or a - function that returns a matrix of pairwise overlaps. Valid string values are "iou", "giou", "diou", and - "ciou" (default). + overlap_func: Which function to use for calculating the IoU between two sets of boxes. Valid values are "iou", + "giou", "diou", and "ciou". predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that the target confidence is 1 if there's an object, and 1.0 means that the target confidence is the output of ``overlap_func``. @@ -1128,6 +1156,7 @@ class YOLOV7Network(nn.Module): class_loss_multiplier: Classification loss will be scaled by this value. xy_scale: Eliminate "grid sensitivity" by scaling the box coordinates by this factor. Using a value > 1.0 helps to produce coordinate values close to one. + aux_weight: Weight for the loss from the auxiliary heads. """ def __init__( @@ -1137,33 +1166,30 @@ def __init__( widths: Sequence[int] = (64, 128, 256, 512, 768, 1024), activation: Optional[str] = "silu", normalization: Optional[str] = "batchnorm", - prior_shapes: Optional[List[Tuple[int, int]]] = None, - aux_weight: float = 0.25, + prior_shapes: Optional[PRIOR_SHAPES] = None, **kwargs: Any, ) -> None: super().__init__() - self.aux_weight = aux_weight - # By default use the prior shapes that have been learned from the COCO data. if prior_shapes is None: prior_shapes = [ - (13, 17), - (31, 25), - (24, 51), - (61, 45), - (61, 45), - (48, 102), - (119, 96), - (97, 189), - (97, 189), - (217, 184), - (171, 384), - (324, 451), - (324, 451), - (545, 357), - (616, 618), - (1024, 1024), + [13, 17], + [31, 25], + [24, 51], + [61, 45], + [61, 45], + [48, 102], + [119, 96], + [97, 189], + [97, 189], + [217, 184], + [171, 384], + [324, 451], + [324, 451], + [545, 357], + [616, 618], + [1024, 1024], ] anchors_per_cell = 4 else: @@ -1204,12 +1230,11 @@ def upsample(in_channels: int, out_channels: int) -> nn.Module: def downsample(in_channels: int, out_channels: int) -> nn.Module: return Conv(in_channels, out_channels, kernel_size=3, stride=2, activation=activation, norm=normalization) - def detect(prior_shape_idxs: Sequence[int], range: float) -> DetectionLayer: + def detect(prior_shape_idxs: Sequence[int]) -> DetectionStageWithAux: assert prior_shapes is not None - return create_detection_layer( - prior_shapes, - prior_shape_idxs, - spatial_range=range, + return DetectionStageWithAux( + prior_shapes=prior_shapes, + prior_shape_idxs=list(prior_shape_idxs), num_classes=num_classes, input_is_normalized=False, **kwargs, @@ -1259,14 +1284,10 @@ def detect(prior_shape_idxs: Sequence[int], range: float) -> DetectionLayer: self.out6 = out(w6 // 2, w6) self.aux_out6 = out(w6 // 2, w6 + (w6 // 4)) - self.detect3 = detect(range(0, anchors_per_cell), 5.0) - self.aux_detect3 = detect(range(0, anchors_per_cell), 3.0) - self.detect4 = detect(range(anchors_per_cell, anchors_per_cell * 2), 5.0) - self.aux_detect4 = detect(range(anchors_per_cell, anchors_per_cell * 2), 3.0) - self.detect5 = detect(range(anchors_per_cell * 2, anchors_per_cell * 3), 5.0) - self.aux_detect5 = detect(range(anchors_per_cell * 2, anchors_per_cell * 3), 3.0) - self.detect6 = detect(range(anchors_per_cell * 3, anchors_per_cell * 4), 5.0) - self.aux_detect6 = detect(range(anchors_per_cell * 3, anchors_per_cell * 4), 3.0) + self.detect3 = detect(range(0, anchors_per_cell)) + self.detect4 = detect(range(anchors_per_cell, anchors_per_cell * 2)) + self.detect5 = detect(range(anchors_per_cell * 2, anchors_per_cell * 3)) + self.detect6 = detect(range(anchors_per_cell * 3, anchors_per_cell * 4)) def forward(self, x: Tensor, targets: Optional[TARGETS] = None) -> NETWORK_OUTPUT: detections: List[Tensor] = [] # Outputs from detection layers @@ -1291,54 +1312,10 @@ def forward(self, x: Tensor, targets: Optional[TARGETS] = None) -> NETWORK_OUTPU x = torch.cat((self.downsample5(n5), c6), dim=1) n6 = self.pan6(x) - run_detection_with_aux_head( - self.detect3, - self.aux_detect3, - self.out3(n3), - self.aux_out3(n3), - targets, - image_size, - self.aux_weight, - detections, - losses, - hits, - ) - run_detection_with_aux_head( - self.detect4, - self.aux_detect4, - self.out4(n4), - self.aux_out4(p4), - targets, - image_size, - self.aux_weight, - detections, - losses, - hits, - ) - run_detection_with_aux_head( - self.detect5, - self.aux_detect5, - self.out5(n5), - self.aux_out5(p5), - targets, - image_size, - self.aux_weight, - detections, - losses, - hits, - ) - run_detection_with_aux_head( - self.detect6, - self.aux_detect6, - self.out6(n6), - self.aux_out6(c6), - targets, - image_size, - self.aux_weight, - detections, - losses, - hits, - ) + self.detect3(self.out3(n3), self.aux_out3(n3), targets, image_size, detections, losses, hits) + self.detect4(self.out4(n4), self.aux_out4(p4), targets, image_size, detections, losses, hits) + self.detect5(self.out5(n5), self.aux_out5(p5), targets, image_size, detections, losses, hits) + self.detect6(self.out6(n6), self.aux_out6(c6), targets, image_size, detections, losses, hits) return detections, losses, hits @@ -1415,10 +1392,10 @@ class YOLOXNetwork(nn.Module): "linear", or "none". normalization: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none". prior_shapes: A list of prior box dimensions, used for scaling the predicted dimensions and possibly for - matching the targets to the anchors. The list should contain (width, height) tuples in the network input - resolution. There should be `3N` tuples, where `N` defines the number of anchors per spatial location. They - are assigned to the layers from the lowest (high-resolution) to the highest (low-resolution) layer, meaning - that you typically want to sort the shapes from the smallest to the largest. + matching the targets to the anchors. The list should contain [width, height] pairs in the network input + resolution. There should be `3N` pairs, where `N` is the number of anchors per spatial location. They are + assigned to the layers from the lowest (high-resolution) to the highest (low-resolution) layer, meaning that + you typically want to sort the shapes from the smallest to the largest. matching_algorithm: Which algorithm to use for matching targets to anchors. "simota" (the SimOTA matching rule from YOLOX), "size" (match those prior shapes, whose width and height relative to the target is below given ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the prior shape that @@ -1431,9 +1408,8 @@ class YOLOXNetwork(nn.Module): ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the prior shape has IoU with some target greater than this threshold, the predictor will not be taken into account when calculating the confidence loss. - overlap_func: A function for calculating the pairwise overlaps between two sets of boxes. Either a string or a - function that returns a matrix of pairwise overlaps. Valid string values are "iou", "giou", "diou", and - "ciou" (default). + overlap_func: Which function to use for calculating the IoU between two sets of boxes. Valid values are "iou", + "giou", "diou", and "ciou". predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that the target confidence is 1 if there's an object, and 1.0 means that the target confidence is the output of ``overlap_func``. @@ -1454,14 +1430,14 @@ def __init__( depth: int = 3, activation: Optional[str] = "silu", normalization: Optional[str] = "batchnorm", - prior_shapes: Optional[List[Tuple[int, int]]] = None, + prior_shapes: Optional[PRIOR_SHAPES] = None, **kwargs: Any, ) -> None: super().__init__() # By default use one anchor per cell and the stride as the prior size. if prior_shapes is None: - prior_shapes = [(8, 8), (16, 16), (32, 32)] + prior_shapes = [[8, 8], [16, 16], [32, 32]] anchors_per_cell = 1 else: anchors_per_cell, modulo = divmod(len(prior_shapes), 3) @@ -1497,10 +1473,14 @@ def head(in_channels: int, hidden_channels: int) -> YOLOXHead: norm=normalization, ) - def detect(prior_shape_idxs: Sequence[int]) -> DetectionLayer: + def detect(prior_shape_idxs: Sequence[int]) -> DetectionStage: assert prior_shapes is not None - return create_detection_layer( - prior_shapes, prior_shape_idxs, num_classes=num_classes, input_is_normalized=False, **kwargs + return DetectionStage( + prior_shapes=prior_shapes, + prior_shape_idxs=list(prior_shape_idxs), + num_classes=num_classes, + input_is_normalized=False, + **kwargs, ) self.backbone = backbone or YOLOV5Backbone( @@ -1557,9 +1537,9 @@ def forward(self, x: Tensor, targets: Optional[TARGETS] = None) -> NETWORK_OUTPU x = torch.cat((self.downsample4(n4), p5), dim=1) n5 = self.pan5(x) - run_detection(self.detect3, self.out3(n3), targets, image_size, detections, losses, hits) - run_detection(self.detect4, self.out4(n4), targets, image_size, detections, losses, hits) - run_detection(self.detect5, self.out5(n5), targets, image_size, detections, losses, hits) + self.detect3(self.out3(n3), targets, image_size, detections, losses, hits) + self.detect4(self.out4(n4), targets, image_size, detections, losses, hits) + self.detect5(self.out5(n5), targets, image_size, detections, losses, hits) return detections, losses, hits @@ -1585,9 +1565,8 @@ class DarknetNetwork(nn.Module): ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the corresponding anchor has IoU with some target greater than this threshold, the predictor will not be taken into account when calculating the confidence loss. - overlap_func: A function for calculating the pairwise overlaps between two sets of boxes. Either a string or a - function that returns a matrix of pairwise overlaps. Valid string values are "iou", "giou", "diou", and - "ciou". + overlap_func: Which function to use for calculating the IoU between two sets of boxes. Valid values are "iou", + "giou", "diou", and "ciou". predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that the target confidence is 1 if there's an object, and 1.0 means that the target confidence is the output of ``overlap_func``. @@ -1811,7 +1790,7 @@ def convert(key: str, value: str) -> Union[str, int, float, List[Union[str, int, return sections -def _create_layer(config: CONFIG, num_inputs: List[int], **kwargs: Any) -> CREATE_LAYER_OUTPUT: +def _create_layer(config: DARKNET_CONFIG, num_inputs: List[int], **kwargs: Any) -> CREATE_LAYER_OUTPUT: """Calls one of the ``_create_(config, num_inputs)`` functions to create a PyTorch module from the layer config. @@ -1834,7 +1813,7 @@ def _create_layer(config: CONFIG, num_inputs: List[int], **kwargs: Any) -> CREAT return create_func[config["type"]](config, num_inputs, **kwargs) -def _create_convolutional(config: CONFIG, num_inputs: List[int], **kwargs: Any) -> CREATE_LAYER_OUTPUT: +def _create_convolutional(config: DARKNET_CONFIG, num_inputs: List[int], **kwargs: Any) -> CREATE_LAYER_OUTPUT: """Creates a convolutional layer. Args: @@ -1861,7 +1840,7 @@ def _create_convolutional(config: CONFIG, num_inputs: List[int], **kwargs: Any) return layer, config["filters"] -def _create_maxpool(config: CONFIG, num_inputs: List[int], **kwargs: Any) -> CREATE_LAYER_OUTPUT: +def _create_maxpool(config: DARKNET_CONFIG, num_inputs: List[int], **kwargs: Any) -> CREATE_LAYER_OUTPUT: """Creates a max pooling layer. Padding is added so that the output resolution will be the input resolution divided by stride, rounded upwards. @@ -1878,7 +1857,7 @@ def _create_maxpool(config: CONFIG, num_inputs: List[int], **kwargs: Any) -> CRE return layer, num_inputs[-1] -def _create_route(config: CONFIG, num_inputs: List[int], **kwargs: Any) -> CREATE_LAYER_OUTPUT: +def _create_route(config: DARKNET_CONFIG, num_inputs: List[int], **kwargs: Any) -> CREATE_LAYER_OUTPUT: """Creates a routing layer. A routing layer concatenates the output (or part of it) from the layers specified by the "layers" configuration @@ -1907,7 +1886,7 @@ def _create_route(config: CONFIG, num_inputs: List[int], **kwargs: Any) -> CREAT return layer, num_outputs -def _create_shortcut(config: CONFIG, num_inputs: List[int], **kwargs: Any) -> CREATE_LAYER_OUTPUT: +def _create_shortcut(config: DARKNET_CONFIG, num_inputs: List[int], **kwargs: Any) -> CREATE_LAYER_OUTPUT: """Creates a shortcut layer. A shortcut layer adds a residual connection from the layer specified by the "from" configuration option. @@ -1924,7 +1903,7 @@ def _create_shortcut(config: CONFIG, num_inputs: List[int], **kwargs: Any) -> CR return layer, num_inputs[-1] -def _create_upsample(config: CONFIG, num_inputs: List[int], **kwargs: Any) -> CREATE_LAYER_OUTPUT: +def _create_upsample(config: DARKNET_CONFIG, num_inputs: List[int], **kwargs: Any) -> CREATE_LAYER_OUTPUT: """Creates a layer that upsamples the data. Args: @@ -1940,15 +1919,15 @@ def _create_upsample(config: CONFIG, num_inputs: List[int], **kwargs: Any) -> CR def _create_yolo( - config: CONFIG, + config: DARKNET_CONFIG, num_inputs: List[int], - prior_shapes: Optional[List[Tuple[int, int]]] = None, + prior_shapes: Optional[PRIOR_SHAPES] = None, matching_algorithm: Optional[str] = None, matching_threshold: Optional[float] = None, spatial_range: float = 5.0, size_range: float = 4.0, ignore_bg_threshold: Optional[float] = None, - overlap_func: Optional[Union[str, Callable]] = None, + overlap_func: Optional[str] = None, predict_overlap: Optional[float] = None, label_smoothing: Optional[float] = None, overlap_loss_multiplier: Optional[float] = None, @@ -1962,10 +1941,11 @@ def _create_yolo( config: Dictionary of configuration options for this layer. num_inputs: Number of channels in the input of every layer up to this layer. Not used by the detection layer. prior_shapes: A list of prior box dimensions, used for scaling the predicted dimensions and possibly for - matching the targets to the anchors. The list should contain (width, height) tuples in the network input - resolution. There should be `3N` tuples, where `N` defines the number of anchors per spatial location. They - are assigned to the layers from the lowest (high-resolution) to the highest (low-resolution) layer, meaning - that you typically want to sort the shapes from the smallest to the largest. + matching the targets to the anchors. The list should contain [width, height] pairs in the network input + resolution. There should be `M × N` pairs, where `M` is the number of detection layers and `N` is the number + of anchors per spatial location. They are assigned to the layers from the lowest (high-resolution) to the + highest (low-resolution) layer, meaning that you typically want to sort the shapes from the smallest to the + largest. matching_algorithm: Which algorithm to use for matching targets to anchors. "simota" (the SimOTA matching rule from YOLOX), "size" (match those prior shapes, whose width and height relative to the target is below given ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the prior shape that @@ -1978,9 +1958,8 @@ def _create_yolo( ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the corresponding anchor has IoU with some target greater than this threshold, the predictor will not be taken into account when calculating the confidence loss. - overlap_func: A function for calculating the pairwise overlaps between two sets of boxes. Either a string or a - function that returns a matrix of pairwise overlaps. Valid string values are "iou", "giou", "diou", and - "ciou". + overlap_func: Which function to use for calculating the IoU between two sets of boxes. Valid values are "iou", + "giou", "diou", and "ciou". predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that the target confidence is 1 if there's an object, and 1.0 means that the target confidence is the output of ``overlap_func``. @@ -1997,7 +1976,7 @@ def _create_yolo( if prior_shapes is None: # The "anchors" list alternates width and height. dims = config["anchors"] - prior_shapes = [(dims[i], dims[i + 1]) for i in range(0, len(dims), 2)] + prior_shapes = [[dims[i], dims[i + 1]] for i in range(0, len(dims), 2)] if ignore_bg_threshold is None: ignore_bg_threshold = config.get("ignore_thresh", 1.0) assert isinstance(ignore_bg_threshold, float) From 7bb600829f4e3a4372749e269421b9714f0e4784 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Wed, 19 Apr 2023 20:23:35 +0300 Subject: [PATCH 08/13] Fixed formatting --- test/test_models_detection_anchor_utils.py | 8 +++- test/test_models_detection_yolo_networks.py | 4 +- torchvision/models/detection/__init__.py | 4 +- .../models/detection/target_matching.py | 2 +- torchvision/models/detection/yolo.py | 4 +- torchvision/models/detection/yolo_loss.py | 14 +++--- torchvision/models/detection/yolo_networks.py | 44 +++++++++++-------- torchvision/models/yolo.py | 2 +- 8 files changed, 48 insertions(+), 34 deletions(-) diff --git a/test/test_models_detection_anchor_utils.py b/test/test_models_detection_anchor_utils.py index 661b8685bda..edba1b1626c 100644 --- a/test/test_models_detection_anchor_utils.py +++ b/test/test_models_detection_anchor_utils.py @@ -1,7 +1,13 @@ import pytest import torch from common_utils import assert_equal -from torchvision.models.detection.anchor_utils import AnchorGenerator, DefaultBoxGenerator, global_xy, grid_centers, grid_offsets +from torchvision.models.detection.anchor_utils import ( + AnchorGenerator, + DefaultBoxGenerator, + global_xy, + grid_centers, + grid_offsets, +) from torchvision.models.detection.image_list import ImageList diff --git a/test/test_models_detection_yolo_networks.py b/test/test_models_detection_yolo_networks.py index 121f59df7cf..0bec4c09e8f 100644 --- a/test/test_models_detection_yolo_networks.py +++ b/test/test_models_detection_yolo_networks.py @@ -42,8 +42,8 @@ def test_create_convolutional(config): @pytest.mark.parametrize( "config", [ - ({ "size": 2, "stride": 2 }), - ({ "size": 6, "stride": 3 }), + ({"size": 2, "stride": 2}), + ({"size": 6, "stride": 3}), ], ) def test_create_maxpool(config): diff --git a/torchvision/models/detection/__init__.py b/torchvision/models/detection/__init__.py index 9be3dea4594..35fcdcf9015 100644 --- a/torchvision/models/detection/__init__.py +++ b/torchvision/models/detection/__init__.py @@ -5,12 +5,12 @@ from .retinanet import * from .ssd import * from .ssdlite import * -from .yolo import YOLO, YOLOV4_Backbone_Weights, YOLOV4_Weights, yolov4, yolo_darknet +from .yolo import YOLO, yolo_darknet, yolov4, YOLOV4_Backbone_Weights, YOLOV4_Weights from .yolo_networks import ( DarknetNetwork, - YOLOV4TinyNetwork, YOLOV4Network, YOLOV4P6Network, + YOLOV4TinyNetwork, YOLOV5Network, YOLOV7Network, YOLOXNetwork, diff --git a/torchvision/models/detection/target_matching.py b/torchvision/models/detection/target_matching.py index f0c2f06fda8..43b512dc683 100644 --- a/torchvision/models/detection/target_matching.py +++ b/torchvision/models/detection/target_matching.py @@ -334,7 +334,7 @@ class SimOTAMatching: prior_shape_idxs: List of indices to ``prior_shapes`` that is used to select the (usually 3) prior shapes that this layer uses. loss_func: A ``YOLOLoss`` object that can be used to calculate the pairwise costs. - spatial_range: For each target, restrict to the anchors that are within an `N × N` grid cell are centered at the + spatial_range: For each target, restrict to the anchors that are within an `N x N` grid cell are centered at the target, where `N` is the value of this parameter. size_range: For each target, restrict to the anchors whose prior dimensions are not larger than the target dimensions multiplied by this value and not smaller than the target dimensions divided by this value. diff --git a/torchvision/models/detection/yolo.py b/torchvision/models/detection/yolo.py index b7fce691faf..21e2305b62a 100644 --- a/torchvision/models/detection/yolo.py +++ b/torchvision/models/detection/yolo.py @@ -1,8 +1,8 @@ +import warnings from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn -import warnings from torch import Tensor from ...ops import batched_nms @@ -11,7 +11,7 @@ from .._utils import _ovewrite_value_param from ..yolo import YOLOV4Backbone from .backbone_utils import _validate_trainable_layers -from .yolo_networks import DarknetNetwork, YOLOV4Network, PRED, TARGET, TARGETS +from .yolo_networks import DarknetNetwork, PRED, TARGET, TARGETS, YOLOV4Network IMAGES = List[Tensor] # TorchScript doesn't allow a tuple. diff --git a/torchvision/models/detection/yolo_loss.py b/torchvision/models/detection/yolo_loss.py index 3de448da411..1ac0940680a 100644 --- a/torchvision/models/detection/yolo_loss.py +++ b/torchvision/models/detection/yolo_loss.py @@ -17,10 +17,8 @@ def _binary_cross_entropy( - inputs: Tensor, - targets: Tensor, - reduction: str = "mean", - input_is_normalized: bool = True): + inputs: Tensor, targets: Tensor, reduction: str = "mean", input_is_normalized: bool = True +) -> Tensor: """Returns the binary cross entropy from either normalized inputs or logits. It would be more convenient to pass the correct cross entropy function to every function that uses it, but @@ -246,7 +244,9 @@ def pairwise( overlap_loss = 1.0 - overlap assert overlap_loss.shape == loss_shape - confidence_loss = _pairwise_confidence_loss(preds["confidences"], overlap, input_is_normalized, self.predict_overlap) + confidence_loss = _pairwise_confidence_loss( + preds["confidences"], overlap, input_is_normalized, self.predict_overlap + ) assert confidence_loss.shape == loss_shape pred_probs = preds["classprobs"].unsqueeze(1) # [N, 1, classes] @@ -292,7 +292,9 @@ def elementwise_sums( overlap = 1.0 - overlap_loss overlap_loss = (overlap_loss * _size_compensation(targets["boxes"], image_size)).sum() - confidence_loss = _foreground_confidence_loss(preds["confidences"], overlap, input_is_normalized, self.predict_overlap) + confidence_loss = _foreground_confidence_loss( + preds["confidences"], overlap, input_is_normalized, self.predict_overlap + ) confidence_loss += _background_confidence_loss(preds["bg_confidences"], input_is_normalized) pred_probs = preds["classprobs"] diff --git a/torchvision/models/detection/yolo_networks.py b/torchvision/models/detection/yolo_networks.py index 47e3e145b2e..5daa405a157 100644 --- a/torchvision/models/detection/yolo_networks.py +++ b/torchvision/models/detection/yolo_networks.py @@ -267,7 +267,7 @@ def create_detection_layer( ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the prior shape that gives the highest IoU, default). matching_threshold: Threshold for "size" and "iou" matching algorithms. - spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N × N` grid cell + spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N x N` grid cell area centered at the target, where `N` is the value of this parameter. size_range: The "simota" matching algorithm will restrict to anchors whose dimensions are no more than `N` and no less than `1/N` times the target dimensions, where `N` is the value of this parameter. @@ -329,6 +329,7 @@ class DetectionStage(nn.Module): It might be cleaner to implement this as a function, but TorchScript allows only specific types in function arguments, not modules. """ + def __init__(self, **kwargs: Any) -> None: super().__init__() self.detection_layer = create_detection_layer(**kwargs) @@ -352,7 +353,8 @@ def forward( image_size: Width and height in a vector that defines the scale of the target coordinates. detections: A list where a tensor containing the detections will be appended to. losses: A list where a tensor containing the losses will be appended to, if ``targets`` is given. - hits: A list where the number of targets that matched this layer will be appended to, if ``targets`` is given. + hits: A list where the number of targets that matched this layer will be appended to, if ``targets`` is + given. """ output, preds = self.detection_layer(layer_input, image_size) detections.append(output) @@ -367,13 +369,16 @@ class DetectionStageWithAux(nn.Module): """This class represents a combination of a lead and an auxiliary detection layer. Args: - spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N × N` grid cell + spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N x N` grid cell area centered at the target. This parameter specifies `N` for the lead head. - aux_spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N × N` grid cell + aux_spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N x N` grid cell area centered at the target. This parameter specifies `N` for the auxiliary head. aux_weight: Weight for the loss from the auxiliary head. """ - def __init__(self, spatial_range: float = 5.0, aux_spatial_range: float = 3.0, aux_weight: float = 0.25, **kwargs: Any) -> None: + + def __init__( + self, spatial_range: float = 5.0, aux_spatial_range: float = 3.0, aux_weight: float = 0.25, **kwargs: Any + ) -> None: self.detection_layer = create_detection_layer(spatial_range=spatial_range, **kwargs) self.aux_detection_layer = create_detection_layer(spatial_range=aux_spatial_range, **kwargs) self.aux_weight = aux_weight @@ -400,7 +405,8 @@ def forward( image_size: Width and height in a vector that defines the scale of the target coordinates. detections: A list where a tensor containing the detections will be appended to. losses: A list where a tensor containing the losses will be appended to, if ``targets`` is given. - hits: A list where the number of targets that matched this layer will be appended to, if ``targets`` is given. + hits: A list where the number of targets that matched this layer will be appended to, if ``targets`` is + given. """ output, preds = self.detection_layer(layer_input, image_size) detections.append(output) @@ -460,7 +466,7 @@ class YOLOV4TinyNetwork(nn.Module): ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the prior shape that gives the highest IoU, default). matching_threshold: Threshold for "size" and "iou" matching algorithms. - spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N × N` grid cell + spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N x N` grid cell area centered at the target, where `N` is the value of this parameter. size_range: The "simota" matching algorithm will restrict to anchors whose dimensions are no more than `N` and no less than `1/N` times the target dimensions, where `N` is the value of this parameter. @@ -599,7 +605,7 @@ class YOLOV4Network(nn.Module): ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the prior shape that gives the highest IoU, default). matching_threshold: Threshold for "size" and "iou" matching algorithms. - spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N × N` grid cell + spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N x N` grid cell area centered at the target, where `N` is the value of this parameter. size_range: The "simota" matching algorithm will restrict to anchors whose dimensions are no more than `N` and no less than `1/N` times the target dimensions, where `N` is the value of this parameter. @@ -769,7 +775,7 @@ class YOLOV4P6Network(nn.Module): ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the prior shape that gives the highest IoU, default). matching_threshold: Threshold for "size" and "iou" matching algorithms. - spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N × N` grid cell + spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N x N` grid cell area centered at the target, where `N` is the value of this parameter. size_range: The "simota" matching algorithm will restrict to anchors whose dimensions are no more than `N` and no less than `1/N` times the target dimensions, where `N` is the value of this parameter. @@ -968,7 +974,7 @@ class YOLOV5Network(nn.Module): ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the prior shape that gives the highest IoU, default). matching_threshold: Threshold for "size" and "iou" matching algorithms. - spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N × N` grid cell + spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N x N` grid cell area centered at the target, where `N` is the value of this parameter. size_range: The "simota" matching algorithm will restrict to anchors whose dimensions are no more than `N` and no less than `1/N` times the target dimensions, where `N` is the value of this parameter. @@ -1135,9 +1141,9 @@ class YOLOV7Network(nn.Module): ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the prior shape that gives the highest IoU, default). matching_threshold: Threshold for "size" and "iou" matching algorithms. - spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N × N` grid cell + spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N x N` grid cell area centered at the target. This parameter specifies `N` for the lead head. - aux_spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N × N` grid cell + aux_spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N x N` grid cell area centered at the target. This parameter specifies `N` for the auxiliary head. size_range: The "simota" matching algorithm will restrict to anchors whose dimensions are no more than `N` and no less than `1/N` times the target dimensions, where `N` is the value of this parameter. @@ -1401,7 +1407,7 @@ class YOLOXNetwork(nn.Module): ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the prior shape that gives the highest IoU, default). matching_threshold: Threshold for "size" and "iou" matching algorithms. - spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N × N` grid cell + spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N x N` grid cell area centered at the target, where `N` is the value of this parameter. size_range: The "simota" matching algorithm will restrict to anchors whose dimensions are no more than `N` and no less than `1/N` times the target dimensions, where `N` is the value of this parameter. @@ -1558,7 +1564,7 @@ class DarknetNetwork(nn.Module): ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the prior shape that gives the highest IoU, default). matching_threshold: Threshold for "size" and "iou" matching algorithms. - spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N × N` grid cell + spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N x N` grid cell area centered at the target, where `N` is the value of this parameter. size_range: The "simota" matching algorithm will restrict to anchors whose dimensions are no more than `N` and no less than `1/N` times the target dimensions, where `N` is the value of this parameter. @@ -1942,16 +1948,16 @@ def _create_yolo( num_inputs: Number of channels in the input of every layer up to this layer. Not used by the detection layer. prior_shapes: A list of prior box dimensions, used for scaling the predicted dimensions and possibly for matching the targets to the anchors. The list should contain [width, height] pairs in the network input - resolution. There should be `M × N` pairs, where `M` is the number of detection layers and `N` is the number - of anchors per spatial location. They are assigned to the layers from the lowest (high-resolution) to the - highest (low-resolution) layer, meaning that you typically want to sort the shapes from the smallest to the - largest. + resolution. There should be `M x N` pairs, where `M` is the number of detection layers and `N` is the number + of anchors per spatial location. They are assigned to the layers from the lowest (high-resolution) to the + highest (low-resolution) layer, meaning that you typically want to sort the shapes from the smallest to the + largest. matching_algorithm: Which algorithm to use for matching targets to anchors. "simota" (the SimOTA matching rule from YOLOX), "size" (match those prior shapes, whose width and height relative to the target is below given ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the prior shape that gives the highest IoU, default). matching_threshold: Threshold for "size" and "iou" matching algorithms. - spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N × N` grid cell + spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N x N` grid cell area centered at the target, where `N` is the value of this parameter. size_range: The "simota" matching algorithm will restrict to anchors whose dimensions are no more than `N` and no less than `1/N` times the target dimensions, where `N` is the value of this parameter. diff --git a/torchvision/models/yolo.py b/torchvision/models/yolo.py index aaaa8a1e449..8f3be429a63 100644 --- a/torchvision/models/yolo.py +++ b/torchvision/models/yolo.py @@ -2,7 +2,7 @@ from typing import List, Optional, Sequence, Tuple import torch -from torch import Tensor, nn +from torch import nn, Tensor def _get_padding(kernel_size: int, stride: int) -> Tuple[int, nn.Module]: From 27b9c9e9bbe93f84496e3e431e29a8623cfff24b Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Wed, 19 Apr 2023 20:25:49 +0300 Subject: [PATCH 09/13] Fixed flake8 --- test/test_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_models.py b/test/test_models.py index 4dfbebad0fe..b220aebff77 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -922,7 +922,7 @@ def test_yolo_darknet(dev): if not full_validation: msg = ( - f"The output of yolo_darknet could only be partially validated. " + "The output of yolo_darknet could only be partially validated. " "This is likely due to unit-test flakiness, but you may " "want to do additional manual checks if you made " "significant changes to the codebase." From 89f1500bb4a354c0f8d354159c4349f2c79e76e9 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Wed, 19 Apr 2023 20:35:55 +0300 Subject: [PATCH 10/13] Fixed docstrings --- torchvision/models/detection/yolo.py | 4 ++-- torchvision/models/detection/yolo_networks.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/torchvision/models/detection/yolo.py b/torchvision/models/detection/yolo.py index 21e2305b62a..3eb9d7a1c04 100644 --- a/torchvision/models/detection/yolo.py +++ b/torchvision/models/detection/yolo.py @@ -17,8 +17,8 @@ class YOLO(nn.Module): - """YOLO implementation that supports the most important features of YOLOv3, YOLOv4, YOLOv5, YOLOv7, Scaled-YOLOv4, - and YOLOX. + """YOLO implementation that supports the most important features of YOLOv3, YOLOv4, YOLOv5, YOLOv7, Scaled- + YOLOv4, and YOLOX. *YOLOv3 paper*: `Joseph Redmon and Ali Farhadi `__ diff --git a/torchvision/models/detection/yolo_networks.py b/torchvision/models/detection/yolo_networks.py index 5daa405a157..3765d585e7c 100644 --- a/torchvision/models/detection/yolo_networks.py +++ b/torchvision/models/detection/yolo_networks.py @@ -393,8 +393,8 @@ def forward( losses: List[Tensor], hits: List[int], ) -> None: - """Runs the detection layer and the auxiliary detection layer on their respective inputs and appends the outputs - to the ``detections`` list. + """Runs the detection layer and the auxiliary detection layer on their respective inputs and appends the + outputs to the ``detections`` list. If ``targets`` is given, also calculates the losses and appends to the ``losses`` list. @@ -1949,9 +1949,9 @@ def _create_yolo( prior_shapes: A list of prior box dimensions, used for scaling the predicted dimensions and possibly for matching the targets to the anchors. The list should contain [width, height] pairs in the network input resolution. There should be `M x N` pairs, where `M` is the number of detection layers and `N` is the number - of anchors per spatial location. They are assigned to the layers from the lowest (high-resolution) to the - highest (low-resolution) layer, meaning that you typically want to sort the shapes from the smallest to the - largest. + of anchors per spatial location. They are assigned to the layers from the lowest (high-resolution) to the + highest (low-resolution) layer, meaning that you typically want to sort the shapes from the smallest to the + largest. matching_algorithm: Which algorithm to use for matching targets to anchors. "simota" (the SimOTA matching rule from YOLOX), "size" (match those prior shapes, whose width and height relative to the target is below given ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the prior shape that From b7a836e336b921bf1fed61c4fb8e39d1d5ca01dd Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Thu, 20 Apr 2023 09:45:42 +0300 Subject: [PATCH 11/13] Fixed target matching --- .../models/detection/target_matching.py | 61 ++++++++----------- torchvision/models/detection/yolo_networks.py | 1 + 2 files changed, 28 insertions(+), 34 deletions(-) diff --git a/torchvision/models/detection/target_matching.py b/torchvision/models/detection/target_matching.py index 43b512dc683..b28b89d3b33 100644 --- a/torchvision/models/detection/target_matching.py +++ b/torchvision/models/detection/target_matching.py @@ -11,7 +11,7 @@ PRIOR_SHAPES = List[List[int]] # TorchScript doesn't allow a list of tuples. -def target_boxes_to_grid(preds: Tensor, targets: Tensor, image_size: Tensor) -> Tensor: +def target_boxes_to_grid(preds: Tensor, targets: Tensor, image_size: Tensor) -> Tuple[Tensor, Tensor]: """Scales target bounding boxes to feature map coordinates. It would be better to implement this in a super class, but TorchScript doesn't allow class inheritance. @@ -22,7 +22,9 @@ def target_boxes_to_grid(preds: Tensor, targets: Tensor, image_size: Tensor) -> image_size: Input image width and height. Returns: - A tensor containing target x, y, width, and height in the feature map coordinates. + Two tensors with as many rows as there are targets. An integer tensor containing x/y coordinates to the feature + map that correspond to the target position, and a floating point tensor containing the target width and height + scaled to the feature map size. """ height, width = preds.shape[:2] @@ -33,11 +35,11 @@ def target_boxes_to_grid(preds: Tensor, targets: Tensor, image_size: Tensor) -> # Bounding box center coordinates are converted to the feature map dimensions so that the whole number tells the # cell index and the fractional part tells the location inside the cell. xywh = box_convert(targets, in_fmt="xyxy", out_fmt="cxcywh") - grid_xy = xywh[:, :2] * image_to_grid - cell_i = grid_xy[:, 0].to(torch.int64).clamp(0, width - 1) - cell_j = grid_xy[:, 1].to(torch.int64).clamp(0, height - 1) - - return torch.cat((cell_i.unsqueeze(1), cell_j.unsqueeze(1), xywh[:, 2:]), 1) + xy = (xywh[:, :2] * image_to_grid).to(torch.int64) + x = xy[:, 0].clamp(0, width - 1) + y = xy[:, 1].clamp(0, height - 1) + xy = torch.stack((x, y), 1) + return xy, xywh[:, 2:] class HighestIoUMatching: @@ -105,21 +107,18 @@ def __call__( Returns: The indices of the matched predictions, background mask, and a mask for selecting the matched targets. """ - scaled_targets = target_boxes_to_grid(preds["boxes"], targets["boxes"], image_size) - target_selector, anchor_selector = self.match(scaled_targets[:, 2:]) - - scaled_targets = scaled_targets[target_selector] - cell_i = scaled_targets[:, 0] - cell_j = scaled_targets[:, 1] + anchor_xy, target_wh = target_boxes_to_grid(preds["boxes"], targets["boxes"], image_size) + target_selector, anchor_idx = self.match(target_wh) + anchor_x = anchor_xy[target_selector, 0] + anchor_y = anchor_xy[target_selector, 1] # Background mask is used to select anchors that are not responsible for predicting any object, for # calculating the part of the confidence loss with zero as the target confidence. It is set to False, if a # predicted box overlaps any target significantly, or if a prediction is matched to a target. background_mask = iou_below(preds["boxes"], targets["boxes"], self.ignore_bg_threshold) - background_mask[cell_j, cell_i, anchor_selector] = False - - pred_selector = [cell_j, cell_i, anchor_selector] + background_mask[anchor_y, anchor_x, anchor_idx] = False + pred_selector = [anchor_y, anchor_x, anchor_idx] return pred_selector, background_mask, target_selector @@ -183,21 +182,18 @@ def __call__( Returns: The indices of the matched predictions, background mask, and a mask for selecting the matched targets. """ - scaled_targets = target_boxes_to_grid(preds["boxes"], targets["boxes"], image_size) - target_selector, anchor_selector = self.match(scaled_targets[:, 2:]) - - scaled_targets = scaled_targets[target_selector] - cell_i = scaled_targets[:, 0] - cell_j = scaled_targets[:, 1] + anchor_xy, target_wh = target_boxes_to_grid(preds["boxes"], targets["boxes"], image_size) + target_selector, anchor_idx = self.match(target_wh) + anchor_x = anchor_xy[target_selector, 0] + anchor_y = anchor_xy[target_selector, 1] # Background mask is used to select anchors that are not responsible for predicting any object, for # calculating the part of the confidence loss with zero as the target confidence. It is set to False, if a # predicted box overlaps any target significantly, or if a prediction is matched to a target. background_mask = iou_below(preds["boxes"], targets["boxes"], self.ignore_bg_threshold) - background_mask[cell_j, cell_i, anchor_selector] = False - - pred_selector = [cell_j, cell_i, anchor_selector] + background_mask[anchor_y, anchor_x, anchor_idx] = False + pred_selector = [anchor_y, anchor_x, anchor_idx] return pred_selector, background_mask, target_selector @@ -262,21 +258,18 @@ def __call__( Returns: The indices of the matched predictions, background mask, and a mask for selecting the matched targets. """ - scaled_targets = target_boxes_to_grid(preds["boxes"], targets["boxes"], image_size) - target_selector, anchor_selector = self.match(scaled_targets[:, 2:]) - - scaled_targets = scaled_targets[target_selector] - cell_i = scaled_targets[:, 0] - cell_j = scaled_targets[:, 1] + anchor_xy, target_wh = target_boxes_to_grid(preds["boxes"], targets["boxes"], image_size) + target_selector, anchor_idx = self.match(target_wh) + anchor_x = anchor_xy[target_selector, 0] + anchor_y = anchor_xy[target_selector, 1] # Background mask is used to select anchors that are not responsible for predicting any object, for # calculating the part of the confidence loss with zero as the target confidence. It is set to False, if a # predicted box overlaps any target significantly, or if a prediction is matched to a target. background_mask = iou_below(preds["boxes"], targets["boxes"], self.ignore_bg_threshold) - background_mask[cell_j, cell_i, anchor_selector] = False - - pred_selector = [cell_j, cell_i, anchor_selector] + background_mask[anchor_y, anchor_x, anchor_idx] = False + pred_selector = [anchor_y, anchor_x, anchor_idx] return pred_selector, background_mask, target_selector diff --git a/torchvision/models/detection/yolo_networks.py b/torchvision/models/detection/yolo_networks.py index 3765d585e7c..d51f4bfe54e 100644 --- a/torchvision/models/detection/yolo_networks.py +++ b/torchvision/models/detection/yolo_networks.py @@ -379,6 +379,7 @@ class DetectionStageWithAux(nn.Module): def __init__( self, spatial_range: float = 5.0, aux_spatial_range: float = 3.0, aux_weight: float = 0.25, **kwargs: Any ) -> None: + super().__init__() self.detection_layer = create_detection_layer(spatial_range=spatial_range, **kwargs) self.aux_detection_layer = create_detection_layer(spatial_range=aux_spatial_range, **kwargs) self.aux_weight = aux_weight From 37200dc9d59c30ef39a345dc999e0140595518d9 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Thu, 20 Apr 2023 16:08:54 +0300 Subject: [PATCH 12/13] Fixed DarknetNetwork TorchScript compilation --- torchvision/models/detection/yolo_networks.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torchvision/models/detection/yolo_networks.py b/torchvision/models/detection/yolo_networks.py index d51f4bfe54e..dd395f95af5 100644 --- a/torchvision/models/detection/yolo_networks.py +++ b/torchvision/models/detection/yolo_networks.py @@ -1617,6 +1617,10 @@ def __init__( with open(weights_path) as weight_file: self.load_weights(weight_file) + # A workaround for TorchScript compilation. For some reason, the compilation will crash with "Unknown type name + # 'ShortcutLayer'" without this. + self._ = ShortcutLayer(0) + def forward(self, x: Tensor, targets: Optional[TARGETS] = None) -> NETWORK_OUTPUT: outputs: List[Tensor] = [] # Outputs from all layers detections: List[Tensor] = [] # Outputs from detection layers From ae30df455405fb56946425bf3f3c318280b0a7ae Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Mon, 1 May 2023 20:58:57 +0300 Subject: [PATCH 13/13] YOLO.forward() returns only losses when training --- references/detection/presets.py | 10 ++++++++++ torchvision/models/detection/yolo.py | 21 +++++++++++---------- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/references/detection/presets.py b/references/detection/presets.py index 779f3f218ca..c13d0256e8d 100644 --- a/references/detection/presets.py +++ b/references/detection/presets.py @@ -53,6 +53,16 @@ def __init__(self, *, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104 T.ConvertImageDtype(torch.float), ] ) + elif data_augmentation == "yolo": + self.transforms = T.Compose( + [ + T.ScaleJitter(target_size=(640, 640)), + T.FixedSizeCrop(size=(640, 640), fill=mean), + T.RandomHorizontalFlip(p=hflip_prob), + T.PILToTensor(), + T.ConvertImageDtype(torch.float), + ] + ) else: raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"') diff --git a/torchvision/models/detection/yolo.py b/torchvision/models/detection/yolo.py index 3eb9d7a1c04..54ae7b9aee9 100644 --- a/torchvision/models/detection/yolo.py +++ b/torchvision/models/detection/yolo.py @@ -1,5 +1,5 @@ import warnings -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Union import torch import torch.nn as nn @@ -111,7 +111,7 @@ def __init__( def forward( self, images: Union[Tensor, IMAGES], targets: Optional[TARGETS] = None - ) -> Union[Tensor, Tuple[Tensor, Tensor, List[int]]]: + ) -> Union[Tensor, Dict[str, Tensor]]: """Runs a forward pass through the network (all layers listed in ``self.network``), and if training targets are provided, computes the losses from the detection layers. @@ -121,25 +121,26 @@ def forward( Args: images: A tensor of size ``[batch_size, channels, height, width]`` containing a batch of images or a list of image tensors. - targets: If given, computes losses from detection layers against these targets. A list of target - dictionaries, one for each image. + targets: Compute losses against these targets. A list of dictionaries, one for each image. Must be given in + training mode. Returns: - detections (:class:`~torch.Tensor`), losses (:class:`~torch.Tensor`), hits (List[int]): Detections, and if - targets were provided, a dictionary of losses. Detections are shaped ``[batch_size, anchors, classes + 5]``, - where ``anchors`` is the feature map size (width * height) times the number of anchors per cell. The - predicted box coordinates are in `(x1, y1, x2, y2)` format and scaled to the input image size. + If targets are given, returns a dictionary containing the three losses (overlap, confidence, and + classification). Otherwise returns detections in a tensor shaped ``[batch_size, anchors, classes + 5]``, + where ``anchors`` is the total number of anchors in all detection layers. The number of anchors in a + detection layer is the feature map size (width * height) times the number of anchors per cell (usually 3 or + 4). The predicted box coordinates are in `(x1, y1, x2, y2)` format and scaled to the input image size. """ self.validate_batch(images, targets) images_tensor = images if isinstance(images, Tensor) else torch.stack(images) detections, losses, hits = self.network(images_tensor, targets) - detections = torch.cat(detections, 1) if targets is None: + detections = torch.cat(detections, 1) return detections losses = torch.stack(losses).sum(0) - return detections, losses, hits + return {"overlap": losses[0], "confidence": losses[1], "classification": losses[2]} def infer(self, image: Tensor) -> PRED: """Feeds an image to the network and returns the detected bounding boxes, confidence scores, and class