Support advanced options for pooling operators
1. Support count_include_pad and divisor_override in average pooling.
2. Refactor the computation definitions of the avg and max pool operators.
yaoyaoding committed Dec 28, 2023
1 parent 9a2f213 commit 4980a7c
Showing 6 changed files with 445 additions and 289 deletions.
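
For context, the two newly supported arguments follow the torch.nn.functional.avg_pool2d semantics: count_include_pad controls whether padded zeros are counted in each window's divisor, and divisor_override, when given, replaces that divisor with a fixed value. A minimal illustration in plain PyTorch (not hidet code):

import torch
import torch.nn.functional as F

x = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)

# Padded zeros are averaged in: the corner window divides by the kernel area 2*2 = 4.
y_incl = F.avg_pool2d(x, kernel_size=2, stride=2, padding=1, count_include_pad=True)

# Only real elements are averaged: the corner window divides by 1.
y_excl = F.avg_pool2d(x, kernel_size=2, stride=2, padding=1, count_include_pad=False)

# Every window is divided by the fixed value 3, regardless of padding.
y_over = F.avg_pool2d(x, kernel_size=2, stride=2, padding=1, divisor_override=3)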
26 changes: 11 additions & 15 deletions python/hidet/graph/frontend/torch/register_functions.py
@@ -328,27 +328,23 @@ def unsqueeze(x: Tensor, dim: int):
 def avg_pool2d(
     x: Tensor, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None
 ):
-    if ceil_mode:
-        raise NotImplementedError("ceil_mode=True")
-    if not count_include_pad:
-        raise NotImplementedError("count_include_pad=False")
-    if divisor_override is not None:
-        raise NotImplementedError("divisor_override is not None")
     if stride is None:
         stride = kernel_size
-    y = ops.avg_pool2d(x, kernel_size, stride, padding)
+    y = ops.avg_pool2d(
+        x,
+        kernel_size,
+        stride,
+        padding,
+        ceil_mode=ceil_mode,
+        count_include_pad=count_include_pad,
+        divisor_override=divisor_override,
+    )
     return y


 @register_function(torch.nn.functional.avg_pool3d)
 def avg_pool3d(x: Tensor, kernel_size, stride, padding, ceil_mode=False, count_include_pad=True, divisor_override=None):
-    if ceil_mode:
-        raise NotImplementedError("ceil_mode=True")
-    if not count_include_pad:
-        raise NotImplementedError("count_include_pad=False")
-    if divisor_override is not None:
-        raise NotImplementedError("divisor_override is not None")
-    y = ops.avg_pool3d(x, kernel_size, stride, padding)
+    y = ops.avg_pool3d(x, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)
     return y


@@ -1238,7 +1234,7 @@ def isinf(x: Tensor) -> Tensor:


 @register_function(torch.nn.functional.pad)
-def torch_pad(x: Tensor, pad: Union[Tuple[int], List[int]], mode: str = 'constant', value=0):
+def torch_pad(x: Tensor, pad: Union[Tuple[int, ...], List[int]], mode: str = 'constant', value=0):
     if isinstance(pad, tuple):
         pad = list(pad)
     # Torch's pad list has form [p2left, p2right, p1left, p1right, p0left, p0right]
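
The comment above describes torch's last-dimension-first pad layout; the conversion itself is truncated in this diff. A hypothetical helper (the name and the output layout are assumptions, not the code in register_functions.py) showing how such a list can be expanded into per-dimension begin/end pads:

def expand_torch_pad(pad, rank):
    # torch lists pads for the last dimensions first: [p_last_left, p_last_right, ...]
    pairs = [(pad[i], pad[i + 1]) for i in range(0, len(pad), 2)]
    pairs = pairs[::-1]                             # last-dim-first -> first-dim-first
    pairs = [(0, 0)] * (rank - len(pairs)) + pairs  # leading dimensions are unpadded
    return [p[0] for p in pairs] + [p[1] for p in pairs]  # hypothetical begins-then-ends layout

# A 4-D tensor padded only along its last (width) dimension:
assert expand_torch_pad([1, 2], rank=4) == [0, 0, 0, 1, 0, 0, 0, 2]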
7 changes: 7 additions & 0 deletions python/hidet/graph/frontend/torch/register_modules.py
@@ -162,6 +162,13 @@ def __call__(self, x: Tensor) -> Tensor:
)


+@register_module(torch.nn.ZeroPad2d)
+class HidetZeroPad2d(HidetModule):
+    def __call__(self, x: Tensor) -> Tensor:
+        assert isinstance(self.mod, torch.nn.ZeroPad2d)
+        return regs.torch_pad(x=x, pad=self.mod.padding, mode='constant', value=0.0)


 @register_module(torch.nn.Linear)
 class HidetLinear(HidetModule):
     def __init__(self, torch_module: torch.nn.Module):
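
With the HidetZeroPad2d adapter added above, zero-padding modules map onto the torch_pad function updated earlier in this commit. A small usage sketch (again assuming the 'hidet' torch.compile backend is available):

import torch
import hidet

pad = torch.nn.ZeroPad2d(padding=(1, 1, 2, 2))  # (left, right, top, bottom)
x = torch.randn(1, 3, 8, 8)

compiled = torch.compile(pad, backend='hidet')
torch.testing.assert_close(compiled(x), pad(x))  # zero padding matches eager PyTorch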