From 8a8b35bc7846118bfa05a44e29f20c1e49db6491 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 15 Jul 2024 11:24:18 +0100 Subject: [PATCH] Add basic check for int32 --- test/test_transforms_v2.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 5e12a8b860a..3a338f55409 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -2179,6 +2179,7 @@ def test_uint16(self): img_uint8 = F.to_dtype(img_uint16, torch.uint8, scale=True) img_float32 = F.to_dtype(img_uint16, torch.float32, scale=True) + img_int32 = F.to_dtype(img_uint16, torch.int32, scale=True) assert_equal(img_uint8, (img_uint16 / 256).to(torch.uint8)) assert_close(img_float32, (img_uint16 / 65535)) @@ -2187,6 +2188,7 @@ def test_uint16(self): # Ideally we'd check against (img_uint16 & 0xFF00) but bitwise and isn't supported for it yet # so we simulate it by scaling down and up again. assert_equal(F.to_dtype(img_uint8, torch.uint16, scale=True), ((img_uint16 / 256).to(torch.uint16) * 256)) + assert_equal(F.to_dtype(img_int32, torch.uint16, scale=True), img_uint16) assert_equal(F.to_dtype(img_float32, torch.uint8, scale=True), img_uint8) assert_close(F.to_dtype(img_uint8, torch.float32, scale=True), img_float32, rtol=0, atol=1e-2)