System Info
AWS EC2 instance: trn1.32xlarge
Platform:
- Platform: Linux-5.15.0-1031-aws-x86_64-with-glibc2.35
- Python version: 3.11.9
Python packages:
- `optimum-neuron` version: 0.0.24
- `neuron-sdk` version: 2.19.1
- `optimum` version: 1.20.0
- `transformers` version: 4.41.1
- `huggingface_hub` version: 0.24.6
- `torch` version: 2.1.2+cu121
- `aws-neuronx-runtime-discovery` version: 2.9
- `libneuronxla` version: 2.0.2335
- `neuronx-cc` version: 2.14.213.0
- `neuronx-distributed` version: 0.8.0
- `neuronx-hwm` version: NA
- `torch-neuronx` version: 2.1.2.2.2.0
- `torch-xla` version: 2.1.3
- `transformers-neuronx` version: NA
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction (minimal, reproducible, runnable)
Error message:
Traceback (most recent call last):
File "/home/ubuntu/projects/seq2seq/train_t5_small.py", line 48, in <module>
trainer = Seq2SeqNeuronTrainer(
^^^^^^^^^^^^^^^^^^^^^
File "/home/ubuntu/miniconda3/envs/py311/lib/python3.11/site-packages/optimum/neuron/trainers.py", line 144, in __init__
raise ValueError(
ValueError: The NeuronTrainer only accept NeuronTrainingArguments, but <class 'optimum.neuron.training_args.Seq2SeqNeuronTrainingArguments'> was provided.
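For context, the check that raises this error (optimum/neuron/trainers.py, line 144 in version 0.0.24) appears to be a strict isinstance test against NeuronTrainingArguments. A sketch of what it presumably looks like, reconstructed from the traceback and error message rather than the exact source:

# Sketch of the check in NeuronTrainer.__init__ that triggers the error
# (reconstructed from the error message; the exact source may differ).
if not isinstance(self.args, NeuronTrainingArguments):
    raise ValueError(
        f"The NeuronTrainer only accept NeuronTrainingArguments, but {type(self.args)} was provided."
    )

Since Seq2SeqNeuronTrainingArguments does not pass this check, constructing a Seq2SeqNeuronTrainer with it fails immediately.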
Minimal example to reproduce:
Run the following script with `torchrun train.py`.
from transformers import T5Tokenizer, AutoModelForSeq2SeqLM
from datasets import load_dataset
from optimum.neuron import Seq2SeqNeuronTrainer, Seq2SeqNeuronTrainingArguments
from optimum.neuron.distributed import lazy_load_for_parallelism

# Load dataset
dataset = load_dataset("samsum")

# Load tokenizer
tokenizer = T5Tokenizer.from_pretrained("t5-small")

# Preprocess the data
def preprocess_function(examples):
    inputs = ["summarize: " + doc for doc in examples["dialogue"]]
    model_inputs = tokenizer(inputs, max_length=512, truncation=True, padding='max_length')
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["summary"], max_length=150, truncation=True, padding='max_length')
    model_inputs["labels"] = labels["input_ids"]
    print("keys", model_inputs.keys())
    print("len labels", len(model_inputs['labels']))
    print("len inpids", len(model_inputs['input_ids']))
    print("len attmsk", len(model_inputs['attention_mask']))
    return model_inputs

tokenized_dataset = dataset.map(preprocess_function, batched=True)

# Define training arguments
training_args = Seq2SeqNeuronTrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=3,
    predict_with_generate=False,  # should be False since we don't provide a generation_config
)

# Load model
with lazy_load_for_parallelism(tensor_parallel_size=training_args.tensor_parallel_size):
    model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

# Initialize the trainer
trainer = Seq2SeqNeuronTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    tokenizer=tokenizer,
)

# Train the model
trainer.train()
Expected behavior
The NeuronTrainer accepts Seq2SeqNeuronTrainingArguments.
I have a workaround where I have patched these lines to also accept Seq2SeqNeuronTrainingArguments:
if not isinstance(self.args, NeuronTrainingArguments) and not isinstance(self.args, Seq2SeqNeuronTrainingArguments):
    raise ValueError(
        f"The NeuronTrainer only accepts NeuronTrainingArguments and Seq2SeqNeuronTrainingArguments, but {type(self.args)} was provided."
    )
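An equivalent, slightly tidier form of the same patch passes both classes to a single isinstance call (same behavior, just a style preference):

# Equivalent workaround: accept either argument class in one isinstance check
if not isinstance(self.args, (NeuronTrainingArguments, Seq2SeqNeuronTrainingArguments)):
    raise ValueError(
        f"The NeuronTrainer only accepts NeuronTrainingArguments and Seq2SeqNeuronTrainingArguments, "
        f"but {type(self.args)} was provided."
    )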