1 change: 1 addition & 0 deletions requirements.in
@@ -1,3 +1,4 @@
+expecttest
 filelock
 fsspec
 jinja2
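For context, and as an assumption on my part rather than something stated in this diff: `expecttest` is PyTorch's inline snapshot-testing library, so the new dependency suggests tests are being rewritten as expect tests. A minimal usage sketch, mirroring the upstream README:

```python
# Minimal expecttest sketch (illustrative; the real call sites are not shown
# in this diff). Running the tests with EXPECTTEST_ACCEPT=1 rewrites the
# inline expected string in place to match the actual output.
import expecttest


class ExampleTest(expecttest.TestCase):

  def test_split(self):
    s = "hello world"
    self.assertExpectedInline(str(s.split()), """['hello', 'world']""")
```

The pinned version and hashes are then propagated into each lock file below.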
4 changes: 4 additions & 0 deletions requirements_lock_3_10.txt
@@ -4,6 +4,10 @@
 #
 # bazel run //:requirements.update
 #
+expecttest==0.3.0 \
+    --hash=sha256:60f88103086e1754240b42175f622be83b6ffeac419434691ee5a5be819d0544 \
+    --hash=sha256:6e8512fb86523ada1f94fd1b14e280f924e379064bb8a29ee399950e513eeccd
+    # via -r requirements.in
 filelock==3.14.0 \
     --hash=sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f \
     --hash=sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a
4 changes: 4 additions & 0 deletions requirements_lock_3_11.txt
@@ -4,6 +4,10 @@
 #
 # bazel run //:requirements.update
 #
+expecttest==0.3.0 \
+    --hash=sha256:60f88103086e1754240b42175f622be83b6ffeac419434691ee5a5be819d0544 \
+    --hash=sha256:6e8512fb86523ada1f94fd1b14e280f924e379064bb8a29ee399950e513eeccd
+    # via -r requirements.in
 filelock==3.14.0 \
     --hash=sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f \
     --hash=sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a
4 changes: 4 additions & 0 deletions requirements_lock_3_12.txt
@@ -4,6 +4,10 @@
 #
 # bazel run //:requirements.update
 #
+expecttest==0.3.0 \
+    --hash=sha256:60f88103086e1754240b42175f622be83b6ffeac419434691ee5a5be819d0544 \
+    --hash=sha256:6e8512fb86523ada1f94fd1b14e280f924e379064bb8a29ee399950e513eeccd
+    # via -r requirements.in
 filelock==3.18.0 \
     --hash=sha256:adbc88eabb99d2fec8c9c1b229b171f18afa655400173ddc653d5d01501fb9f2 \
     --hash=sha256:c401f4f8377c4464e6db25fff06205fd89bdd83b65eb0488ed1b160f780e21de
4 changes: 4 additions & 0 deletions requirements_lock_3_13.txt
@@ -4,6 +4,10 @@
 #
 # bazel run //:requirements.update
 #
+expecttest==0.3.0 \
+    --hash=sha256:60f88103086e1754240b42175f622be83b6ffeac419434691ee5a5be819d0544 \
+    --hash=sha256:6e8512fb86523ada1f94fd1b14e280f924e379064bb8a29ee399950e513eeccd
+    # via -r requirements.in
 filelock==3.18.0 \
     --hash=sha256:adbc88eabb99d2fec8c9c1b229b171f18afa655400173ddc653d5d01501fb9f2 \
     --hash=sha256:c401f4f8377c4464e6db25fff06205fd89bdd83b65eb0488ed1b160f780e21de
4 changes: 4 additions & 0 deletions requirements_lock_3_8.txt
@@ -4,6 +4,10 @@
 #
 # bazel run //:requirements.update
 #
+expecttest==0.3.0 \
+    --hash=sha256:60f88103086e1754240b42175f622be83b6ffeac419434691ee5a5be819d0544 \
+    --hash=sha256:6e8512fb86523ada1f94fd1b14e280f924e379064bb8a29ee399950e513eeccd
+    # via -r requirements.in
 filelock==3.14.0 \
     --hash=sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f \
     --hash=sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a
4 changes: 4 additions & 0 deletions requirements_lock_3_9.txt
@@ -4,6 +4,10 @@
 #
 # bazel run //:requirements.update
 #
+expecttest==0.3.0 \
+    --hash=sha256:60f88103086e1754240b42175f622be83b6ffeac419434691ee5a5be819d0544 \
+    --hash=sha256:6e8512fb86523ada1f94fd1b14e280f924e379064bb8a29ee399950e513eeccd
+    # via -r requirements.in
 filelock==3.14.0 \
     --hash=sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f \
     --hash=sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a
164 changes: 0 additions & 164 deletions test/test_operations.py
@@ -88,11 +88,6 @@ def skipIfFunctionalizationDisabled(reason):
   return _skipIfFunctionalization(value=True, reason=reason)
 
 
-def onlyOnCPU(fn):
-  accelerator = os.environ.get("PJRT_DEVICE").lower()
-  return unittest.skipIf(accelerator != "cpu", "PJRT_DEVICE=CPU required")(fn)
-
-
 def onlyIfXLAExperimentalContains(feat):
   experimental = os.environ.get("XLA_EXPERIMENTAL", "").split(":")
   return unittest.skipIf(feat not in experimental,
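Two notes on the removed `onlyOnCPU` helper (my reading of the diff, not something the PR states): its only user, `test_construct_large_tensor_raises_error`, is deleted in the same diff, and `os.environ.get("PJRT_DEVICE").lower()` raises `AttributeError` rather than skipping when `PJRT_DEVICE` is unset. A more defensive variant would look like this sketch:

```python
# Hypothetical, more defensive variant of the removed decorator; not part of
# this PR. Defaulting to "" makes an unset PJRT_DEVICE skip the test instead
# of crashing with AttributeError.
import os
import unittest


def onlyOnCPU(fn):
  accelerator = os.environ.get("PJRT_DEVICE", "").lower()
  return unittest.skipIf(accelerator != "cpu", "PJRT_DEVICE=CPU required")(fn)
```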
@@ -2372,165 +2367,6 @@ def test_isneginf_no_fallback(self):
     t = t.to(torch.float16)
     self._test_no_fallback(torch.isneginf, (t,))
 
-  def test_add_broadcast_error(self):
-    a = torch.rand(2, 2, 4, 4, device="xla")
-    b = torch.rand(2, 2, device="xla")
-
-    expected_regex = (
-        r"Shapes are not compatible for broadcasting: f32\[2,2,4,4\] vs. f32\[2,2\]. "
-        r"Expected dimension 2 of shape f32\[2,2,4,4\] \(4\) to match dimension "
-        r"0 of shape f32\[2,2\] \(2\). .*")
-
-    with self.assertRaisesRegex(RuntimeError, expected_regex):
-      torch.add(a, b)
-      torch_xla.sync()
-
-  @onlyOnCPU
-  def test_construct_large_tensor_raises_error(self):
-    with self.assertRaisesRegex(RuntimeError,
-                                r"Out of memory allocating \d+ bytes"):
-      # When eager-mode is enabled, OOM is triggered here.
-      a = torch.rand(1024, 1024, 1024, 1024, 1024, device=torch_xla.device())
-      b = a.sum()
-      # OOM is raised when we try to bring data from the device.
-      b.cpu()
-
-  def test_cat_raises_error_on_incompatible_shapes(self):
-    a = torch.rand(2, 2, device=torch_xla.device())
-    b = torch.rand(5, 1, device=torch_xla.device())
-
-    try:
-      torch.cat([a, b])
-    except RuntimeError as e:
-      expected_error = (
-          "cat(): cannot concatenate tensors of shape f32[2,2] with f32[5,1] "
-          "at dimension 0. Expected shapes to be equal (except at dimension 0) "
-          "or that either of them was a 1D empty tensor of size (0,).")
-      self.assertEqual(str(e), expected_error)
-
-  def test_div_raises_error_on_invalid_rounding_mode(self):
-    a = torch.rand(2, 2, device=torch_xla.device())
-
-    try:
-      torch.div(a, 2, rounding_mode="bad")
-    except RuntimeError as e:
-      expected_error = (
-          "div(): invalid rounding mode `bad`. Expected it to be either "
-          "'trunc', 'floor', or be left unspecified.")
-      self.assertEqual(str(e), expected_error)
-
-  def test_flip_raises_error_on_duplicated_dims(self):
-    a = torch.rand(2, 2, 2, 2, device=torch_xla.device())
-    dims = [0, 0, 0, 1, 2, 3, -1]
-    dims_suggestion = [0, 1, 2, 3]
-
-    try:
-      torch.flip(a, dims=dims)
-    except RuntimeError as e:
-      expected_error = (
-          "flip(): expected each dimension to appear at most once. Found "
-          "dimensions: 0 (3 times), 3 (2 times). Consider changing dims "
-          f"from {dims} to {dims_suggestion}.")
-      self.assertEqual(str(e), expected_error)
-
-  def test_full_raises_error_on_negative_size(self):
-    shape = [2, -2, 2]
-    try:
-      torch.full(shape, 1.5, device="xla")
-    except RuntimeError as e:
-      expected_error = (
-          "full(): expected concrete sizes (i.e. non-symbolic) to be "
-          f"positive values. However found negative ones: {shape}.")
-      self.assertEqual(str(e), expected_error)
-
-  def test_gather_raises_error_on_rank_mismatch(self):
-    S = 2
-
-    input = torch.arange(4, device=torch_xla.device()).view(S, S)
-    index = torch.randint(0, S, (S, S, S), device=torch_xla.device())
-    dim = 1
-
-    try:
-      torch.gather(input, dim, index)
-    except RuntimeError as e:
-      expected_error = (
-          "gather(): expected rank of input (2) and index (3) tensors "
-          "to be the same.")
-      self.assertEqual(str(e), expected_error)
-
-  def test_gather_raises_error_on_invalid_index_size(self):
-    S = 2
-    X = S + 2
-
-    input = torch.arange(16, device=torch_xla.device()).view(S, S, S, S)
-    index = torch.randint(0, S, (X, S, X, S), device=torch_xla.device())
-    dim = 1
-
-    try:
-      torch.gather(input, dim, index)
-    except RuntimeError as e:
-      expected_error = (
-          f"gather(): expected sizes of index [{X}, {S}, {X}, {S}] to be "
-          f"smaller or equal those of input [{S}, {S}, {S}, {S}] on all "
-          f"dimensions, except on dimension {dim}. "
-          "However, that's not true on dimensions [0, 2].")
-      self.assertEqual(str(e), expected_error)
-
-  def test_random__raises_error_on_empty_interval(self):
-    a = torch.empty(10, device=torch_xla.device())
-    from_ = 3
-    to_ = 1
-
-    try:
-      a.random_(from_, to_)
-    except RuntimeError as e:
-      expected_error = (
-          f"random_(): expected `from` ({from_}) to be smaller than "
-          f"`to` ({to_}).")
-      self.assertEqual(str(e), expected_error)
-
-  def test_random__raises_error_on_value_out_of_type_value_range(self):
-    a = torch.empty(10, device=torch_xla.device(), dtype=torch.float16)
-    from_ = 3
-    to_ = 65504 + 1
-
-    try:
-      a.random_(from_, to_)
-    except RuntimeError as e:
-      expected_error = (
-          f"random_(): expected `to` to be within the range "
-          f"[-65504, 65504]. However got value {to_}, which is greater "
-          "than the upper bound.")
-      self.assertEqual(str(e), expected_error)
-
-  def test_mm_raises_error_on_non_matrix_input(self):
-    device = torch_xla.device()
-    a = torch.rand(2, 2, 2, device=device)
-    b = torch.rand(2, 2, device=device)
-
-    try:
-      torch.mm(a, b)
-    except RuntimeError as e:
-      expected_error = (
-          "mm(): expected the first input tensor f32[2,2,2] to be a "
-          "matrix (i.e. a 2D tensor).")
-      self.assertEqual(str(e), expected_error)
-
-  def test_mm_raises_error_on_incompatible_shapes(self):
-    device = torch_xla.device()
-    a = torch.rand(2, 5, device=device)
-    b = torch.rand(8, 2, device=device)
-
-    try:
-      torch.mm(a, b)
-    except RuntimeError as e:
-      expected_error = (
-          "mm(): cannot matrix-multiply tensors f32[2,5] and f32[8,2]. "
-          "Expected the size of dimension 1 of the first input tensor (5) "
-          "to be equal the size of dimension 0 of the second input "
-          "tensor (8).")
-      self.assertEqual(str(e), expected_error)
-
 
 class MNISTComparator(nn.Module):
 
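A closing observation about the removed tests, offered as commentary rather than anything the PR states: most of them use a `try`/`except` plus `assertEqual` pattern with no `self.fail()` after the `try` body, so they pass vacuously if the operation ever stops raising. A non-vacuous sketch of one of them using `assertRaisesRegex`:

```python
# Sketch (not part of this PR) of a non-vacuous version of the removed
# error-message tests. assertRaisesRegex fails when the operation stops
# raising, unlike the removed try/except form. re.escape is needed because
# the expected messages contain regex metacharacters such as brackets.
import re
import unittest

import torch
import torch_xla


class MmErrorTest(unittest.TestCase):  # hypothetical test class

  def test_mm_raises_error_on_incompatible_shapes(self):
    device = torch_xla.device()
    a = torch.rand(2, 5, device=device)
    b = torch.rand(8, 2, device=device)

    expected_error = (
        "mm(): cannot matrix-multiply tensors f32[2,5] and f32[8,2]. "
        "Expected the size of dimension 1 of the first input tensor (5) "
        "to be equal the size of dimension 0 of the second input "
        "tensor (8).")
    with self.assertRaisesRegex(RuntimeError, re.escape(expected_error)):
      torch.mm(a, b)
```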