Skip to content

Commit

Permalink
rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
yaoyaoding committed Dec 28, 2023
1 parent e69b068 commit 18763f7
Showing 1 changed file with 12 additions and 31 deletions.
43 changes: 12 additions & 31 deletions python/hidet/graph/ops/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def __init__(
class AdaptivePoolNdOp(Operator):
spatial_ndim: Optional[int] = None
reduce_type: Optional[str] = None
last_channel_layout: Optional[bool] = None
last_channel: Optional[bool] = None

def __init__(self, x: Tensor, output_size):
if len(x.shape) != self.spatial_ndim + 2:
Expand All @@ -362,7 +362,7 @@ def __init__(self, x: Tensor, output_size):
self.reduce_type = self.reduce_type

# todo: merge AdaptivePoolTask and AdaptivePoolChannelLastTask into one class
if self.last_channel_layout:
if self.last_channel:
task = AdaptivePoolChannelLastTask(input_like(x, 'x'), output_size, reduce_type=self.reduce_type)
else:
task = AdaptivePoolTask(input_like(x, 'x'), output_size, reduce_type=self.reduce_type)
Expand Down Expand Up @@ -433,43 +433,43 @@ class AvgPool3dChannelLastOp(AvgPoolNdOp):
class AdaptiveAvgPool1dOp(AdaptivePoolNdOp):
reduce_type = 'avg'
spatial_ndim = 1
last_channel_layout = False
last_channel = False


class AdaptiveAvgPool2dOp(AdaptivePoolNdOp):
reduce_type = 'avg'
spatial_ndim = 2
last_channel_layout = False
last_channel = False


class AdaptiveAvgPool3dOp(AdaptivePoolNdOp):
reduce_type = 'avg'
spatial_ndim = 3
last_channel_layout = False
last_channel = False


class AdaptiveAvgPool2dChannelLastOp(AdaptivePoolNdOp):
reduce_type = 'avg'
spatial_ndim = 2
last_channel_layout = True
last_channel = True


class AdaptiveMaxPool1dOp(AdaptivePoolNdOp):
reduce_type = 'max'
spatial_ndim = 1
last_channel_layout = False
last_channel = False


class AdaptiveMaxPool2dOp(AdaptivePoolNdOp):
reduce_type = 'max'
spatial_ndim = 2
last_channel_layout = False
last_channel = False


class AdaptiveMaxPool3dOp(AdaptivePoolNdOp):
reduce_type = 'max'
spatial_ndim = 3
last_channel_layout = False
last_channel = False


def max_pool1d(x: Tensor, kernel, stride, padding, ceil_mode=False) -> Tensor:
Expand Down Expand Up @@ -575,8 +575,10 @@ def adaptive_avg_pool2d_channel_last(x: Tensor, output_size: Union[int, Sequence

@register_resolve_rule(AdaptivePoolNdOp)
class AdaptivePoolResolveRule(ResolveRule):
def resolve(self, op: Operator) -> Optional[List[Tensor]]:
def resolve(self, op: AdaptivePoolNdOp) -> Optional[List[Tensor]]:
assert isinstance(op, AdaptivePoolNdOp)
if not op.last_channel:
return None
x: Tensor = op.inputs[0]
output_size = op.attrs['output_size']
reduce_type = op.reduce_type
Expand All @@ -590,24 +592,3 @@ def resolve(self, op: Operator) -> Optional[List[Tensor]]:
elif reduce_type == 'avg':
return [mean(x, dims=dims[2:], keep_dim=True)]
return None


@register_resolve_rule(AdaptivePoolChannelLastOp)
class AdaptivePoolChannelLastResolveRule(ResolveRule):
def resolve(self, op: Operator) -> Optional[List[Tensor]]:
assert isinstance(op, AdaptivePoolChannelLastOp)
x: Tensor = op.inputs[0]
# TODO: Deal with generic N-dimensional convolution
if len(x.shape) != 4:
return None
output_size = op.attrs['output_size']
reduce_type = op.reduce_type
resolve_to_reduce = output_size == 1 if isinstance(output_size, int) else all(d == 1 for d in output_size)
if resolve_to_reduce:
from hidet.graph.ops import mean, max

if reduce_type == 'max':
return [max(x, dims=[1, 2], keep_dim=True)]
elif reduce_type == 'avg':
return [mean(x, dims=[1, 2], keep_dim=True)]
return None

0 comments on commit 18763f7

Please sign in to comment.