[Operators] Adding support for torch.nn.GLU module (#461)
Closes #228

Additionally, while working on PR #455, I noticed that we had not registered the function/method `floor_divide`. Adding support for it is straightforward, since it is functionally equivalent to `torch.div(..., rounding_mode='floor')`. I forgot to include the change in that PR, so I am including it here.
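
For reference, a quick check of that equivalence in plain PyTorch (a minimal sketch; it assumes a recent PyTorch release, where `torch.floor_divide` floors the quotient rather than truncating it):

import torch

x = torch.tensor([7.0, -7.0, 8.0])
y = torch.tensor([2.0, 2.0, -3.0])

# Both round the quotient toward negative infinity (floor), including for
# negative operands, so the results should match element-wise.
assert torch.equal(torch.floor_divide(x, y), torch.div(x, y, rounding_mode='floor'))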
BolinSNLHM authored and vadiklyutiy committed Dec 20, 2024
1 parent 798ce6e commit f756051
Showing 4 changed files with 30 additions and 0 deletions.
15 changes: 15 additions & 0 deletions python/hidet/graph/frontend/torch/register_functions.py
@@ -619,6 +619,13 @@ def div(x: Union[Tensor, Number], y: Union[Tensor, Number], *, rounding_mode: Op
return int(result)


@register_function(torch.floor_divide)
@register_method(torch.Tensor.floor_divide)
@register_method(torch.Tensor.floor_divide_)
def floor_divide(x: Union[Tensor, Number], y: Union[Tensor, Number]):
return div(x, y, rounding_mode='floor')


@register_function(torch.as_strided)
@register_method(torch.Tensor.as_strided)
def torch_as_strided(
@@ -1161,6 +1168,14 @@ def silu(x: Tensor, inplace: bool = False):
return ops.silu(x)


@register_function(torch.nn.functional.glu)
def glu(x: Tensor, dim: int = -1):

# split the tensor into two halves along the specified dim
x1, x2 = ops.split(x, 2, axis=dim)
return x1 * ops.sigmoid(x2)


@register_function(torch.nn.functional.hardswish)
def hardswish(x: Tensor, inplace: bool = False):
if inplace:
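The registered `glu` above follows the standard definition GLU(x) = a * sigmoid(b), where a and b are the two halves of x along `dim`. A minimal sketch checking that decomposition against PyTorch's reference implementation (plain PyTorch, independent of hidet):

import torch
import torch.nn.functional as F

x = torch.randn(10, 20)

# Split into two halves along the last dim and gate one half with the other.
a, b = torch.chunk(x, 2, dim=-1)
manual = a * torch.sigmoid(b)

assert torch.allclose(manual, F.glu(x, dim=-1))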
7 changes: 7 additions & 0 deletions python/hidet/graph/frontend/torch/register_modules.py
@@ -143,6 +143,13 @@ def __call__(self, x: Tensor) -> Tensor:
return reg_funcs.leaky_relu(x, self.mod.negative_slope, self.mod.inplace)


@register_module(torch.nn.GLU)
class HidetGLU(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.GLU)
return reg_funcs.glu(x, self.mod.dim)


@register_module(torch.nn.MaxPool2d)
class HidetMaxPool2d(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
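With the module mapping in place, a `torch.nn.GLU` layer inside a compiled model should be lowered through `HidetGLU`. A rough usage sketch, assuming hidet is installed with CUDA support and registers its usual 'hidet' backend for `torch.compile`:

import torch
import hidet  # assumed to register the 'hidet' backend for torch.compile

model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.GLU(dim=-1)).cuda().eval()
x = torch.randn(4, 16, device='cuda')

compiled = torch.compile(model, backend='hidet')
with torch.no_grad():
    torch.testing.assert_close(compiled(x), model(x), atol=1e-4, rtol=1e-4)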
6 changes: 6 additions & 0 deletions tests/frontends/torch/test_torch_activation.py
@@ -121,5 +121,11 @@ def test_mish(shape, dtype):
check_module(torch.nn.Mish(), [torch.randn(shape, dtype=dtype)])


@pytest.mark.parametrize("shape", [(10, 20)])
@pytest.mark.parametrize("dim", [0, 1, -1])
def test_glu(shape, dim):
check_module(torch.nn.GLU(dim), [torch.randn(shape)])


if __name__ == '__main__':
pytest.main([__file__])
2 changes: 2 additions & 0 deletions tests/frontends/torch/test_torch_interoperability.py
@@ -47,8 +47,10 @@ def test_torch_div(input1, input2):
input2 = input2.cuda() if isinstance(input2, torch.Tensor) else input2
func = FunctionalModule(op=lambda x, y: torch.div(x, y))
func_floor = FunctionalModule(op=lambda x, y: torch.div(x, y, rounding_mode='floor'))
func_floor_divide = FunctionalModule(op=lambda x, y: torch.floor_divide(x, y))
check_module(func, args=[input1, input2], atol=1e-5, rtol=1e-5)
check_module(func_floor, args=[input1, input2], atol=1e-5, rtol=1e-5)
check_module(func_floor_divide, args=[input1, input2], atol=1e-5, rtol=1e-5)


@pytest.mark.parametrize('shape,expanded_shape', [([2, 1], [2, 11]), ([2, 3, 4], [2, 3, 4]), ([1], [6])])
