Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement parallel levenshtein distance on GPU #1057

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions k2/python/csrc/torch.cu
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "k2/python/csrc/torch/fsa_algo.h"
#include "k2/python/csrc/torch/index_add.h"
#include "k2/python/csrc/torch/index_select.h"
#include "k2/python/csrc/torch/levenshtein_distance.h"
#include "k2/python/csrc/torch/mutual_information.h"
#include "k2/python/csrc/torch/nbest.h"
#include "k2/python/csrc/torch/ragged.h"
Expand All @@ -42,6 +43,7 @@ void PybindTorch(py::module &m) {
PybindFsaAlgo(m);
PybindIndexAdd(m);
PybindIndexSelect(m);
PybindLevenshteinDistance(m);
PybindMutualInformation(m);
PybindNbest(m);
PybindRagged(m);
Expand Down
1 change: 0 additions & 1 deletion k2/python/csrc/torch.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@

#include "k2/csrc/log.h"
#include "k2/csrc/torch_util.h"
#include "k2/python/csrc/torch.h"
#include "torch/extension.h"

namespace pybind11 {
Expand Down
4 changes: 3 additions & 1 deletion k2/python/csrc/torch/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ set(torch_srcs
fsa_algo.cu
index_add.cu
index_select.cu
levenshtein_distance.cu
levenshtein_distance_cpu.cu
mutual_information.cu
mutual_information_cpu.cu
nbest.cu
Expand All @@ -20,7 +22,7 @@ set(torch_srcs
)

if (K2_WITH_CUDA)
list(APPEND torch_srcs mutual_information_cuda.cu)
list(APPEND torch_srcs levenshtein_distance_cuda.cu mutual_information_cuda.cu)
endif()

set(torch_srcs_with_prefix)
Expand Down
44 changes: 44 additions & 0 deletions k2/python/csrc/torch/levenshtein_distance.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/**
* @copyright
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "k2/csrc/device_guard.h"
#include "k2/csrc/torch_util.h"
#include "k2/python/csrc/torch/levenshtein_distance.h"

void PybindLevenshteinDistance(py::module &m) {
m.def(
"levenshtein_distance",
[](torch::Tensor px, torch::Tensor py,
torch::optional<torch::Tensor> boundary) -> torch::Tensor {
k2::DeviceGuard guard(k2::GetContext(px));
if (px.device().is_cpu()) {
return k2::LevenshteinDistanceCpu(px, py, boundary);
} else {
#ifdef K2_WITH_CUDA
return k2::LevenshteinDistanceCuda(px, py, boundary);
#else
K2_LOG(FATAL) << "Failed to find native CUDA module, make sure "
<< "that you compiled the code with K2_WITH_CUDA.";
return torch::Tensor();
#endif
}
},
py::arg("px"), py::arg("py"), py::arg("boundary"));
}
74 changes: 74 additions & 0 deletions k2/python/csrc/torch/levenshtein_distance.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/**
* @copyright
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef K2_PYTHON_CSRC_TORCH_LEVENSHTEIN_DISTANCE_H_
#define K2_PYTHON_CSRC_TORCH_LEVENSHTEIN_DISTANCE_H_

#include <torch/extension.h>

#include <vector>

#include "k2/python/csrc/torch.h"

namespace k2 {

/*
Compute the levenshtein distance between sequences in batches.

@param px A two-dimensional tensor with the shape of ``[B][S]`` containing
sequences. It's data type MUST be ``torch.int32``.
@param py A two-dimensional tensor with the shape of ``[B][U]`` containing
sequences. It's data type MUST be ``torch.int32``.
``py`` and ``px`` should have the same batch size.
@param boundary If supplied, a torch.LongTensor of shape ``[B][4]``, where
each row contains ``[s_begin, u_begin, s_end, u_end]``,
with ``0 <= s_begin <= s_end <= S`` and
``0 <= u_begin <= u_end <= U``
(this implies that empty sequences are allowed).
If not supplied, the values ``[0, 0, S, U]`` will be
assumed. These are the beginning and one-past-the-last
positions in the ``px`` and ``py`` sequences respectively,
and can be used if not all sequences are of the same
length.
@return A tensor with shape ``[B][S + 1][U + 1]`` containing the
levenshtein distance between the sequences. Each element
``[b][s][u]`` means the levenshtein distance between ``px[b][:s]``
and ``py[b][:u]``. If `boundary` is set, the values in the
positions out of the range of boundary are uninitialized, can be
any random values.
*/
torch::Tensor LevenshteinDistanceCpu(
torch::Tensor px, // [B][S]
torch::Tensor py, // [B][U]
torch::optional<torch::Tensor> boundary); // [B][4], int64_t.

/*
The same as above function, but it runs on GPU.
*/
torch::Tensor LevenshteinDistanceCuda(
torch::Tensor px, // [B][S]
torch::Tensor py, // [B][U]
torch::optional<torch::Tensor> boundary); // [B][4], int64_t.

} // namespace k2

void PybindLevenshteinDistance(py::module &m);

#endif // K2_PYTHON_CSRC_TORCH_LEVENSHTEIN_DISTANCE_H_
81 changes: 81 additions & 0 deletions k2/python/csrc/torch/levenshtein_distance_cpu.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/**
* @copyright
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <algorithm>

#include "k2/python/csrc/torch/levenshtein_distance.h"

namespace k2 {

torch::Tensor LevenshteinDistanceCpu(
torch::Tensor px, torch::Tensor py,
torch::optional<torch::Tensor> opt_boundary) {
TORCH_CHECK(px.dim() == 2, "px must be 2-dimensional");
TORCH_CHECK(py.dim() == 2, "py must be 2-dimensional.");
TORCH_CHECK(px.device().is_cpu() && py.device().is_cpu(),
"inputs must be CPU tensors");
TORCH_CHECK(px.dtype() == torch::kInt32 && py.dtype() == torch::kInt32,
"The dtype of inputs must be kInt32");

auto opts = torch::TensorOptions().dtype(px.dtype()).device(px.device());

const int B = px.size(0), S = px.size(1), U = py.size(1);
TORCH_CHECK(B == py.size(0), "px and py must have same batch size");

auto boundary = opt_boundary.value_or(
torch::tensor({0, 0, S, U},
torch::dtype(torch::kInt64).device(torch::kCPU))
.reshape({1, 4})
.expand({B, 4}));
TORCH_CHECK(boundary.dim() == 2, "boundary must be 2-dimensional.");
TORCH_CHECK(boundary.size(0) == B && boundary.size(1) == 4);
TORCH_CHECK(boundary.device().is_cpu() && boundary.dtype() == torch::kInt64);

torch::Tensor ans = torch::empty({B, S + 1, U + 1}, opts);

auto px_a = px.accessor<int32_t, 2>(), py_a = py.accessor<int32_t, 2>();
auto boundary_a = boundary.accessor<int64_t, 2>();
auto ans_a = ans.accessor<int32_t, 3>();

for (int b = 0; b < B; b++) {
int s_begin = boundary_a[b][0];
int u_begin = boundary_a[b][1];
int s_end = boundary_a[b][2];
int u_end = boundary_a[b][3];
ans_a[b][s_begin][u_begin] = 0;

for (int s = s_begin + 1; s <= s_end; ++s)
ans_a[b][s][u_begin] = s - s_begin;
for (int u = u_begin + 1; u <= u_end; ++u)
ans_a[b][s_begin][u] = u - u_begin;

for (int s = s_begin + 1; s <= s_end; ++s) {
for (int u = u_begin + 1; u <= u_end; ++u) {
int cost = px_a[b][s - 1] == py_a[b][u - 1] ? 0 : 1;
ans_a[b][s][u] =
min(min(ans_a[b][s - 1][u] + 1, ans_a[b][s][u - 1] + 1),
ans_a[b][s - 1][u - 1] + cost);
}
}
}
return ans;
}

} // namespace k2
Loading