Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions bolt/exec/LocalPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -372,11 +372,6 @@ uint32_t maxDrivers(
else if (node->name() == "ValueStream") {
return 1;
}
// multi-threaded spark: SparkShuffleReader is designed to be
// single-threaded for now. This assumption might not hold in the future.
else if (node->name() == "SparkShuffleReader") {
return 1;
}
// multi-threaded spark: SparkShuffleWriter is designed to be
// single-threaded for now. This assumption might not hold in the future.
else if (node->name() == "SparkShuffleWriter") {
Expand Down
15 changes: 15 additions & 0 deletions bolt/exec/Task.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,10 @@ Task::~Task() {
#define CLEAR(_action_) \
clearStage = #_action_; \
_action_;
#ifdef SPARK_COMPATIBLE
CLEAR(shuffleReaderClients_.clear());
CLEAR(shuffleReaderClientByPlanNode_.clear());
#endif
CLEAR(threadFinishPromises_.clear());
CLEAR(splitGroupStates_.clear());
CLEAR(taskStats_ = TaskStats());
Expand Down Expand Up @@ -401,6 +405,11 @@ void Task::init(std::optional<common::SpillDiskOptions>&& spillDiskOpts) {
1,
queryCtx_->queryConfig().isMultiDriverEnabled());
exchangeClients_.resize(driverFactories_.size());
#ifdef SPARK_COMPATIBLE
shuffleReaderClients_.resize(driverFactories_.size());
shuffleReaderClientsInitFlags_ =
std::vector<std::mutex>(driverFactories_.size());
#endif

// In Task::next() we always assume ungrouped execution.
for (const auto& factory : driverFactories_) {
Expand Down Expand Up @@ -820,6 +829,12 @@ void Task::createDriverFactoriesLocked(uint32_t maxDrivers) {
queryCtx_->queryConfig().isMultiDriverEnabled(),
queryCtx_->queryConfig().morselDrivenEnabled());

#ifdef SPARK_COMPATIBLE
shuffleReaderClients_.resize(driverFactories_.size());
shuffleReaderClientsInitFlags_ =
std::vector<std::mutex>(driverFactories_.size());
#endif

// Keep one exchange client per pipeline (NULL if not used).
const uint32_t numPipelines = driverFactories_.size();
// Calculates total number of drivers and create pipeline stats.
Expand Down
50 changes: 50 additions & 0 deletions bolt/exec/Task.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@
#include "bolt/exec/TaskStructs.h"
#include "bolt/exec/TraceConfig.h"
#include "bolt/vector/ComplexVector.h"

#ifdef SPARK_COMPATIBLE
namespace bytedance::bolt::shuffle::sparksql {
class BoltShuffleReaderClient;
} // namespace bytedance::bolt::shuffle::sparksql
#endif

namespace bytedance::bolt::exec {

class OutputBufferManager;
Expand Down Expand Up @@ -734,6 +741,39 @@ class Task : public std::enable_shared_from_this<Task> {
return numThreads_;
}

#ifdef SPARK_COMPATIBLE
// Returns the shuffle-reader client for `pipelineId`, lazily creating it via
// `initClient` on first use. Creation is serialized by the pipeline's init
// mutex, so concurrent drivers of the same pipeline share one client.
//
// NOTE(review): `shuffleReaderClientByPlanNode_` is a task-wide map, yet it is
// mutated here while holding only a *per-pipeline* mutex. If two different
// pipelines can reach this path concurrently, the map insertion is a data
// race — confirm pipeline initialization is serialized, or protect the map
// with its own lock.
std::shared_ptr<shuffle::sparksql::BoltShuffleReaderClient>
getOrCreateShuffleReaderClient(
    int32_t pipelineId,
    const core::PlanNodeId& planNodeId,
    std::function<
        std::shared_ptr<shuffle::sparksql::BoltShuffleReaderClient>()>
        initClient) {
  BOLT_CHECK_GE(pipelineId, 0);
  const auto idx = static_cast<size_t>(pipelineId);
  BOLT_CHECK_LT(idx, shuffleReaderClients_.size());
  std::lock_guard<std::mutex> guard(shuffleReaderClientsInitFlags_[idx]);
  auto& client = shuffleReaderClients_[idx];
  if (client == nullptr) {
    client = initClient();
    shuffleReaderClientByPlanNode_.emplace(planNodeId, client);
  }
  return client;
}

// Creates a leaf memory pool for the shuffle-reader client serving the given
// plan node / pipeline, parented under that node's pool. Ownership is kept in
// `childPools_`, so the returned raw pointer stays valid for the task's
// lifetime.
bolt::memory::MemoryPool* addShuffleReaderClientPool(
    const core::PlanNodeId& planNodeId,
    uint32_t pipelineId) {
  auto poolName =
      fmt::format("shuffleReaderClient.{}.{}", planNodeId, pipelineId);
  auto leafPool = getOrAddNodePool(planNodeId)->addLeafChild(
      std::move(poolName), true, createExchangeClientReclaimer());
  auto* rawPool = leafPool.get();
  childPools_.push_back(std::move(leafPool));
  return rawPool;
}
#endif

private:
Task(
const std::string& taskId,
Expand Down Expand Up @@ -1157,6 +1197,16 @@ class Task : public std::enable_shared_from_this<Task> {
std::unordered_map<core::PlanNodeId, std::shared_ptr<ExchangeClient>>
exchangeClientByPlanNode_;

#ifdef SPARK_COMPATIBLE
std::vector<std::shared_ptr<shuffle::sparksql::BoltShuffleReaderClient>>
shuffleReaderClients_;
std::unordered_map<
core::PlanNodeId,
std::shared_ptr<shuffle::sparksql::BoltShuffleReaderClient>>
shuffleReaderClientByPlanNode_;
std::vector<std::mutex> shuffleReaderClientsInitFlags_;
#endif

ConsumerSupplier consumerSupplier_;

// The function that is executed when the task encounters its first error,
Expand Down
96 changes: 66 additions & 30 deletions bolt/shuffle/sparksql/ShuffleReaderNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
*/

#include "bolt/shuffle/sparksql/ShuffleReaderNode.h"
#include "bolt/common/time/Timer.h"
#include "bolt/exec/Task.h"
#include "bolt/shuffle/sparksql/compression/Compression.h"
using namespace bytedance::bolt::shuffle::sparksql;

Expand All @@ -27,10 +29,40 @@ SparkShuffleReader::SparkShuffleReader(
shuffleReaderNode->outputType(),
operatorId,
shuffleReaderNode->id(),
std::string(shuffleReaderNode->name())),
shuffleReaderOptions_(shuffleReaderNode->getShuffleReaderOptions()),
std::string(shuffleReaderNode->name())) {
auto initClient = [&]() {
return std::make_shared<BoltShuffleReaderClient>(
shuffleReaderNode,
operatorCtx_->task()->addShuffleReaderClientPool(
shuffleReaderNode->id(), operatorCtx_->driverCtx()->pipelineId));
};

shuffleReaderClient_ = operatorCtx_->task()->getOrCreateShuffleReaderClient(
operatorCtx_->driverCtx()->pipelineId,
shuffleReaderNode->id(),
initClient);
BOLT_CHECK_NOT_NULL(shuffleReaderClient_);
}

// Pulls the next batch from the shared shuffle-reader client. A null batch
// signals end-of-stream, which marks this operator finished.
bytedance::bolt::RowVectorPtr SparkShuffleReader::getOutput() {
  auto batch = shuffleReaderClient_->next();
  if (batch == nullptr) {
    finished_ = true;
    return nullptr;
  }
  return batch;
}

// Drops this operator's reference to the shared client before running the
// base-class teardown; the client is destroyed once the last holder (another
// operator of the same pipeline, or the owning Task) releases it.
void SparkShuffleReader::close() {
  shuffleReaderClient_.reset();
  bytedance::bolt::exec::SourceOperator::close();
}

BoltShuffleReaderClient::BoltShuffleReaderClient(
std::shared_ptr<const SparkShuffleReaderNode> shuffleReaderNode,
memory::MemoryPool* pool)
: shuffleReaderOptions_(shuffleReaderNode->getShuffleReaderOptions()),
readerStreamIterator_(shuffleReaderNode->getReaderStreams()),
arrowPool_(std::make_shared<BoltArrowMemoryPool>(pool())),
arrowPool_(std::make_shared<BoltArrowMemoryPool>(pool)),
codec_(createCodec(
shuffleReaderOptions_.compressionType,
CodecOptions{
Expand All @@ -45,8 +77,10 @@ SparkShuffleReader::SparkShuffleReader(
partitioningShortName_(shuffleReaderOptions_.partitionShortName),
rowBufferPool_(std::make_shared<RowBufferPool>(arrowPool_.get())),
row2ColConverter_(std::make_shared<ShuffleRowToColumnarConverter>(
outputType_,
pool())) {
shuffleReaderNode->outputType(),
pool)),
outputType_(shuffleReaderNode->outputType()),
pool_(pool) {
isValidityBuffer_.reserve(outputType_->size());
for (size_t i = 0; i < outputType_->size(); ++i) {
switch (outputType_->childAt(i)->kind()) {
Expand Down Expand Up @@ -83,16 +117,32 @@ SparkShuffleReader::SparkShuffleReader(
(shuffleWriterType_ == ShuffleWriterType::RowBased));
}

void SparkShuffleReader::init() {
// Flushes the accumulated reader metrics into the stream iterator and closes
// it. A no-op if the iterator was already released.
BoltShuffleReaderClient::~BoltShuffleReaderClient() {
  if (readerStreamIterator_ == nullptr) {
    return;
  }
  readerStreamIterator_->updateMetrics(
      numRows_,
      numBatches_,
      decompressTime_,
      deserializeTime_,
      totalReadTime_);
  readerStreamIterator_->close();
  readerStreamIterator_.reset();
}

// One-time lazy initialization, dispatched through std::call_once from
// next(). Allocation is deferred out of the constructor because Bolt
// operators must not allocate memory while being constructed.
void BoltShuffleReaderClient::init() {
  zstdCodec_ = std::make_shared<AdaptiveParallelZstdCodec>(
      1 /*not used*/, false, arrowPool_.get());
  schema_ = boltTypeToArrowSchema(outputType_, pool_);
}

bytedance::bolt::RowVectorPtr SparkShuffleReader::getOutput() {
std::call_once(initFlag_, &SparkShuffleReader::init, this);
bytedance::bolt::RowVectorPtr BoltShuffleReaderClient::next() {
std::call_once(initFlag_, &BoltShuffleReaderClient::init, this);
std::lock_guard<std::mutex> lock(mutex_);
bytedance::bolt::NanosecondTimer timer(&totalReadTime_);

while (true) {
if (!columnarBatchDeserializer_) {
auto in = readerStreamIterator_->nextStream(arrowPool_.get());
Expand All @@ -106,7 +156,7 @@ bytedance::bolt::RowVectorPtr SparkShuffleReader::getOutput() {
batchSize_,
shuffleBatchByteSize_,
arrowPool_.get(),
pool(),
pool_,
&isValidityBuffer_,
hasComplexType_,
deserializeTime_,
Expand All @@ -116,31 +166,17 @@ bytedance::bolt::RowVectorPtr SparkShuffleReader::getOutput() {
rowBufferPool_.get(),
row2ColConverter_.get());
} else {
finished_ = true;
return nullptr;
}
}

auto output = columnarBatchDeserializer_->next();
if (output) {
return output;
} else {
auto data = columnarBatchDeserializer_->next();
if (!data) {
columnarBatchDeserializer_ = nullptr;
} else {
numBatches_++;
numRows_ += data->size();
return data;
}
}
}

void SparkShuffleReader::close() {
auto stats = this->stats().rlock();
readerStreamIterator_->updateMetrics(
stats->outputPositions,
stats->outputVectors,
decompressTime_,
deserializeTime_,
stats->getOutputTiming.wallNanos);
if (readerStreamIterator_) {
readerStreamIterator_->close();
readerStreamIterator_ = nullptr;
}
bytedance::bolt::exec::SourceOperator::close();
}
27 changes: 25 additions & 2 deletions bolt/shuffle/sparksql/ShuffleReaderNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,16 @@
#pragma once

#include <cstdint>
#include <mutex>
#include "bolt/exec/Driver.h"
#include "bolt/exec/Operator.h"
#include "bolt/shuffle/sparksql/BoltArrowMemoryPool.h"
#include "bolt/shuffle/sparksql/BoltShuffleReader.h"
#include "bolt/shuffle/sparksql/ReaderStreamIterator.h"
namespace bytedance::bolt::shuffle::sparksql {

class BoltShuffleReaderClient;

class SparkShuffleReaderNode : public bytedance::bolt::core::PlanNode {
public:
SparkShuffleReaderNode(
Expand Down Expand Up @@ -107,9 +110,23 @@ class SparkShuffleReader : public bytedance::bolt::exec::SourceOperator {

void close() override;

void init();
private:
bool finished_ = false;
std::shared_ptr<BoltShuffleReaderClient> shuffleReaderClient_;
};

class BoltShuffleReaderClient {
public:
BoltShuffleReaderClient(
std::shared_ptr<const SparkShuffleReaderNode> shuffleReaderNode,
memory::MemoryPool* pool);
~BoltShuffleReaderClient();

RowVectorPtr next();

private:
void init();

std::once_flag initFlag_;
ShuffleReaderOptions shuffleReaderOptions_;
std::shared_ptr<ReaderStreamIterator> readerStreamIterator_;
Expand All @@ -129,6 +146,9 @@ class SparkShuffleReader : public bytedance::bolt::exec::SourceOperator {

uint64_t deserializeTime_{0};
uint64_t decompressTime_{0};
uint64_t totalReadTime_{0};
int64_t numBatches_{0};
int64_t numRows_{0};

// for rowbased shuffle
std::shared_ptr<AdaptiveParallelZstdCodec> zstdCodec_{nullptr};
Expand All @@ -138,8 +158,11 @@ class SparkShuffleReader : public bytedance::bolt::exec::SourceOperator {
std::unique_ptr<BoltColumnarBatchDeserializer> columnarBatchDeserializer_;

bool isRowBased_ = false;
const RowTypePtr outputType_;

bool finished_ = false;
memory::MemoryPool* const pool_;

std::mutex mutex_;
};

class SparkShuffleReaderTranslator
Expand Down