diff --git a/python/hidet/graph/frontend/torch/register_functions.py b/python/hidet/graph/frontend/torch/register_functions.py index 940a0e681..a5ed4f3f3 100644 --- a/python/hidet/graph/frontend/torch/register_functions.py +++ b/python/hidet/graph/frontend/torch/register_functions.py @@ -962,6 +962,8 @@ def tensor_where(self: Tensor, condition: Tensor, y: Union[Tensor, Number]): @register_method(torch.Tensor.pow_) def torch_pow(base: Union[Number, Tensor], exponent: Union[Number, Tensor]): if isinstance(exponent, (int, float, bool)): + if exponent in (2, 2.0): + return ops.square(base) exponent = full_like(base, exponent) elif isinstance(base, (int, float, bool)): base = full_like(exponent, base)