 from forge.data.datasets.packed import PackedDataset, TextPacker
 from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset
 from forge.data.tokenizer import HuggingFaceModelTokenizer
+from forge.observability import get_or_create_metric_logger, record_metric, Reduce
 from forge.util.config import parse
 
 from monarch.actor import current_rank, current_size, endpoint
@@ -77,7 +78,6 @@ def __init__(self, config: DictConfig):
 
         self.current_step = 0
         self.num_training_steps = job_config.training.steps
-        self.metric_logger = None  # TODO: fix this
         self.gradient_accumulation_steps = 1  # Example value, adjust as needed
         self._rank = current_rank().rank
         self._size = math.prod(current_size().values())
@@ -109,9 +109,22 @@ def _init_dist(self):
         os.environ.update(env)
         logger.info("env: {}".format(env))
 
+    async def setup_metric_logger(self):
+        """Initialization happens in the main process. Here we just retrieve it"""
+        mlogger = await get_or_create_metric_logger()
+        return mlogger
+
+    def record_batch_metrics(self, data_metrics: list):
+        """Since the dataloader creates new processes, we don't call `record_metric` in the dataset.
+        Instead, pop the metrics from the batch and record them here."""
+        for metric in data_metrics:
+            record_metric(metric.key, metric.value, metric.reduction)
+
     @endpoint
     async def setup(self):
         self.train_dataloader = self.setup_data()
+        self.mlogger = await self.setup_metric_logger()
+
         # self.train_dataloader = self.setup_data(
         #     self.train_config.train_dataset_config,
         #     self.train_config.train_dataloader_config,
@@ -234,7 +247,9 @@ def train_step(self, batch) -> None:
         # ) as grad_acc:
         labels = batch.pop("labels")
         loss = self.forward_backward(batch, labels)
+        loss = loss.item()
 
+        record_metric("ForgeSFTRecipe/train_step/loss", loss, Reduce.MEAN)
         logger.info(f"{self.current_step} / {self.num_training_steps}|Loss: {loss}")
         # self.pbar.set_description(f"{self.current_step}|Loss: {loss}")
         # self.pbar.update(1)
@@ -251,14 +266,25 @@ async def train(self) -> None:
 
         while self.current_step < self.num_training_steps:
             batch = next(dataloader)
+
+            # Pop and record metrics from batch before moving to device
+            self.record_batch_metrics(batch.pop("metrics", []))
+            record_metric("ForgeSFTRecipe/train/step", self.current_step, Reduce.MEAN)
+
             # Move tensors to the appropriate device
             for k, v in batch.items():
                 if isinstance(v, torch.Tensor):
                     batch[k] = v.to("cuda")  # TODO: hardcoded for now
+
             self.train_step(batch)
             # self.profiler.step()
             self.current_step += 1
 
+            # Flush metrics
+            if self._rank == 0:
+                logger.debug(f"Flushing metrics at step {self.current_step}")
+                await self.mlogger.flush.call_one(global_step=self.current_step)
+
             self.checkpointer.save(
                 curr_step=self.current_step,
                 last_step=self.current_step == self.num_training_steps,
@@ -270,16 +296,23 @@ async def train(self) -> None:
     async def cleanup(self) -> None:
         if self.checkpointer:
             self.checkpointer.close()
-        if self.metric_logger:
-            self.metric_logger.close()
+        if getattr(self, "mlogger", None):
+            await self.mlogger.shutdown.call_one()
 
     def __repr__(self) -> str:
         return "Trainer"
 
 
 async def run(cfg: DictConfig) -> None:
-    logging.info("Spawing recipe...")
+
+    logging.info("Spawning recipe...")
     process_cfg = cfg.pop("processes")
+
+    # Initialize metric logger in main process
+    metric_logging_cfg = cfg.get("metric_logging", {})
+    mlogger = await get_or_create_metric_logger(process_name="Controller")
+    await mlogger.init_backends.call_one(metric_logging_cfg)
+
     recipe = await ForgeSFTRecipe.options(**process_cfg).as_actor(cfg)
 
     logging.info("Created recipe, running setup.")
@@ -290,6 +323,7 @@ async def run(cfg: DictConfig) -> None:
 
     logging.info("Done training. Clean up")
     await recipe.cleanup.call()
+
     await recipe.mesh.stop()
     logging.info("All done!")
 
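The metric-logger lifecycle introduced here spans three call sites: the controller creates the logger and initializes backends in `run()`, each trainer rank retrieves the same instance in `setup()`, and rank 0 flushes once per step. Below is a minimal sketch of that flow using only the calls visible in the diff (`get_or_create_metric_logger`, `init_backends`, `record_metric`, `flush`, `shutdown`); the empty backend config and the standalone `asyncio` harness are assumptions, since in the recipe these run inside Monarch actor endpoints.

```python
import asyncio

from forge.observability import get_or_create_metric_logger, record_metric, Reduce


async def main() -> None:
    # Controller process: create the logger once and initialize backends.
    # The recipe passes cfg.get("metric_logging", {}); {} is a stand-in here.
    mlogger = await get_or_create_metric_logger(process_name="Controller")
    await mlogger.init_backends.call_one({})

    for step in range(3):
        # Any process may record; values with the same key are reduced
        # (here: averaged) when flushed.
        record_metric("ForgeSFTRecipe/train_step/loss", 1.0 / (step + 1), Reduce.MEAN)
        # Rank 0 flushes the accumulated metrics for this step.
        await mlogger.flush.call_one(global_step=step)

    await mlogger.shutdown.call_one()


asyncio.run(main())
```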
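The `record_batch_metrics` helper exists because dataloader workers live in separate processes, so calling `record_metric` inside the dataset would attribute metrics to the wrong process; the dataset instead ships them inside the batch under a `"metrics"` key. A sketch of the consumer side, assuming each entry exposes `key`, `value`, and `reduction` as the diff's loop implies (the `Metric` container below is a hypothetical stand-in):

```python
from dataclasses import dataclass

from forge.observability import record_metric, Reduce


@dataclass
class Metric:  # hypothetical stand-in for whatever the dataset ships in the batch
    key: str
    value: float
    reduction: Reduce


batch = {
    "metrics": [Metric("dataset/num_tokens", 512.0, Reduce.MEAN)],
    # ... tensors ...
}

# Trainer side: pop the metrics out of the batch before moving tensors to
# device, and record them locally so they land in the training process.
for metric in batch.pop("metrics", []):
    record_metric(metric.key, metric.value, metric.reduction)
```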