102102from  jetstream .core .metrics .prometheus  import  JetstreamMetricsCollector 
103103import  numpy  as  np 
104104
105- log_level  =  os .getenv ("LOG_LEVEL" , "WARNING" ).upper ()
105+ from  jax .experimental  import  layout  as  jax_layout 
106+ DLL  =  jax_layout .DeviceLocalLayout 
107+ Layout  =  jax_layout .Layout 
108+ 
109+ log_level  =  os .getenv ("LOG_LEVEL" , "DEBUG" ).upper ()
106110
107111logger  =  logging .getLogger ("JetstreamLogger" )
108112logger .propagate  =  False 
@@ -405,6 +409,29 @@ def __init__(
405409
406410    self ._jax_padding  =  jax_padding 
407411
412+     ##### Auto layout compile for interleaved engine 
413+     self ._generate_executables  =  [None  for  _  in  self ._generate_engines ]
414+     self ._cached_insert  =  [None  for  _  in  self ._generate_engines ]
415+     self ._cached_prefill  =  [None  for  _  in  self ._prefill_engines ]
416+     self ._decode_states  =  [None  for  _  in  self ._generate_engines ]
417+     if  self ._interleaved_mode :
418+       for  idx  in  range (len (self ._generate_engines )):
419+         logger .debug ("Compiling interleaved engine {}" .format (idx ))
420+         engine  =  self ._generate_engines [idx ]
421+         params  =  self ._generate_params [idx ]
422+         engine , params , gen_fn , prefill_fn , insert_fn , decode_state  =  self ._auto_layout_compile (engine , params )
423+ 
424+         self ._prefill_engines [idx ] =  engine 
425+         self ._generate_engines [idx ] =  engine 
426+         self ._prefill_params [idx ] =  params 
427+         self ._generate_params [idx ] =  params 
428+         self ._cached_prefill [idx ] =  prefill_fn 
429+         self ._cached_insert [idx ] =  insert_fn 
430+         self ._generate_executables [idx ] =  gen_fn 
431+ 
432+         self ._decode_states [idx ] =  decode_state 
433+ 
434+ 
408435    # Create all threads 
409436    self ._prefill_threads  =  [
410437        JetThread (
@@ -670,6 +697,56 @@ def _do_chunked_prefill(
670697
671698    return  prefill_result , first_token 
672699
700+   def  _auto_layout_compile (self , engine , params ):
701+     logger .debug ("Compiling generate function" )
702+     generate_executable , params , decode_state_executable  =  engine .aot_compile (
703+         params , pass_rng_shape = False 
704+     )
705+     decode_state  =  decode_state_executable (None )
706+ 
707+     # prefill 
708+     interesting_buckets  =  [
709+         64 ,
710+         128 ,
711+         256 ,
712+         512 ,
713+         1024 ,
714+     ]
715+ 
716+     cached_prefill  =  {}
717+     cached_insert  =  {}
718+     for  length  in  interesting_buckets :
719+       i32_scalar  =  jax .ShapeDtypeStruct ((), int )
720+       logger .debug ("Compiling prefill: %d" , length )
721+       input_data  =  jax .ShapeDtypeStruct ((length ,), jax .numpy .dtype ("int32" ))
722+ 
723+       cached_prefill [length ] =  (
724+           jax .jit (
725+             engine .prefill_aot ,
726+             in_shardings = (engine .param_layouts , None , None ),
727+             out_shardings = (Layout (DLL .AUTO ), Layout (DLL .AUTO )),
728+           ).lower (params , input_data , i32_scalar )
729+       ).compile (compiler_options = None )
730+ 
731+       logger .debug ("Generate dummy prefix: %d" , length )
732+       dummy_tokens  =  jax .numpy .ones (shape = (length ,), dtype = jax .numpy .dtype ("int32" ))
733+       prefix_shapes  =  jax .eval_shape (engine .prefill_aot , params , dummy_tokens , 1 )
734+       
735+       logger .debug ("Compiling insert: %d" , length )
736+       prefill_output_layout , _  =  cached_prefill [length ].output_layouts 
737+       logger .debug ("Prefill output layout: {}" .format (prefill_output_layout ))
738+       logger .debug ("Prefix shapes: {}" .format (prefix_shapes ))
739+       i32_scalar  =  jax .ShapeDtypeStruct ((), int )
740+       cached_insert [length ] =  (
741+           jax .jit (
742+             engine .insert ,
743+             in_shardings = (prefill_output_layout , engine .decode_state_layouts , None ),
744+             out_shardings = (engine .decode_state_layouts ),
745+             donate_argnames = ("decode_state" ),
746+           ).lower (prefix_shapes [0 ], engine .decode_state_shapes , i32_scalar )
747+       ).compile (compiler_options = None )
748+     return  engine , params , generate_executable , cached_prefill , cached_insert , decode_state 
749+ 
673750  def  _prefill_thread (self , idx : int ):
674751    """Thread which runs in the background performing prefills.""" 
675752    logger .info ("Spinning up prefill thread %d." , idx )
@@ -683,6 +760,13 @@ def _prefill_thread(self, idx: int):
683760    thread_name  =  f"Prefill thread { idx }  " 
684761    ThreadDebugLog (thread_name , f"Prefill params { idx }   loaded." )
685762
763+     if  not  self ._interleaved_mode :
764+       logger .debug ("Compiling for disagg mode" )
765+       prefill_engine , prefill_params , gen_fn , prefill_fn , insert_fn , _  =  self ._auto_layout_compile (
766+         prefill_engine , prefill_params 
767+       )
768+       self ._cached_prefill [idx ] =  prefill_fn 
769+ 
686770    while  self .live :
687771      my_transfer_backlog  =  self ._transfer_backlogs [idx ]
688772      # The prefill thread can just sleep until it has work to do. 
@@ -759,10 +843,11 @@ def _prefill_thread(self, idx: int):
759843          )
760844        else :
761845          # Compute new kv cache for the prefill_content. 
762-           prefill_result , first_token  =  prefill_engine .prefill (
763-               params = final_prefill_params ,
764-               padded_tokens = padded_tokens ,
765-               true_length = true_length ,
846+           assert  padded_tokens .shape [0 ] in  self ._cached_prefill [idx ]
847+           prefill_result , first_token  =  self ._cached_prefill [idx ][padded_tokens .shape [0 ]](
848+               final_prefill_params ,
849+               padded_tokens ,
850+               true_length ,
766851          )
767852
768853        request .complete  =  np .zeros (
@@ -967,10 +1052,14 @@ def _insert_if_possible(
9671052        else :
9681053          break 
9691054
970-       decode_state  =  generate_engine .insert (
1055+       if  'decoder'  in  new_request .prefill_result ['cache' ]:
1056+         length  =  new_request .prefill_result ['cache' ]['decoder' ]['layers_0' ]['self_attention' ]['KVCache_0' ]['cache_prefill_segment_id' ].value .shape [1 ]
1057+       else :
1058+         length  =  new_request .prefill_result ['cache' ]['self_attention' ]['KVCache_0' ]['cache_prefill_segment_id' ].value .shape [2 ]
1059+       decode_state  =  self ._cached_insert [idx ][length ](
9711060          new_request .prefill_result ,
9721061          decode_state ,
973-           slot = slot ,
1062+           slot ,
9741063          # request_id=new_request.request_id, 
9751064      )
9761065      ThreadDebugLog (
@@ -1115,9 +1204,19 @@ def _generate_thread(self, idx: int):
11151204    # Keep track of what step tokens were generated at. 
11161205    generate_timestep  =  0 
11171206    # State to store things like running kv cache in. 
1118-     decode_state  =  generate_engine .init_decode_state ()
1119- 
11201207    generate_params  =  self ._generate_params [idx ]
1208+ 
1209+     if  not  self ._interleaved_mode :
1210+       logger .debug ("Compiling for disagg mode" )
1211+       generate_engine , generate_params , gen_fn , prefill_fn , insert_fn , decode_state  =  self ._auto_layout_compile (
1212+         generate_engine , generate_params 
1213+       )
1214+       self ._generate_executables [idx ] =  gen_fn 
1215+       self ._cached_insert [idx ] =  insert_fn 
1216+       self ._decode_states [idx ] =  decode_state 
1217+ 
1218+     decode_state  =  self ._decode_states [idx ]
1219+ 
11211220    thread_name  =  f"Generate thread { idx }  " 
11221221    ThreadDebugLog (thread_name , f"Generate params { idx }   loaded." )
11231222    time_of_last_generate  =  time .time ()
@@ -1178,8 +1277,8 @@ def _generate_thread(self, idx: int):
11781277      ), "At this point we must have some requests inserted into the slots." 
11791278
11801279      # Now we actually take a generate step on requests in the slots. 
1181-       decode_state , sampled_tokens  =  generate_engine . generate (
1182-           generate_params , decode_state 
1280+       decode_state , sampled_tokens  =  self . _generate_executables [ idx ] (
1281+           generate_params , decode_state ,  None 
11831282      )
11841283      sampled_tokens .copy_to_host_async ()
11851284      # Respond to detokenization backpressure. 
0 commit comments