Skip to content

Commit

Permalink
fix a bug in cpp mask (#95)
Browse files Browse the repository at this point in the history
  • Loading branch information
feifeibear authored Jun 23, 2020
1 parent a47bbf1 commit af84878
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions example/cpp/bert_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,11 +158,13 @@ struct BertModel::Impl {
for (size_t i = 0; i < inputs.size();
++i, iptr += max_seq_len, mptr += max_seq_len) {
auto &input = inputs[i];
// TODO(jiaruifang) Bert_Attention use mask value as 1 to indicate a valid
// position.
std::copy(input.begin(), input.end(), iptr);
std::fill(mptr, mptr + input.size(), 0);
std::fill(mptr, mptr + input.size(), 1);
if (input.size() != static_cast<size_t>(max_seq_len)) {
std::fill(iptr + input.size(), iptr + max_seq_len, 0);
std::fill(mptr + input.size(), mptr + max_seq_len, 1);
std::fill(mptr + input.size(), mptr + max_seq_len, 0);
}
}
if (device_type_ == DLDeviceType::kDLGPU) {
Expand Down

0 comments on commit af84878

Please sign in to comment.