diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py
index e40183d8..4b232f84 100644
--- a/array_api_compat/torch/_aliases.py
+++ b/array_api_compat/torch/_aliases.py
@@ -4,6 +4,7 @@ from functools import reduce as _reduce, wraps as _wraps
 from builtins import all as _builtin_all, any as _builtin_any
 from typing import Any, Literal
+import math
 
 import torch
 
@@ -857,13 +858,24 @@ def _isscalar(a: object):
     min_is_scalar = _isscalar(min)
     max_is_scalar = _isscalar(max)
 
-    if min is not None and max is not None:
-        if min_is_scalar and not max_is_scalar:
-            min = torch.as_tensor(min, dtype=x.dtype, device=x.device)
-        if max_is_scalar and not min_is_scalar:
-            max = torch.as_tensor(max, dtype=x.dtype, device=x.device)
+    if min_is_scalar and max_is_scalar:
+        if (min is not None and math.isnan(min)) or (max is not None and math.isnan(max)):
+            # edge case: torch.clamp(torch.zeros(1), float('nan')) -> tensor(0.)
+            # https://github.com/pytorch/pytorch/issues/172067
+            return torch.full_like(x, fill_value=torch.nan)
+        return torch.clamp(x, min, max, **kwargs)
 
-    return torch.clamp(x, min, max, **kwargs)
+    # pytorch has (tensor, tensor, tensor) and (tensor, scalar, scalar) signatures,
+    # but does not accept (tensor, scalar, tensor)
+    a_min = min
+    if min is not None and min_is_scalar:
+        a_min = torch.as_tensor(min, dtype=x.dtype, device=x.device)
+
+    a_max = max
+    if max is not None and max_is_scalar:
+        a_max = torch.as_tensor(max, dtype=x.dtype, device=x.device)
+
+    return torch.clamp(x, a_min, a_max, **kwargs)
 
 
 def sign(x: Array, /) -> Array:
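
A minimal sketch of the two `torch.clamp` behaviors the second hunk works around; the printed outputs assume a torch build where pytorch/pytorch#172067 is still unresolved:

```python
import torch

x = torch.zeros(3)

# 1) A NaN scalar bound is silently ignored by torch.clamp on affected
#    versions, so the result stays 0. instead of propagating NaN
#    (https://github.com/pytorch/pytorch/issues/172067); the patch
#    special-cases this and returns a NaN-filled tensor itself.
print(torch.clamp(x, min=float("nan")))  # tensor([0., 0., 0.]) rather than NaN

# 2) torch.clamp accepts (tensor, scalar, scalar) and (tensor, tensor, tensor),
#    but a mixed (tensor, scalar, tensor) call is rejected.
lo, hi = 0.5, torch.full_like(x, 2.0)
try:
    torch.clamp(x, lo, hi)  # scalar min, tensor max -> TypeError
except TypeError as exc:
    print(f"mixed scalar/tensor bounds rejected: {exc}")

# The patch's workaround: promote the scalar bound to a tensor first,
# matching x's dtype and device.
lo_t = torch.as_tensor(lo, dtype=x.dtype, device=x.device)
print(torch.clamp(x, lo_t, hi))  # tensor([0.5000, 0.5000, 0.5000])
```

Promoting only the scalar side (rather than always tensorizing both bounds) keeps the fast all-scalar `torch.clamp(x, min, max)` path intact when no tensor bound is involved.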