From ccbd57a8b665fbb5b1d566c0b800dc6ede509e8e Mon Sep 17 00:00:00 2001
From: Joao Gante <joaofranciscocardosogante@gmail.com>
Date: Mon, 4 Nov 2024 16:18:50 +0000
Subject: [PATCH] MPS: `isin_mps_friendly` can support 0D tensors (#34538)

* apply fix

* tested

* make fixup
---
 src/transformers/pytorch_utils.py  | 5 ++++-
 tests/utils/test_modeling_utils.py | 7 ++++++-
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py
index f3663c09902f..a808f2cb63e8 100644
--- a/src/transformers/pytorch_utils.py
+++ b/src/transformers/pytorch_utils.py
@@ -314,7 +314,7 @@ def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor | int)
 
     Args:
         elements (`torch.Tensor`): Input elements
-        test_elements (`torch.Tensor`): The elements to check against.
+        test_elements (`torch.Tensor` or `int`): The elements to check against.
 
     Returns:
         `torch.Tensor`: A boolean tensor of the same shape as `elements` that is True for `elements` in `test_elements`
@@ -322,6 +322,9 @@ def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor | int)
     """
 
     if elements.device.type == "mps" and not is_torch_greater_or_equal_than_2_4:
+        test_elements = torch.tensor(test_elements)
+        if test_elements.ndim == 0:
+            test_elements = test_elements.unsqueeze(0)
         return elements.tile(test_elements.shape[0], 1).eq(test_elements.unsqueeze(1)).sum(dim=0).bool().squeeze()
     else:
         # Note: don't use named arguments in `torch.isin`, see https://github.com/pytorch/pytorch/issues/126045
diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py
index 8af47cde8e53..430043496c42 100644
--- a/tests/utils/test_modeling_utils.py
+++ b/tests/utils/test_modeling_utils.py
@@ -1710,7 +1710,12 @@ def test_isin_mps_friendly(self):
                 torch.isin(random_ids, random_test_integer), isin_mps_friendly(random_ids, random_test_integer)
             )
         )
-        # We can match against an tensor of integers
+        # We can match against an 0D tensor
+        random_test_tensor = torch.randint(0, 100, (1,)).squeeze()
+        self.assertTrue(
+            torch.equal(torch.isin(random_ids, random_test_tensor), isin_mps_friendly(random_ids, random_test_tensor))
+        )
+        # We can match against an 1D tensor (with many items)
         random_test_tensor = torch.randint(0, 100, (10,))
         self.assertTrue(
             torch.equal(torch.isin(random_ids, random_test_tensor), isin_mps_friendly(random_ids, random_test_tensor))