 from jetstream.core.metrics.prometheus import JetstreamMetricsCollector
 import numpy as np
 
-log_level = os.getenv("LOG_LEVEL", "WARNING").upper()
+from jax.experimental import layout as jax_layout
+DLL = jax_layout.DeviceLocalLayout
+Layout = jax_layout.Layout
+
+log_level = os.getenv("LOG_LEVEL", "DEBUG").upper()
 
 logger = logging.getLogger("JetstreamLogger")
 logger.propagate = False
@@ -405,6 +409,26 @@ def __init__(
 
     self._jax_padding = jax_padding
 
+    ##### Auto layout compile for interleaved engine
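+    # In interleaved mode a single engine serves both prefill and generate, so the
+    # AOT-compiled executables, the re-laid-out params, and the engine itself are
+    # shared between the prefill and generate slots for each index.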
+    self._generate_executables = [None for _ in self._generate_engines]
+    self._cached_insert = [None for _ in self._generate_engines]
+    self._cached_prefill = [None for _ in self._prefill_engines]
+    if self._interleaved_mode:
+      for idx in range(len(self._generate_engines)):
+        logger.debug("Compiling interleaved engine %d", idx)
+        engine = self._generate_engines[idx]
+        params = self._generate_params[idx]
+        engine, params, gen_fn, prefill_fn, insert_fn = self._auto_layout_compile(
+            engine, params
+        )
+
+        self._prefill_engines[idx] = engine
+        self._generate_engines[idx] = engine
+        self._prefill_params[idx] = params
+        self._generate_params[idx] = params
+        self._cached_prefill[idx] = prefill_fn
+        self._cached_insert[idx] = insert_fn
+        self._generate_executables[idx] = gen_fn
+
+
     # Create all threads
     self._prefill_threads = [
         JetThread(
@@ -670,6 +694,56 @@ def _do_chunked_prefill(
 
     return prefill_result, first_token
 
+  def _auto_layout_compile(self, engine, params):
+    logger.debug("Compiling generate function")
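+    # engine.aot_compile is expected to return the compiled generate executable, the
+    # params re-laid-out for the compiled layouts, and a callable that builds the
+    # initial decode state (an assumption about the engine's AOT API as used here);
+    # the decode state is kept on self for the generate thread.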
+    generate_executable, params, decode_state_executable = engine.aot_compile(
+        params, pass_rng_shape=False
+    )
+    self.decode_state = decode_state_executable(None)
+
+    # prefill
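+    # Prefill inputs are padded to one of these bucket lengths, and a separate
+    # executable is compiled per bucket so the lookup at serving time is exact-match.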
+    interesting_buckets = [64, 128, 256, 512, 1024]
+
+    cached_prefill = {}
+    cached_insert = {}
+    for length in interesting_buckets:
+      i32_scalar = jax.ShapeDtypeStruct((), int)
+      logger.debug("Compiling prefill: %d", length)
+      input_data = jax.ShapeDtypeStruct((length,), jax.numpy.dtype("int32"))
+
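+      # Compile prefill for this bucket: params keep the engine-provided layouts
+      # (engine.param_layouts) and both outputs use Layout(DLL.AUTO) so XLA can pick
+      # device-local layouts for the returned prefix and first token.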
+      cached_prefill[length] = (
+          jax.jit(
+              engine.prefill_aot,
+              in_shardings=(engine.param_layouts, None, None),
+              out_shardings=(Layout(DLL.AUTO), Layout(DLL.AUTO)),
+          )
+          .lower(params, input_data, i32_scalar)
+          .compile(compiler_options=None)
+      )
+
+      logger.debug("Generate dummy prefix: %d", length)
+      dummy_tokens = jax.numpy.ones(shape=(length,), dtype=jax.numpy.dtype("int32"))
+      prefix_shapes = jax.eval_shape(engine.prefill_aot, params, dummy_tokens, 1)
+
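+      # Compile insert against the layouts the prefill executable actually produces
+      # (prefill_output_layout), so the cached prefix should not need an extra
+      # relayout when it is inserted into the donated decode state.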
+      logger.debug("Compiling insert: %d", length)
+      prefill_output_layout, _ = cached_prefill[length].output_layouts
+      logger.debug("Prefill output layout: %s", prefill_output_layout)
+      logger.debug("Prefix shapes: %s", prefix_shapes)
+      cached_insert[length] = (
+          jax.jit(
+              engine.insert,
+              in_shardings=(prefill_output_layout, engine.decode_state_layouts, None),
+              out_shardings=engine.decode_state_layouts,
+              donate_argnames=("decode_state",),
+          )
+          .lower(prefix_shapes[0], engine.decode_state_shapes, i32_scalar)
+          .compile(compiler_options=None)
+      )
+    return engine, params, generate_executable, cached_prefill, cached_insert
+
   def _prefill_thread(self, idx: int):
     """Thread which runs in the background performing prefills."""
     logger.info("Spinning up prefill thread %d.", idx)
@@ -683,6 +757,12 @@ def _prefill_thread(self, idx: int):
     thread_name = f"Prefill thread {idx}"
     ThreadDebugLog(thread_name, f"Prefill params {idx}")
 
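+    # In disaggregated (non-interleaved) mode each prefill thread compiles its own
+    # engine here; only the per-bucket prefill executables are kept for this thread.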
+    if not self._interleaved_mode:
+      prefill_engine, prefill_params, gen_fn, prefill_fn, insert_fn = self._auto_layout_compile(
+          prefill_engine, prefill_params
+      )
+      self._cached_prefill[idx] = prefill_fn
+
     while self.live:
       my_transfer_backlog = self._transfer_backlogs[idx]
       # The prefill thread can just sleep until it has work to do.
@@ -759,10 +839,11 @@ def _prefill_thread(self, idx: int):
           )
         else:
           # Compute new kv cache for the prefill_content.
-          prefill_result, first_token = prefill_engine.prefill(
-              params=final_prefill_params,
-              padded_tokens=padded_tokens,
-              true_length=true_length,
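+          # padded_tokens was padded to one of the compiled bucket lengths, so its
+          # length selects the matching pre-compiled prefill executable.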
+          assert padded_tokens.shape[0] in self._cached_prefill[idx]
+          prefill_result, first_token = self._cached_prefill[idx][padded_tokens.shape[0]](
+              final_prefill_params,
+              padded_tokens,
+              true_length,
           )
 
         request.complete = np.zeros(
@@ -967,10 +1048,11 @@ def _insert_if_possible(
         else:
           break
 
-      decode_state = generate_engine.insert(
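+      # Recover the prefill bucket length from the cached KV shape (second dim of
+      # cache_prefill_segment_id) and use it to pick the matching insert executable.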
+      length = new_request.prefill_result["cache"]["decoder"]["layers_0"][
+          "self_attention"
+      ]["KVCache_0"]["cache_prefill_segment_id"].value.shape[1]
+      decode_state = self._cached_insert[idx][length](
           new_request.prefill_result,
           decode_state,
-          slot=slot,
+          slot,
           # request_id=new_request.request_id,
       )
       ThreadDebugLog(
@@ -1115,9 +1197,15 @@ def _generate_thread(self, idx: int):
     # Keep track of what step tokens were generated at.
     generate_timestep = 0
     # State to store things like running kv cache in.
-    decode_state = generate_engine.init_decode_state()
-
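+    # The decode state here is the one built during _auto_layout_compile, so its
+    # layouts match the AOT-compiled insert and generate executables.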
+    decode_state = self.decode_state
     generate_params = self._generate_params[idx]
+
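+    # In disaggregated mode the generate thread compiles its own engine; only the
+    # AOT generate executable is kept for this thread.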
+    if not self._interleaved_mode:
+      generate_engine, generate_params, gen_fn, prefill_fn, insert_fn = self._auto_layout_compile(
+          generate_engine, generate_params
+      )
+      self._generate_executables[idx] = gen_fn
+
     thread_name = f"Generate thread {idx}"
     ThreadDebugLog(thread_name, f"Generate params {idx}")
     time_of_last_generate = time.time()
@@ -1178,8 +1266,8 @@ def _generate_thread(self, idx: int):
       ), "At this point we must have some requests inserted into the slots."
 
       # Now we actually take a generate step on requests in the slots.
-      decode_state, sampled_tokens = generate_engine.generate(
-          generate_params, decode_state
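+      # The AOT generate executable takes a third positional argument (presumably the
+      # RNG slot, compiled out via pass_rng_shape=False), so None is passed here.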
+      decode_state, sampled_tokens = self._generate_executables[idx](
+          generate_params, decode_state, None
       )
       sampled_tokens.copy_to_host_async()
       # Respond to detokenization backpressure.