@@ -202,7 +202,7 @@ void LlamaBatch::ProcessInferRequests(const Requests& reqs, std::vector<Signal>&
     int idx = 0;
     for (const auto& r : reqs) {
 
-        if (tp_rank_ == 0) {
+        if (is_driver_) {
             TM_LOG_INFO("[ProcessInferRequests] Request for %ld received.", (long)r->id);
         }
 
@@ -230,7 +230,7 @@ void LlamaBatch::ProcessInferRequests(const Requests& reqs, std::vector<Signal>&
             s = ptr->tokens.size();
         }
         else if (s > ptr->tokens.size()) {
-            if (tp_rank_ == 0) {
+            if (is_driver_) {
                 TM_LOG_WARNING("[ProcessInferRequests] Skipping invalid step (%d) setting for ID %lu", s, ptr->id);
             }
             s = ptr->tokens.size();
@@ -363,7 +363,7 @@ void LlamaBatch::ProcessInferRequests(const Requests& reqs, std::vector<Signal>&
         // the actual sequence length is seq_limit_len + 1, hence seq_limit_len must truncated to session_len - 1
         if (state.seq_len_limit[idx] >= session_len_) {
             state.seq_len_limit[idx] = session_len_ - 1;
-            if (tp_rank_ == 0) {
+            if (is_driver_) {
                 const int trunc_output_len = state.seq_len_limit[idx] - state.h_context_length[idx];
                 TM_LOG_WARNING(
                     "[ProcessInferRequests] [%ld] total sequence length (%d + %d) exceeds `session_len` (%d), `max_new_tokens` is truncated to %d",
@@ -840,7 +840,8 @@ LlamaBatch::LlamaBatch(DataType data_type,
                        std::unique_ptr<Context>  ctx,  // ! This is moved
                        std::shared_ptr<Gateway>  gateway,
                        int                       device_id,
-                       int                       dp_rank):
+                       int                       dp_rank,
+                       bool                      is_driver):
     param_(param),
     gateway_(gateway),
     max_batch_size_(param.max_batch_size),
@@ -852,6 +853,7 @@ LlamaBatch::LlamaBatch(DataType data_type,
     dp_rank_(dp_rank),
     tp_size_(model->tp_size_),
     tp_rank_(model->tp_rank_),
+    is_driver_(is_driver),
     data_type_(data_type),
     debug_(isDebug()),
     stream_(ctx->stream),
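
Note: below is a minimal, standalone sketch of the pattern this change introduces, for illustration only. The struct, names, and wiring are assumptions made for the example, not code from this PR or from LMDeploy. The idea is that the creator of the batch computes a single driver flag once, passes it through the constructor, and every host-side logging/output path is gated on that flag instead of re-testing tp_rank_ == 0. Whether the driver happens to be TP rank 0 is left to the caller.

    // sketch.cc -- illustrative only
    #include <cstdio>

    struct EngineShard {
        int  tp_rank;
        int  dp_rank;
        bool is_driver;  // assumption: exactly one shard per engine drives request I/O

        EngineShard(int tp, int dp, bool driver): tp_rank(tp), dp_rank(dp), is_driver(driver) {}

        void Step(int step) const
        {
            if (is_driver) {  // was: if (tp_rank == 0)
                std::printf("[dp%d] step = %d\n", dp_rank, step);
            }
        }
    };

    int main()
    {
        // Hypothetical wiring: the creating code decides which shard is the driver.
        EngineShard driver{/*tp_rank=*/0, /*dp_rank=*/0, /*is_driver=*/true};
        EngineShard worker{/*tp_rank=*/1, /*dp_rank=*/0, /*is_driver=*/false};
        driver.Step(10);  // logs
        worker.Step(10);  // silent: not the driver
        return 0;
    }
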
@@ -980,7 +982,7 @@ void LlamaBatch::ComputeAndOutputLogits(const Tensor& hidden_states, int first,
 
     auto logits = model_->postDecodeEmbedding(hidden_states, symm_logits_buf_.buffer());
 
-    if (tp_rank_ == 0) {
+    if (is_driver_) {
         OutputLogits(logits, first, last, GenerationConfig::kAll);
     }
 }
@@ -1141,7 +1143,7 @@ void LlamaBatch::Finish(GenerationState& g, std::vector<Signal>& signals)
     }
 
     // ! Only rank-0 writes to output
-    if (tp_rank_ == 0 && output_logprobs) {
+    if (is_driver_ && output_logprobs) {
         NvtxScope scope("logprobs");
         float*    sampled_logprobs_ptr = h_sampled_logprobs_.data();
         uint32_t* sampled_indexes_ptr  = h_sampled_indexes_.data();
@@ -1168,7 +1170,7 @@ void LlamaBatch::Finish(GenerationState& g, std::vector<Signal>& signals)
     }
 
     // ! Only rank-0 writes to output
-    if (tp_rank_ == 0) {
+    if (is_driver_) {
         NvtxScope scope("output_ids");
         for (int i = 0; i < batch_size - g.partial; ++i) {
             if (auto& r = state_->requests[i]) {
@@ -1184,7 +1186,7 @@ void LlamaBatch::Finish(GenerationState& g, std::vector<Signal>& signals)
     // Cache computed blocks to block trie
     sequence_manager_->CacheIfEnabled(state_->sequences, batch_size);
 
-    if (debug_ && tp_rank_ == 0) {
+    if (debug_ && is_driver_) {
         for (int i = 0; i < batch_size; ++i) {
             // ss << (i ? ", " : "") << "(" << state_->h_context_length[i] << "," << state_->h_finished[i] << ")";
             std::vector<int> tokens(state_->h_context_length[i]);
@@ -1225,7 +1227,7 @@ void LlamaBatch::Finish(GenerationState& g, std::vector<Signal>& signals)
             // Interrupt should reset r
             FT_CHECK(!r);
         }
-        else if (r->stream_output && tp_rank_ == 0) {
+        else if (r->stream_output && is_driver_) {
             const auto seq_len = *r->sequence_length.data();
             // Create signals by copying the request handles for non-finished streaming requests
             signals.push_back([this, r, seq_len] {  //
@@ -1249,15 +1251,15 @@ void LlamaBatch::Finish(GenerationState& g, std::vector<Signal>& signals)
 
 auto LlamaBatch::Interrupt(int index, bool force_stop, bool force_end) -> Signal
 {
-    if (tp_rank_ == 0) {
+    if (is_driver_) {
         TM_LOG_INFO("[Interrupt] slot %d, request %lu, stop %d, end %d",
                     index,
                     (long)state_->requests[index]->id,
                     force_stop,
                     force_end);
     }
 
-    if (debug_ && tp_rank_ == 0) {
+    if (debug_ && is_driver_) {
         std::vector<int> tokens(state_->h_context_length[index]);
         core::Copy(state_->output_ids.data() + index * session_len_, tokens.size(), tokens.data());
         cudaStreamSynchronize(stream_);
@@ -1331,7 +1333,7 @@ void LlamaBatch::InternalThreadEntry()
 
     std::shared_ptr<RequestData> req;
 
-    if (tp_rank_ == 0) {
+    if (is_driver_) {
         req = std::make_shared<RequestData>();
         {
             NvtxScope _("pop");
@@ -1372,7 +1374,7 @@ void LlamaBatch::InternalThreadEntry()
 
     ProcessCancelRequests(req->cancel, signals);
 
-    if (tp_rank_ == 0) {
+    if (is_driver_) {
         gateway_->notify(std::move(signals));
     }
 
@@ -1393,7 +1395,7 @@ void LlamaBatch::InternalThreadEntry()
         comm_.h_tp_group->Sync();
     }
 
-    if (tp_rank_ == 0) {
+    if (is_driver_) {
         gateway_->notify(std::move(signals));
     }
 }
@@ -1426,7 +1428,7 @@ bool LlamaBatch::Forward(GenerationState& g)
     const int active_size = state_->active_size;
 
     constexpr int kLogInterval = 10;
-    if (tp_rank_ == 0 && (g.step - 1) % kLogInterval == 0) {
+    if (is_driver_ && (g.step - 1) % kLogInterval == 0) {
         TM_LOG_INFO("------------------------- step = %d -------------------------", g.step - 1);
     }
 
@@ -1506,7 +1508,7 @@ bool LlamaBatch::Forward(GenerationState& g)
         const int dc_batch_size = p ? 0 : pf_offset;
         const int pf_batch_size = mini_batch_size - dc_batch_size;
 
-        if (tp_rank_ == 0) {
+        if (is_driver_) {
             if (pf_batch_size) {
                 const auto max_q =
                     *std::max_element(h_input_length_buf_.data() + first, h_input_length_buf_.data() + last);
@@ -1622,7 +1624,7 @@ bool LlamaBatch::Forward(GenerationState& g)
         });
     AnomalyHandler::instance().Reset();
 
-    if (debug_ && tp_rank_ == 0) {
+    if (debug_ && is_driver_) {
         std::vector<int> curr(active_size);
         core::Copy(token_ids_buf_.data() + g.step * active_size, active_size, curr.data());
         cudaStreamSynchronize(stream_);
@@ -1679,7 +1681,7 @@ void LlamaBatch::Warmup()
     if (auto str = std::getenv("TM_GEMM_IMPORT")) {
         std::ifstream ifs(str);
         const int n_imported = linear.Import(ifs);
-        if (tp_rank_ == 0) {
+        if (is_driver_) {
             TM_LOG_INFO("[Gemm2] %d records imported", n_imported);
         }
         return;
@@ -1697,7 +1699,7 @@ void LlamaBatch::Warmup()
         bss.push_back(max_forward_token_num_);
     }
 
-    if (tp_rank_ == 0) {
+    if (is_driver_) {
         auto str = Join(bss.begin(), bss.end(), ", ");
         TM_LOG_INFO("[Gemm2] Tuning sequence: %s", str.c_str());
     }
@@ -1720,7 +1722,7 @@ void LlamaBatch::Warmup()
 
     /// NOTE: No explicit barrier can be used here as internal threads are waiting on it now
     for (auto token_num : bss) {
-        if (tp_rank_ == 0) {
+        if (is_driver_) {
             TM_LOG_INFO("[Gemm2] %d", token_num);
         }
 
@@ -1749,7 +1751,7 @@ void LlamaBatch::Warmup()
 
     auto tock = std::chrono::steady_clock::now();
 
-    if (tp_rank_ == 0) {
+    if (is_driver_) {
         TM_LOG_INFO("[Gemm2] Tuning finished in %.2f seconds.",
                     std::chrono::duration<float, std::ratio<1, 1>>(tock - tick).count());
     }
@@ -1759,7 +1761,7 @@ void LlamaBatch::Warmup()
     check_cuda_error(cudaStreamSynchronize(stream_));
 
     // Only rank-0 exports the dispatch cache
-    if (tp_rank_ == 0) {
+    if (is_driver_) {
         if (auto path = std::getenv("TM_GEMM_EXPORT")) {
             std::ofstream ofs(path);
             const auto n_records = context_->linear->Export(ofs);
@@ -1805,7 +1807,7 @@ void LlamaBatch::InitializeBufferAndKVCache()
 
     const size_t max_session_len = sequence_manager_->max_block_count() * cache_block_seq_len;
     if (max_session_len < session_len_) {
-        if (tp_rank_ == 0) {
+        if (is_driver_) {
             TM_LOG_WARNING("No enough blocks for `session_len` (%d), `session_len` truncated to %d.",
                            session_len_,
                            max_session_len);