     UserWarning,
 )
 
+from instructlab.training.hpu_utils import is_torch_hpu_available
+
+if is_torch_hpu_available():
+    import habana_frameworks.torch.core as htcore
+    import habana_frameworks.torch.distributed.hccl
+    from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
+    adapt_transformers_to_gaudi()
+
 # Third Party
 from tqdm import tqdm
 from transformers import AutoConfig
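
Note: the helper is_torch_hpu_available comes from instructlab.training.hpu_utils, which is not part of this diff. For orientation only, a minimal sketch of what such a probe could look like, assuming the usual pattern of checking for the Habana PyTorch bridge (this is not the module's actual code):

# Hypothetical sketch of instructlab.training.hpu_utils.is_torch_hpu_available;
# the real implementation is not shown in this diff.
import importlib.util


def is_torch_hpu_available() -> bool:
    # The Habana PyTorch bridge must be installed and importable.
    if importlib.util.find_spec("habana_frameworks") is None:
        return False
    try:
        # Importing the bridge registers the "hpu" device type with torch.
        import habana_frameworks.torch  # noqa: F401
        import torch

        return hasattr(torch, "hpu") and torch.hpu.is_available()
    except (ImportError, RuntimeError):
        return False
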
@@ -139,10 +147,19 @@ def train(
             total_length = float(torch.tensor([batch.pop("total_length")]))
             if not args.use_dolomite:
                 for k in batch:
-                    batch[k] = batch[k].to(local_rank)
+                    batch[k] = batch[k].to('hpu' if is_torch_hpu_available() else local_rank)
+
+            hpu_args = {}
+            if is_torch_hpu_available():
+                hpu_args = {
+                    "use_flash_attention": True,
+                    "lazy_mode": False,
+                }
+
             output = model(
                 **batch,
                 use_cache=False,
+                **hpu_args,
             )
             loss = output.loss
             log_loss = loss.detach().item()
@@ -179,8 +196,14 @@ def train(
                 elapsed_time = time.time() - start
                 overall_throughput = args.samples_per_gpu * world_size / elapsed_time
                 current_lr = accelerator.lr_scheduler.get_last_lr()[0]
-                cuda_mem_allocated = torch.cuda.memory_allocated() / (1024**3)
-                cuda_malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"]
+
+                if is_torch_hpu_available():
+                    mem_allocated = torch.hpu.memory_allocated() / (1024**3)
+                    malloc_retries = 0
+                else:
+                    mem_allocated = torch.cuda.memory_allocated() / (1024**3)
+                    malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"]
+
                 global_grad_norm = (
                     model.get_global_grad_norm()
                     if hasattr(model, "get_global_grad_norm")
@@ -202,8 +225,8 @@ def train(
                         "rank": torch.distributed.get_rank(),
                         "overall_throughput": overall_throughput,
                         "lr": current_lr,
-                        "cuda_mem_allocated": cuda_mem_allocated,
-                        "cuda_malloc_retries": cuda_malloc_retries,
+                        ("hpu" if is_torch_hpu_available() else "cuda") + "_mem_allocated": mem_allocated,
+                        ("hpu" if is_torch_hpu_available() else "cuda") + "_malloc_retries": malloc_retries,
                         "num_loss_counted_tokens": int(num_loss_counted_tokens),
                         "num_tokens_rank0": int(total_length),
                         "batch_size": int(micro_batch_size),
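
torch.hpu exposes memory_allocated(), but no num_alloc_retries counter, which is why the patch logs the retry count as 0 on Gaudi and switches the metric keys from cuda_* to hpu_*. The inlined branch could also be factored into a small helper; a sketch under that assumption (device_memory_stats is a hypothetical name, not part of the patch):

import torch

from instructlab.training.hpu_utils import is_torch_hpu_available


def device_memory_stats() -> tuple[str, float, int]:
    # Hypothetical helper: returns the metric-key prefix, allocated accelerator
    # memory in GiB, and the allocation-retry count. torch.hpu does not expose
    # "num_alloc_retries", so Gaudi reports 0 for it.
    if is_torch_hpu_available():
        return "hpu", torch.hpu.memory_allocated() / (1024**3), 0
    return (
        "cuda",
        torch.cuda.memory_allocated() / (1024**3),
        torch.cuda.memory_stats()["num_alloc_retries"],
    )
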
@@ -236,7 +259,10 @@ def train(
             global_step += 1
             if local_rank == 0:
                 inner_pb.update(1)
-            torch.cuda.empty_cache()
+
+            if not is_torch_hpu_available():
+                torch.cuda.empty_cache()
+
         if args.checkpoint_at_epoch:
             base_logger.debug(f"Saving checkpoint at epoch {epoch}")
             save_checkpoint(
@@ -314,17 +340,24 @@ def main(args):
     args.model_type = model_conf.model_type
 
     #### distributed init #####
-    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+    if is_torch_hpu_available():
+        torch.hpu.set_device(int(os.environ["LOCAL_RANK"]))
+    else:
+        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+
     args.local_rank = int(os.environ["LOCAL_RANK"])
 
     timeout = _get_collective_timeout()
-    if timeout is not None:
-        torch.distributed.init_process_group(timeout=timeout)
-    else:
-        torch.distributed.init_process_group()
+    backend = "hccl" if is_torch_hpu_available() else None
+    torch.distributed.init_process_group(backend=backend, timeout=timeout)
 
     args.global_rank = torch.distributed.get_rank()
-    tensor = torch.ByteTensor([False]).cuda()
+
+    if is_torch_hpu_available():
+        tensor = torch.ByteTensor([False]).to('hpu')
+    else:
+        tensor = torch.ByteTensor([False]).cuda()
+
     torch.distributed.all_reduce(tensor)
     torch.distributed.barrier()
 
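
The habana_frameworks.torch.distributed.hccl import at the top of the file is what registers the "hccl" backend with torch.distributed, so passing backend="hccl" here works on Gaudi. The same device/backend selection can be exercised in isolation; a minimal, illustrative sketch assuming a torchrun launch (which sets LOCAL_RANK) and an installed Habana bridge, not the project's actual test code:

import os

import torch
import torch.distributed as dist

try:
    # Importing the bridge registers the "hpu" device and the "hccl" backend.
    import habana_frameworks.torch.distributed.hccl  # noqa: F401

    use_hpu = hasattr(torch, "hpu") and torch.hpu.is_available()
except ImportError:
    use_hpu = False

local_rank = int(os.environ["LOCAL_RANK"])
if use_hpu:
    torch.hpu.set_device(local_rank)
    dist.init_process_group(backend="hccl")
else:
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl")

# Same sanity check as in the patch: a ByteTensor all_reduce plus a barrier.
tensor = torch.ByteTensor([False]).to("hpu" if use_hpu else "cuda")
dist.all_reduce(tensor)
dist.barrier()
dist.destroy_process_group()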