
Commit a53d6e1: Enable fine tuning on HPU
Parent: 1532531
5 files changed (+152, -15 lines)

src/instructlab/training/hpu_utils.py (new file)

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+import torch
+from functools import lru_cache
+
+
+@lru_cache(maxsize=None)
+def is_torch_hpu_available() -> bool:
+    try:
+        import habana_frameworks.torch.core  # noqa: F401
+    except ImportError:
+        return False
+    return True
+
+
+def simple_bucket(length):
+    """
+    This bucket algorithm relies only on the given number, rather than on
+    slicing the known (min, max) range, for several reasons:
+        1) Due to the use of the first-fit-decreasing (FFD) algorithm, the
+           (min, max) sequence length of each rank will be much smaller than the
+           (min, max) sequence length of the dataset. Bucketing on the
+           (min, max) sequence length of the dataset is not practical.
+        2) The (min, max) sequence length of a given rank is unknown until
+           one epoch finishes, since the packing is done on the fly.
+        3) Due to the shuffling, the (min, max) sequence length of a given rank
+           may vary between ranks. Once the (min, max) sequence length of a
+           given rank changes, the bucketing also needs adjustment.
+
+    This bucket algorithm is based on the most significant set bit of the input number.
+    It first checks which bit is the most significant set bit, say bit "S",
+    and then slices the range [2 ** S, 2 ** (S+1)] into buckets of equal size.
+    By default the range is divided into 16 buckets, so the bucket size will be
+    2 ** (S - 4).
+    For example, 0b10001 will be padded to 0b10010.
+    This approach limits the overhead of bucketing (at most 1/16 of the input
+    number) and also prevents recompilation due to a too-small bucket size.
+    """
+    l = length
+    msb = 0
+    while l > 0:
+        msb += 1
+        l = l // 2
+
+    align = (1 << (msb - 4)) if msb >= 4 else 1
+
+    return (length + align - 1) // align * align
+
+
+def bucket(length):
+    return simple_bucket(length)
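
A quick usage sketch (an editorial illustration, not part of the commit); the expected values follow directly from the code above, including the 0b10001 example in the docstring:

    from instructlab.training.hpu_utils import bucket, simple_bucket

    assert simple_bucket(17) == 18      # 0b10001 is padded up to 0b10010 (msb = 5, align = 2)
    assert simple_bucket(1000) == 1024  # 1000 uses 10 bits, so align = 1 << 6 = 64
    assert bucket(7) == 7               # lengths below 16 keep align = 1 and are unchanged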

src/instructlab/training/main_ds.py

Lines changed: 72 additions & 10 deletions
@@ -43,6 +43,14 @@
         UserWarning,
     )

+from instructlab.training.hpu_utils import is_torch_hpu_available
+
+if is_torch_hpu_available():
+    import habana_frameworks.torch.core as htcore
+    import habana_frameworks.torch.distributed.hccl
+    from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
+    adapt_transformers_to_gaudi()
+
 # Third Party
 from instructlab.dolomite.hf_models import GPTDolomiteForCausalLM
 from torch.utils.data import DataLoader
@@ -174,6 +182,13 @@ def setup_model(
     else:
         model = AutoModelForCausalLM.from_pretrained(**base_model_args)

+    if is_torch_hpu_available():
+        torch._dynamo.config.cache_size_limit = int(1e4)
+        torch._dynamo.config.accumulated_cache_size_limit = int(2e4)
+        model = torch.compile(model, backend="hpu_backend", dynamic=False)
+        for layer in model.model.layers:
+            layer.compile(backend="hpu_backend", dynamic=False)
+
     # store the base model args so we can recall them later if saving a LoRA model
     args.base_model_args = base_model_args

@@ -222,7 +237,22 @@ def setup_model(
     )
     model.config.eos_token_id = tokenizer.eos_token_id

-    if "ForCausalLM" not in model.__class__.__name__:
+    if not is_torch_hpu_available():
+        class_name = model.__class__.__name__
+    else:
+        class_name = model._orig_mod.__class__.__name__ if model.__class__.__name__ == 'OptimizedModule' else model.__class__.__name__
+
+    replace_no_split_modules = {
+        'GaudiLlamaForCausalLM': ['GaudiLlamaDecoderLayer',]
+    }
+
+    if class_name in replace_no_split_modules:
+        if model.__class__.__name__ == 'OptimizedModule':
+            model._orig_mod._no_split_modules = replace_no_split_modules[class_name]
+        else:
+            model._no_split_modules = replace_no_split_modules[class_name]
+
+    if "ForCausalLM" not in class_name:
         raise ValueError(
             f"Model class name: {model.__class__.__name__} is not supported."
         )
@@ -272,6 +302,11 @@ def make_inputs_require_grad(module, input, output): # pylint: disable=unused-a
         model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

     accelerator = setup_accelerator(args, model, grad_accum)
+
+    if is_torch_hpu_available():
+        accelerator.state.fsdp_plugin.use_orig_params = True
+        accelerator.state.fsdp_plugin.sync_module_states = True
+
     if args.distributed_training_framework == DistributedBackend.FSDP.value:
         model = accelerator.prepare(model)
     optimizer = setup_optimizer(args, model)
@@ -414,10 +449,19 @@ def train(
             total_length = float(torch.tensor([batch.pop("total_length")]))
             if not args.use_dolomite:
                 for k in batch:
-                    batch[k] = batch[k].to(local_rank)
+                    batch[k] = batch[k].to('hpu' if is_torch_hpu_available() else local_rank)
+
+            hpu_args = {}
+            if is_torch_hpu_available():
+                hpu_args = {
+                    "use_flash_attention": True,
+                    "lazy_mode": False,
+                }
+
             output = model(
                 **batch,
                 use_cache=False,
+                **hpu_args,
             )
             loss = output.loss
             log_loss = loss.detach().item()
@@ -454,8 +498,14 @@ def train(
                 elapsed_time = time.time() - start
                 overall_throughput = args.samples_per_gpu * world_size / elapsed_time
                 current_lr = lr_scheduler.get_last_lr()[0]
-                cuda_mem_allocated = torch.cuda.memory_allocated() / (1024**3)
-                cuda_malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"]
+
+                if is_torch_hpu_available():
+                    mem_allocated = torch.hpu.memory_allocated() / (1024**3)
+                    malloc_retries = 0
+                else:
+                    mem_allocated = torch.cuda.memory_allocated() / (1024**3)
+                    malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"]
+
                 global_grad_norm = (
                     model.get_global_grad_norm()
                     if hasattr(model, "get_global_grad_norm")
@@ -477,8 +527,8 @@ def train(
                         "rank": torch.distributed.get_rank(),
                         "overall_throughput": overall_throughput,
                         "lr": current_lr,
-                        "cuda_mem_allocated": cuda_mem_allocated,
-                        "cuda_malloc_retries": cuda_malloc_retries,
+                        ("hpu" if is_torch_hpu_available() else "cuda") + "_mem_allocated": mem_allocated,
+                        ("hpu" if is_torch_hpu_available() else "cuda") + "_malloc_retries": malloc_retries,
                         "num_loss_counted_tokens": int(num_loss_counted_tokens),
                         "num_tokens_rank0": int(total_length),
                         "batch_size": int(micro_batch_size),
@@ -519,7 +569,10 @@ def train(
             global_step += 1
             if local_rank == 0:
                 inner_pb.update(1)
-            torch.cuda.empty_cache()
+
+            if not is_torch_hpu_available():
+                torch.cuda.empty_cache()
+
         if args.checkpoint_at_epoch:
             base_logger.debug(f"Saving checkpoint at epoch {epoch}")
             save_checkpoint(
@@ -595,18 +648,27 @@ def main(args):
     args.model_type = model_conf.model_type

     #### distributed init #####
-    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+    if is_torch_hpu_available():
+        torch.hpu.set_device(int(os.environ["LOCAL_RANK"]))
+    else:
+        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+
     args.local_rank = int(os.environ["LOCAL_RANK"])

     timeout = _get_collective_timeout()
-    init = functools.partial(torch.distributed.init_process_group, "nccl")
+    init = functools.partial(torch.distributed.init_process_group, "hccl" if is_torch_hpu_available() else "nccl")
     if timeout is not None:
         init(timeout=timeout)
     else:
         init()

     args.global_rank = torch.distributed.get_rank()
-    tensor = torch.ByteTensor([False]).cuda()
+
+    if is_torch_hpu_available():
+        tensor = torch.ByteTensor([False]).to('hpu')
+    else:
+        tensor = torch.ByteTensor([False]).cuda()
+
     torch.distributed.all_reduce(tensor)
     torch.distributed.barrier()

src/instructlab/training/multipack_sampler.py

Lines changed: 7 additions & 0 deletions
@@ -34,6 +34,8 @@
 import torch
 import torch.distributed as dist

+from instructlab.training.hpu_utils import is_torch_hpu_available, bucket
+

 def find_max_pack_len_with_padding(
     dataset,
@@ -395,6 +397,11 @@ def generate_batches(self, set_stats=False):
         )

         lengths = self.lengths[indices]
+
+        if is_torch_hpu_available():
+            bucket_v = np.vectorize(bucket)
+            lengths = bucket_v(lengths)
+
         lengths_cumsum = np.cumsum(lengths)

         batches, total_used, total_slots = allocate(
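
As a rough illustration (not part of the commit) of what this change does: vectorizing bucket over the length array, as the sampler now does on HPU, collapses nearby sequence lengths onto a small set of values, so far fewer distinct tensor shapes reach the HPU graph compiler:

    import numpy as np
    from instructlab.training.hpu_utils import bucket

    lengths = np.array([17, 18, 19, 20, 100, 101])
    bucket_v = np.vectorize(bucket)
    print(bucket_v(lengths))  # [ 18  18  20  20 104 104]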

src/instructlab/training/setup_accelerator.py

Lines changed: 12 additions & 4 deletions
@@ -2,7 +2,6 @@
 from functools import partial

 # Third Party
-from accelerate import Accelerator
 from peft.utils.other import fsdp_auto_wrap_policy
 from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy
 from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
@@ -12,6 +11,12 @@
 # First Party
 from instructlab.training.config import DeepSpeedOptions
 from instructlab.training.utils import get_module_class_from_name, patch_target_module
+from instructlab.training.hpu_utils import is_torch_hpu_available
+
+if is_torch_hpu_available():
+    from optimum.habana.accelerate import GaudiAccelerator
+else:
+    from accelerate import Accelerator


 def get_ds_plugin(world_size, samples_per_gpu, grad_accum, opts: DeepSpeedOptions):
@@ -51,7 +56,10 @@ def get_ds_plugin(world_size, samples_per_gpu, grad_accum, opts: DeepSpeedOption

 def get_fsdp_config(args, model: PreTrainedModel):
     # Third Party
-    from accelerate.utils import FullyShardedDataParallelPlugin
+    if is_torch_hpu_available():
+        from optimum.habana.accelerate.utils import GaudiFullyShardedDataParallelPlugin
+    else:
+        from accelerate.utils import FullyShardedDataParallelPlugin
     from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload

     is_lora = args.lora_r > 0
@@ -73,7 +81,7 @@ def get_fsdp_config(args, model: PreTrainedModel):
     prefetch_policy = (
         BackwardPrefetch.BACKWARD_POST if is_lora else BackwardPrefetch.BACKWARD_PRE
     )
-    fsdp_plugin = FullyShardedDataParallelPlugin(
+    fsdp_plugin = (GaudiFullyShardedDataParallelPlugin if is_torch_hpu_available() else FullyShardedDataParallelPlugin)(
         auto_wrap_policy=wrap_policy,
         limit_all_gathers=True,
         backward_prefetch=prefetch_policy,
@@ -128,7 +136,7 @@ def setup_accelerator(args, model: PreTrainedModel, grad_accum):
         raise ValueError(
             f"Unknown sharding framework: {args.distributed_training_framework}"
         )
-    accelerator = Accelerator(
+    accelerator = (GaudiAccelerator if is_torch_hpu_available() else Accelerator)(
         **accel_args,
     )
     accelerator.even_batches = False

src/instructlab/training/utils.py

Lines changed: 12 additions & 1 deletion
@@ -50,6 +50,7 @@
     QuantizeDataType,
     TrainingArgs,
 )
+from instructlab.training.hpu_utils import is_torch_hpu_available, bucket

 logger = logging.getLogger("instructlab.training")

@@ -209,6 +210,9 @@ def listen(self):


 def supports_flash_attention(device_id=0):
+    if is_torch_hpu_available():
+        return False
+
     """Check if a GPU supports FlashAttention."""
     major, minor = torch.cuda.get_device_capability(device_id)
     # Check if the GPU architecture is Ampere (SM 8.x) or newer (SM 9.0)
@@ -300,6 +304,9 @@ def pad_collate_fn(batch):
     lens = np.array([len(item["input_ids"]) for item in batch])
     max_len = max(lens)

+    if is_torch_hpu_available():
+        max_len = bucket(max_len)
+
     input_ids = torch.stack(
         [
             F.pad(
@@ -411,6 +418,7 @@ def reduce_sum_forward(
         output_attentions=output_attentions,
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
+        **(_deprecated_arguments if is_torch_hpu_available() else {}),
     )

     return_dict = isinstance(output, dict)
@@ -1093,7 +1101,10 @@ def set_random_seed(seed):
     random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
+    if is_torch_hpu_available():
+        torch.hpu.manual_seed_all(seed)
+    else:
+        torch.cuda.manual_seed_all(seed)


 def save_checkpoint(
