Skip to content

Commit e86363f

Browse files
[SYCL] Fix race during image decompression (#19981)
**Problem** There can be race during image decompression. Consider the case when one thread is reading the decompressed buffer while another thread modifies it. This can lead to invalid SPIRV errors emitted by IGC. **Proposed solution** Use `std::call_once` to ensure that only one thread does the decompression of an image.
1 parent bbae036 commit e86363f

File tree

3 files changed

+84
-12
lines changed

3 files changed

+84
-12
lines changed

sycl/source/detail/device_binary_image.cpp

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -710,22 +710,29 @@ CompressedRTDeviceBinaryImage::CompressedRTDeviceBinaryImage(
710710
static_cast<size_t>(Bin->BinaryEnd - Bin->BinaryStart));
711711
}
712712

713+
// std::call_once ensures that this function is thread_safe and prevents
714+
// race during image decompression.
713715
void CompressedRTDeviceBinaryImage::Decompress() {
716+
auto DecompressFunc = [&]() {
717+
size_t CompressedDataSize =
718+
static_cast<size_t>(Bin->BinaryEnd - Bin->BinaryStart);
714719

715-
size_t CompressedDataSize =
716-
static_cast<size_t>(Bin->BinaryEnd - Bin->BinaryStart);
720+
size_t DecompressedSize = 0;
721+
m_DecompressedData = ZSTDCompressor::DecompressBlob(
722+
reinterpret_cast<const char *>(Bin->BinaryStart), CompressedDataSize,
723+
DecompressedSize);
717724

718-
size_t DecompressedSize = 0;
719-
m_DecompressedData = ZSTDCompressor::DecompressBlob(
720-
reinterpret_cast<const char *>(Bin->BinaryStart), CompressedDataSize,
721-
DecompressedSize);
725+
Bin->BinaryStart =
726+
reinterpret_cast<const unsigned char *>(m_DecompressedData.get());
727+
Bin->BinaryEnd = Bin->BinaryStart + DecompressedSize;
722728

723-
Bin->BinaryStart =
724-
reinterpret_cast<const unsigned char *>(m_DecompressedData.get());
725-
Bin->BinaryEnd = Bin->BinaryStart + DecompressedSize;
729+
Bin->Format = ur::getBinaryImageFormat(Bin->BinaryStart, getSize());
730+
Format = static_cast<ur::DeviceBinaryType>(Bin->Format);
726731

727-
Bin->Format = ur::getBinaryImageFormat(Bin->BinaryStart, getSize());
728-
Format = static_cast<ur::DeviceBinaryType>(Bin->Format);
732+
m_IsCompressed.store(false);
733+
};
734+
735+
std::call_once(m_InitFlag, DecompressFunc);
729736
}
730737

731738
CompressedRTDeviceBinaryImage::~CompressedRTDeviceBinaryImage() {

sycl/source/detail/device_binary_image.hpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <atomic>
2020
#include <cstring>
2121
#include <memory>
22+
#include <mutex>
2223

2324
namespace sycl {
2425
inline namespace _V1 {
@@ -321,7 +322,8 @@ class CompressedRTDeviceBinaryImage : public RTDeviceBinaryImage {
321322
return m_ImageSize;
322323
}
323324

324-
bool IsCompressed() const { return m_DecompressedData.get() == nullptr; }
325+
bool IsCompressed() const { return m_IsCompressed.load(); }
326+
325327
void print() const override {
326328
RTDeviceBinaryImage::print();
327329
std::cerr << " COMPRESSED\n";
@@ -330,6 +332,10 @@ class CompressedRTDeviceBinaryImage : public RTDeviceBinaryImage {
330332
private:
331333
std::unique_ptr<char[]> m_DecompressedData;
332334
size_t m_ImageSize = 0;
335+
336+
// Flag to ensure decompression happens only once.
337+
std::once_flag m_InitFlag;
338+
std::atomic<bool> m_IsCompressed{true};
333339
};
334340
#endif // SYCL_RT_ZSTD_AVAILABLE
335341

sycl/unittests/compression/CompressionTests.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "../thread_safety/ThreadUtils.h"
1010
#include <detail/compression.hpp>
11+
#include <detail/device_binary_image.hpp>
1112
#include <sycl/sycl.hpp>
1213

1314
#include <string>
@@ -113,3 +114,61 @@ TEST(CompressionTest, ConcurrentCompressionDecompression) {
113114
::ThreadPool MPool(ThreadCount, testCompressDecompress);
114115
}
115116
}
117+
118+
// Test to decompress CompressedRTDeviceImage using multiple threads.
119+
// The idea behind this test is to ensure that a device image is
120+
// decompressed only once even if multiple threads try to decompress
121+
// it at the same time.
122+
TEST(CompressionTest, ConcurrentDecompressionOfDeviceImage) {
123+
// Data to compress.
124+
std::string data = "Hello World! I'm about to get compressed :P";
125+
126+
// Compress this data.
127+
size_t compressedSize = 0;
128+
auto compressedData = ZSTDCompressor::CompressBlob(data.c_str(), data.size(),
129+
compressedSize, 1);
130+
131+
unsigned char *compressedDataPtr =
132+
reinterpret_cast<unsigned char *>(compressedData.get());
133+
134+
const char *EntryName = "Entry";
135+
_sycl_offload_entry_struct EntryStruct = {
136+
/*addr*/ nullptr, const_cast<char *>(EntryName), strlen(EntryName),
137+
/*flags*/ 0, /*reserved*/ 0};
138+
sycl_device_binary_struct BinStruct{/*Version*/ 1,
139+
/*Kind*/ 4,
140+
/*Format*/ SYCL_DEVICE_BINARY_TYPE_SPIRV,
141+
/*DeviceTargetSpec*/ nullptr,
142+
/*CompileOptions*/ nullptr,
143+
/*LinkOptions*/ nullptr,
144+
/*ManifestStart*/ nullptr,
145+
/*ManifestEnd*/ nullptr,
146+
/*BinaryStart*/ compressedDataPtr,
147+
/*BinaryEnd*/ compressedDataPtr +
148+
compressedSize,
149+
/*EntriesBegin*/ &EntryStruct,
150+
/*EntriesEnd*/ &EntryStruct + 1,
151+
/*PropertySetsBegin*/ nullptr,
152+
/*PropertySetsEnd*/ nullptr};
153+
sycl_device_binary Bin = &BinStruct;
154+
CompressedRTDeviceBinaryImage Img{Bin};
155+
156+
// Decompress the image with multiple threads.
157+
constexpr size_t ThreadCount = 20;
158+
Barrier b(ThreadCount);
159+
{
160+
auto testDecompress = [&](size_t threadId) {
161+
b.wait();
162+
Img.Decompress();
163+
164+
// Check if decompressed data is same as original data.
165+
// Img.getRawData will change if there's a race in image decompression
166+
// and the check will fail.
167+
for (size_t i = 0; i < Img.getSize(); ++i) {
168+
ASSERT_EQ(data[i], Img.getRawData().BinaryStart[i]);
169+
}
170+
};
171+
172+
::ThreadPool MPool(ThreadCount, testDecompress);
173+
}
174+
}

0 commit comments

Comments
 (0)