@@ -272,21 +272,36 @@ SYCLBINBinaries::SYCLBINBinaries(const char *SYCLBINContent, size_t SYCLBINSize)
272
272
: SYCLBINContentCopy{ContentCopy (SYCLBINContent, SYCLBINSize)},
273
273
SYCLBINContentCopySize{SYCLBINSize},
274
274
ParsedSYCLBIN (SYCLBIN{SYCLBINContentCopy.get (), SYCLBINSize}) {
275
- size_t NumJITBinaries = 0 , NumNativeBinaries = 0 ;
276
- for (const SYCLBIN::AbstractModule &AM : ParsedSYCLBIN.AbstractModules ) {
277
- NumJITBinaries += AM.IRModules .size ();
278
- NumNativeBinaries += AM.NativeDeviceCodeImages .size ();
279
- }
280
- DeviceBinaries.reserve (NumJITBinaries + NumNativeBinaries);
281
- JITDeviceBinaryImages.reserve (NumJITBinaries);
282
- NativeDeviceBinaryImages.reserve (NumNativeBinaries);
275
+ AbstractModuleDescriptors = std::unique_ptr<AbstractModuleDesc[]>(
276
+ new AbstractModuleDesc[ParsedSYCLBIN.AbstractModules .size ()]);
277
+
278
+ size_t NumBinaries = 0 ;
279
+ for (const SYCLBIN::AbstractModule &AM : ParsedSYCLBIN.AbstractModules )
280
+ NumBinaries += AM.IRModules .size () + AM.NativeDeviceCodeImages .size ();
281
+ DeviceBinaries.reserve (NumBinaries);
282
+ BinaryImages = std::unique_ptr<RTDeviceBinaryImage[]>(
283
+ new RTDeviceBinaryImage[NumBinaries]);
284
+
285
+ RTDeviceBinaryImage *CurrentBinaryImagesStart = BinaryImages.get ();
286
+ for (size_t I = 0 ; I < getNumAbstractModules (); ++I) {
287
+ SYCLBIN::AbstractModule &AM = ParsedSYCLBIN.AbstractModules [I];
288
+ AbstractModuleDesc &AMDesc = AbstractModuleDescriptors[I];
289
+
290
+ // Set up the abstract module descriptor.
291
+ AMDesc.NumJITBinaries = AM.IRModules .size ();
292
+ AMDesc.NumNativeBinaries = AM.NativeDeviceCodeImages .size ();
293
+ AMDesc.JITBinaries = CurrentBinaryImagesStart;
294
+ AMDesc.NativeBinaries = CurrentBinaryImagesStart + AMDesc.NumJITBinaries ;
295
+ CurrentBinaryImagesStart +=
296
+ AMDesc.NumJITBinaries + AM.NativeDeviceCodeImages .size ();
283
297
284
- for (SYCLBIN::AbstractModule &AM : ParsedSYCLBIN.AbstractModules ) {
285
298
// Construct properties from SYCLBIN metadata.
286
299
std::vector<_sycl_device_binary_property_set_struct> &BinPropertySets =
287
300
convertAbstractModuleProperties (AM);
288
301
289
- for (SYCLBIN::IRModule &IRM : AM.IRModules ) {
302
+ for (size_t J = 0 ; J < AM.IRModules .size (); ++J) {
303
+ SYCLBIN::IRModule &IRM = AM.IRModules [J];
304
+
290
305
sycl_device_binary_struct &DeviceBinary = DeviceBinaries.emplace_back ();
291
306
DeviceBinary.Version = SYCL_DEVICE_BINARY_VERSION;
292
307
DeviceBinary.Kind = 4 ;
@@ -309,11 +324,12 @@ SYCLBINBinaries::SYCLBINBinaries(const char *SYCLBINContent, size_t SYCLBINSize)
309
324
DeviceBinary.PropertySetsEnd =
310
325
BinPropertySets.data () + BinPropertySets.size ();
311
326
// Create an image from it.
312
- JITDeviceBinaryImages. emplace_back ( &DeviceBinary) ;
327
+ AMDesc. JITBinaries [J] = RTDeviceBinaryImage{ &DeviceBinary} ;
313
328
}
314
329
315
- for (const SYCLBIN::NativeDeviceCodeImage &NDCI :
316
- AM.NativeDeviceCodeImages ) {
330
+ for (size_t J = 0 ; J < AM.NativeDeviceCodeImages .size (); ++J) {
331
+ const SYCLBIN::NativeDeviceCodeImage &NDCI = AM.NativeDeviceCodeImages [J];
332
+
317
333
assert (NDCI.Metadata != nullptr );
318
334
PropertySet &NDCIMetadataProps = (*NDCI.Metadata )
319
335
[PropertySetRegistry::SYCLBIN_NATIVE_DEVICE_CODE_IMAGE_METADATA];
@@ -346,7 +362,7 @@ SYCLBINBinaries::SYCLBINBinaries(const char *SYCLBINContent, size_t SYCLBINSize)
346
362
DeviceBinary.PropertySetsEnd =
347
363
BinPropertySets.data () + BinPropertySets.size ();
348
364
// Create an image from it.
349
- NativeDeviceBinaryImages. emplace_back ( &DeviceBinary) ;
365
+ AMDesc. NativeBinaries [J] = RTDeviceBinaryImage{ &DeviceBinary} ;
350
366
}
351
367
}
352
368
}
@@ -394,33 +410,44 @@ SYCLBINBinaries::convertAbstractModuleProperties(SYCLBIN::AbstractModule &AM) {
394
410
}
395
411
396
412
std::vector<const RTDeviceBinaryImage *>
397
- SYCLBINBinaries::getBestCompatibleImages (device_impl &Dev) {
398
- auto SelectCompatibleImages =
399
- [&](const std::vector<RTDeviceBinaryImage> &Imgs) {
400
- std::vector<const RTDeviceBinaryImage *> CompatImgs;
401
- for (const RTDeviceBinaryImage &Img : Imgs)
402
- if (doesDevSupportDeviceRequirements (Dev, Img) &&
403
- doesImageTargetMatchDevice (Img, Dev))
404
- CompatImgs.push_back (&Img);
405
- return CompatImgs;
406
- };
407
-
408
- // Try with native images first.
409
- std::vector<const RTDeviceBinaryImage *> NativeImgs =
410
- SelectCompatibleImages (NativeDeviceBinaryImages);
411
- if (!NativeImgs.empty ())
412
- return NativeImgs;
413
-
414
- // If there were no native images, pick JIT images.
415
- return SelectCompatibleImages (JITDeviceBinaryImages);
413
+ SYCLBINBinaries::getBestCompatibleImages (device_impl &Dev, bundle_state State) {
414
+ auto GetCompatibleImage = [&](const RTDeviceBinaryImage *Imgs,
415
+ size_t NumImgs) {
416
+ const RTDeviceBinaryImage *CompatImagePtr =
417
+ std::find_if (Imgs, Imgs + NumImgs, [&](const RTDeviceBinaryImage &Img) {
418
+ return doesDevSupportDeviceRequirements (Dev, Img) &&
419
+ doesImageTargetMatchDevice (Img, Dev);
420
+ });
421
+ return (CompatImagePtr != Imgs + NumImgs) ? CompatImagePtr : nullptr ;
422
+ };
423
+
424
+ std::vector<const RTDeviceBinaryImage *> Images;
425
+ for (size_t I = 0 ; I < getNumAbstractModules (); ++I) {
426
+ const AbstractModuleDesc &AMDesc = AbstractModuleDescriptors[I];
427
+ // If the target state is executable, try with native images first.
428
+ if (State == bundle_state::executable) {
429
+ if (const RTDeviceBinaryImage *CompatImagePtr = GetCompatibleImage (
430
+ AMDesc.NativeBinaries , AMDesc.NumNativeBinaries )) {
431
+ Images.push_back (CompatImagePtr);
432
+ continue ;
433
+ }
434
+ }
435
+
436
+ // Otherwise, select the first compatible JIT binary.
437
+ if (const RTDeviceBinaryImage *CompatImagePtr =
438
+ GetCompatibleImage (AMDesc.JITBinaries , AMDesc.NumJITBinaries ))
439
+ Images.push_back (CompatImagePtr);
440
+ }
441
+ return Images;
416
442
}
417
443
418
444
std::vector<const RTDeviceBinaryImage *>
419
- SYCLBINBinaries::getBestCompatibleImages (devices_range Devs) {
445
+ SYCLBINBinaries::getBestCompatibleImages (devices_range Devs,
446
+ bundle_state State) {
420
447
std::set<const RTDeviceBinaryImage *> Images;
421
448
for (device_impl &Dev : Devs) {
422
449
std::vector<const RTDeviceBinaryImage *> BestImagesForDev =
423
- getBestCompatibleImages (Dev);
450
+ getBestCompatibleImages (Dev, State );
424
451
Images.insert (BestImagesForDev.cbegin (), BestImagesForDev.cend ());
425
452
}
426
453
return {Images.cbegin (), Images.cend ()};
0 commit comments