Commit

try to resolve a torchscript loading issue

taoleicn committed May 13, 2021
1 parent bcf6a0d commit 72d9184
Showing 3 changed files with 216 additions and 41 deletions.
188 changes: 188 additions & 0 deletions sru/csrc/sru_cuda_impl_dummy.cpp
@@ -0,0 +1,188 @@
#include <torch/script.h>
#include <vector>

/* Implementation starts here */

// unidirectional forward()
std::vector<at::Tensor> sru_forward_simple(
const at::Tensor & U,
const at::optional<at::Tensor> & x,
const at::Tensor & weight_c,
const at::Tensor & bias,
const at::Tensor & c_init,
const at::optional<at::Tensor> & mask_c,
const at::optional<at::Tensor> & mask_pad,
const int64_t length,
const int64_t batch_size,
const int64_t hidden_size) {

throw "Failed to load SRU recurrence operators for GPU";
}

// unidirectional backward()
std::vector<at::Tensor> sru_backward_simple(
const at::Tensor & U,
const at::optional<at::Tensor> & x,
const at::Tensor & weight_c,
const at::Tensor & bias,
const at::Tensor & c_init,
const at::optional<at::Tensor> & mask_c,
const at::optional<at::Tensor> & mask_pad,
const at::Tensor & c,
const at::Tensor & grad_h,
const at::Tensor & grad_last,
const int64_t length,
const int64_t batch_size,
const int64_t hidden_size) {

throw "Failed to load SRU recurrence operators for GPU";
}

// bidirectional forward()
std::vector<at::Tensor> sru_bi_forward_simple(
const at::Tensor & U,
const at::optional<at::Tensor> & x,
const at::Tensor & weight_c,
const at::Tensor & bias,
const at::Tensor & c_init,
const at::optional<at::Tensor> & mask_c,
const at::optional<at::Tensor> & mask_pad,
const int64_t length,
const int64_t batch_size,
const int64_t hidden_size) {

throw "Failed to load SRU recurrence operators for GPU";
}

// bidirectional backward()
std::vector<at::Tensor> sru_bi_backward_simple(
const at::Tensor & U,
const at::optional<at::Tensor> & x,
const at::Tensor & weight_c,
const at::Tensor & bias,
const at::Tensor & c_init,
const at::optional<at::Tensor> & mask_c,
const at::optional<at::Tensor> & mask_pad,
const at::Tensor & c,
const at::Tensor & grad_h,
const at::Tensor & grad_last,
const int64_t length,
const int64_t batch_size,
const int64_t hidden_size) {

throw "Failed to load SRU recurrence operators for GPU";
}

// unidirectional forward()
std::vector<at::Tensor> sru_forward(
const at::Tensor & U,
const at::optional<at::Tensor> & x,
const at::Tensor & weight_c,
const at::Tensor & bias,
const at::Tensor & c_init,
const at::optional<at::Tensor> & mask_c,
const at::optional<at::Tensor> & mask_pad,
const int64_t length,
const int64_t batch_size,
const int64_t hidden_size,
const int64_t k,
const int64_t activation_type,
const int64_t skip_type,
const int64_t is_custom) {

throw "Failed to load SRU recurrence operators for GPU";
}

// bidirectional forward()
std::vector<at::Tensor> sru_bi_forward(
const at::Tensor & U,
const at::optional<at::Tensor> & x,
const at::Tensor & weight_c,
const at::Tensor & bias,
const at::Tensor & c_init,
const at::optional<at::Tensor> & mask_c,
const at::optional<at::Tensor> & mask_pad,
const int64_t length,
const int64_t batch_size,
const int64_t hidden_size,
const int64_t k,
const int64_t activation_type,
const int64_t skip_type,
const int64_t is_custom) {

throw "Failed to load SRU recurrence operators for GPU";
}

// unidirectional backward()
std::vector<at::Tensor> sru_backward(
const at::Tensor & U,
const at::optional<at::Tensor> & x,
const at::Tensor & weight_c,
const at::Tensor & bias,
const at::Tensor & c_init,
const at::optional<at::Tensor> & mask_c,
const at::optional<at::Tensor> & mask_pad,
const at::Tensor & c,
const at::Tensor & grad_h,
const at::Tensor & grad_last,
const int64_t length,
const int64_t batch_size,
const int64_t hidden_size,
const int64_t k,
const int64_t activation_type,
const int64_t skip_type,
const int64_t is_custom) {

throw "Failed to load SRU recurrence operators for GPU";
}

// bidirectional backward()
std::vector<at::Tensor> sru_bi_backward(
const at::Tensor & U,
const at::optional<at::Tensor> & x,
const at::Tensor & weight_c,
const at::Tensor & bias,
const at::Tensor & c_init,
const at::optional<at::Tensor> & mask_c,
const at::optional<at::Tensor> & mask_pad,
const at::Tensor & c,
const at::Tensor & grad_h,
const at::Tensor & grad_last,
const int64_t length,
const int64_t batch_size,
const int64_t hidden_size,
const int64_t k,
const int64_t activation_type,
const int64_t skip_type,
const int64_t is_custom) {

throw "Failed to load SRU recurrence operators for GPU";
}

// This way of registering custom ops is based on earlier PRs of PyTorch:
// https://github.com/pytorch/pytorch/pull/28229
//
// In PyTorch 1.6, the recommended way is to use TORCH_LIBRARY(), e.g.
//
// TORCH_LIBRARY(sru_cpu, m) {
// m.def("cpu_forward", &cpu_forward);
// m.def("cpu_bi_forward", &cpu_bi_forward);
// }
//
// We choose this way for backward compatibility.
static auto registory1 =
torch::RegisterOperators("sru_cuda::sru_forward_simple", &sru_forward_simple);
static auto registory2 =
torch::RegisterOperators("sru_cuda::sru_backward_simple", &sru_backward_simple);
static auto registory3 =
torch::RegisterOperators("sru_cuda::sru_bi_forward_simple", &sru_bi_forward_simple);
static auto registory4 =
torch::RegisterOperators("sru_cuda::sru_bi_backward_simple", &sru_bi_backward_simple);
static auto registory5 =
torch::RegisterOperators("sru_cuda::sru_forward", &sru_forward);
static auto registory6 =
torch::RegisterOperators("sru_cuda::sru_backward", &sru_backward);
static auto registory7 =
torch::RegisterOperators("sru_cuda::sru_bi_forward", &sru_bi_forward);
static auto registory8 =
torch::RegisterOperators("sru_cuda::sru_bi_backward", &sru_bi_backward);
37 changes: 26 additions & 11 deletions sru/cuda_functional.py
@@ -1,22 +1,37 @@
from typing import Optional, Tuple
import os
+import warnings

import torch
from torch import Tensor
from torch.autograd import Function
from torch.utils.cpp_extension import load

-sources = [
-    os.path.join(os.path.dirname(__file__), "csrc", "sru_cuda_impl.cpp"),
-    os.path.join(os.path.dirname(__file__), "csrc", "sru_cuda_kernel.cu"),
-]
-load(
-    name="sru_cuda",
-    sources=sources,
-    extra_cflags=['-O3'],
-    is_python_module=False,
-    verbose=False
-)
+try:
+    sources = [
+        os.path.join(os.path.dirname(__file__), "csrc", "sru_cuda_impl.cpp"),
+        os.path.join(os.path.dirname(__file__), "csrc", "sru_cuda_kernel.cu"),
+    ]
+    load(
+        name="sru_cuda",
+        sources=sources,
+        extra_cflags=['-O3'],
+        is_python_module=False,
+        verbose=False
+    )
+except Exception as e:
+    warnings.warn("Just-in-time loading and compiling the CUDA kernels of SRU was unsuccessful. "
+                  "Got the following error:\n" + str(e))
+    sources_dummy = [
+        os.path.join(os.path.dirname(__file__), "csrc", "sru_cuda_impl_dummy.cpp"),
+    ]
+    load(
+        name="sru_cuda",
+        sources=sources_dummy,
+        extra_cflags=['-O3'],
+        is_python_module=False,
+        verbose=False
+    )


@torch.jit.script
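
A usage sketch of the intended effect of this fallback (not part of the commit; sru.SRU and the (length, batch, input_size) input layout are the package's public API, assumed here, and the dummy build still needs a working C++ toolchain). On a machine where the CUDA kernels cannot be compiled, importing the package now only emits the warning above and registers the dummy operators, so CPU-only use and TorchScript scripting keep working; the GPU error surfaces only if the CUDA path is actually exercised.

import torch
import sru  # succeeds even without nvcc/CUDA headers; a warning is emitted instead

model = sru.SRU(input_size=128, hidden_size=128).eval()
scripted = torch.jit.script(model)   # resolving this torchscript loading issue is the point of the commit

x = torch.randn(20, 4, 128)          # (length, batch, input_size)
with torch.no_grad():
    h, state = scripted(x)           # CPU path, served by the CPU kernels

# Moving the model and input to CUDA on a dummy build would raise
# "Failed to load SRU recurrence operators for GPU" at call time.
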
32 changes: 2 additions & 30 deletions sru/ops.py
@@ -6,6 +6,7 @@
import torch
from torch import Tensor
from torch.utils.cpp_extension import load
+from .cuda_functional import elementwise_recurrence_forward

# JIT compilation of elementwise fwd operator (CPU version)
cpu_source = os.path.join(os.path.dirname(__file__), "csrc", "sru_cpu_impl.cpp")
@@ -18,35 +19,6 @@
)


-def elementwise_recurrence_dummy(
-        u: Tensor,
-        x: Tensor,
-        weight_c: Tensor,
-        bias: Tensor,
-        init: Tensor,
-        activation_type: int,
-        d_out: int,
-        bidirectional: bool,
-        has_skip_term: bool,
-        scale_x: Optional[Tensor] = None,
-        mask_c: Optional[Tensor] = None,
-        mask_pad: Optional[Tensor] = None
-) -> Tuple[Tensor, Tensor, Tensor]:
-    """Dummy function for the case that CUDA isn't available
-    """
-    raise Exception("Failed to load the CUDA kernel of SRU elementwise recurrence.")
-
-
-# If we failed to import CUDA implementation, we use a dummy method that simply
-# raises an exception. This ensures the torchscript method can compile on machines
-# that don't have GPUs or CUDA.
-try:
-    from .cuda_functional import elementwise_recurrence_forward
-    elementwise_recurrence_cuda_torchscript = elementwise_recurrence_forward
-except Exception:
-    elementwise_recurrence_cuda_torchscript = elementwise_recurrence_dummy
-
-
@torch.jit.script
def elementwise_recurrence_inference(U: Tensor,
x: Tensor,
@@ -71,7 +43,7 @@ def elementwise_recurrence_inference(U: Tensor,
is_custom = weight_c.dim() > 1
mask_pad = None if mask_pad is None else mask_pad.to(dtype=torch.bool).contiguous()
if U.is_cuda:
-        h, last_hidden, c = elementwise_recurrence_cuda_torchscript(
+        h, last_hidden, c = elementwise_recurrence_forward(
U,
x,
weight_c,
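
The dummy-alias block removed earlier in this file kept an always-raising Python stand-in so that the scripted function could still compile when cuda_functional failed to import. With the dummy C++ operators in place, elementwise_recurrence_forward is importable on every machine, so the alias is no longer needed. A minimal, illustrative sketch of the underlying TorchScript constraint (names here are invented, not from the repo): every function referenced inside a scripted body must resolve at scripting time, and a function whose body only raises is still scriptable, as the removed dummy showed.

import torch
from torch import Tensor

def gpu_only_op(x: Tensor) -> Tensor:
    # Stand-in for a backend that is importable everywhere but unusable
    # without CUDA; TorchScript accepts a body that always raises.
    raise Exception("CUDA kernels are not available on this machine")

@torch.jit.script
def dispatch(x: Tensor) -> Tensor:
    # gpu_only_op must be resolvable here even if this branch never runs
    # on the current machine.
    if x.is_cuda:
        return gpu_only_op(x)
    return torch.sigmoid(x)

print(dispatch(torch.randn(3)))  # CPU branch runs normally
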
