Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: enable DETR and RT-DETRv2 training #1574

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ set(ddetect_SOURCES deepdetect.h deepdetect.cc mllibstrategy.h mlmodel.h
svminputfileconn.h svminputfileconn.cc txtinputfileconn.h
txtinputfileconn.cc apidata.h apidata.cc chain_actions.h chain_actions.cc
service_stats.h service_stats.cc chain.h chain.cc resources.cc ext/rmustache/mustache.h ext/rmustache/mustache.cc
utils/oatpp.cc dto/ddtypes.cc utils/db.cpp utils/db_lmdb.cpp ${CMAKE_BINARY_DIR}/src/caffe.pb.cc ${CMAKE_BINARY_DIR}/dd_config.cc)
utils/oatpp.cc dto/ddtypes.cc utils/db.cpp utils/db_lmdb.cpp utils/rectangular_lsap.cpp ${CMAKE_BINARY_DIR}/src/caffe.pb.cc ${CMAKE_BINARY_DIR}/dd_config.cc)

if (USE_JSON_API)
list(APPEND ddetect_SOURCES jsonapi.h jsonapi.cc)
Expand Down
2 changes: 2 additions & 0 deletions src/backends/torch/torchlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,8 @@ namespace dd
}
else if (_template == "detr")
{
_module._loss_id = 0;
_module._detr = true;
}
else if (!_template.empty())
{
Expand Down
49 changes: 49 additions & 0 deletions src/backends/torch/torchmodule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "graph/graph.h"
#include "native/native.h"
#include "torchutils.h"
#include "utils/rectangular_lsap.h"

namespace dd
{
Expand Down Expand Up @@ -347,6 +348,9 @@ namespace dd
}
}

if (_training && _detr)
return detr_postprocess(source);

if (_training && _loss_id >= 0)
{
// if we are in training mode and model does output the loss (eg
Expand Down Expand Up @@ -751,6 +755,51 @@ namespace dd
_frozen_params_count = total_frozen_count;
}

/**
 * \brief DETR training post-processing: solves one linear sum assignment
 *        problem per cost matrix produced by the model, then calls the
 *        traced module's "loss" method with the resulting matches.
 *
 * Entries read from \p source:
 *  - source[1]: detr_outputs, forwarded untouched to "loss"
 *  - source[2]: detr_targets, forwarded untouched to "loss"
 *  - source[3]: detr_indices, a list of 2D cost matrices (one per item)
 *
 * @param source values returned by the model forward pass
 * @return whatever the traced "loss" method returns (a dictionary of
 *         losses, with total_loss key used as custom loss)
 * @throws MLLibBadParamException if a cost matrix is not 2D, if the
 *         assignment solver reports an error, if the model is not traced
 *         or if it has no "loss" method
 * @throws std::out_of_range if \p source holds fewer than 4 entries
 */
c10::IValue TorchModule::detr_postprocess(std::vector<c10::IValue> &source)
{
  // DETR matcher outputs a List[Tuple[Tensor, Tensor]]
  // https://github.com/facebookresearch/detr/blob/main/models/matcher.py#L82
  // which seems impossible/difficult to wrap in IValue
  // https://github.com/pytorch/pytorch/issues/90398
  // we output a List[Tensor] of 2D tensors instead and unwrap it later
  std::vector<torch::Tensor> out_list;

  // solve the linear_sum_assignment problems
  // TODO: do it in parallel?
  auto in_list_raw = source.at(3); // detr_indices
  auto in_list = torch_utils::unwrap_c10_vector(in_list_raw);
  out_list.reserve(in_list.size());
  for (auto &in_item_raw : in_list)
    {
      // The solver dereferences a raw host pointer and assumes a dense
      // row-major layout: bring the cost matrix to CPU and make it
      // contiguous before handing out its data pointer.
      auto in_item = in_item_raw.toTensor()
                         .to(torch::kCPU, torch::kFloat64)
                         .contiguous();
      if (in_item.dim() != 2)
        throw MLLibBadParamException(
            "detr_postprocess: cost matrix must be a 2D tensor");

      const int64_t rows = in_item.size(0);
      const int64_t cols = in_item.size(1);
      // A rectangular assignment yields min(rows, cols) pairs (scipy
      // linear_sum_assignment contract). Sizing the output with cols only
      // would leave zero-filled columns, i.e. bogus (0, 0) matches,
      // whenever rows < cols.
      const int64_t n_matches = rows < cols ? rows : cols;

      auto out_item = torch::zeros({ 2, n_matches }, torch::kInt64);
      // Row views share storage with out_item: row 0 receives the row
      // indices, row 1 the column indices.
      auto row_ind = out_item[0];
      auto col_ind = out_item[1];
      auto ret = scipy::solve_rectangular_linear_sum_assignment(
          rows, cols, in_item.data_ptr<double>(), false,
          row_ind.data_ptr<int64_t>(), col_ind.data_ptr<int64_t>());
      if (ret)
        throw MLLibBadParamException(
            "detr_postprocess: linear_sum_assignment error");
      out_list.push_back(out_item);
    }

  // call loss
  if (!_traced)
    throw MLLibBadParamException("detr_postprocess: model is not traced");
  auto method = _traced->find_method("loss");
  if (!method)
    throw MLLibBadParamException("detr_postprocess: loss method not found");
  auto output = (*method)({
      source.at(1), // detr_outputs
      source.at(2), // detr_targets
      out_list      // detr_indices
  });

  // return a dictionary of losses, with total_loss key used as custom loss
  return output;
}

template void TorchModule::post_transform(
const std::string tmpl, const APIData &template_params,
const ImgTorchInputFileConn &inputc, const TorchModel &tmodel,
Expand Down
6 changes: 6 additions & 0 deletions src/backends/torch/torchmodule.h
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ namespace dd
int _linear_in = 0; /**<id of the input of the final linear layer */
int _loss_id = -1; /**<id of the loss output. If >= 0, forward returns this
output only during training */
bool _detr = false;
bool _hidden_states = false; /**< Take BERT hidden states as input. */

unsigned int _nclasses = 0; /**< number of classes */
Expand Down Expand Up @@ -277,6 +278,11 @@ namespace dd
* load linear layer weights only from pt format
*/
void crnn_head_load();

/**
* DETR postprocessing
*/
c10::IValue detr_postprocess(std::vector<c10::IValue> &source);
};
}
#endif
Loading