From 4c5585ac5bb9eb1a322e420cc1cd73c2190add41 Mon Sep 17 00:00:00 2001 From: pkufool Date: Tue, 2 Jul 2024 18:38:53 +0800 Subject: [PATCH] wrap unsqueeze --- k2/python/csrc/torch/v2/any.cu | 2 ++ k2/python/csrc/torch/v2/ragged_any.cu | 8 ++++++++ k2/python/csrc/torch/v2/ragged_any.h | 2 ++ k2/python/k2/__init__.py | 1 + 4 files changed, 13 insertions(+) diff --git a/k2/python/csrc/torch/v2/any.cu b/k2/python/csrc/torch/v2/any.cu index 4475e3c09..a3c6653f0 100644 --- a/k2/python/csrc/torch/v2/any.cu +++ b/k2/python/csrc/torch/v2/any.cu @@ -353,6 +353,8 @@ void PybindRaggedAny(py::module &m) { any.def("unique", &RaggedAny::Unique, py::arg("need_num_repeats") = false, py::arg("need_new2old_indexes") = false, kRaggedAnyUniqueDoc); + any.def("unsqueeze", &RaggedAny::Unsqueeze, py::arg("axis")); + any.def("normalize", &RaggedAny::Normalize, py::arg("use_log"), kRaggedAnyNormalizeDoc); diff --git a/k2/python/csrc/torch/v2/ragged_any.cu b/k2/python/csrc/torch/v2/ragged_any.cu index 3c690d5d9..e0114ae07 100644 --- a/k2/python/csrc/torch/v2/ragged_any.cu +++ b/k2/python/csrc/torch/v2/ragged_any.cu @@ -575,6 +575,14 @@ RaggedAny RaggedAny::Cat(const std::vector &srcs, int32_t axis) { return {}; } +RaggedAny RaggedAny::Unsqueeze(int32_t axis) { + DeviceGuard guard(any.Context()); + Dtype t = any.GetDtype(); + FOR_REAL_AND_INT32_TYPES(t, T, { + return RaggedAny(k2::Unsqueeze(any.Specialize(), axis).Generic()); + }); +} + std::tuple, torch::optional> RaggedAny::Unique(bool need_num_repeats /*= false*/, diff --git a/k2/python/csrc/torch/v2/ragged_any.h b/k2/python/csrc/torch/v2/ragged_any.h index d70011924..b741347b2 100644 --- a/k2/python/csrc/torch/v2/ragged_any.h +++ b/k2/python/csrc/torch/v2/ragged_any.h @@ -264,6 +264,8 @@ struct RaggedAny { torch::optional> Unique(bool need_num_repeats = false, bool need_new2old_indexes = false); + RaggedAny Unsqueeze(int32_t axis); + /// Wrapper for k2::NormalizePerSublist RaggedAny Normalize(bool use_log) /*const*/; diff --git a/k2/python/k2/__init__.py b/k2/python/k2/__init__.py index 5d0691c1e..96696de7f 100644 --- a/k2/python/k2/__init__.py +++ b/k2/python/k2/__init__.py @@ -31,6 +31,7 @@ from .ragged import RaggedShape from .ragged import RaggedTensor +from .ragged import create_ragged_shape2 from . import autograd from . import autograd_utils