Skip to content

Commit

Permalink
Merge pull request #188 from asappresearch/torchscript_gpu_v2.5
Browse files Browse the repository at this point in the history
support GPU inference in torchscript model for v2.5 / v2.6
  • Loading branch information
taoleicn authored May 18, 2021
2 parents 6e0038e + 32d9ebb commit a698784
Show file tree
Hide file tree
Showing 12 changed files with 1,714 additions and 544 deletions.
9 changes: 9 additions & 0 deletions sru/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,20 @@ target_compile_features(sru_cpu PRIVATE cxx_std_14)
# Link against LibTorch
target_link_libraries(sru_cpu "${TORCH_LIBRARIES}")

# Dummy CUDA implementation library: provides the sru_cuda symbols so that
# CPU-only builds of the example app still link (real kernels live elsewhere).
add_library(sru_cuda SHARED sru_cuda_impl_dummy.cpp)
# Enable C++14
target_compile_features(sru_cuda PRIVATE cxx_std_14)
# Link against LibTorch
target_link_libraries(sru_cuda PRIVATE "${TORCH_LIBRARIES}")

# Example executable that loads/exercises the SRU op libraries.
add_executable(example_app main_test_cpp.cpp)
target_link_libraries(example_app PRIVATE "${TORCH_LIBRARIES}")
if (UNIX AND NOT APPLE)
    # GNU ld drops libraries it considers unused; the op libraries register
    # themselves via static initializers, so force them to remain linked.
    target_link_libraries(example_app PRIVATE -Wl,--no-as-needed sru_cpu)
    target_link_libraries(example_app PRIVATE -Wl,--no-as-needed sru_cuda)
else()
    # Apple ld equivalent: load all archive members so the operator
    # registration code is not stripped from the final binary.
    target_link_libraries(example_app PRIVATE -Wl,-all_load sru_cpu)
    target_link_libraries(example_app PRIVATE -Wl,-all_load sru_cuda)
endif()
target_compile_features(example_app PRIVATE cxx_std_14)
8 changes: 4 additions & 4 deletions sru/csrc/sru_cpu_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ std::vector<at::Tensor> cpu_forward(
const float* reset_b_ptr = forget_b_ptr + hidden_size;
const float* U_ptr = U.data_ptr<float>();
const float* x_ptr = x.data_ptr<float>();
const float* pad_ptr = mask_pad.has_value() ?
mask_pad.value().data_ptr<float>() : NULL;
const bool* pad_ptr = mask_pad.has_value() ?
mask_pad.value().data_ptr<bool>() : NULL;

auto h = at::zeros({length, batch_size, hidden_size}, U.options());
auto c = c_init.clone();
Expand Down Expand Up @@ -175,8 +175,8 @@ std::vector<at::Tensor> cpu_bi_forward(
const float* reset_b_ptr = forget_b_ptr + hidden_size*2;
const float* U_ptr = U.data_ptr<float>();
const float* x_ptr = x.data_ptr<float>();
const float* pad_ptr = mask_pad.has_value() ?
mask_pad.value().data_ptr<float>() : NULL;
const bool* pad_ptr = mask_pad.has_value() ?
mask_pad.value().data_ptr<bool>() : NULL;

auto h = at::zeros({length, batch_size, hidden_size*2}, U.options());
auto c = c_init.clone();
Expand Down
Loading

0 comments on commit a698784

Please sign in to comment.