diff --git a/bolt/exec/LocalPlanner.cpp b/bolt/exec/LocalPlanner.cpp index 9df9af872..c287fb1bd 100644 --- a/bolt/exec/LocalPlanner.cpp +++ b/bolt/exec/LocalPlanner.cpp @@ -372,11 +372,6 @@ uint32_t maxDrivers( else if (node->name() == "ValueStream") { return 1; } - // multi-threaded spark: SparkShuffleReader is designed to be - // single-threaded for now. This assumption might not hold in the future. - else if (node->name() == "SparkShuffleReader") { - return 1; - } // multi-threaded spark: SparkShuffleWriter is designed to be // single-threaded for now. This assumption might not hold in the future. else if (node->name() == "SparkShuffleWriter") { diff --git a/bolt/exec/Task.cpp b/bolt/exec/Task.cpp index 99d40e3aa..bc80225c1 100644 --- a/bolt/exec/Task.cpp +++ b/bolt/exec/Task.cpp @@ -355,6 +355,10 @@ Task::~Task() { #define CLEAR(_action_) \ clearStage = #_action_; \ _action_; +#ifdef SPARK_COMPATIBLE + CLEAR(shuffleReaderClients_.clear()); + CLEAR(shuffleReaderClientByPlanNode_.clear()); +#endif CLEAR(threadFinishPromises_.clear()); CLEAR(splitGroupStates_.clear()); CLEAR(taskStats_ = TaskStats()); @@ -401,6 +405,11 @@ void Task::init(std::optional&& spillDiskOpts) { 1, queryCtx_->queryConfig().isMultiDriverEnabled()); exchangeClients_.resize(driverFactories_.size()); +#ifdef SPARK_COMPATIBLE + shuffleReaderClients_.resize(driverFactories_.size()); + shuffleReaderClientsInitFlags_ = + std::vector(driverFactories_.size()); +#endif // In Task::next() we always assume ungrouped execution. for (const auto& factory : driverFactories_) { @@ -820,6 +829,12 @@ void Task::createDriverFactoriesLocked(uint32_t maxDrivers) { queryCtx_->queryConfig().isMultiDriverEnabled(), queryCtx_->queryConfig().morselDrivenEnabled()); +#ifdef SPARK_COMPATIBLE + shuffleReaderClients_.resize(driverFactories_.size()); + shuffleReaderClientsInitFlags_ = + std::vector(driverFactories_.size()); +#endif + // Keep one exchange client per pipeline (NULL if not used). const uint32_t numPipelines = driverFactories_.size(); // Calculates total number of drivers and create pipeline stats. diff --git a/bolt/exec/Task.h b/bolt/exec/Task.h index 7693b15ec..dca8a8cdd 100644 --- a/bolt/exec/Task.h +++ b/bolt/exec/Task.h @@ -42,6 +42,13 @@ #include "bolt/exec/TaskStructs.h" #include "bolt/exec/TraceConfig.h" #include "bolt/vector/ComplexVector.h" + +#ifdef SPARK_COMPATIBLE +namespace bytedance::bolt::shuffle::sparksql { +class BoltShuffleReaderClient; +} // namespace bytedance::bolt::shuffle::sparksql +#endif + namespace bytedance::bolt::exec { class OutputBufferManager; @@ -734,6 +741,39 @@ class Task : public std::enable_shared_from_this { return numThreads_; } +#ifdef SPARK_COMPATIBLE + std::shared_ptr + getOrCreateShuffleReaderClient( + int32_t pipelineId, + const core::PlanNodeId& planNodeId, + std::function< + std::shared_ptr()> + initClient) { + BOLT_CHECK_GE(pipelineId, 0); + const auto pipelineIndex = static_cast(pipelineId); + BOLT_CHECK_LT(pipelineIndex, shuffleReaderClients_.size()); + std::lock_guard l( + shuffleReaderClientsInitFlags_[pipelineIndex]); + if (!shuffleReaderClients_[pipelineIndex]) { + shuffleReaderClients_[pipelineIndex] = initClient(); + shuffleReaderClientByPlanNode_.emplace( + planNodeId, shuffleReaderClients_[pipelineIndex]); + } + return shuffleReaderClients_[pipelineIndex]; + } + + bolt::memory::MemoryPool* addShuffleReaderClientPool( + const core::PlanNodeId& planNodeId, + uint32_t pipelineId) { + auto* nodePool = getOrAddNodePool(planNodeId); + childPools_.push_back(nodePool->addLeafChild( + fmt::format("shuffleReaderClient.{}.{}", planNodeId, pipelineId), + true, + createExchangeClientReclaimer())); + return childPools_.back().get(); + } +#endif + private: Task( const std::string& taskId, @@ -1157,6 +1197,16 @@ class Task : public std::enable_shared_from_this { std::unordered_map> exchangeClientByPlanNode_; +#ifdef SPARK_COMPATIBLE + std::vector> + shuffleReaderClients_; + std::unordered_map< + core::PlanNodeId, + std::shared_ptr> + shuffleReaderClientByPlanNode_; + std::vector shuffleReaderClientsInitFlags_; +#endif + ConsumerSupplier consumerSupplier_; // The function that is executed when the task encounters its first error, diff --git a/bolt/shuffle/sparksql/ShuffleReaderNode.cpp b/bolt/shuffle/sparksql/ShuffleReaderNode.cpp index 8e8272cf6..d729d01ed 100644 --- a/bolt/shuffle/sparksql/ShuffleReaderNode.cpp +++ b/bolt/shuffle/sparksql/ShuffleReaderNode.cpp @@ -15,6 +15,8 @@ */ #include "bolt/shuffle/sparksql/ShuffleReaderNode.h" +#include "bolt/common/time/Timer.h" +#include "bolt/exec/Task.h" #include "bolt/shuffle/sparksql/compression/Compression.h" using namespace bytedance::bolt::shuffle::sparksql; @@ -27,10 +29,40 @@ SparkShuffleReader::SparkShuffleReader( shuffleReaderNode->outputType(), operatorId, shuffleReaderNode->id(), - std::string(shuffleReaderNode->name())), - shuffleReaderOptions_(shuffleReaderNode->getShuffleReaderOptions()), + std::string(shuffleReaderNode->name())) { + auto initClient = [&]() { + return std::make_shared( + shuffleReaderNode, + operatorCtx_->task()->addShuffleReaderClientPool( + shuffleReaderNode->id(), operatorCtx_->driverCtx()->pipelineId)); + }; + + shuffleReaderClient_ = operatorCtx_->task()->getOrCreateShuffleReaderClient( + operatorCtx_->driverCtx()->pipelineId, + shuffleReaderNode->id(), + initClient); + BOLT_CHECK_NOT_NULL(shuffleReaderClient_); +} + +bytedance::bolt::RowVectorPtr SparkShuffleReader::getOutput() { + auto data = shuffleReaderClient_->next(); + if (!data) { + finished_ = true; + } + return data; +} + +void SparkShuffleReader::close() { + shuffleReaderClient_ = nullptr; + bytedance::bolt::exec::SourceOperator::close(); +} + +BoltShuffleReaderClient::BoltShuffleReaderClient( + std::shared_ptr shuffleReaderNode, + memory::MemoryPool* pool) + : shuffleReaderOptions_(shuffleReaderNode->getShuffleReaderOptions()), readerStreamIterator_(shuffleReaderNode->getReaderStreams()), - arrowPool_(std::make_shared(pool())), + arrowPool_(std::make_shared(pool)), codec_(createCodec( shuffleReaderOptions_.compressionType, CodecOptions{ @@ -45,8 +77,10 @@ SparkShuffleReader::SparkShuffleReader( partitioningShortName_(shuffleReaderOptions_.partitionShortName), rowBufferPool_(std::make_shared(arrowPool_.get())), row2ColConverter_(std::make_shared( - outputType_, - pool())) { + shuffleReaderNode->outputType(), + pool)), + outputType_(shuffleReaderNode->outputType()), + pool_(pool) { isValidityBuffer_.reserve(outputType_->size()); for (size_t i = 0; i < outputType_->size(); ++i) { switch (outputType_->childAt(i)->kind()) { @@ -83,16 +117,32 @@ SparkShuffleReader::SparkShuffleReader( (shuffleWriterType_ == ShuffleWriterType::RowBased)); } -void SparkShuffleReader::init() { +BoltShuffleReaderClient::~BoltShuffleReaderClient() { + if (readerStreamIterator_) { + readerStreamIterator_->updateMetrics( + numRows_, + numBatches_, + decompressTime_, + deserializeTime_, + totalReadTime_); + readerStreamIterator_->close(); + readerStreamIterator_ = nullptr; + } +} + +void BoltShuffleReaderClient::init() { // Bolt operator should not alloc memory during construct, so init schema and // codec here - schema_ = boltTypeToArrowSchema(outputType_, pool()); + schema_ = boltTypeToArrowSchema(outputType_, pool_); zstdCodec_ = std::make_shared( 1 /*not used*/, false, arrowPool_.get()); } -bytedance::bolt::RowVectorPtr SparkShuffleReader::getOutput() { - std::call_once(initFlag_, &SparkShuffleReader::init, this); +bytedance::bolt::RowVectorPtr BoltShuffleReaderClient::next() { + std::call_once(initFlag_, &BoltShuffleReaderClient::init, this); + std::lock_guard lock(mutex_); + bytedance::bolt::NanosecondTimer timer(&totalReadTime_); + while (true) { if (!columnarBatchDeserializer_) { auto in = readerStreamIterator_->nextStream(arrowPool_.get()); @@ -106,7 +156,7 @@ bytedance::bolt::RowVectorPtr SparkShuffleReader::getOutput() { batchSize_, shuffleBatchByteSize_, arrowPool_.get(), - pool(), + pool_, &isValidityBuffer_, hasComplexType_, deserializeTime_, @@ -116,31 +166,17 @@ bytedance::bolt::RowVectorPtr SparkShuffleReader::getOutput() { rowBufferPool_.get(), row2ColConverter_.get()); } else { - finished_ = true; return nullptr; } } - auto output = columnarBatchDeserializer_->next(); - if (output) { - return output; - } else { + auto data = columnarBatchDeserializer_->next(); + if (!data) { columnarBatchDeserializer_ = nullptr; + } else { + numBatches_++; + numRows_ += data->size(); + return data; } } } - -void SparkShuffleReader::close() { - auto stats = this->stats().rlock(); - readerStreamIterator_->updateMetrics( - stats->outputPositions, - stats->outputVectors, - decompressTime_, - deserializeTime_, - stats->getOutputTiming.wallNanos); - if (readerStreamIterator_) { - readerStreamIterator_->close(); - readerStreamIterator_ = nullptr; - } - bytedance::bolt::exec::SourceOperator::close(); -} diff --git a/bolt/shuffle/sparksql/ShuffleReaderNode.h b/bolt/shuffle/sparksql/ShuffleReaderNode.h index a1da0e23f..bc2ec80bf 100644 --- a/bolt/shuffle/sparksql/ShuffleReaderNode.h +++ b/bolt/shuffle/sparksql/ShuffleReaderNode.h @@ -32,6 +32,7 @@ #pragma once #include +#include #include "bolt/exec/Driver.h" #include "bolt/exec/Operator.h" #include "bolt/shuffle/sparksql/BoltArrowMemoryPool.h" @@ -39,6 +40,8 @@ #include "bolt/shuffle/sparksql/ReaderStreamIterator.h" namespace bytedance::bolt::shuffle::sparksql { +class BoltShuffleReaderClient; + class SparkShuffleReaderNode : public bytedance::bolt::core::PlanNode { public: SparkShuffleReaderNode( @@ -107,9 +110,23 @@ class SparkShuffleReader : public bytedance::bolt::exec::SourceOperator { void close() override; - void init(); + private: + bool finished_ = false; + std::shared_ptr shuffleReaderClient_; +}; + +class BoltShuffleReaderClient { + public: + BoltShuffleReaderClient( + std::shared_ptr shuffleReaderNode, + memory::MemoryPool* pool); + ~BoltShuffleReaderClient(); + + RowVectorPtr next(); private: + void init(); + std::once_flag initFlag_; ShuffleReaderOptions shuffleReaderOptions_; std::shared_ptr readerStreamIterator_; @@ -129,6 +146,9 @@ class SparkShuffleReader : public bytedance::bolt::exec::SourceOperator { uint64_t deserializeTime_{0}; uint64_t decompressTime_{0}; + uint64_t totalReadTime_{0}; + int64_t numBatches_{0}; + int64_t numRows_{0}; // for rowbased shuffle std::shared_ptr zstdCodec_{nullptr}; @@ -138,8 +158,11 @@ class SparkShuffleReader : public bytedance::bolt::exec::SourceOperator { std::unique_ptr columnarBatchDeserializer_; bool isRowBased_ = false; + const RowTypePtr outputType_; - bool finished_ = false; + memory::MemoryPool* const pool_; + + std::mutex mutex_; }; class SparkShuffleReaderTranslator