From f7d2dbf25aaa5491a3e3b8951ef779425c9823e1 Mon Sep 17 00:00:00 2001
From: Antoine Jacquet
Date: Wed, 13 Nov 2024 14:44:10 +0100
Subject: [PATCH] feat: enable DETR and RT-DETRv2 training

---
 src/CMakeLists.txt                |   2 +-
 src/backends/torch/torchlib.cc    |   2 +
 src/backends/torch/torchmodule.cc |  49 +++++
 src/backends/torch/torchmodule.h  |   6 +
 src/utils/rectangular_lsap.cpp    | 297 +++++++++++++++++++++++++++++
 src/utils/rectangular_lsap.h      |  63 ++++++
 tools/torch/trace_detr.py         | 220 +++++++++++++++++++--
 tools/torch/trace_rtdetrv2.py     | 307 ++++++++++++++++++++++++++++++
 8 files changed, 932 insertions(+), 14 deletions(-)
 create mode 100644 src/utils/rectangular_lsap.cpp
 create mode 100644 src/utils/rectangular_lsap.h
 create mode 100644 tools/torch/trace_rtdetrv2.py

diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index cad0bd617..8e6d3ec82 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -24,7 +24,7 @@ set(ddetect_SOURCES deepdetect.h deepdetect.cc mllibstrategy.h mlmodel.h
   svminputfileconn.h svminputfileconn.cc txtinputfileconn.h txtinputfileconn.cc
   apidata.h apidata.cc chain_actions.h chain_actions.cc service_stats.h
   service_stats.cc chain.h chain.cc resources.cc ext/rmustache/mustache.h ext/rmustache/mustache.cc
-  utils/oatpp.cc dto/ddtypes.cc utils/db.cpp utils/db_lmdb.cpp ${CMAKE_BINARY_DIR}/src/caffe.pb.cc ${CMAKE_BINARY_DIR}/dd_config.cc)
+  utils/oatpp.cc dto/ddtypes.cc utils/db.cpp utils/db_lmdb.cpp utils/rectangular_lsap.cpp ${CMAKE_BINARY_DIR}/src/caffe.pb.cc ${CMAKE_BINARY_DIR}/dd_config.cc)
 
 if (USE_JSON_API)
   list(APPEND ddetect_SOURCES jsonapi.h jsonapi.cc)

diff --git a/src/backends/torch/torchlib.cc b/src/backends/torch/torchlib.cc
index 419969348..7178c0d10 100644
--- a/src/backends/torch/torchlib.cc
+++ b/src/backends/torch/torchlib.cc
@@ -468,6 +468,8 @@ namespace dd
         }
       else if (_template == "detr")
         {
+          _module._loss_id = 0;
+          _module._detr = true;
        }
      else if (!_template.empty())
        {
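The torchmodule.cc change below exists because DETR's matcher natively returns List[Tuple[Tensor, Tensor]], which does not round-trip through IValue, so the traced model exchanges plain 2D tensors with the C++ side instead. A minimal eager-mode sketch of that pack/unpack convention (illustration only, not part of the patch):

import torch

# DETR matcher output: one (row_indices, col_indices) pair per image.
matcher_out = [(torch.tensor([0, 1]), torch.tensor([1, 0]))]
# Pack each pair as a single 2xK tensor so it fits in a List[Tensor] IValue.
packed = [torch.stack(pair) for pair in matcher_out]
# Unpack on the other side, as the exported loss method does below.
unpacked = [(t[0], t[1]) for t in packed]
assert torch.equal(unpacked[0][0], matcher_out[0][0])
assert torch.equal(unpacked[0][1], matcher_out[0][1])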
diff --git a/src/backends/torch/torchmodule.cc b/src/backends/torch/torchmodule.cc
index 004024833..4012d84fc 100644
--- a/src/backends/torch/torchmodule.cc
+++ b/src/backends/torch/torchmodule.cc
@@ -24,6 +24,7 @@
 #include "graph/graph.h"
 #include "native/native.h"
 #include "torchutils.h"
+#include "utils/rectangular_lsap.h"
 
 namespace dd
 {
@@ -347,6 +348,9 @@ namespace dd
           }
       }
 
+    if (_training && _detr)
+      return detr_postprocess(source);
+
     if (_training && _loss_id >= 0)
       {
        // if we are in training mode and model does output the loss (eg
@@ -751,6 +755,51 @@ namespace dd
     _frozen_params_count = total_frozen_count;
   }
 
+  c10::IValue TorchModule::detr_postprocess(std::vector<c10::IValue> &source)
+  {
+    // DETR matcher outputs a List[Tuple[Tensor, Tensor]]
+    // https://github.com/facebookresearch/detr/blob/main/models/matcher.py#L82
+    // which seems impossible/difficult to wrap in IValue
+    // https://github.com/pytorch/pytorch/issues/90398
+    // we output a List[Tensor] of 2D tensors instead and unwrap it later
+    std::vector<torch::Tensor> out_list;
+
+    // solve the linear_sum_assignment problems
+    // TODO: do it in parallel?
+    auto in_list_raw = source.at(3); // detr_indices
+    auto in_list = torch_utils::unwrap_c10_vector(in_list_raw);
+    for (auto &in_item_raw : in_list)
+      {
+        auto in_item = in_item_raw.toTensor().to(torch::kFloat64);
+        auto shape = in_item.sizes();
+        int rows = shape[0];
+        int cols = shape[1];
+        auto out_item = torch::zeros({ 2, cols }, torch::kInt64);
+        auto ret = scipy::solve_rectangular_linear_sum_assignment(
+            rows, cols, in_item.data_ptr<double>(), false,
+            out_item[0].data_ptr<int64_t>(), out_item[1].data_ptr<int64_t>());
+        if (ret)
+          throw MLLibBadParamException(
+              "detr_postprocess: linear_sum_assignment error");
+        out_list.push_back(out_item);
+      }
+
+    // call loss
+    if (!_traced)
+      throw MLLibBadParamException("detr_postprocess: model is not traced");
+    auto method = _traced->find_method("loss");
+    if (!method)
+      throw MLLibBadParamException("detr_postprocess: loss method not found");
+    auto output = (*method)({
+        source.at(1), // detr_outputs
+        source.at(2), // detr_targets
+        out_list      // detr_indices
+    });
+
+    // return a dictionary of losses, with total_loss key used as custom loss
+    return output;
+  }
+
   template void TorchModule::post_transform(
       const std::string tmpl, const APIData &template_params,
       const ImgTorchInputFileConn &inputc, const TorchModel &tmodel,

diff --git a/src/backends/torch/torchmodule.h b/src/backends/torch/torchmodule.h
index 3ccb6e40e..b13ab578f 100644
--- a/src/backends/torch/torchmodule.h
+++ b/src/backends/torch/torchmodule.h
@@ -228,6 +228,7 @@ namespace dd
     int _linear_in = 0;
     int _loss_id = -1; /**< if >= 0, forward returns this output only during training */
+    bool _detr = false;
     bool _hidden_states = false; /**< Take BERT hidden states as input. */
 
     unsigned int _nclasses = 0; /**< number of classes */
@@ -277,6 +278,11 @@ namespace dd
      * load linear layer weights only from pt format
      */
     void crnn_head_load();
+
+    /**
+     * DETR postprocessing
+     */
+    c10::IValue detr_postprocess(std::vector<c10::IValue> &source);
   };
 }
 #endif
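The per-image flow of detr_postprocess can be mirrored in Python with SciPy's reference implementation of the same solver (the vendored code below is taken from SciPy); shapes follow the C++ above, illustration only:

import torch
from scipy.optimize import linear_sum_assignment

# source.at(3) holds one [num_queries x num_targets] cost matrix per image.
cost = torch.rand(100, 3).to(torch.float64)
rows, cols = linear_sum_assignment(cost.numpy())
# The C++ packs the solution into a [2 x num_targets] int64 tensor.
out_item = torch.stack([torch.as_tensor(rows, dtype=torch.int64),
                        torch.as_tensor(cols, dtype=torch.int64)])
assert out_item.shape == (2, 3)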
diff --git a/src/utils/rectangular_lsap.cpp b/src/utils/rectangular_lsap.cpp
new file mode 100644
index 000000000..2625afc45
--- /dev/null
+++ b/src/utils/rectangular_lsap.cpp
@@ -0,0 +1,297 @@
+// from https://github.com/scipy/scipy
+
+/*
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions
+are met:
+
+1. Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above
+   copyright notice, this list of conditions and the following
+   disclaimer in the documentation and/or other materials provided
+   with the distribution.
+
+3. Neither the name of the copyright holder nor the names of its
+   contributors may be used to endorse or promote products derived
+   from this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+
+This code implements the shortest augmenting path algorithm for the
+rectangular assignment problem. This implementation is based on the
+pseudocode described in pages 1685-1686 of:
+
+    DF Crouse. On implementing 2D rectangular assignment algorithms.
+    IEEE Transactions on Aerospace and Electronic Systems
+    52(4):1679-1696, August 2016
+    doi: 10.1109/TAES.2016.140952
+
+Author: PM Larsen
+*/
+
+#include <algorithm>
+#include <cmath>
+#include <numeric>
+#include <vector>
+#include "rectangular_lsap.h"
+
+namespace dd
+{
+  namespace scipy
+  {
+
+    template <typename T>
+    std::vector<intptr_t> argsort_iter(const std::vector<T> &v)
+    {
+      std::vector<intptr_t> index(v.size());
+      std::iota(index.begin(), index.end(), 0);
+      std::sort(index.begin(), index.end(),
+                [&v](intptr_t i, intptr_t j) { return v[i] < v[j]; });
+      return index;
+    }
+
+    static intptr_t
+    augmenting_path(intptr_t nc, double *cost, std::vector<double> &u,
+                    std::vector<double> &v, std::vector<intptr_t> &path,
+                    std::vector<intptr_t> &row4col,
+                    std::vector<double> &shortestPathCosts, intptr_t i,
+                    std::vector<bool> &SR, std::vector<bool> &SC,
+                    std::vector<intptr_t> &remaining, double *p_minVal)
+    {
+      double minVal = 0;
+
+      // Crouse's pseudocode uses set complements to keep track of remaining
+      // nodes. Here we use a vector, as it is more efficient in C++.
+      intptr_t num_remaining = nc;
+      for (intptr_t it = 0; it < nc; it++)
+        {
+          // Filling this up in reverse order ensures that the solution of a
+          // constant cost matrix is the identity matrix (c.f. #11602).
+          remaining[it] = nc - it - 1;
+        }
+
+      std::fill(SR.begin(), SR.end(), false);
+      std::fill(SC.begin(), SC.end(), false);
+      std::fill(shortestPathCosts.begin(), shortestPathCosts.end(), INFINITY);
+
+      // find shortest augmenting path
+      intptr_t sink = -1;
+      while (sink == -1)
+        {
+          intptr_t index = -1;
+          double lowest = INFINITY;
+          SR[i] = true;
+
+          for (intptr_t it = 0; it < num_remaining; it++)
+            {
+              intptr_t j = remaining[it];
+
+              double r = minVal + cost[i * nc + j] - u[i] - v[j];
+              if (r < shortestPathCosts[j])
+                {
+                  path[j] = i;
+                  shortestPathCosts[j] = r;
+                }
+
+              // When multiple nodes have the minimum cost, we select one which
+              // gives us a new sink node. This is particularly important for
+              // integer cost matrices with small coefficients.
+              if (shortestPathCosts[j] < lowest
+                  || (shortestPathCosts[j] == lowest && row4col[j] == -1))
+                {
+                  lowest = shortestPathCosts[j];
+                  index = it;
+                }
+            }
+
+          minVal = lowest;
+          if (minVal == INFINITY)
+            { // infeasible cost matrix
+              return -1;
+            }
+
+          intptr_t j = remaining[index];
+          if (row4col[j] == -1)
+            {
+              sink = j;
+            }
+          else
+            {
+              i = row4col[j];
+            }
+
+          SC[j] = true;
+          remaining[index] = remaining[--num_remaining];
+        }
+
+      *p_minVal = minVal;
+      return sink;
+    }
+
+    static int solve(intptr_t nr, intptr_t nc, double *cost, bool maximize,
+                     int64_t *a, int64_t *b)
+    {
+      // handle trivial inputs
+      if (nr == 0 || nc == 0)
+        {
+          return 0;
+        }
+
+      // tall rectangular cost matrix must be transposed
+      bool transpose = nc < nr;
+
+      // make a copy of the cost matrix if we need to modify it
+      std::vector<double> temp;
+      if (transpose || maximize)
+        {
+          temp.resize(nr * nc);
+
+          if (transpose)
+            {
+              for (intptr_t i = 0; i < nr; i++)
+                {
+                  for (intptr_t j = 0; j < nc; j++)
+                    {
+                      temp[j * nr + i] = cost[i * nc + j];
+                    }
+                }
+
+              std::swap(nr, nc);
+            }
+          else
+            {
+              std::copy(cost, cost + nr * nc, temp.begin());
+            }
+
+          // negate cost matrix for maximization
+          if (maximize)
+            {
+              for (intptr_t i = 0; i < nr * nc; i++)
+                {
+                  temp[i] = -temp[i];
+                }
+            }
+
+          cost = temp.data();
+        }
+
+      // test for NaN and -inf entries
+      for (intptr_t i = 0; i < nr * nc; i++)
+        {
+          if (cost[i] != cost[i] || cost[i] == -INFINITY)
+            {
+              return RECTANGULAR_LSAP_INVALID;
+            }
+        }
+
+      // initialize variables
+      std::vector<double> u(nr, 0);
+      std::vector<double> v(nc, 0);
+      std::vector<double> shortestPathCosts(nc);
+      std::vector<intptr_t> path(nc, -1);
+      std::vector<intptr_t> col4row(nr, -1);
+      std::vector<intptr_t> row4col(nc, -1);
+      std::vector<bool> SR(nr);
+      std::vector<bool> SC(nc);
+      std::vector<intptr_t> remaining(nc);
+
+      // iteratively build the solution
+      for (intptr_t curRow = 0; curRow < nr; curRow++)
+        {
+          double minVal;
+          intptr_t sink = augmenting_path(nc, cost, u, v, path, row4col,
+                                          shortestPathCosts, curRow, SR, SC,
+                                          remaining, &minVal);
+          if (sink < 0)
+            {
+              return RECTANGULAR_LSAP_INFEASIBLE;
+            }
+
+          // update dual variables
+          u[curRow] += minVal;
+          for (intptr_t i = 0; i < nr; i++)
+            {
+              if (SR[i] && i != curRow)
+                {
+                  u[i] += minVal - shortestPathCosts[col4row[i]];
+                }
+            }
+
+          for (intptr_t j = 0; j < nc; j++)
+            {
+              if (SC[j])
+                {
+                  v[j] -= minVal - shortestPathCosts[j];
+                }
+            }
+
+          // augment previous solution
+          intptr_t j = sink;
+          while (1)
+            {
+              intptr_t i = path[j];
+              row4col[j] = i;
+              std::swap(col4row[i], j);
+              if (i == curRow)
+                {
+                  break;
+                }
+            }
+        }
+
+      if (transpose)
+        {
+          intptr_t i = 0;
+          for (auto v : argsort_iter(col4row))
+            {
+              a[i] = col4row[v];
+              b[i] = v;
+              i++;
+            }
+        }
+      else
+        {
+          for (intptr_t i = 0; i < nr; i++)
+            {
+              a[i] = i;
+              b[i] = col4row[i];
+            }
+        }
+
+      return 0;
+    }
+
+#ifdef __cplusplus
+    extern "C"
+    {
+#endif
+
+      int solve_rectangular_linear_sum_assignment(intptr_t nr, intptr_t nc,
+                                                  double *input_cost,
+                                                  bool maximize, int64_t *a,
+                                                  int64_t *b)
+      {
+        return solve(nr, nc, input_cost, maximize, a, b);
+      }
+
+#ifdef __cplusplus
+    }
+#endif
+
+  }
+}
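Since the vendored solver is SciPy's rectangular LSAP (Crouse 2016), its behaviour can be sanity-checked against scipy.optimize directly; the identity property asserted below is the #11602 comment above (illustration only):

import numpy as np
from scipy.optimize import linear_sum_assignment

# For a constant cost matrix the returned solution is the identity assignment.
cost = np.ones((3, 3))
rows, cols = linear_sum_assignment(cost)
assert (rows == cols).all()

# Rectangular case: every row of the short side is matched exactly once.
cost = np.array([[4.0, 1.0, 3.0], [2.0, 0.0, 5.0]])
rows, cols = linear_sum_assignment(cost)
print(cost[rows, cols].sum())  # minimal total cost: 3.0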
diff --git a/src/utils/rectangular_lsap.h b/src/utils/rectangular_lsap.h
new file mode 100644
index 000000000..14caec954
--- /dev/null
+++ b/src/utils/rectangular_lsap.h
@@ -0,0 +1,63 @@
+// from https://github.com/scipy/scipy
+
+/*
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions
+are met:
+
+1. Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above
+   copyright notice, this list of conditions and the following
+   disclaimer in the documentation and/or other materials provided
+   with the distribution.
+
+3. Neither the name of the copyright holder nor the names of its
+   contributors may be used to endorse or promote products derived
+   from this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+
+#ifndef RECTANGULAR_LSAP_H
+#define RECTANGULAR_LSAP_H
+
+#define RECTANGULAR_LSAP_INFEASIBLE -1
+#define RECTANGULAR_LSAP_INVALID -2
+
+#include <stddef.h>
+#include <stdint.h>
+
+namespace dd
+{
+  namespace scipy
+  {
+
+#ifdef __cplusplus
+    extern "C"
+    {
+#endif
+
+      int solve_rectangular_linear_sum_assignment(intptr_t nr, intptr_t nc,
+                                                  double *input_cost,
+                                                  bool maximize, int64_t *a,
+                                                  int64_t *b);
+
+#ifdef __cplusplus
+    }
+#endif
+
+  }
+}
+
+#endif

diff --git a/tools/torch/trace_detr.py b/tools/torch/trace_detr.py
index 5936d5b98..36aeef741 100644
--- a/tools/torch/trace_detr.py
+++ b/tools/torch/trace_detr.py
@@ -28,7 +28,6 @@ def box_cxcywh_to_xyxy(self, x):
              (x_c + 0.5 * w), (y_c + 0.5 * h)]
         return torch.stack(b, dim=-1)
 
-    @torch.no_grad()
     def forward(self, out_logits, out_bbox, target_sizes):
         """ Perform the computation
         Parameters:
@@ -42,6 +41,8 @@ def forward(self, out_logits, out_bbox, target_sizes):
 
         prob = self.softmax(out_logits)
         scores, labels = prob[..., :-1].max(-1)
+        # DETR classes start at 0
+        labels += 1
 
         # convert to [x0, y0, x1, y1] format
         boxes = self.box_cxcywh_to_xyxy(out_bbox).cpu()
@@ -52,35 +53,115 @@ def forward(self, out_logits, out_bbox, target_sizes):
         results = [{'scores': s, 'labels': l, 'boxes': b}
                    for s, l, b in zip(scores, labels, boxes)]
         return results
 
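The two box parametrisations used in this file (box_cxcywh_to_xyxy above, box_xyxy_to_cxcywh in the PostProcessTrain class below) are exact inverses; a quick eager-mode round trip, using the formulas from the patch (illustration only):

import torch

def box_cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x.unbind(-1)
    return torch.stack([x_c - 0.5 * w, y_c - 0.5 * h,
                        x_c + 0.5 * w, y_c + 0.5 * h], dim=-1)

def box_xyxy_to_cxcywh(x):
    x0, y0, x1, y1 = x.unbind(-1)
    return torch.stack([(x0 + x1) * 0.5, (y0 + y1) * 0.5,
                        x1 - x0, y1 - y0], dim=-1)

b = torch.tensor([[30.0, 20.0, 330.0, 240.0]])  # DD-style absolute xyxy
assert torch.allclose(box_cxcywh_to_xyxy(box_xyxy_to_cxcywh(b)), b)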
+class PostProcessTrain(torch.nn.Module):
+    def __init__(self, criterion):
+        super().__init__()
+        self.criterion = criterion
+
+    def box_xyxy_to_cxcywh(self, x):
+        x0, y0, x1, y1 = x.unbind(-1)
+        b = [(x0 + x1) * 0.5, (y0 + y1) * 0.5,
+             (x1 - x0), (y1 - y0)]
+        return torch.stack(b, dim=-1)
+
+    def forward(self, detr_outputs, target_sizes, ids, bboxes, labels):
+        # type: (Dict[str, Tensor], Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor])
+        assert ids is not None
+        assert bboxes is not None
+        assert labels is not None
+
+        # convert DD bboxes to DETR bboxes instead, because of the L1 loss here:
+        # https://github.com/facebookresearch/detr/blob/main/models/detr.py#L153
+
+        # assume all images in the batch are the same size
+        img_h, img_w = target_sizes[0].unbind(0)
+
+        # convert to [xc, yc, w, h] format
+        bboxes = self.box_xyxy_to_cxcywh(bboxes)
+
+        # and to relative [0, 1] coordinates
+        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=0).cuda()
+        bboxes = bboxes / scale_fct
+
+        # convert ids, bboxes, labels to DETR targets
+        # DD uses ids, DETR expects lists of boxes, labels
+        detr_targets : List[Dict[str, torch.Tensor]] = []
+        # DETR classes start at 0
+        labels -= 1
+        batch_size = target_sizes.shape[0]
+        count = torch.bincount(ids, minlength=batch_size)
+        start = 0
+        for i in range(batch_size):
+            stop = start + count[i]
+            target = {
+                "labels": labels[start:stop],
+                "boxes": bboxes[start:stop],
+            }
+            detr_targets.append(target)
+            start = stop
+
+        with torch.no_grad():
+            detr_indices = self.criterion.matcher(detr_outputs, detr_targets)
+
+        return detr_targets, detr_indices
+
 class WrappedDETR(torch.nn.Module):
-    def __init__(self, model):
+    def __init__(self, model, criterion):
         super().__init__()
         self.model = model
+        self.criterion = criterion
         self.pp = PostProcess()
-
-    def forward(self, x):
+        self.pp_train = PostProcessTrain(criterion)
+
+    def forward(self, x, ids=None, bboxes=None, labels=None):
+        # type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor])
         """
         x: batch of images of dimensions [batch size, channel count, width, height]
         """
         l_x = [x[i] for i in range(x.shape[0])]
         sample = nested_tensor_from_tensor_list(l_x)
-        output = self.model(sample)
+
+        # default output placeholders
+        dd_outputs = [{"dummy": torch.zeros((0, ))}]
+        detr_outputs = self.model(sample)
+        detr_indices = [torch.zeros((0, ))]
+        detr_targets : List[Dict[str, torch.Tensor]] = []
+
         image_sizes = torch.zeros([len(l_x), 2]).cpu()
         i = 0
         for x in l_x:
             image_sizes[i][0] = x.shape[1]
             image_sizes[i][1] = x.shape[2]
             i += 1
 
-        # converting detr to torchvision detection format
-        processed_output = self.pp(output['pred_logits'], output['pred_boxes'], image_sizes)
-        return processed_output
+        if self.training:
+            detr_targets, detr_indices = self.pp_train(detr_outputs, image_sizes, ids, bboxes, labels)
+        else:
+            with torch.no_grad():
+                # converting detr to torchvision detection format
+                dd_outputs = self.pp(detr_outputs['pred_logits'], detr_outputs['pred_boxes'], image_sizes)
+
+        return dd_outputs, detr_outputs, detr_targets, detr_indices
+
+    @torch.jit.export
+    def loss(self, outputs, targets, indices):
+        # type: (Dict[str, Tensor], List[Dict[str, Tensor]], List[Tensor])
+        # convert List[Tensor] of 2D tensors indices to List[Tuple[Tensor, Tensor]]
+        # as expected by the DETR criterion
+        indices = [(x[0], x[1]) for x in indices]
+        losses = self.criterion(outputs, targets, indices)
+        weights = self.criterion.weight_dict
+        losses = {k: losses[k] * weights[k] for k in losses.keys()}
+        # DD expects a total_loss key as the model loss
+        losses["total_loss"] = torch.stack(list(losses.values())).sum()
+        return losses
 
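How PostProcessTrain splits DD's flat, id-indexed targets into per-image DETR targets can be checked standalone; a minimal sketch with toy values borrowed from the debug section further below (illustration only):

import torch

ids = torch.tensor([0, 1, 1])            # which image each box belongs to
labels = torch.tensor([10, 11, 11]) - 1  # DETR classes start at 0
boxes = torch.rand(3, 4)                 # already relative cxcywh here

count = torch.bincount(ids, minlength=2)
targets, start = [], 0
for i in range(2):
    stop = start + count[i]
    targets.append({"labels": labels[start:stop], "boxes": boxes[start:stop]})
    start = stop
assert [len(t["labels"]) for t in targets] == [1, 2]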
 parser = argparse.ArgumentParser(description="Trace DETR model")
 parser.add_argument('--model-in-file',help='path to model .pth file')
 parser.add_argument('--dataset-file',type=str,help='unused',default='coco')
 parser.add_argument('--device',default='cuda',help='device used for inference')
 parser.add_argument('--path-to-detr',help='path to detr repository',required=True)
+parser.add_argument("--num_classes", type=int, default=91 + 1, help="Number of classes of the model")
+parser.add_argument('-o', "--output-dir", default=".", type=str, help="Output directory for traced models")
 
 # * Training
 parser.add_argument('--lr', default=1e-4, type=float)
@@ -140,6 +221,8 @@
                     help="Relative classification weight of the no-object class")
 
 args = parser.parse_args()
+# DETR already reserves a no-object class
+args.num_classes -= 1
 
 sys.path.append(args.path_to_detr)
 import models
@@ -148,9 +231,21 @@
 model_without_ddp, criterion, postprocessors = build_model(args)
 model_without_ddp.eval()
-checkpoint = torch.load(args.model_in_file,map_location='cpu')
-model_without_ddp.load_state_dict(checkpoint['model'])
-model_without_ddp = WrappedDETR(model_without_ddp)
+
+if args.model_in_file:
+    checkpoint = torch.load(args.model_in_file,map_location='cpu')
+    # handle pretrained with different number of classes
+    # https://github.com/facebookresearch/detr/issues/9#issuecomment-636391562
+    if checkpoint['model']['class_embed.bias'].shape[0] != args.num_classes + 1:
+        print('pretrained used different num_classes, removing class_embed')
+        del checkpoint['model']['class_embed.weight']
+        del checkpoint['model']['class_embed.bias']
+    if checkpoint['model']['query_embed.weight'].shape[0] != args.num_queries:
+        print('pretrained used different num_queries, removing query_embed')
+        del checkpoint['model']['query_embed.weight']
+    model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
+
+model_without_ddp = WrappedDETR(model_without_ddp, criterion)
 model_without_ddp.cuda()
 
 ## code for inference / testing
@@ -159,7 +254,106 @@
 #output = model_without_ddp(image_loader(data_transforms, '/home/beniz/bus.jpg'))
 #print(output)
 
+filename = os.path.join(
+    args.output_dir,
+    "detr_"
+    + args.backbone
+    + "_cls"
+    + str(args.num_classes + 1)
+    + "_queries"
+    + str(args.num_queries)
+    + ("_pretrained" if args.model_in_file else "")
+    + ".pt",
+)
+
 print('Attempting jit export...')
 model_jit = torch.jit.script(model_without_ddp)
-model_jit.save(args.model_in_file.replace('.pth','.pt'))
+model_jit.save(filename)
 print('jit detr export successful')
+quit()
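Once exported, the traced module can be smoke-tested outside DD. A minimal sketch; the file name here assumes the default backbone and query settings above, and CUDA is required since the wrapper moves tensors to the GPU (illustration only):

import torch

# Hypothetical output name; with the defaults above the script writes
# detr_resnet50_cls92_queries100_pretrained.pt when --model-in-file is given.
model = torch.jit.load("detr_resnet50_cls92_queries100_pretrained.pt").cuda()
model.eval()
images = torch.rand(1, 3, 640, 480).cuda()
dd_outputs, _, _, _ = model(images)
print(dd_outputs[0]["scores"].shape, dd_outputs[0]["labels"].shape)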
+
+# TODO: remove? some debug code below
+from scipy.optimize import linear_sum_assignment
+from PIL import Image
+
+def image_loader(loader, image_name):
+    image = Image.open(image_name)
+    image = loader(image).float()
+    image = torch.tensor(image, requires_grad=False)
+    image = image.unsqueeze(0)
+    return image
+
+## code for inference / testing
+print('predict on image\n')
+data_transforms = transforms.Compose([transforms.ToTensor()])
+images = image_loader(data_transforms, '/home/royale/dog.jpg').cuda()
+
+# targets in DD format
+ids = torch.tensor([0]).cuda()
+labels = torch.tensor([18]).cuda()
+boxes = torch.tensor([[30.0, 20.0, 330.0, 240.0]]).cuda()
+
+if 0:
+    # check train with bs=2 and various size bboxes
+    images = torch.rand((2, 3, 100, 100)).cuda()
+    ids = torch.tensor([0, 1, 1]).cuda()
+    labels = torch.tensor([10, 11, 11]).cuda()
+    boxes = torch.tensor([
+        [10., 10., 10., 10.],
+        [21., 21., 21., 21.],
+        [32., 32., 32., 32.],
+    ]).cuda()
+
+if 0:
+    # check torch eval
+    model_without_ddp.eval()
+    dd_outputs, _, _, _ = model_without_ddp(images)
+    scores, labels, boxes = dd_outputs[0].values()
+    n = scores.argmax()
+    print("max", n)
+    print("label", labels[n])
+    print("bbox", boxes[n])
+    quit()
+
+if 0:
+    # check torch train
+    model_without_ddp.train()
+    outputs = model_without_ddp(images, ids, boxes, labels)
+    quit()
+
+if 0:
+    # check jit eval
+    model_jit.eval()
+    dd_outputs, _, _, _ = model_jit(images)
+    scores, labels, boxes = dd_outputs[0].values()
+    n = scores.argmax()
+    print("max", n)
+    print("label", labels[n])
+    print("bbox", boxes[n])
+    quit()
+
+# check jit train
+model_jit.train()
+_, detr_outputs, detr_targets, detr_indices = model_jit.forward(images, ids, boxes, labels)
+with torch.no_grad():
+    print("detr_indices", detr_indices)
+    print("len detr_indices", len(detr_indices))
+    print("detr_index.shape", detr_indices[0].shape)
+    detr_indices = [torch.tensor(linear_sum_assignment(i.cpu())) for i in detr_indices]
+    print(detr_indices)
+    print(detr_indices[0].shape)
+    print(detr_indices[0].dtype)
+    print("linear_sum_assignment", detr_indices)
+    #detr_indices = [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in detr_indices]
+#losses = model_jit.criterion(detr_outputs, detr_targets, detr_indices)
+losses = model_jit.loss(detr_outputs, detr_targets, detr_indices)
+from pprint import pprint
+pprint(losses)
+
+# check backward step
+if 0:
+    loss = sum(x for x in losses.values())
+    print(loss)
+    layer = list(model_jit.named_parameters())[0][1]
+    print("grad before", layer.grad)
+    loss.backward()
+    print("grad after", layer.grad)
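Before the RT-DETRv2 script, note that both wrappers reduce the criterion's weighted loss dict to the single total_loss scalar that DD trains on. In isolation, with toy values (the real keys and weights come from criterion.weight_dict; loss_ce and loss_bbox are standard DETR keys):

import torch

losses = {"loss_ce": torch.tensor(0.7), "loss_bbox": torch.tensor(0.2)}
weights = {"loss_ce": 1.0, "loss_bbox": 5.0}
losses = {k: losses[k] * weights[k] for k in losses.keys()}
losses["total_loss"] = torch.stack(list(losses.values())).sum()
print(losses["total_loss"])  # tensor(1.7000)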
diff --git a/tools/torch/trace_rtdetrv2.py b/tools/torch/trace_rtdetrv2.py
new file mode 100644
index 000000000..db0a32b58
--- /dev/null
+++ b/tools/torch/trace_rtdetrv2.py
@@ -0,0 +1,307 @@
+import os
+import sys
+import argparse
+from typing import Dict, List, Optional
+
+import torch
+
+
+class WrappedRTDETR(torch.nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        self.model = cfg.model.deploy()
+        self.postprocessor = cfg.postprocessor.deploy()
+        self.criterion = cfg.criterion
+
+    def box_xyxy_to_cxcywh(self, x):
+        x0, y0, x1, y1 = x.unbind(-1)
+        b = [(x0 + x1) * 0.5, (y0 + y1) * 0.5, (x1 - x0), (y1 - y0)]
+        return torch.stack(b, dim=-1)
+
+    def dd_targets_to_detr_targets(self, ids, bboxes, labels, target_sizes):
+        # type: (Optional[Tensor], Optional[Tensor], Optional[Tensor], Tensor)
+        assert ids is not None
+        assert bboxes is not None
+        assert labels is not None
+
+        # convert DD bboxes to DETR bboxes instead, because of the L1 loss here:
+        # https://github.com/facebookresearch/detr/blob/main/models/detr.py#L153
+
+        # assume all images in the batch are the same size
+        img_h, img_w = target_sizes[0].unbind(0)
+
+        # convert to [xc, yc, w, h] format
+        bboxes = self.box_xyxy_to_cxcywh(bboxes)
+
+        # and to relative [0, 1] coordinates
+        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=0).cuda()
+        bboxes = bboxes / scale_fct
+
+        # convert ids, bboxes, labels to DETR targets
+        # DD uses ids, DETR expects lists of boxes, labels
+        detr_targets: List[Dict[str, torch.Tensor]] = []
+        # DETR classes start at 0
+        labels -= 1
+        batch_size = target_sizes.shape[0]
+        count = torch.bincount(ids, minlength=batch_size)
+        start = 0
+        for i in range(batch_size):
+            stop = start + count[i]
+            target = {
+                "labels": labels[start:stop],
+                "boxes": bboxes[start:stop],
+            }
+            detr_targets.append(target)
+            start = stop
+
+        return detr_targets
+
+    def detr_outputs_to_dd_outputs(self, outputs, target_sizes):
+        # type: (Dict[str, Tensor], Tensor)
+        labels, boxes, scores = self.postprocessor(outputs, target_sizes)
+        # DETR classes start at 0
+        labels += 1
+        results = [
+            {"scores": s, "labels": l, "boxes": b}
+            for s, l, b in zip(scores, labels, boxes)
+        ]
+        return results
+
+    def forward(self, x, ids=None, bboxes=None, labels=None):
+        # type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor])
+        """
+        x: batch of images of dimensions [batch size, channel count, width, height]
+        """
+        l_x = [x[i] for i in range(x.shape[0])]
+        sample = x
+        image_sizes = torch.zeros([len(l_x), 2]).cuda()
+        i = 0
+        for x in l_x:
+            image_sizes[i][0] = x.shape[1]
+            image_sizes[i][1] = x.shape[2]
+            i += 1
+
+        # default placeholders
+        dd_outputs = [{"dummy": torch.zeros((0,))}]
+        detr_targets: List[Dict[str, torch.Tensor]] = []
+        detr_indices = [torch.zeros((0,))]
+
+        # get targets
+        if self.training:
+            detr_targets = self.dd_targets_to_detr_targets(
+                ids, bboxes, labels, image_sizes
+            )
+
+        # forward with the targets
+        detr_outputs = self.model(sample, detr_targets)
+
+        if self.training:
+            with torch.no_grad():
+                # do the match
+                detr_indices = self.criterion.matcher(detr_outputs, detr_targets)
+        else:
+            with torch.no_grad():
+                # converting detr to torchvision detection format
+                dd_outputs = self.detr_outputs_to_dd_outputs(detr_outputs, image_sizes)
+
+        return dd_outputs, detr_outputs, detr_targets, detr_indices
+
+    @torch.jit.export
+    def loss(self, outputs, targets, indices):
+        # type: (Dict[str, Tensor], List[Dict[str, Tensor]], List[Tensor])
+        # convert List[Tensor] of 2D tensors indices to List[Tuple[Tensor, Tensor]]
+        # as expected by the DETR criterion
+        indices = [(x[0], x[1]) for x in indices]
+        losses = self.criterion(outputs, targets, indices)
+        weights = self.criterion.weight_dict
+        losses = {k: losses[k] * weights[k] for k in losses.keys()}
+        # DD expects a total_loss key as the model loss
+        losses["total_loss"] = torch.stack(list(losses.values())).sum()
+        return losses
+
+
+# map model names to their config files
+configs = {
+    # base models
+    "rtdetrv2_s": "rtdetrv2_r18vd_120e_coco.yml",
+    "rtdetrv2_m": "rtdetrv2_r50vd_m_7x_coco.yml",
+    "rtdetrv2_l": "rtdetrv2_r50vd_6x_coco.yml",
+    "rtdetrv2_x": "rtdetrv2_r101vd_6x_coco.yml",
+    # discrete sampling
+    "rtdetrv2_s_dsp": "rtdetrv2_r18vd_dsp_3x_coco.yml",
+    "rtdetrv2_m_dsp": "rtdetrv2_r50vd_m_dsp_3x_coco.yml",
+    "rtdetrv2_l_dsp": "rtdetrv2_r50vd_dsp_1x_coco.yml",
+}
+
+parser = argparse.ArgumentParser(description="Trace RT-DETR model")
+parser.add_argument("model", type=str, help="Model to export", choices=configs.keys())
+parser.add_argument(
+    "--path-to-rtdetrv2", help="path to rtdetrv2 repository", required=True
+)
+parser.add_argument("--model-in-file", help="path to model .pth file")
+parser.add_argument(
+    "-o",
+    "--output-dir",
+    default=".",
+    type=str,
+    help="Output directory for traced models",
+)
+parser.add_argument(
+    "--num_classes", type=int, default=81, help="Number of classes of the model"
+)
+# parser.add_argument(
+#     "--num_queries", default=300, type=int, help="Number of query slots"
+# )
+args = parser.parse_args()
+
+# DETR already reserves a no-object class
+args.num_classes -= 1
+
+# load model
+sys.path.append(args.path_to_rtdetrv2)
+from src.core import YAMLConfig
+
+# TODO handle cfg eval_spatial_size, 640x640 by default
+config = args.path_to_rtdetrv2 + "/configs/rtdetrv2/" + configs[args.model]
+
+# from https://github.com/lyuwenyu/RT-DETR/blob/main/rtdetrv2_pytorch/tools/export_onnx.py
+cfg = YAMLConfig(
+    config,
+    resume=args.model_in_file,
+    num_classes=args.num_classes,
+    PResNet={"pretrained": args.model_in_file is not None},
+    # RTDETRTransformerv2={"num_queries": args.num_queries},
+    # RTDETRPostProcessor={"num_top_queries": args.num_queries},
+)
+
+# load checkpoint
+if args.model_in_file:
+    checkpoint = torch.load(args.model_in_file, map_location="cpu")
+    if "ema" in checkpoint:
+        state = checkpoint["ema"]["module"]
+    else:
+        state = checkpoint["model"]
+
+    # handle keys that moved due to tracing
+    state = {
+        k.replace("decoder.query_pos_head", "decoder.decoder.query_pos_head")
+        .replace("decoder.dec_bbox_head", "decoder.decoder.bbox_head")
+        .replace("decoder.dec_score_head", "decoder.decoder.score_head"): v
+        for k, v in state.items()
+    }
+
+    # remove keys incompatible with num_classes
+    if args.num_classes != 80:
+        state = {
+            k: v
+            for k, v in state.items()
+            if not any(
+                k.startswith(x)
+                for x in [
+                    "decoder.denoising_class_embed",
+                    "decoder.enc_score_head",
+                    "decoder.decoder.score_head",
+                ]
+            )
+        }
+
+    cfg.model.load_state_dict(state, strict=False)
+
+# wrap model
+model = WrappedRTDETR(cfg)
+model.cuda()
+model.eval()
+
+filename = os.path.join(
+    args.output_dir,
+    args.model
+    + "_cls"
+    + str(args.num_classes + 1)
+    # + "_queries"
+    # + str(args.num_queries)
+    + ("_pretrained" if args.model_in_file else "")
+    + ".pt",
+)
+print("Attempting jit export...")
+model_jit = torch.jit.script(model)
+model_jit.save(filename)
+print("jit rtdetrv2 export successful")
+quit()
+
+
+# TODO: remove? some debug code below
+from PIL import Image
+from torchvision import transforms
+
+
+def image_loader(loader, image_name):
+    image = Image.open(image_name)
+    image = loader(image).float()
+    image = torch.tensor(image, requires_grad=False)
+    image = image.unsqueeze(0)
+    return image
+
+
+## code for inference / testing
+print("predict on image\n")
+data_transforms = transforms.Compose(
+    [transforms.Resize((640, 640)), transforms.ToTensor()]
+)
+images = image_loader(data_transforms, "/home/royale/dog.jpg").cuda()
+size = torch.tensor([[640, 640]])
+
+# targets in DD format
+ids = torch.tensor([0]).cuda()
+labels = torch.tensor([18]).cuda()
+boxes = torch.tensor([[30.0, 20.0, 330.0, 240.0]]).cuda()
+
+from scipy.optimize import linear_sum_assignment
+
+# torch inference
+if 0:
+    model.eval()
+    outputs = model(images, ids, boxes, labels)
+    print(outputs)
+
+# torch train
+if 0:
+    model.train()
+    _, detr_outputs, detr_targets, detr_indices = model.forward(
+        images, ids, boxes, labels
+    )
+    with torch.no_grad():
+        print("detr_indices", detr_indices)
+        print("len detr_indices", len(detr_indices))
+        print("detr_index.shape", detr_indices[0].shape)
+        detr_indices = [torch.tensor(linear_sum_assignment(i.cpu())) for i in detr_indices]
+        print(detr_indices)
+        print(detr_indices[0].shape)
+        print(detr_indices[0].dtype)
+        print("linear_sum_assignment", detr_indices)
+        # detr_indices = [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in detr_indices]
+    losses = model.loss(detr_outputs, detr_targets, detr_indices)
+    from pprint import pprint
+
+    pprint(losses)
+    quit()
+
+# jit inference
+if 0:
+    model_jit.eval()
+    outputs = model_jit(images, ids, boxes, labels)
+    print(outputs)
+
+# jit train
+if 0:
+    model_jit.train()
+    _, detr_outputs, detr_targets, detr_indices = model_jit.forward(
+        images, ids, boxes, labels
+    )
+    with torch.no_grad():
+        print("detr_indices", detr_indices)
+        print("len detr_indices", len(detr_indices))
+        print("detr_index.shape", detr_indices[0].shape)
+        detr_indices = [torch.tensor(linear_sum_assignment(i.cpu())) for i in detr_indices]
+        print(detr_indices)
+        print(detr_indices[0].shape)
+        print(detr_indices[0].dtype)
+        print("linear_sum_assignment", detr_indices)
+        # detr_indices = [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in detr_indices]
+    losses = model_jit.loss(detr_outputs, detr_targets, detr_indices)
+    from pprint import pprint
+
+    pprint(losses)
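For reference, typical invocations of the two tracing scripts (repository paths and checkpoint names are placeholders):

  python trace_detr.py --path-to-detr /path/to/detr --model-in-file detr-r50.pth -o out/
  python trace_rtdetrv2.py rtdetrv2_s --path-to-rtdetrv2 /path/to/RT-DETR/rtdetrv2_pytorch --model-in-file rtdetrv2_r18vd_120e_coco.pth -o out/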