Can't use Trainer on mps device #35954

Open
2 of 4 tasks
smartliuhw opened this issue Jan 29, 2025 · 1 comment
@smartliuhw

System Info

version information:
transformers -- 4.48.1
torch -- 2.2.2
accelerate -- 1.3.0
system -- mac os 15.2 (24C101)

Who can help?

@muellerzr @SunMarc

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I'm using the C3 dataset to fine-tune Qwen2.5-3B-Instruct for a classification task. The code runs successfully on a CUDA device, but when I run the same code on an MPS device, I get the error messages below:
(Image: screenshot of the error messages)
The following is the most important part of my code:

## Train the model

exp_num=0

training_args = TrainingArguments(
    output_dir=f"./output/{lora_rank}_{exp_num}",
    learning_rate=5e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=2,
    num_train_epochs=5,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    fp16=True,
    bf16=False
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=dev_dataset,
    tokenizer=tokenizer,
    data_collator=datacollator,
    compute_metrics=compute_metrics
)
logger.info("Start training")
trainer.train()

Expected behavior

The code runs successfully.

@smartliuhw smartliuhw added the bug label Jan 29, 2025
@SunMarc
Member

SunMarc commented Jan 29, 2025

This is indeed not supported yet. Follow this PR, which should solve your issue: huggingface/accelerate#3373
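Assuming the unsupported feature here is fp16 mixed precision on MPS (the `fp16=True` flag in the reproduction; the error screenshot is not preserved, so this is a guess), one possible stopgap until the linked PR lands is to request fp16 only on CUDA and fall back to full precision elsewhere. The helpers `detect_device_type` and `precision_flags` below are hypothetical names for illustration, not part of transformers or accelerate.

```python
def detect_device_type() -> str:
    """Return "cuda", "mps", or "cpu"; fall back to "cpu" if torch is missing."""
    try:
        import torch
    except ImportError:
        return "cpu"
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def precision_flags(device_type: str) -> dict:
    """Hypothetical helper: enable fp16 only on CUDA.

    Assumption: fp16 mixed precision is the part that is not yet
    supported on MPS, so MPS and CPU train in full precision.
    """
    use_fp16 = device_type == "cuda"
    return {"fp16": use_fp16, "bf16": False}


if __name__ == "__main__":
    device_type = detect_device_type()
    print(device_type, precision_flags(device_type))
```

The returned dict could then replace the hard-coded `fp16=True, bf16=False` pair in the reproduction, for example `TrainingArguments(..., **precision_flags(detect_device_type()))`. Training in full precision on MPS will use more memory than fp16, so the batch size may need to be reduced.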
