Commit 75dfe98
add driver flag
1 parent 7db0cda

File tree

3 files changed (+34, -28 lines)

src/turbomind/models/llama/LlamaBatch.cc

Lines changed: 25 additions & 23 deletions
@@ -202,7 +202,7 @@ void LlamaBatch::ProcessInferRequests(const Requests& reqs, std::vector<Signal>&
     int idx = 0;
     for (const auto& r : reqs) {

-        if (tp_rank_ == 0) {
+        if (is_driver_) {
             TM_LOG_INFO("[ProcessInferRequests] Request for %ld received.", (long)r->id);
         }

@@ -230,7 +230,7 @@ void LlamaBatch::ProcessInferRequests(const Requests& reqs, std::vector<Signal>&
             s = ptr->tokens.size();
         }
         else if (s > ptr->tokens.size()) {
-            if (tp_rank_ == 0) {
+            if (is_driver_) {
                 TM_LOG_WARNING("[ProcessInferRequests] Skipping invalid step (%d) setting for ID %lu", s, ptr->id);
             }
             s = ptr->tokens.size();
@@ -363,7 +363,7 @@ void LlamaBatch::ProcessInferRequests(const Requests& reqs, std::vector<Signal>&
         // the actual sequence length is seq_limit_len + 1, hence seq_limit_len must truncated to session_len - 1
         if (state.seq_len_limit[idx] >= session_len_) {
             state.seq_len_limit[idx] = session_len_ - 1;
-            if (tp_rank_ == 0) {
+            if (is_driver_) {
                 const int trunc_output_len = state.seq_len_limit[idx] - state.h_context_length[idx];
                 TM_LOG_WARNING(
                     "[ProcessInferRequests] [%ld] total sequence length (%d + %d) exceeds `session_len` (%d), `max_new_tokens` is truncated to %d",
@@ -840,7 +840,8 @@ LlamaBatch::LlamaBatch(DataType data_type,
                        std::unique_ptr<Context> ctx,  // ! This is moved
                        std::shared_ptr<Gateway> gateway,
                        int device_id,
-                       int dp_rank):
+                       int dp_rank,
+                       bool is_driver):
     param_(param),
     gateway_(gateway),
     max_batch_size_(param.max_batch_size),
@@ -852,6 +853,7 @@ LlamaBatch::LlamaBatch(DataType data_type,
     dp_rank_(dp_rank),
     tp_size_(model->tp_size_),
     tp_rank_(model->tp_rank_),
+    is_driver_(is_driver),
    data_type_(data_type),
     debug_(isDebug()),
     stream_(ctx->stream),
@@ -980,7 +982,7 @@ void LlamaBatch::ComputeAndOutputLogits(const Tensor& hidden_states, int first,

     auto logits = model_->postDecodeEmbedding(hidden_states, symm_logits_buf_.buffer());

-    if (tp_rank_ == 0) {
+    if (is_driver_) {
         OutputLogits(logits, first, last, GenerationConfig::kAll);
     }
 }
@@ -1141,7 +1143,7 @@ void LlamaBatch::Finish(GenerationState& g, std::vector<Signal>& signals)
     }

     // ! Only rank-0 writes to output
-    if (tp_rank_ == 0 && output_logprobs) {
+    if (is_driver_ && output_logprobs) {
         NvtxScope scope("logprobs");
         float* sampled_logprobs_ptr = h_sampled_logprobs_.data();
         uint32_t* sampled_indexes_ptr = h_sampled_indexes_.data();
@@ -1168,7 +1170,7 @@ void LlamaBatch::Finish(GenerationState& g, std::vector<Signal>& signals)
     }

     // ! Only rank-0 writes to output
-    if (tp_rank_ == 0) {
+    if (is_driver_) {
         NvtxScope scope("output_ids");
         for (int i = 0; i < batch_size - g.partial; ++i) {
             if (auto& r = state_->requests[i]) {
@@ -1184,7 +1186,7 @@ void LlamaBatch::Finish(GenerationState& g, std::vector<Signal>& signals)
     // Cache computed blocks to block trie
     sequence_manager_->CacheIfEnabled(state_->sequences, batch_size);

-    if (debug_ && tp_rank_ == 0) {
+    if (debug_ && is_driver_) {
         for (int i = 0; i < batch_size; ++i) {
             // ss << (i ? ", " : "") << "(" << state_->h_context_length[i] << "," << state_->h_finished[i] << ")";
             std::vector<int> tokens(state_->h_context_length[i]);
@@ -1225,7 +1227,7 @@ void LlamaBatch::Finish(GenerationState& g, std::vector<Signal>& signals)
             // Interrupt should reset r
             FT_CHECK(!r);
         }
-        else if (r->stream_output && tp_rank_ == 0) {
+        else if (r->stream_output && is_driver_) {
             const auto seq_len = *r->sequence_length.data();
             // Create signals by copying the request handles for non-finished streaming requests
             signals.push_back([this, r, seq_len] {  //
@@ -1249,15 +1251,15 @@ void LlamaBatch::Finish(GenerationState& g, std::vector<Signal>& signals)

 auto LlamaBatch::Interrupt(int index, bool force_stop, bool force_end) -> Signal
 {
-    if (tp_rank_ == 0) {
+    if (is_driver_) {
         TM_LOG_INFO("[Interrupt] slot %d, request %lu, stop %d, end %d",
                     index,
                     (long)state_->requests[index]->id,
                     force_stop,
                     force_end);
     }

-    if (debug_ && tp_rank_ == 0) {
+    if (debug_ && is_driver_) {
         std::vector<int> tokens(state_->h_context_length[index]);
         core::Copy(state_->output_ids.data() + index * session_len_, tokens.size(), tokens.data());
         cudaStreamSynchronize(stream_);
@@ -1331,7 +1333,7 @@ void LlamaBatch::InternalThreadEntry()

     std::shared_ptr<RequestData> req;

-    if (tp_rank_ == 0) {
+    if (is_driver_) {
         req = std::make_shared<RequestData>();
         {
             NvtxScope _("pop");
@@ -1372,7 +1374,7 @@ void LlamaBatch::InternalThreadEntry()

     ProcessCancelRequests(req->cancel, signals);

-    if (tp_rank_ == 0) {
+    if (is_driver_) {
         gateway_->notify(std::move(signals));
     }

@@ -1393,7 +1395,7 @@ void LlamaBatch::InternalThreadEntry()
         comm_.h_tp_group->Sync();
     }

-    if (tp_rank_ == 0) {
+    if (is_driver_) {
         gateway_->notify(std::move(signals));
     }
 }
@@ -1426,7 +1428,7 @@ bool LlamaBatch::Forward(GenerationState& g)
     const int active_size = state_->active_size;

     constexpr int kLogInterval = 10;
-    if (tp_rank_ == 0 && (g.step - 1) % kLogInterval == 0) {
+    if (is_driver_ && (g.step - 1) % kLogInterval == 0) {
         TM_LOG_INFO("------------------------- step = %d -------------------------", g.step - 1);
     }

@@ -1506,7 +1508,7 @@ bool LlamaBatch::Forward(GenerationState& g)
         const int dc_batch_size = p ? 0 : pf_offset;
         const int pf_batch_size = mini_batch_size - dc_batch_size;

-        if (tp_rank_ == 0) {
+        if (is_driver_) {
             if (pf_batch_size) {
                 const auto max_q =
                     *std::max_element(h_input_length_buf_.data() + first, h_input_length_buf_.data() + last);
@@ -1622,7 +1624,7 @@ bool LlamaBatch::Forward(GenerationState& g)
     });
     AnomalyHandler::instance().Reset();

-    if (debug_ && tp_rank_ == 0) {
+    if (debug_ && is_driver_) {
         std::vector<int> curr(active_size);
         core::Copy(token_ids_buf_.data() + g.step * active_size, active_size, curr.data());
         cudaStreamSynchronize(stream_);
@@ -1679,7 +1681,7 @@ void LlamaBatch::Warmup()
     if (auto str = std::getenv("TM_GEMM_IMPORT")) {
         std::ifstream ifs(str);
         const int n_imported = linear.Import(ifs);
-        if (tp_rank_ == 0) {
+        if (is_driver_) {
             TM_LOG_INFO("[Gemm2] %d records imported", n_imported);
         }
         return;
@@ -1697,7 +1699,7 @@ void LlamaBatch::Warmup()
         bss.push_back(max_forward_token_num_);
     }

-    if (tp_rank_ == 0) {
+    if (is_driver_) {
         auto str = Join(bss.begin(), bss.end(), ", ");
         TM_LOG_INFO("[Gemm2] Tuning sequence: %s", str.c_str());
     }
@@ -1720,7 +1722,7 @@ void LlamaBatch::Warmup()

     /// NOTE: No explicit barrier can be used here as internal threads are waiting on it now
     for (auto token_num : bss) {
-        if (tp_rank_ == 0) {
+        if (is_driver_) {
             TM_LOG_INFO("[Gemm2] %d", token_num);
         }

@@ -1749,7 +1751,7 @@ void LlamaBatch::Warmup()

     auto tock = std::chrono::steady_clock::now();

-    if (tp_rank_ == 0) {
+    if (is_driver_) {
         TM_LOG_INFO("[Gemm2] Tuning finished in %.2f seconds.",
                     std::chrono::duration<float, std::ratio<1, 1>>(tock - tick).count());
     }
@@ -1759,7 +1761,7 @@ void LlamaBatch::Warmup()
     check_cuda_error(cudaStreamSynchronize(stream_));

     // Only rank-0 exports the dispatch cache
-    if (tp_rank_ == 0) {
+    if (is_driver_) {
         if (auto path = std::getenv("TM_GEMM_EXPORT")) {
             std::ofstream ofs(path);
             const auto n_records = context_->linear->Export(ofs);
@@ -1805,7 +1807,7 @@ void LlamaBatch::InitializeBufferAndKVCache()

     const size_t max_session_len = sequence_manager_->max_block_count() * cache_block_seq_len;
     if (max_session_len < session_len_) {
-        if (tp_rank_ == 0) {
+        if (is_driver_) {
             TM_LOG_WARNING("No enough blocks for `session_len` (%d), `session_len` truncated to %d.",
                            session_len_,
                            max_session_len);
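
The change in this file is mechanical: every `tp_rank_ == 0` guard around logging and host-side output becomes `is_driver_`, which the constructor now receives instead of each call site re-deriving the privileged rank. A minimal sketch of the pattern, using hypothetical names (`Worker`, `Step`) rather than TurboMind's real classes:

```cpp
#include <cstdio>

// Sketch of the refactoring above (hypothetical class, not TurboMind's):
// the "who talks to the outside world" decision is made once by the
// caller and injected as a flag, instead of hard-coding tp_rank_ == 0.
class Worker {
public:
    Worker(int tp_rank, bool is_driver): tp_rank_(tp_rank), is_driver_(is_driver) {}

    void Step(int step)
    {
        // Before: if (tp_rank_ == 0) { ... }
        // After:  if (is_driver_)    { ... }
        if (is_driver_) {
            std::printf("step = %d\n", step);
        }
    }

private:
    const int  tp_rank_;   // still available for rank-dependent math
    const bool is_driver_; // single designated writer of logs/outputs
};

int main()
{
    Worker driver(/*tp_rank=*/0, /*is_driver=*/true);
    Worker follower(/*tp_rank=*/1, /*is_driver=*/false);
    driver.Step(1);    // prints
    follower.Step(1);  // silent
}
```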

src/turbomind/models/llama/LlamaBatch.h

Lines changed: 3 additions & 1 deletion
@@ -115,7 +115,8 @@ class LlamaBatch {
               std::unique_ptr<Context> ctx,
               std::shared_ptr<Gateway> gateway,
               int device_id,
-              int dp_rank);
+              int dp_rank,
+              bool is_driver);

     ~LlamaBatch();

@@ -211,6 +212,7 @@ class LlamaBatch {
     const int tp_rank_;
     const DataType data_type_;
     const bool debug_;
+    const bool is_driver_;

     // Refs into `Context<T>`
     cudaStream_t const stream_{};
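
Because `is_driver_` is declared `const`, it can only be set in the constructor's member-initializer list, which is why the declaration and the constructor signature change together. A tiny standalone illustration of that constraint (hypothetical class name, not the real header):

```cpp
// Hypothetical reduction of the header change: a const data member must be
// initialized in the member-initializer list, so adding `const bool
// is_driver_` forces a matching constructor parameter.
class Batch {
public:
    Batch(int dp_rank, bool is_driver): dp_rank_(dp_rank), is_driver_(is_driver) {}
    // Assigning is_driver_ inside the constructor body would not compile.

private:
    const int  dp_rank_;
    const bool is_driver_;
};
```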

src/turbomind/triton_backend/llama/LlamaTritonModel.cc

Lines changed: 6 additions & 4 deletions
@@ -525,14 +525,16 @@ void LlamaTritonModel::createEngine(int device_id, int rank)
     h_comm->Sync();

     try {
-        const int dp_rank = engine_param.outer_dp_rank * engine_param.attn_dp_size + engine_param.attn_dp_rank;
-        engines_[device_id] = std::make_unique<Engine>(dtype_,
-                                                       engine_param_,  //
+        const int  dp_rank   = engine_param.outer_dp_rank * engine_param.attn_dp_size + engine_param.attn_dp_rank;
+        const bool is_driver = engine_param.attn_tp_rank == 0;
+        engines_[device_id]  = std::make_unique<Engine>(dtype_,
+                                                        engine_param,  //
                                                        std::move(model),
                                                        std::move(ctx),
                                                        gateway_,
                                                        engine_param_.devices[device_id],
-                                                       dp_rank);
+                                                       dp_rank,
+                                                       is_driver);
     }
     catch (const std::exception& e) {
         TM_LOG_ERROR("[Engine][Init] %s", e.what());
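
Here the driver is chosen per attention tensor-parallel group: `is_driver` is true exactly when `attn_tp_rank == 0`, and `dp_rank` flattens the (outer DP, attention DP) coordinates into a single index. A minimal sketch of the same arithmetic, with a hypothetical `EngineParam` struct carrying only the fields the diff uses:

```cpp
#include <cstdio>

// Hypothetical subset of the engine parameters; only the fields used by
// the dp_rank / is_driver computation above.
struct EngineParam {
    int outer_dp_rank;
    int attn_dp_size;
    int attn_dp_rank;
    int attn_tp_rank;
};

int main()
{
    // Example topology: 2 outer DP groups x 2 attention DP ranks x TP size 2.
    for (int outer = 0; outer < 2; ++outer) {
        for (int dp = 0; dp < 2; ++dp) {
            for (int tp = 0; tp < 2; ++tp) {
                EngineParam p{outer, /*attn_dp_size=*/2, dp, tp};
                const int  dp_rank   = p.outer_dp_rank * p.attn_dp_size + p.attn_dp_rank;
                const bool is_driver = p.attn_tp_rank == 0;  // one driver per TP group
                std::printf("outer=%d dp=%d tp=%d -> dp_rank=%d is_driver=%d\n",
                            outer, dp, tp, dp_rank, (int)is_driver);
            }
        }
    }
    return 0;
}
```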
