Skip to content

Commit

Permalink
support GPU on torchscript
Browse files Browse the repository at this point in the history
  • Loading branch information
taoleicn committed May 12, 2021
1 parent 6e0038e commit 69d28cd
Show file tree
Hide file tree
Showing 6 changed files with 1,406 additions and 523 deletions.
8 changes: 4 additions & 4 deletions sru/csrc/sru_cpu_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ std::vector<at::Tensor> cpu_forward(
const float* reset_b_ptr = forget_b_ptr + hidden_size;
const float* U_ptr = U.data_ptr<float>();
const float* x_ptr = x.data_ptr<float>();
const float* pad_ptr = mask_pad.has_value() ?
mask_pad.value().data_ptr<float>() : NULL;
const bool* pad_ptr = mask_pad.has_value() ?
mask_pad.value().data_ptr<bool>() : NULL;

auto h = at::zeros({length, batch_size, hidden_size}, U.options());
auto c = c_init.clone();
Expand Down Expand Up @@ -175,8 +175,8 @@ std::vector<at::Tensor> cpu_bi_forward(
const float* reset_b_ptr = forget_b_ptr + hidden_size*2;
const float* U_ptr = U.data_ptr<float>();
const float* x_ptr = x.data_ptr<float>();
const float* pad_ptr = mask_pad.has_value() ?
mask_pad.value().data_ptr<float>() : NULL;
const bool* pad_ptr = mask_pad.has_value() ?
mask_pad.value().data_ptr<bool>() : NULL;

auto h = at::zeros({length, batch_size, hidden_size*2}, U.options());
auto c = c_init.clone();
Expand Down
Loading

0 comments on commit 69d28cd

Please sign in to comment.