Skip to content

Commit

Permalink
feat: enable DETR training
Browse files Browse the repository at this point in the history
  • Loading branch information
royale committed Nov 29, 2024
1 parent 0bbaf10 commit f888196
Show file tree
Hide file tree
Showing 7 changed files with 595 additions and 12 deletions.
2 changes: 1 addition & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ set(ddetect_SOURCES deepdetect.h deepdetect.cc mllibstrategy.h mlmodel.h
svminputfileconn.h svminputfileconn.cc txtinputfileconn.h
txtinputfileconn.cc apidata.h apidata.cc chain_actions.h chain_actions.cc
service_stats.h service_stats.cc chain.h chain.cc resources.cc ext/rmustache/mustache.h ext/rmustache/mustache.cc
utils/oatpp.cc dto/ddtypes.cc utils/db.cpp utils/db_lmdb.cpp ${CMAKE_BINARY_DIR}/src/caffe.pb.cc ${CMAKE_BINARY_DIR}/dd_config.cc)
utils/oatpp.cc dto/ddtypes.cc utils/db.cpp utils/db_lmdb.cpp utils/rectangular_lsap.cpp ${CMAKE_BINARY_DIR}/src/caffe.pb.cc ${CMAKE_BINARY_DIR}/dd_config.cc)

if (USE_JSON_API)
list(APPEND ddetect_SOURCES jsonapi.h jsonapi.cc)
Expand Down
2 changes: 2 additions & 0 deletions src/backends/torch/torchlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,8 @@ namespace dd
}
else if (_template == "detr")
{
_module._loss_id = 0;
_module._detr = true;
}
else if (!_template.empty())
{
Expand Down
48 changes: 48 additions & 0 deletions src/backends/torch/torchmodule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "graph/graph.h"
#include "native/native.h"
#include "torchutils.h"
#include "utils/rectangular_lsap.h"

namespace dd
{
Expand Down Expand Up @@ -347,6 +348,9 @@ namespace dd
}
}

if (_training && _detr)
return detr_postprocess(source);

if (_training && _loss_id >= 0)
{
// if we are in training mode and model does output the loss (eg
Expand Down Expand Up @@ -751,6 +755,50 @@ namespace dd
_frozen_params_count = total_frozen_count;
}

c10::IValue TorchModule::detr_postprocess(std::vector<c10::IValue> &source)
{
// DETR matcher outputs a List[Tuple[Tensor, Tensor]]
// https://github.com/facebookresearch/detr/blob/main/models/matcher.py#L82
// which seems impossible/difficult to wrap in IValue
// https://github.com/pytorch/pytorch/issues/90398
// we output a List[Tensor] of 2D tensors instead and unwrap it later
std::vector<torch::Tensor> out_list;

// solve the linear_sum_assignment problems
// TODO: do it in parallel?
auto in_list_raw = source.at(3); // detr_indices
auto in_list = torch_utils::unwrap_c10_vector(in_list_raw);
for (auto &in_item_raw : in_list)
{
auto in_item = in_item_raw.toTensor().to(torch::kFloat64);
auto shape = in_item.sizes();
int rows = shape[0];
int cols = shape[1];
auto out_item = torch::zeros({ 2, cols }, torch::kInt64);
auto ret = scipy::solve_rectangular_linear_sum_assignment(
rows, cols, in_item.data_ptr<double>(), false,
out_item[0].data_ptr<int64_t>(), out_item[1].data_ptr<int64_t>());
if (ret)
throw MLLibBadParamException(
"detr_postprocess: linear_sum_assignment error");
out_list.push_back(out_item);
}

// call loss
if (!_traced)
throw MLLibBadParamException("detr_postprocess: model is not traced");
auto method = _traced->find_method("loss");
if (!method)
throw MLLibBadParamException("detr_postprocess: loss method not found");
auto output = (*method)({
source.at(1), // detr_outputs
source.at(2), // detr_targets
out_list // detr_indices
});
source = torch_utils::unwrap_c10_vector(output);
return source.at(0);
}

template void TorchModule::post_transform(
const std::string tmpl, const APIData &template_params,
const ImgTorchInputFileConn &inputc, const TorchModel &tmodel,
Expand Down
6 changes: 6 additions & 0 deletions src/backends/torch/torchmodule.h
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ namespace dd
int _linear_in = 0; /**<id of the input of the final linear layer */
int _loss_id = -1; /**<id of the loss output. If >= 0, forward returns this
output only during training */
bool _detr = false;
bool _hidden_states = false; /**< Take BERT hidden states as input. */

unsigned int _nclasses = 0; /**< number of classes */
Expand Down Expand Up @@ -277,6 +278,11 @@ namespace dd
* load linear layer weights only from pt format
*/
void crnn_head_load();

/**
* DETR postprocessing
*/
c10::IValue detr_postprocess(std::vector<c10::IValue> &source);
};
}
#endif
Loading

0 comments on commit f888196

Please sign in to comment.