@@ -210,6 +210,16 @@ def decode_jpeg(
210210 raise ValueError ("All elements of the input list must be tensors." )
211211 if not all (t .device .type == "cpu" for t in input ):
212212 raise ValueError ("Input list must contain tensors on CPU." )
213+ custom_privateuse1_name = torch ._C ._get_privateuse1_backend_name ()
214+ if device .type == custom_privateuse1_name or device .type == "privateuseone" :
215+ # When the target device is privateuseone, switch to calling the custom decode_jpegs_privateuseone.
216+ # This operator needs to be pre-registered by the user through torch.library.define/impl.
217+ decoder = getattr (torch .ops .image , "decode_jpegs_privateuseone" , None )
218+ if decoder is None :
219+ raise RuntimeError (
220+ "decode_jpeg(device='privateuseone') need register torch.ops.image.decode_jpegs_privateuseone."
221+ )
222+ return decoder (input , mode .value , apply_exif_orientation )
213223 if device .type == "cuda" :
214224 return torch .ops .image .decode_jpegs_cuda (input , mode .value , device )
215225 else :
@@ -218,6 +228,14 @@ def decode_jpeg(
218228 else : # input is tensor
219229 if input .device .type != "cpu" :
220230 raise ValueError ("Input tensor must be a CPU tensor" )
231+ if device .type == custom_privateuse1_name or device .type == "privateuseone" :
232+ custom_privateuse1_name = torch ._C ._get_privateuse1_backend_name ()
233+ decoder = getattr (torch .ops .image , "decode_jpegs_privateuseone" , None )
234+ if decoder is None :
235+ raise RuntimeError (
236+ "decode_jpeg(device='privateuseone') need register torch.ops.image.decode_jpegs_privateuseone."
237+ )
238+ return decoder ([input ], mode .value , apply_exif_orientation )[0 ]
221239 if device .type == "cuda" :
222240 return torch .ops .image .decode_jpegs_cuda ([input ], mode .value , device )[0 ]
223241 else :
0 commit comments