From f75605124a72bdc7369f24b4d513a0103f53c6d3 Mon Sep 17 00:00:00 2001
From: Bolin Sun
Date: Thu, 12 Sep 2024 13:08:32 -0400
Subject: [PATCH] [Operators] Adding support for `torch.nn.GLU` module (#461)

Closes #228

Additionally, while working on PR #455, I noticed that we had not
registered the function/method `floor_divide`. Adding support for it is
straightforward, since it is functionally equivalent to
`torch.div(..., rounding_mode='floor')`. I forgot to include the change
in that PR, so I am including it here.
---
 .../graph/frontend/torch/register_functions.py | 15 +++++++++++++++
 .../graph/frontend/torch/register_modules.py   |  7 +++++++
 tests/frontends/torch/test_torch_activation.py |  6 ++++++
 .../torch/test_torch_interoperability.py       |  2 ++
 4 files changed, 30 insertions(+)

diff --git a/python/hidet/graph/frontend/torch/register_functions.py b/python/hidet/graph/frontend/torch/register_functions.py
index 0f6f89c3b..2de7b2a16 100644
--- a/python/hidet/graph/frontend/torch/register_functions.py
+++ b/python/hidet/graph/frontend/torch/register_functions.py
@@ -619,6 +619,13 @@ def div(x: Union[Tensor, Number], y: Union[Tensor, Number], *, rounding_mode: Op
         return int(result)
 
 
+@register_function(torch.floor_divide)
+@register_method(torch.Tensor.floor_divide)
+@register_method(torch.Tensor.floor_divide_)
+def floor_divide(x: Union[Tensor, Number], y: Union[Tensor, Number]):
+    return div(x, y, rounding_mode='floor')
+
+
 @register_function(torch.as_strided)
 @register_method(torch.Tensor.as_strided)
 def torch_as_strided(
@@ -1161,6 +1168,14 @@ def silu(x: Tensor, inplace: bool = False):
     return ops.silu(x)
 
 
+@register_function(torch.nn.functional.glu)
+def glu(x: Tensor, dim: int = -1):
+
+    # split the tensor into two halves along the specified dim
+    x1, x2 = ops.split(x, 2, axis=dim)
+    return x1 * ops.sigmoid(x2)
+
+
 @register_function(torch.nn.functional.hardswish)
 def hardswish(x: Tensor, inplace: bool = False):
     if inplace:
diff --git a/python/hidet/graph/frontend/torch/register_modules.py b/python/hidet/graph/frontend/torch/register_modules.py
index 41f874a18..dd0eb23e8 100644
--- a/python/hidet/graph/frontend/torch/register_modules.py
+++ b/python/hidet/graph/frontend/torch/register_modules.py
@@ -143,6 +143,13 @@ def __call__(self, x: Tensor) -> Tensor:
         return reg_funcs.leaky_relu(x, self.mod.negative_slope, self.mod.inplace)
 
 
+@register_module(torch.nn.GLU)
+class HidetGLU(HidetModule):
+    def __call__(self, x: Tensor) -> Tensor:
+        assert isinstance(self.mod, torch.nn.GLU)
+        return reg_funcs.glu(x, self.mod.dim)
+
+
 @register_module(torch.nn.MaxPool2d)
 class HidetMaxPool2d(HidetModule):
     def __call__(self, x: Tensor) -> Tensor:
diff --git a/tests/frontends/torch/test_torch_activation.py b/tests/frontends/torch/test_torch_activation.py
index 4ebde06db..c6c7ec458 100644
--- a/tests/frontends/torch/test_torch_activation.py
+++ b/tests/frontends/torch/test_torch_activation.py
@@ -121,5 +121,11 @@ def test_mish(shape, dtype):
     check_module(torch.nn.Mish(), [torch.randn(shape, dtype=dtype)])
 
 
+@pytest.mark.parametrize("shape", [(10, 20)])
+@pytest.mark.parametrize("dim", [0, 1, -1])
+def test_glu(shape, dim):
+    check_module(torch.nn.GLU(dim), [torch.randn(shape)])
+
+
 if __name__ == '__main__':
     pytest.main([__file__])
diff --git a/tests/frontends/torch/test_torch_interoperability.py b/tests/frontends/torch/test_torch_interoperability.py
index 5e1b25e85..230fcfb4b 100644
--- a/tests/frontends/torch/test_torch_interoperability.py
+++ b/tests/frontends/torch/test_torch_interoperability.py
@@ -47,8 +47,10 @@ def test_torch_div(input1, input2):
     input2 = input2.cuda() if isinstance(input2, torch.Tensor) else input2
     func = FunctionalModule(op=lambda x, y: torch.div(x, y))
     func_floor = FunctionalModule(op=lambda x, y: torch.div(x, y, rounding_mode='floor'))
+    func_floor_divide = FunctionalModule(op=lambda x, y: torch.floor_divide(x, y))
     check_module(func, args=[input1, input2], atol=1e-5, rtol=1e-5)
     check_module(func_floor, args=[input1, input2], atol=1e-5, rtol=1e-5)
+    check_module(func_floor_divide, args=[input1, input2], atol=1e-5, rtol=1e-5)
 
 
 @pytest.mark.parametrize('shape,expanded_shape', [([2, 1], [2, 11]), ([2, 3, 4], [2, 3, 4]), ([1], [6])])
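
Reviewer note: a quick standalone sanity check of the two identities this
patch relies on. This is a minimal sketch in plain PyTorch (not part of the
patch itself), assuming a recent torch version in which `floor_divide`
performs true floor division rather than truncation:

    import torch

    # GLU splits the input into two halves along `dim` and gates the first
    # half with the sigmoid of the second -- the same decomposition that the
    # hidet `glu` above implements with ops.split and ops.sigmoid.
    x = torch.randn(10, 20)
    for dim in (0, 1, -1):
        x1, x2 = torch.chunk(x, 2, dim=dim)
        assert torch.allclose(torch.nn.functional.glu(x, dim=dim), x1 * torch.sigmoid(x2))

    # floor_divide is functionally equivalent to div with rounding_mode='floor',
    # which is why the registered function can simply delegate to div.
    a, b = torch.randn(100), torch.randn(100)
    assert torch.equal(torch.floor_divide(a, b), torch.div(a, b, rounding_mode='floor'))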