diff --git a/bolt/shuffle/sparksql/BoltShuffleReader.cpp b/bolt/shuffle/sparksql/BoltShuffleReader.cpp index 67fc92265..6a5c6d045 100644 --- a/bolt/shuffle/sparksql/BoltShuffleReader.cpp +++ b/bolt/shuffle/sparksql/BoltShuffleReader.cpp @@ -371,6 +371,50 @@ RowVectorPtr makeColumnarBatch( } // namespace +std::unique_ptr BoltColumnarBatchDeserializer::drainSaved() { + if (savedPayloads_.payloads.empty()) { + return nullptr; + } + if (savedPayloads_.payloads.size() == 1) { + auto payload = std::move(savedPayloads_.payloads[0]); + savedPayloads_ = {}; + return payload; + } + auto numBuffers = savedPayloads_.payloads[0]->numBuffers(); + // padding size to avoid simd instructions memory overflow + std::vector bufferSizes(numBuffers, simd::kPadding); + for (const auto& payload : savedPayloads_.payloads) { + for (size_t i = 0; i < numBuffers; ++i) { + bufferSizes[i] += payload->bufferSizeAt(i); + } + } + std::vector> arrowBuffers; + for (int i = 0; i < numBuffers; ++i) { + auto buffer = arrow::AllocateResizableBuffer(bufferSizes[i], memoryPool_); + BOLT_CHECK( + buffer.ok(), + "Failed to allocate resizable buffer at index " + std::to_string(i)); + buffer.ValueUnsafe()->Resize(0, false); + arrowBuffers.emplace_back(std::move(buffer).ValueUnsafe()); + } + + auto payload = std::make_unique( + 0, isValidityBuffer_, std::move(arrowBuffers)); + + for (auto& savedPayload : savedPayloads_.payloads) { + auto result = InMemoryPayload::merge( + std::move(payload), + std::move(savedPayload), + memoryPool_, + INT64_MAX, + INT64_MIN); + BOLT_CHECK(result.ok(), "Failed to merge payloads"); + payload = std::move(result.ValueUnsafe()); + } + savedPayloads_ = {}; + return payload; +} + BoltColumnarBatchDeserializer::BoltColumnarBatchDeserializer( std::shared_ptr in, const std::shared_ptr& schema, @@ -475,22 +519,18 @@ RowVectorPtr BoltColumnarBatchDeserializer::next() { } if (reachEos_) { - if (merged_) { + if (!savedPayloads_.payloads.empty()) { return makeColumnarBatch( - rowType_, - std::move(merged_), - boltPool_, - deserializeTime_, - memoryPool_); + rowType_, drainSaved(), boltPool_, deserializeTime_, memoryPool_); } return nullptr; } std::vector> arrowBuffers{}; uint32_t numRows = 0; - while (!merged_ || - (merged_->numRows() < batchSize_ && - merged_->getBufferSize() < shuffleBatchByteSize_)) { + while (savedPayloads_.payloads.empty() || + (savedPayloads_.rowCount < batchSize_ && + savedPayloads_.size < shuffleBatchByteSize_)) { if (!payloadType_.has_value()) { int64_t bytes = 0; bool isComposite = isCompositeRowVectorLayout(bytes); @@ -502,7 +542,7 @@ RowVectorPtr BoltColumnarBatchDeserializer::next() { vectorLayout_ = RowVectorLayout::kComposite; zstdCodec_->markHeaderSkipped(readAheadBuffer_.size); readAheadBuffer_.reset(); - if (!merged_) { + if (savedPayloads_.payloads.empty()) { return nextFromRows(); } else { break; @@ -522,47 +562,36 @@ RowVectorPtr BoltColumnarBatchDeserializer::next() { result.ok(), "Failed to deserialize BlockPayload: " + result.status().message()); arrowBuffers = std::move(result.ValueUnsafe()); - if (!merged_) { - merged_ = std::make_unique( - numRows, isValidityBuffer_, std::move(arrowBuffers)); + if (savedPayloads_.payloads.empty()) { + savedPayloads_.save(std::make_unique( + numRows, isValidityBuffer_, std::move(arrowBuffers))); arrowBuffers.clear(); continue; } - auto mergedRows = merged_->numRows() + numRows; - auto mergedByteSize = - merged_->getBufferSize() + getBufferSize(arrowBuffers); + auto mergedRows = savedPayloads_.rowCount + numRows; + auto mergedByteSize = savedPayloads_.size + getBufferSize(arrowBuffers); if (mergedRows > batchSize_ || mergedByteSize > shuffleBatchByteSize_) { break; } auto append = std::make_unique( numRows, isValidityBuffer_, std::move(arrowBuffers)); - auto mergeResult = InMemoryPayload::merge( - std::move(merged_), - std::move(append), - memoryPool_, - INT64_MAX, - INT64_MIN); - BOLT_CHECK( - mergeResult.ok(), - "Failed to merge payloads: " + mergeResult.status().message()); - merged_ = std::move(mergeResult.ValueUnsafe()); - + savedPayloads_.save(std::move(append)); arrowBuffers.clear(); } // Reach EOS. - if (reachEos_ && !merged_) { + if (reachEos_ && savedPayloads_.payloads.empty()) { return nullptr; } auto columnarBatch = makeColumnarBatch( - rowType_, std::move(merged_), boltPool_, deserializeTime_, memoryPool_); + rowType_, drainSaved(), boltPool_, deserializeTime_, memoryPool_); // Save remaining rows. if (!arrowBuffers.empty()) { - merged_ = std::make_unique( - numRows, isValidityBuffer_, std::move(arrowBuffers)); + savedPayloads_.save(std::make_unique( + numRows, isValidityBuffer_, std::move(arrowBuffers))); } return columnarBatch; } diff --git a/bolt/shuffle/sparksql/BoltShuffleReader.h b/bolt/shuffle/sparksql/BoltShuffleReader.h index d969a73a8..624e18c07 100644 --- a/bolt/shuffle/sparksql/BoltShuffleReader.h +++ b/bolt/shuffle/sparksql/BoltShuffleReader.h @@ -120,6 +120,8 @@ class BoltColumnarBatchDeserializer { private: bytedance::bolt::RowVectorPtr nextFromRows(); FLATTEN bool isCompositeRowVectorLayout(int64_t& bytes); + // merge the saved payloads into a single InMemoryPayload and clear saved. + std::unique_ptr drainSaved(); std::shared_ptr in_; std::shared_ptr schema_; @@ -135,7 +137,18 @@ class BoltColumnarBatchDeserializer { uint64_t& deserializeTime_; uint64_t& decompressTime_; - std::unique_ptr merged_{nullptr}; + struct SavedPayload { + uint64_t size{0}; + uint64_t rowCount{0}; + std::vector> payloads; + + void save(std::unique_ptr payload) { + size += payload->getBufferSize(); + rowCount += payload->numRows(); + payloads.emplace_back(std::move(payload)); + } + } savedPayloads_; + bool reachEos_{false}; // for row format shuffle read diff --git a/bolt/shuffle/sparksql/Payload.h b/bolt/shuffle/sparksql/Payload.h index f147d7403..09138b4ed 100644 --- a/bolt/shuffle/sparksql/Payload.h +++ b/bolt/shuffle/sparksql/Payload.h @@ -203,6 +203,10 @@ class InMemoryPayload final : public Payload { arrow::Result> readBufferAt( uint32_t index) override; + int64_t bufferSizeAt(uint32_t index) const { + return buffers_[index] ? buffers_[index]->size() : 0; + } + arrow::Result> toBlockPayload( Payload::Type payloadType, arrow::MemoryPool* pool, diff --git a/bolt/shuffle/sparksql/tests/BoltShuffleReaderTest.cpp b/bolt/shuffle/sparksql/tests/BoltShuffleReaderTest.cpp new file mode 100644 index 000000000..bd22010d9 --- /dev/null +++ b/bolt/shuffle/sparksql/tests/BoltShuffleReaderTest.cpp @@ -0,0 +1,206 @@ +/* + * Copyright (c) ByteDance Ltd. and/or its affiliates + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "bolt/common/base/Exceptions.h" +#include "bolt/common/memory/Memory.h" +#include "bolt/shuffle/sparksql/BoltShuffleReader.h" +#include "bolt/shuffle/sparksql/Payload.h" +#include "bolt/type/Type.h" +#include "bolt/vector/FlatVector.h" + +namespace bytedance::bolt::shuffle::sparksql::test { +namespace { + +constexpr int32_t kNumColumns = 2; + +int64_t expectedValue(int64_t globalRow, int32_t col) { + return globalRow * (col + 1); +} + +// Two buffers per BIGINT column: validity bitmap (all-valid) + values seeded +// with the sentinel pattern from expectedValue(). +std::vector> makePayloadBuffers( + int32_t payloadIdx, + int32_t rowsPerPayload, + arrow::MemoryPool* pool) { + const int64_t valueBytes = rowsPerPayload * sizeof(int64_t); + const int64_t validityBytes = arrow::bit_util::BytesForBits(rowsPerPayload); + std::vector> buffers; + for (int32_t c = 0; c < kNumColumns; ++c) { + auto v = arrow::AllocateResizableBuffer(validityBytes, pool).ValueOrDie(); + std::memset(v->mutable_data(), 0xFF, validityBytes); + buffers.push_back(std::move(v)); + + auto d = arrow::AllocateResizableBuffer(valueBytes, pool).ValueOrDie(); + auto* values = reinterpret_cast(d->mutable_data()); + for (int32_t r = 0; r < rowsPerPayload; ++r) { + values[r] = expectedValue(payloadIdx * rowsPerPayload + r, c); + } + buffers.push_back(std::move(d)); + } + return buffers; +} + +std::shared_ptr buildStream( + int32_t numPayloads, + int32_t rowsPerPayload, + const std::vector* isValidityBuffer, + arrow::MemoryPool* pool) { + auto stream = + arrow::io::BufferOutputStream::Create(1 << 12, pool).ValueOrDie(); + for (int32_t p = 0; p < numPayloads; ++p) { + auto payload = BlockPayload::fromBuffers( + Payload::Type::kUncompressed, + rowsPerPayload, + makePayloadBuffers(p, rowsPerPayload, pool), + isValidityBuffer, + pool, + /*codec=*/nullptr, + Payload::Mode::kBuffer, + /*hasComplexType=*/false) + .ValueOrDie(); + BOLT_CHECK(payload->serialize(stream.get()).ok(), "serialize failed"); + } + return stream->Finish().ValueOrDie(); +} + +class BoltShuffleReaderTest : public ::testing::Test { + protected: + void SetUp() override { + pool_ = arrow::default_memory_pool(); + boltPool_ = bytedance::bolt::memory::memoryManager()->addLeafPool(); + + std::vector names; + std::vector types; + std::vector> fields; + for (int32_t c = 0; c < kNumColumns; ++c) { + names.push_back("c" + std::to_string(c)); + types.push_back(bytedance::bolt::BIGINT()); + fields.push_back( + arrow::field("c" + std::to_string(c), arrow::int64(), true)); + } + rowType_ = bytedance::bolt::ROW(std::move(names), std::move(types)); + schema_ = arrow::schema(fields); + for (int32_t c = 0; c < kNumColumns; ++c) { + isValidityBuffer_.push_back(true); + isValidityBuffer_.push_back(false); + } + } + + std::unique_ptr makeDeserializer( + const std::shared_ptr& stream, + int32_t batchSize, + int64_t shuffleBatchByteSize) { + factory_ = std::make_unique( + schema_, + /*codec=*/nullptr, + rowType_, + batchSize, + shuffleBatchByteSize, + pool_, + boltPool_.get(), + /*checksumEnabled=*/false); + factory_->setpartitioningShortName("single"); + return factory_->createDeserializer( + std::make_shared(stream)); + } + + // Drain the deserializer, verifying every cell value and total row count. + void drainAndVerify(BoltColumnarBatchDeserializer& d, int64_t expectedRows) { + int64_t total = 0; + while (auto batch = d.next()) { + ASSERT_EQ(batch->childrenSize(), kNumColumns); + for (int32_t c = 0; c < kNumColumns; ++c) { + auto* col = batch->childAt(c)->asFlatVector(); + ASSERT_NE(col, nullptr); + for (vector_size_t i = 0; i < batch->size(); ++i) { + EXPECT_FALSE(col->isNullAt(i)); + EXPECT_EQ(col->valueAt(i), expectedValue(total + i, c)); + } + } + total += batch->size(); + } + EXPECT_EQ(total, expectedRows); + } + + arrow::MemoryPool* pool_{}; + std::shared_ptr boltPool_; + bytedance::bolt::RowTypePtr rowType_; + std::shared_ptr schema_; + std::vector isValidityBuffer_; + std::unique_ptr factory_; +}; + +TEST_F(BoltShuffleReaderTest, SinglePayload) { + constexpr int32_t kRows = 16; + auto stream = buildStream(1, kRows, &isValidityBuffer_, pool_); + auto d = makeDeserializer(stream, /*batchSize=*/1024, /*byteSize=*/1 << 20); + drainAndVerify(*d, kRows); +} + +TEST_F(BoltShuffleReaderTest, ManyPayloadsSingleBatch) { + constexpr int32_t kPayloads = 200; + constexpr int32_t kRows = 16; + auto stream = buildStream(kPayloads, kRows, &isValidityBuffer_, pool_); + auto d = + makeDeserializer(stream, /*batchSize=*/100000, /*byteSize=*/1LL << 30); + drainAndVerify(*d, kPayloads * kRows); +} + +TEST_F(BoltShuffleReaderTest, ManyPayloadsMultipleBatches) { + constexpr int32_t kPayloads = 100; + constexpr int32_t kRows = 16; + auto stream = buildStream(kPayloads, kRows, &isValidityBuffer_, pool_); + auto d = makeDeserializer(stream, /*batchSize=*/160, /*byteSize=*/1LL << 30); + drainAndVerify(*d, kPayloads * kRows); +} + +TEST_F(BoltShuffleReaderTest, AllocationCountBounded) { + constexpr int32_t kPayloads = 200; + constexpr int32_t kRows = 16; + constexpr int32_t kBuffersPerPayload = kNumColumns * 2; + auto stream = buildStream(kPayloads, kRows, &isValidityBuffer_, pool_); + auto d = + makeDeserializer(stream, /*batchSize=*/100000, /*byteSize=*/1LL << 30); + + const int64_t before = pool_->num_allocations(); + drainAndVerify(*d, kPayloads * kRows); + const int64_t allocs = pool_->num_allocations() - before; + + EXPECT_LT(allocs, kPayloads * kBuffersPerPayload * 1.2) + << "saw " << allocs << " allocations; perf fix may have regressed"; +} + +} // namespace +} // namespace bytedance::bolt::shuffle::sparksql::test + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + bytedance::bolt::memory::MemoryManager::initialize({}); + return RUN_ALL_TESTS(); +} diff --git a/bolt/shuffle/sparksql/tests/CMakeLists.txt b/bolt/shuffle/sparksql/tests/CMakeLists.txt index 327e96083..f0355e70d 100644 --- a/bolt/shuffle/sparksql/tests/CMakeLists.txt +++ b/bolt/shuffle/sparksql/tests/CMakeLists.txt @@ -163,3 +163,19 @@ add_test( NAME bolt_shuffle_adaptive_zstd_codec_test COMMAND bolt_shuffle_adaptive_zstd_codec_test ) + +add_executable(bolt_shuffle_reader_test BoltShuffleReaderTest.cpp) + +target_link_libraries( + bolt_shuffle_reader_test + PRIVATE + bolt_shuffle_spark_impl + bolt_testutils + GTest::gtest + GTest::gmock +) + +add_test( + NAME bolt_shuffle_reader_test + COMMAND bolt_shuffle_reader_test +)