Skip to content

Commit 4a4ad1e

Browse files
committed
enable decode_jpeg to support privateuseone device
Signed-off-by: taozhiwei <[email protected]>
1 parent 218d2ab commit 4a4ad1e

File tree

2 files changed

+97
-2
lines changed

2 files changed

+97
-2
lines changed

test/test_image.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,61 @@ def test_decode_bad_huffman_images():
160160
decode_jpeg(bad_huff)
161161

162162

163+
def test_encode_jpeg_privateuseone_custom_backend():
164+
privateuseone_name = torch._C._get_privateuse1_backend_name()
165+
device = torch.device(privateuseone_name)
166+
167+
data = torch.randint(0, 256, size=(3, 4, 5), dtype=torch.uint8, device=device)
168+
169+
with pytest.raises(RuntimeError, match="encode_jpegs_privateuseone"):
170+
encode_jpeg(data)
171+
172+
lib = torch.library.Library("image", "FRAGMENT")
173+
called = {}
174+
175+
try:
176+
lib.define("encode_jpegs_privateuseone(Tensor[] input, int quality=75) -> Tensor[]")
177+
178+
@torch.library.impl(lib, "encode_jpegs_privateuseone", "PrivateUse1")
179+
def _encode_jpegs_privateuseone(input, quality=75):
180+
called["value"] = True
181+
return input
182+
except RuntimeError:
183+
pass
184+
encoded = encode_jpeg(data)
185+
assert called.get("value") is True
186+
torch.testing.assert_close(encoded, data.cpu())
187+
188+
189+
def test_decode_jpeg_privateuseone_custom_backend():
190+
privateuseone_name = torch._C._get_privateuse1_backend_name()
191+
device = torch.device(privateuseone_name)
192+
data = torch.full((1, 2, 3), 233, dtype=torch.uint8)
193+
# When the custom operator is not registered, an error should
194+
# be reported and prompted to register decode_jpegs_privateuseone.
195+
with pytest.raises(RuntimeError, match="decode_jpegs_privateuseone"):
196+
decode_jpeg(data, device=device)
197+
198+
# Register a simple custom implementation to return the original data
199+
called = {}
200+
lib = torch.library.Library("image", "FRAGMENT")
201+
try:
202+
lib.define(
203+
"decode_jpegs_privateuseone(Tensor[] input, int mode=0, bool apply_exif_orientation=False) -> Tensor[]"
204+
)
205+
206+
@torch.library.impl(lib, "decode_jpegs_privateuseone", "CPU")
207+
def _decode_jpegs_privateuseone(input, mode=0, apply_exif_orientation=False):
208+
called["value"] = True
209+
return input
210+
except RuntimeError:
211+
pass
212+
213+
output = decode_jpeg(data, device=device)
214+
assert called.get("value") is True
215+
torch.testing.assert_close(output, data)
216+
217+
163218
@pytest.mark.parametrize(
164219
"img_path",
165220
[

torchvision/io/image.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,17 @@ def decode_jpeg(
210210
raise ValueError("All elements of the input list must be tensors.")
211211
if not all(t.device.type == "cpu" for t in input):
212212
raise ValueError("Input list must contain tensors on CPU.")
213+
custom_privateuse1_name = torch._C._get_privateuse1_backend_name()
214+
if device.type == custom_privateuse1_name or device.type == "privateuseone":
215+
# When the target device is privateuseone, switch to calling the custom decode_jpegs_privateuseone.
216+
# This operator needs to be pre-registered by the user through torch.library.define/impl.
217+
decoder = getattr(torch.ops.image, "decode_jpegs_privateuseone", None)
218+
if decoder is None:
219+
raise RuntimeError(
220+
"decode_jpeg tensors on PrivateUse1 device require registering "
221+
"torch.ops.image.decode_jpegs_privateuseone."
222+
)
223+
return decoder(input, mode.value, apply_exif_orientation)
213224
if device.type == "cuda":
214225
return torch.ops.image.decode_jpegs_cuda(input, mode.value, device)
215226
else:
@@ -218,6 +229,15 @@ def decode_jpeg(
218229
else: # input is tensor
219230
if input.device.type != "cpu":
220231
raise ValueError("Input tensor must be a CPU tensor")
232+
if device.type == custom_privateuse1_name or device.type == "privateuseone":
233+
custom_privateuse1_name = torch._C._get_privateuse1_backend_name()
234+
decoder = getattr(torch.ops.image, "decode_jpegs_privateuseone", None)
235+
if decoder is None:
236+
raise RuntimeError(
237+
"decode_jpeg tensor on PrivateUse1 device require registering "
238+
"torch.ops.image.decode_jpegs_privateuseone."
239+
)
240+
return decoder([input], mode.value, apply_exif_orientation)[0]
221241
if device.type == "cuda":
222242
return torch.ops.image.decode_jpegs_cuda([input], mode.value, device)[0]
223243
else:
@@ -246,16 +266,36 @@ def encode_jpeg(
246266
_log_api_usage_once(encode_jpeg)
247267
if quality < 1 or quality > 100:
248268
raise ValueError("Image quality should be a positive number between 1 and 100")
269+
custom_privateuse1_name = torch._C._get_privateuse1_backend_name()
270+
249271
if isinstance(input, list):
250272
if not input:
251273
raise ValueError("encode_jpeg requires at least one input tensor when a list is passed")
252-
if input[0].device.type == "cuda":
274+
device_type = input[0].device.type
275+
if device_type == custom_privateuse1_name or device_type == "privateuseone":
276+
encoder = getattr(torch.ops.image, "encode_jpegs_privateuseone", None)
277+
if encoder is None:
278+
raise RuntimeError(
279+
"encode_jpeg tensors on PrivateUse1 device require registering "
280+
"torch.ops.image.encode_jpegs_privateuseone."
281+
)
282+
return encoder(input, quality)
283+
if device_type == "cuda":
253284
return torch.ops.image.encode_jpegs_cuda(input, quality)
254285
else:
255286
return [torch.ops.image.encode_jpeg(image, quality) for image in input]
256287
else: # single input tensor
257-
if input.device.type == "cuda":
288+
device_type = input.device.type
289+
if device_type == "cuda":
258290
return torch.ops.image.encode_jpegs_cuda([input], quality)[0]
291+
elif device_type == custom_privateuse1_name or device_type == "privateuseone":
292+
encoder = getattr(torch.ops.image, "encode_jpegs_privateuseone", None)
293+
if encoder is None:
294+
raise RuntimeError(
295+
"encode_jpeg tensor on PrivateUse1 device require registering "
296+
"torch.ops.image.encode_jpegs_privateuseone."
297+
)
298+
return encoder([input], quality)[0]
259299
else:
260300
return torch.ops.image.encode_jpeg(input, quality)
261301

0 commit comments

Comments
 (0)