diff --git a/gallery/how-to-guides/add-subgraph-rewrite-rule.py b/gallery/how-to-guides/add-subgraph-rewrite-rule.py
index 9048b3250..f78f7cdff 100644
--- a/gallery/how-to-guides/add-subgraph-rewrite-rule.py
+++ b/gallery/how-to-guides/add-subgraph-rewrite-rule.py
@@ -133,10 +133,13 @@ def target(self, matched: MatchDict) -> Optional[List[Tensor]]:
 
 # %%
 # We can check that the rewrite rule has been registered:
-from hidet.graph.transforms import registered_rewrite_rules
+from hidet.graph.transforms import (
+    registered_rewrite_rules,
+    clear_registered_rewrite_rules,
+)
 
 print('Registered rewrite rules:')
-for rule in registered_rewrite_rules:
+for rule in registered_rewrite_rules():
     assert isinstance(rule, SubgraphRewriteRule)
     print(rule.name)
 
@@ -146,7 +149,7 @@ def target(self, matched: MatchDict) -> Optional[List[Tensor]]:
 # Besides the predefined rewrite rules, we can see that the rewrite rule we just registered is also included at the
 # last line. In this tutorial, to prevent the default rewrite rules from being applied, we first clear the registered
 # rewrite rules and then register the rewrite rule we just defined:
-registered_rewrite_rules.clear()
+clear_registered_rewrite_rules()
 register_rewrite_rule(
     FuseTwoMatmulRewriteRule()
 )  # a second way to register the rewrite rule
diff --git a/python/hidet/graph/frontend/torch/interpreter.py b/python/hidet/graph/frontend/torch/interpreter.py
index f1407b504..eb95bcfff 100644
--- a/python/hidet/graph/frontend/torch/interpreter.py
+++ b/python/hidet/graph/frontend/torch/interpreter.py
@@ -160,11 +160,33 @@ def __init__(self, torch_module: torch.nn.Module):
     def __call__(self, *args, **kwargs):
         raise NotImplementedError()
 
+    def _get_weight_norm_hook(self, name: str):
+        from torch.nn.utils.weight_norm import WeightNorm
+
+        for hook in self.mod._forward_pre_hooks.values():  # pylint: disable=protected-access
+            if isinstance(hook, WeightNorm) and hook.name == name:
+                return hook
+        return None
+
+    def _used_weight_norm(self, name: str) -> bool:
+        return self._get_weight_norm_hook(name) is not None
+
+    def _compute_weight_norm(self, name: str) -> Tensor:
+        hook = self._get_weight_norm_hook(name)
+        return hook.compute_weight(self.mod)
+
     def param(self, name: str, optional=False) -> Optional[Tensor]:
         if name not in self.torch_params:
+            # see https://pytorch.org/docs/stable/generated/torch.nn.utils.weight_norm.html
+            # to learn more about weight norm.
+            if self._used_weight_norm(name):
+                self.torch_params[name] = self._compute_weight_norm(name)
+                return self.param(name, optional)
+
             if optional:
                 return None
             raise RuntimeError(f"hidet: {self.mod} has no parameter/buffer {name}")
+
         if name not in self.hidet_params:
             if self.torch_params[name] is None:
                 self.hidet_params[name] = None
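Background for the interpreter change above: torch.nn.utils.weight_norm re-parameterizes a module's 'weight' into 'weight_g' and 'weight_v' and recomputes 'weight' in a WeightNorm forward pre-hook, so the original name is absent from the module's parameters until the hook runs; the new param() fallback recomputes it through that hook. A minimal PyTorch-side sketch (illustrative only, not part of the patch):

    import torch
    from torch.nn.utils.weight_norm import WeightNorm

    # After weight_norm, 'weight' is no longer a parameter; 'weight_g' and 'weight_v' are.
    linear = torch.nn.utils.weight_norm(torch.nn.Linear(4, 8))
    print(sorted(name for name, _ in linear.named_parameters()))  # ['bias', 'weight_g', 'weight_v']

    # The registered WeightNorm pre-hook can rebuild the full 'weight' tensor on demand,
    # which is what the param() fallback above relies on.
    hook = next(h for h in linear._forward_pre_hooks.values() if isinstance(h, WeightNorm))
    print(hook.compute_weight(linear).shape)  # torch.Size([8, 4])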
diff --git a/python/hidet/graph/ops/__init__.py b/python/hidet/graph/ops/__init__.py
index 9302dadbb..243321251 100644
--- a/python/hidet/graph/ops/__init__.py
+++ b/python/hidet/graph/ops/__init__.py
@@ -10,7 +10,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # pylint: disable=redefined-builtin
-from .conv1d import conv1d
+from .matmul import batch_matmul, matmul, matmul_x86
+from .conv1d import conv1d, conv1d_gemm
 from .conv1d_transpose import conv1d_transpose
 from .conv2d import (
     conv2d,
@@ -24,7 +25,6 @@
 from .conv2d_transpose import conv2d_transpose, conv2d_transpose_gemm
 from .conv3d import conv3d, conv3d_gemm
 from .conv3d_transpose import conv3d_transpose
-from .matmul import batch_matmul, matmul, matmul_x86
 from .pool import avg_pool2d, avg_pool3d, adaptive_avg_pool1d, adaptive_avg_pool2d, adaptive_avg_pool3d
 from .pool import max_pool2d, max_pool3d, adaptive_max_pool1d, adaptive_max_pool2d, adaptive_max_pool3d
 from .activation import relu, leaky_relu, sigmoid, hardsigmoid, clip, relu6, prelu, gelu, silu, hardswish
diff --git a/python/hidet/graph/ops/conv1d/__init__.py b/python/hidet/graph/ops/conv1d/__init__.py
index 27c9890c2..27173dc14 100644
--- a/python/hidet/graph/ops/conv1d/__init__.py
+++ b/python/hidet/graph/ops/conv1d/__init__.py
@@ -11,3 +11,6 @@
 # limitations under the License.
 from .conv1d import conv1d
 from .conv1d import Conv1dOp
+from .conv1d_gemm import conv1d_gemm
+
+from . import resolve
diff --git a/python/hidet/graph/ops/conv1d/conv1d_gemm.py b/python/hidet/graph/ops/conv1d/conv1d_gemm.py
new file mode 100644
index 000000000..0b365d22f
--- /dev/null
+++ b/python/hidet/graph/ops/conv1d/conv1d_gemm.py
@@ -0,0 +1,93 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from hidet.graph.ops.utils import Operator, input_like
+from hidet.graph.ops.utils import normalize_kernel, normalize_stride
+from hidet.graph.tensor import Tensor
+from hidet.ir.compute import TensorNode
+from hidet.ir.compute import compute
+from hidet.ir.expr import is_constant
+from hidet.ir.task import Task
+from .utils import infer_conv1d_shape
+
+
+class Conv1dGemmImageTransformTask(Task):
+    def __init__(self, x: TensorNode, kernel: int, stride: int, dilation: int, groups: int):
+        n, c, h = x.shape
+        kx = kernel
+        sx = stride
+        dilx = dilation
+        p = (h - dilx * (kx - 1) - 1) // sx + 1
+        self._assert(
+            c % groups == 0,
+            msg='Conv1d expect in_channels % groups == 0, but got in_channels {} and groups {}'.format(c, groups),
+        )
+        gc = c // groups  # group channels
+        gemm_x = compute(
+            name='gemm_x',
+            shape=[groups, n * p, gc * kx],
+            fcompute=lambda g, i, k: x[i // p, g * gc + k // kx, i % p * sx + k % kx * dilx],
+        )
+        super().__init__(name='conv1d_gemm_image_transform', inputs=[x], outputs=[gemm_x])
+
+
+class Conv1dGemmImageTransformOp(Operator):
+    def __init__(self, x: Tensor, kernel, stride, dilations, groups):
+        (kernel,) = normalize_kernel(kernel, dim=1)
+        (stride,) = normalize_stride(stride, dim=1)
+        super().__init__(
+            inputs=[x],
+            attributes={'kernel': kernel, 'stride': stride, 'groups': groups, 'dilations': dilations},
+            task=Conv1dGemmImageTransformTask(input_like(x, 'x'), kernel, stride, dilations, groups),
+        )
+
+
+def conv1d_gemm_image_transform(x: Tensor, kernel: int, stride: int, dilation: int, groups: int = 1) -> Tensor:
+    return Conv1dGemmImageTransformOp(x, kernel, stride, dilation, groups).outputs[0]
+
+
+def conv1d_gemm_filter_transform(w: Tensor, groups: int = 1) -> Tensor:
+    # weight shape: [oc, c, kx]
+    # output shape: [groups, c * kx, ogc] where ogc = oc // groups
+    oc, c, kx = w.shape
+    # TODO: current assertion mechanism does not cover this use case (only on the task-level)
+    if is_constant(oc, groups) and oc % groups != 0:
+        raise ValueError('invalid conv1d groups {} for out channels {}'.format(groups, oc))
+    ogc = oc // groups
+    w = w.reshape([groups, ogc, c, kx])  # [groups, ogc, c, kx]
+    w = w.rearrange([[0], [2, 3], [1]])  # [groups, c * kx, ogc]
+    return w
+
+
+def conv1d_gemm_inverse_transform(gemm_y: Tensor, out_height) -> Tensor:
+    # gemm_y shape: [groups, n * p, ogc]
+    # output shape: [n, oc, p] where oc = groups * ogc
+    p = out_height
+    groups, npq, ogc = gemm_y.shape
+    # TODO: current assertion mechanism does not cover this use case (only on the task-level)
+    if is_constant(npq, p) and npq % p != 0:
+        raise ValueError('invalid conv1d output shape {} for dimension {}'.format(npq, p))
+    n = npq // p
+    y = gemm_y.reshape([groups, n, p, ogc])
+    y = y.rearrange([[1], [0, 3], [2]])
+    return y
+
+
+def conv1d_gemm(data: Tensor, weight: Tensor, stride, dilation: int = 1, groups: int = 1) -> Tensor:
+    from hidet import ops
+
+    gemm_x = conv1d_gemm_image_transform(data, kernel=weight.shape[2], stride=stride, dilation=dilation, groups=groups)
+    gemm_w = conv1d_gemm_filter_transform(weight, groups=groups)
+    gemm_y = ops.matmul(gemm_x, gemm_w, require_prologue=True)
+
+    y_shape = infer_conv1d_shape(data.shape, weight.shape, stride, groups, dilation)
+    y = conv1d_gemm_inverse_transform(gemm_y, out_height=y_shape[2])
+    return y
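The new conv1d_gemm entry point lowers a 1D convolution into an im2col-style image transform, a filter transform, a batched matmul, and an inverse transform. A minimal usage sketch (shapes are made up; conv1d_gemm is assumed to be reachable through hidet.ops, per the ops/__init__.py change above):

    import hidet

    # x: [n, c, w], weight: [oc, c, kx]
    x = hidet.randn([1, 3, 32], device='cpu')
    w = hidet.randn([6, 3, 3], device='cpu')

    # Lowers to image transform + filter transform + matmul + inverse transform.
    y = hidet.ops.conv1d_gemm(x, w, stride=1, dilation=1, groups=1)
    print(y.shape)  # [1, 6, 30]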
diff --git a/python/hidet/graph/ops/conv1d/resolve.py b/python/hidet/graph/ops/conv1d/resolve.py
new file mode 100644
index 000000000..e843a1647
--- /dev/null
+++ b/python/hidet/graph/ops/conv1d/resolve.py
@@ -0,0 +1,28 @@
+from typing import List, Optional
+from hidet.graph.operator import Operator, Tensor
+from hidet.graph.transforms import ResolveRule, register_resolve_rule
+from hidet.graph import ops
+from hidet.ir.expr import is_constant
+
+from .conv1d import Conv1dOp
+
+
+@register_resolve_rule(Conv1dOp)
+class Conv1dResolveRule(ResolveRule):
+    def __init__(self, enable_winograd=False):
+        self.enable_winograd = enable_winograd
+
+    def resolve(self, op: Operator) -> Optional[List[Tensor]]:
+        assert isinstance(op, Conv1dOp)
+        (stride,) = ops.utils.normalize_stride(op.attrs['stride'], dim=1)
+        groups = op.attrs['groups']
+        (dilations,) = op.attrs['dilations']
+        channels = op.inputs[1].shape[0]
+
+        if is_constant(channels) and groups == channels:
+            return None  # use depthwise schedule in the default Task
+
+        data, weight = op.inputs
+        # implicit gemm algorithm
+        out = ops.conv1d_gemm(data, weight, stride, dilations, groups)
+        return [out]
diff --git a/python/hidet/graph/ops/conv1d/utils.py b/python/hidet/graph/ops/conv1d/utils.py
new file mode 100644
index 000000000..5fec60d1d
--- /dev/null
+++ b/python/hidet/graph/ops/conv1d/utils.py
@@ -0,0 +1,31 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import List, Sequence
+from hidet.ir.expr import is_constant
+from ..utils import normalize_stride
+
+
+def infer_conv1d_shape(
+    x_shape: Sequence[int], w_shape: Sequence[int], stride: int, groups: int, dilation: int
+) -> List[int]:
+    n, c, d = x_shape
+    oc, gc, kd = w_shape
+    (sx,) = normalize_stride(stride, dim=1)
+    dilx = dilation
+    if is_constant(c) and gc * groups != c:
+        msg = 'Conv2d: x has {} input channels, w has {} group channels, and groups={}'.format(c, gc, groups)
+        raise ValueError(msg)
+    if oc % groups != 0:
+        msg = 'Conv2d expects out_channels % groups == 0, got out_channels {} and groups {}'.format(oc, groups)
+        raise ValueError(msg)
+    p = (d - dilx * (kd - 1) - 1) // sx + 1
+    return [n, oc, p]
diff --git a/python/hidet/graph/transforms/__init__.py b/python/hidet/graph/transforms/__init__.py
index 1a32b155d..2601dd2ba 100644
--- a/python/hidet/graph/transforms/__init__.py
+++ b/python/hidet/graph/transforms/__init__.py
@@ -21,7 +21,7 @@
 
 from .resolve_variant import ResolveRule, register_resolve_rule, get_resolve_chain
 from .graph_patterns import TensorPattern, OperatorPattern, SubgraphRewriteRule, register_rewrite_rule, op_pattern
-from .graph_patterns import registered_rewrite_rules
+from .graph_patterns import registered_rewrite_rules, clear_registered_rewrite_rules
 
 
 def optimize(graph: FlowGraph) -> FlowGraph:
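With Conv1dResolveRule registered, an ordinary conv1d in a traced graph gets rewritten to the gemm path during graph optimization, except for the fully depthwise case (groups == channels), which keeps the default task. A rough sketch of driving this through the optimizer (shapes are illustrative; it assumes the resolve pass runs as part of hidet.graph.optimize):

    import hidet
    from hidet import ops

    x = hidet.symbol([1, 3, 32], device='cpu')
    w = hidet.randn([6, 3, 3], device='cpu')
    y = ops.conv1d(x, w, stride=1)

    graph = hidet.trace_from(y, inputs=[x])
    graph_opt = hidet.graph.optimize(graph)  # Conv1dOp resolved via conv1d_gemm
    print(graph_opt)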
diff --git a/python/hidet/graph/transforms/base.py b/python/hidet/graph/transforms/base.py
index 5b1e3a8e0..5016489cc 100644
--- a/python/hidet/graph/transforms/base.py
+++ b/python/hidet/graph/transforms/base.py
@@ -91,7 +91,7 @@ def __enter__(self) -> PassContext:
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
-        from ..transforms.graph_patterns import deregister_attn_patterns
+        from ..transforms.graph_patterns.attn_patterns import deregister_attn_patterns
 
         deregister_attn_patterns()
         popped = self._stack.pop()
@@ -166,7 +166,7 @@ def set_use_attention(self, flag=False) -> PassContext:
         if cc < (7, 5):
             return self
 
-        from ..transforms.graph_patterns import register_attn_patterns, deregister_attn_patterns
+        from ..transforms.graph_patterns.attn_patterns import register_attn_patterns, deregister_attn_patterns
 
         self.configs['use_attention'] = flag
         if flag:
diff --git a/python/hidet/graph/transforms/graph_patterns/__init__.py b/python/hidet/graph/transforms/graph_patterns/__init__.py
index 687e8babf..a9cc146a4 100644
--- a/python/hidet/graph/transforms/graph_patterns/__init__.py
+++ b/python/hidet/graph/transforms/graph_patterns/__init__.py
@@ -11,8 +11,4 @@
 # limitations under the License.
 from .base import TensorPattern, OperatorPattern, SubgraphRewriteRule, MatchDict, Usage, graph_pattern_match
 from .base import register_rewrite_rule, op_pattern, registered_rewrite_rules, deregister_rewrite_rule
-from .arithmetic_patterns import arithmetic_patterns
-from .transform_patterns import transform_patterns
-from .attn_patterns import attn_patterns, register_attn_patterns, deregister_attn_patterns
-from .conv2d_patterns import conv2d_patterns
-from .matmul_patterns import matmul_patterns
+from .base import clear_registered_rewrite_rules
""" if isinstance(rule, SubgraphRewriteRule): - registered_rewrite_rules.remove(rule) + _registered_rewrite_rules.remove(rule) return None else: raise TypeError('rule should be a SubgraphRewriteRule') diff --git a/python/hidet/graph/transforms/graph_patterns/register_all_patterns.py b/python/hidet/graph/transforms/graph_patterns/register_all_patterns.py new file mode 100644 index 000000000..a0ecb0fa2 --- /dev/null +++ b/python/hidet/graph/transforms/graph_patterns/register_all_patterns.py @@ -0,0 +1,6 @@ +# pylint: disable=unused-import +from .arithmetic_patterns import arithmetic_patterns +from .transform_patterns import transform_patterns +from .attn_patterns import attn_patterns, register_attn_patterns, deregister_attn_patterns +from .conv2d_patterns import conv2d_patterns +from .matmul_patterns import matmul_patterns diff --git a/python/hidet/graph/transforms/subgraph_rewrite.py b/python/hidet/graph/transforms/subgraph_rewrite.py index f03fac87f..71b0e3a7b 100644 --- a/python/hidet/graph/transforms/subgraph_rewrite.py +++ b/python/hidet/graph/transforms/subgraph_rewrite.py @@ -42,7 +42,7 @@ class SubgraphRewritePass(GraphPass): def process_graph(self, graph: FlowGraph) -> FlowGraph: graph = graph_utils.functors.clone(graph) for _ in range(self.max_num_transforms): - updated, graph = self.try_transform(graph, registered_rewrite_rules) + updated, graph = self.try_transform(graph, registered_rewrite_rules()) if not updated: graph.update_nodes() return graph diff --git a/tests/models/test_gpt2.py b/tests/models/test_gpt2.py index 4ef273ee9..ad46db0a6 100644 --- a/tests/models/test_gpt2.py +++ b/tests/models/test_gpt2.py @@ -36,7 +36,7 @@ def generate(model, text, num_hidden_layers, num_heads, head_dim, device, tokens @pytest.mark.parametrize('device,opt', [('cpu', False), ('cpu', True), ('cuda', False), ('cuda', True)]) def test_gpt2(device: str, opt: bool): - gpt2_module = hidet.testing.models.gpt2.model() + gpt2_module = hidet.testing.models.gpt2.model(disable_cache=True) if device == 'cuda': gpt2_module.cuda()