Skip to content

Commit c95fac6

Browse files
committed
enable decode_jpeg to support privateuseone device
Signed-off-by: taozhiwei <[email protected]>
1 parent 218d2ab commit c95fac6

File tree

2 files changed

+47
-0
lines changed

2 files changed

+47
-0
lines changed

test/test_image.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,35 @@ def test_decode_bad_huffman_images():
160160
decode_jpeg(bad_huff)
161161

162162

163+
def test_decode_jpeg_privateuseone_custom_backend():
164+
privateuseone_name = torch._C._get_privateuse1_backend_name()
165+
device = torch.device(privateuseone_name)
166+
data = torch.full((1, 2, 3), 233, dtype=torch.uint8)
167+
# When the custom operator is not registered, an error should
168+
# be reported and prompted to register decode_jpegs_privateuseone.
169+
with pytest.raises(RuntimeError, match="decode_jpegs_privateuseone"):
170+
decode_jpeg(data, device=device)
171+
172+
# Register a simple custom implementation to return the original data
173+
called = {}
174+
lib = torch.library.Library("image", "FRAGMENT")
175+
try:
176+
lib.define(
177+
"decode_jpegs_privateuseone(Tensor[] input, int mode=0, bool apply_exif_orientation=False) -> Tensor[]"
178+
)
179+
180+
@torch.library.impl(lib, "decode_jpegs_privateuseone", "CPU")
181+
def _decode_jpegs_privateuseone(input, mode=0, apply_exif_orientation=False):
182+
called["value"] = True
183+
return input
184+
except RuntimeError:
185+
pass
186+
187+
output = decode_jpeg(data, device=device)
188+
assert called.get("value") is True
189+
torch.testing.assert_close(output, data)
190+
191+
163192
@pytest.mark.parametrize(
164193
"img_path",
165194
[

torchvision/io/image.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,16 @@ def decode_jpeg(
210210
raise ValueError("All elements of the input list must be tensors.")
211211
if not all(t.device.type == "cpu" for t in input):
212212
raise ValueError("Input list must contain tensors on CPU.")
213+
custom_privateuse1_name = torch._C._get_privateuse1_backend_name()
214+
if device.type == custom_privateuse1_name or device.type == "privateuseone":
215+
# When the target device is privateuseone, switch to calling the custom decode_jpegs_privateuseone.
216+
# This operator needs to be pre-registered by the user through torch.library.define/impl.
217+
decoder = getattr(torch.ops.image, "decode_jpegs_privateuseone", None)
218+
if decoder is None:
219+
raise RuntimeError(
220+
"decode_jpeg(device='privateuseone') need register torch.ops.image.decode_jpegs_privateuseone."
221+
)
222+
return decoder(input, mode.value, apply_exif_orientation)
213223
if device.type == "cuda":
214224
return torch.ops.image.decode_jpegs_cuda(input, mode.value, device)
215225
else:
@@ -218,6 +228,14 @@ def decode_jpeg(
218228
else: # input is tensor
219229
if input.device.type != "cpu":
220230
raise ValueError("Input tensor must be a CPU tensor")
231+
if device.type == custom_privateuse1_name or device.type == "privateuseone":
232+
custom_privateuse1_name = torch._C._get_privateuse1_backend_name()
233+
decoder = getattr(torch.ops.image, "decode_jpegs_privateuseone", None)
234+
if decoder is None:
235+
raise RuntimeError(
236+
"decode_jpeg(device='privateuseone') need register torch.ops.image.decode_jpegs_privateuseone."
237+
)
238+
return decoder([input], mode.value, apply_exif_orientation)[0]
221239
if device.type == "cuda":
222240
return torch.ops.image.decode_jpegs_cuda([input], mode.value, device)[0]
223241
else:

0 commit comments

Comments
 (0)