@@ -3,7 +3,6 @@ use crate::validation::{Validation, ValidationError};
 use crate::GenerateRequest;
 use crate::{Entry, Queue, Token};
 use nohash_hasher::IntMap;
-use std::future::Future;
 use std::sync::Arc;
 use text_generation_client::{
     Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
@@ -81,6 +80,7 @@ impl Infer {
             .limit_concurrent_requests
             .try_acquire_owned()
             .map_err(|err| {
+                metrics::increment_counter!("tgi_request_failure", "err" => "overloaded");
                 tracing::error!("{err}");
                 err
             })?;
@@ -172,6 +172,7 @@ impl Infer {
             })
         } else {
             let err = InferError::IncompleteGeneration;
+            metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
             tracing::error!("{err}");
             Err(err)
         }
@@ -201,7 +202,7 @@ async fn batching_task(
         // This batch might be smaller than the maximum batch size if there are not enough requests
         // waiting in the queue
         while let Some((mut entries, batch, span)) = queue.next_batch(None, max_batch_size).await {
-            let mut cached_batch = wrap_future(client.prefill(batch), &mut entries)
+            let mut cached_batch = prefill(&mut client, batch, &mut entries)
                 .instrument(span)
                 .await;
             let mut waiting_tokens = 1;
@@ -212,6 +213,7 @@ async fn batching_task(
                 // Get current batch info
                 let batch_size = batch.size;
                 let mut batches = vec![batch];
+                metrics::gauge!("tgi_batch_current_size", batch_size as f64);
 
                 // If the current batch is too small, we try to add more requests to it
                 if batch_size <= limit_min_batch_size {
@@ -241,10 +243,9 @@ async fn batching_task(
                         });
 
                         // Generate one token for this new batch to have the attention past in cache
-                        let new_cached_batch =
-                            wrap_future(client.prefill(new_batch), &mut new_entries)
-                                .instrument(span)
-                                .await;
+                        let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
+                            .instrument(span)
+                            .await;
                         // Reset waiting counter
                         waiting_tokens = 1;
                         // Extend current batch with the new batch
@@ -268,29 +269,59 @@ async fn batching_task(
                     entry.temp_span = Some(entry_batch_span);
                 });
 
-                cached_batch = wrap_future(client.decode(batches), &mut entries)
+                cached_batch = decode(&mut client, batches, &mut entries)
                     .instrument(next_batch_span)
                     .await;
                 waiting_tokens += 1;
             }
+            metrics::gauge!("tgi_batch_current_size", 0.0);
         }
     }
 }
 
-/// Wrap a future inside a match statement to handle errors and send the responses to Infer
 #[instrument(skip_all)]
-async fn wrap_future(
-    future: impl Future<Output = Result<(Vec<Generation>, Option<Batch>), ClientError>>,
+async fn prefill(
+    client: &mut ShardedClient,
+    batch: Batch,
     entries: &mut IntMap<u64, Entry>,
 ) -> Option<Batch> {
-    match future.await {
+    let start_time = Instant::now();
+
+    match client.prefill(batch).await {
+        Ok((generations, next_batch)) => {
+            send_generations(generations, entries);
+            metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed(), "method" => "prefill");
+            metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
+            next_batch
+        }
+        // If we have an error, we discard the whole batch
+        Err(err) => {
+            send_errors(err, entries);
+            metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
+            None
+        }
+    }
+}
+
+#[instrument(skip_all)]
+async fn decode(
+    client: &mut ShardedClient,
+    batches: Vec<Batch>,
+    entries: &mut IntMap<u64, Entry>,
+) -> Option<Batch> {
+    let start_time = Instant::now();
+
+    match client.decode(batches).await {
         Ok((generations, next_batch)) => {
             send_generations(generations, entries);
+            metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed(), "method" => "decode");
+            metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
             next_batch
         }
         // If we have an error, we discard the whole batch
         Err(err) => {
             send_errors(err, entries);
+            metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode");
             None
         }
     }
@@ -303,6 +334,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
         // Create and enter a span to link this function back to the entry
         let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
         let err = InferError::GenerationError(error.to_string());
+        metrics::increment_counter!("tgi_request_failure", "err" => "generation");
         tracing::error!("{err}");
 
         // unwrap_or is valid here as we don't care if the receiver is gone.
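Note on the metrics introduced above: they all go through the `metrics` crate facade, so nothing is collected until a recorder is installed somewhere in the router. The following sketch is not part of this commit; it shows one way such a recorder could be wired up with `metrics-exporter-prometheus`, assuming the 0.20-era `metrics` macro syntax used in the diff. The metric names and labels match the ones added here, while the `main` scaffold, the sleep, and the exporter crate choice are illustrative only.

    use std::time::{Duration, Instant};

    fn main() {
        // Install an in-process Prometheus recorder. The returned handle can
        // render the current state of every registered metric in Prometheus
        // text exposition format (a router would serve this from /metrics).
        let handle = metrics_exporter_prometheus::PrometheusBuilder::new()
            .install_recorder()
            .expect("failed to install Prometheus recorder");

        // Same shapes as the batching task: labeled counter, gauge, histogram.
        metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
        metrics::gauge!("tgi_batch_current_size", 4.0);

        let start_time = Instant::now();
        std::thread::sleep(Duration::from_millis(10));
        metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed(), "method" => "prefill");

        println!("{}", handle.render());
    }

Rendering through a handle keeps the example self-contained; in a real service the same recorder would be installed once at startup and the rendered text exposed over HTTP.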