     UserWarning,
 )

+from instructlab.training.hpu_utils import is_torch_hpu_available
+
+if is_torch_hpu_available():
+    import habana_frameworks.torch.core as htcore
+    import habana_frameworks.torch.distributed.hccl
+    from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
+    adapt_transformers_to_gaudi()
+
 # Third Party
 from instructlab.dolomite.hf_models import GPTDolomiteForCausalLM
 from torch.utils.data import DataLoader
@@ -174,6 +182,13 @@ def setup_model(
     else:
         model = AutoModelForCausalLM.from_pretrained(**base_model_args)

+    if is_torch_hpu_available():
+        torch._dynamo.config.cache_size_limit = int(1e4)
+        torch._dynamo.config.accumulated_cache_size_limit = int(2e4)
+        model = torch.compile(model, backend="hpu_backend", dynamic=False)
+        for layer in model.model.layers:
+            layer.compile(backend="hpu_backend", dynamic=False)
+
     # store the base model args so we can recall them later if saving a LoRA model
     args.base_model_args = base_model_args
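For context: the dynamo cache limits are raised because compiling the model plus each decoder layer as separate regions produces many cached graphs, and `"hpu_backend"` is the `torch.compile` backend registered when `habana_frameworks.torch` is imported. A hedged sketch of the same pattern on a toy module (illustrative only, not the trainer's model):

```python
# Illustrative only: the same compile pattern on a toy module, assuming
# habana_frameworks.torch has been imported so "hpu_backend" is registered.
import torch
from torch import nn

# Allow many cached graph variants, since each compiled region adds entries.
torch._dynamo.config.cache_size_limit = int(1e4)
torch._dynamo.config.accumulated_cache_size_limit = int(2e4)

toy = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
compiled = torch.compile(toy, backend="hpu_backend", dynamic=False)

# nn.Module.compile() compiles a submodule in place, mirroring the
# per-decoder-layer loop in the hunk above.
for block in toy:
    block.compile(backend="hpu_backend", dynamic=False)
```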
@@ -222,7 +237,22 @@ def setup_model(
         )
         model.config.eos_token_id = tokenizer.eos_token_id

-    if "ForCausalLM" not in model.__class__.__name__:
+    if not is_torch_hpu_available():
+        class_name = model.__class__.__name__
+    else:
+        class_name = model._orig_mod.__class__.__name__ if model.__class__.__name__ == 'OptimizedModule' else model.__class__.__name__
+
+    replace_no_split_modules = {
+        'GaudiLlamaForCausalLM': ['GaudiLlamaDecoderLayer',]
+    }
+
+    if class_name in replace_no_split_modules:
+        if model.__class__.__name__ == 'OptimizedModule':
+            model._orig_mod._no_split_modules = replace_no_split_modules[class_name]
+        else:
+            model._no_split_modules = replace_no_split_modules[class_name]
+
+    if "ForCausalLM" not in class_name:
         raise ValueError(
             f"Model class name: {model.__class__.__name__} is not supported."
         )
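Aside: `torch.compile` returns an `OptimizedModule` wrapper, which is why the class-name check and the `_no_split_modules` patch above have to special-case `_orig_mod`. A small hypothetical helper (not part of this PR) expresses the same unwrapping:

```python
# Hypothetical helper, not in this PR: torch.compile wraps the model in an
# OptimizedModule, so introspection and attribute patches must reach the
# original module via `_orig_mod`.
def unwrap_compiled(module):
    """Return the underlying nn.Module if `module` is a torch.compile wrapper."""
    return getattr(module, "_orig_mod", module)

# The checks above could then read:
#   base = unwrap_compiled(model)
#   class_name = base.__class__.__name__
#   if class_name in replace_no_split_modules:
#       base._no_split_modules = replace_no_split_modules[class_name]
```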
@@ -272,6 +302,11 @@ def make_inputs_require_grad(module, input, output):  # pylint: disable=unused-argument
     model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

     accelerator = setup_accelerator(args, model, grad_accum)
+
+    if is_torch_hpu_available():
+        accelerator.state.fsdp_plugin.use_orig_params = True
+        accelerator.state.fsdp_plugin.sync_module_states = True
+
     if args.distributed_training_framework == DistributedBackend.FSDP.value:
         model = accelerator.prepare(model)
     optimizer = setup_optimizer(args, model)
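Design note: patching `accelerator.state.fsdp_plugin` after `setup_accelerator` works because Accelerate only consumes the plugin at `prepare()` time. An alternative, sketched here under the assumption of Accelerate's `FullyShardedDataParallelPlugin` keyword names, is to set the flags when the plugin is constructed:

```python
# Sketch only: configure the FSDP plugin up front instead of mutating
# accelerator.state afterwards. Field names assume accelerate's
# FullyShardedDataParallelPlugin API.
from accelerate import Accelerator
from accelerate.utils import FullyShardedDataParallelPlugin

fsdp_plugin = FullyShardedDataParallelPlugin(
    use_orig_params=True,      # typically needed when combining torch.compile with FSDP
    sync_module_states=True,   # broadcast rank-0 module states at wrap time
)
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
```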
@@ -414,10 +449,19 @@ def train(
             total_length = float(torch.tensor([batch.pop("total_length")]))
             if not args.use_dolomite:
                 for k in batch:
-                    batch[k] = batch[k].to(local_rank)
+                    batch[k] = batch[k].to('hpu' if is_torch_hpu_available() else local_rank)
+
+            hpu_args = {}
+            if is_torch_hpu_available():
+                hpu_args = {
+                    "use_flash_attention": True,
+                    "lazy_mode": False,
+                }
+
             output = model(
                 **batch,
                 use_cache=False,
+                **hpu_args,
             )
             loss = output.loss
             log_loss = loss.detach().item()
@@ -454,8 +498,14 @@ def train(
                 elapsed_time = time.time() - start
                 overall_throughput = args.samples_per_gpu * world_size / elapsed_time
                 current_lr = lr_scheduler.get_last_lr()[0]
-                cuda_mem_allocated = torch.cuda.memory_allocated() / (1024**3)
-                cuda_malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"]
+
+                if is_torch_hpu_available():
+                    mem_allocated = torch.hpu.memory_allocated() / (1024**3)
+                    malloc_retries = 0
+                else:
+                    mem_allocated = torch.cuda.memory_allocated() / (1024**3)
+                    malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"]
+
                 global_grad_norm = (
                     model.get_global_grad_norm()
                     if hasattr(model, "get_global_grad_norm")
@@ -477,8 +527,8 @@ def train(
                         "rank": torch.distributed.get_rank(),
                         "overall_throughput": overall_throughput,
                         "lr": current_lr,
-                        "cuda_mem_allocated": cuda_mem_allocated,
-                        "cuda_malloc_retries": cuda_malloc_retries,
+                        ("hpu" if is_torch_hpu_available() else "cuda") + "_mem_allocated": mem_allocated,
+                        ("hpu" if is_torch_hpu_available() else "cuda") + "_malloc_retries": malloc_retries,
                         "num_loss_counted_tokens": int(num_loss_counted_tokens),
                         "num_tokens_rank0": int(total_length),
                         "batch_size": int(micro_batch_size),
@@ -519,7 +569,10 @@ def train(
             global_step += 1
             if local_rank == 0:
                 inner_pb.update(1)
-            torch.cuda.empty_cache()
+
+            if not is_torch_hpu_available():
+                torch.cuda.empty_cache()
+
         if args.checkpoint_at_epoch:
             base_logger.debug(f"Saving checkpoint at epoch {epoch}")
             save_checkpoint(
@@ -595,18 +648,27 @@ def main(args):
     args.model_type = model_conf.model_type

     #### distributed init #####
-    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+    if is_torch_hpu_available():
+        torch.hpu.set_device(int(os.environ["LOCAL_RANK"]))
+    else:
+        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+
     args.local_rank = int(os.environ["LOCAL_RANK"])

     timeout = _get_collective_timeout()
-    init = functools.partial(torch.distributed.init_process_group, "nccl")
+    init = functools.partial(torch.distributed.init_process_group, "hccl" if is_torch_hpu_available() else "nccl")
     if timeout is not None:
         init(timeout=timeout)
     else:
         init()

     args.global_rank = torch.distributed.get_rank()
-    tensor = torch.ByteTensor([False]).cuda()
+
+    if is_torch_hpu_available():
+        tensor = torch.ByteTensor([False]).to('hpu')
+    else:
+        tensor = torch.ByteTensor([False]).cuda()
+
     torch.distributed.all_reduce(tensor)
     torch.distributed.barrier()