From f421a43a24a9c3a92e831b895ab5ebb3574d94f6 Mon Sep 17 00:00:00 2001 From: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com> Date: Tue, 27 Aug 2024 06:01:29 +0400 Subject: [PATCH] [PERF] Specialize pow(x,2) as x*x. llama-7B (#434) Right now `pow` with const exp argument is implemented simply. We convert const to const tensor and run elementwise `pow` of 2 tensors. It is simple but not always efficient. llama2 (RMSNorm part) has `x*x` that is implemented as `tensor.pow(2)`. Convert `pow(x,2)` to `x*x`. Improvement on llama2-7B is around **0.237%** --- python/hidet/graph/frontend/torch/register_functions.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/hidet/graph/frontend/torch/register_functions.py b/python/hidet/graph/frontend/torch/register_functions.py index 940a0e681..a5ed4f3f3 100644 --- a/python/hidet/graph/frontend/torch/register_functions.py +++ b/python/hidet/graph/frontend/torch/register_functions.py @@ -962,6 +962,8 @@ def tensor_where(self: Tensor, condition: Tensor, y: Union[Tensor, Number]): @register_method(torch.Tensor.pow_) def torch_pow(base: Union[Number, Tensor], exponent: Union[Number, Tensor]): if isinstance(exponent, (int, float, bool)): + if exponent in (2, 2.0): + return ops.square(base) exponent = full_like(base, exponent) elif isinstance(base, (int, float, bool)): base = full_like(exponent, base)