ValueError: The NeuronTrainer only accept NeuronTrainingArguments, but <class 'optimum.neuron.training_args.Seq2SeqNeuronTrainingArguments'> was provided. #693

Description

@industrialeaf

System Info

AWS EC2 instance: trn1.32xlarge

Platform:

- Platform: Linux-5.15.0-1031-aws-x86_64-with-glibc2.35
- Python version: 3.11.9


Python packages:

- `optimum-neuron` version: 0.0.24
- `neuron-sdk` version: 2.19.1
- `optimum` version: 1.20.0
- `transformers` version: 4.41.1
- `huggingface_hub` version: 0.24.6
- `torch` version: 2.1.2+cu121
- `aws-neuronx-runtime-discovery` version: 2.9
- `libneuronxla` version: 2.0.2335
- `neuronx-cc` version: 2.14.213.0
- `neuronx-distributed` version: 0.8.0
- `neuronx-hwm` version: NA
- `torch-neuronx` version: 2.1.2.2.2.0
- `torch-xla` version: 2.1.3
- `transformers-neuronx` version: NA

Who can help?

@michaelbenayoun

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction (minimal, reproducible, runnable)

Error message:

Traceback (most recent call last):
  File "/home/ubuntu/projects/seq2seq/train_t5_small.py", line 48, in <module>
    trainer = Seq2SeqNeuronTrainer(
              ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/miniconda3/envs/py311/lib/python3.11/site-packages/optimum/neuron/trainers.py", line 144, in __init__
    raise ValueError(
ValueError: The NeuronTrainer only accept NeuronTrainingArguments, but <class 'optimum.neuron.training_args.Seq2SeqNeuronTrainingArguments'> was provided.
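
The raise comes from the isinstance gate at optimum/neuron/trainers.py line 144, shown in the traceback. A quick diagnostic, only a sketch built from the classes named in the error message, confirms the root cause: Seq2SeqNeuronTrainingArguments is not a subclass of NeuronTrainingArguments, so the check rejects it.

# Diagnostic sketch: verify the class relationship that the trainer's
# isinstance() check relies on. Both classes are taken from the traceback
# and error message above.
from optimum.neuron.training_args import (
    NeuronTrainingArguments,
    Seq2SeqNeuronTrainingArguments,
)

# Prints False on optimum-neuron 0.0.24 (implied by the traceback: the args
# object is a Seq2SeqNeuronTrainingArguments yet fails the isinstance gate).
print(issubclass(Seq2SeqNeuronTrainingArguments, NeuronTrainingArguments))
print(Seq2SeqNeuronTrainingArguments.__mro__)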

Minimal example to reproduce:

Run the following script with `torchrun train.py`.

from transformers import T5Tokenizer, AutoModelForSeq2SeqLM
from datasets import load_dataset
from optimum.neuron import Seq2SeqNeuronTrainer, Seq2SeqNeuronTrainingArguments
from optimum.neuron.distributed import lazy_load_for_parallelism


# Load dataset
dataset = load_dataset("samsum")

# Load tokenizer
tokenizer = T5Tokenizer.from_pretrained("t5-small")

# Preprocess the data
def preprocess_function(examples):
    inputs = ["summarize: " + doc for doc in examples["dialogue"]]
    model_inputs = tokenizer(inputs, max_length=512, truncation=True, padding='max_length')

    # Tokenize the targets; text_target replaces the deprecated
    # as_target_tokenizer() context manager in recent transformers versions.
    labels = tokenizer(examples["summary"], text_target=None, max_length=150, truncation=True, padding='max_length') if False else tokenizer(text_target=examples["summary"], max_length=150, truncation=True, padding='max_length')

    model_inputs["labels"] = labels["input_ids"]
    print("keys", model_inputs.keys())
    print("len labels", len(model_inputs['labels']))
    print("len inpids", len(model_inputs['input_ids']))
    print("len attmsk", len(model_inputs['attention_mask']))
    return model_inputs

tokenized_dataset = dataset.map(preprocess_function, batched=True)

# Define training arguments
training_args = Seq2SeqNeuronTrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=3,
    predict_with_generate=False,  # should be false since we don't provide a generation_config
)

# Load model
with lazy_load_for_parallelism(tensor_parallel_size=training_args.tensor_parallel_size):
    model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

# Initialize the trainer
trainer = Seq2SeqNeuronTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    tokenizer=tokenizer,
)

# Train the model
trainer.train()

Expected behavior

Seq2SeqNeuronTrainer accepts Seq2SeqNeuronTrainingArguments and training starts.

I have a working workaround: I patched these lines in optimum/neuron/trainers.py so the check also accepts Seq2SeqNeuronTrainingArguments:

if not isinstance(self.args, (NeuronTrainingArguments, Seq2SeqNeuronTrainingArguments)):
    raise ValueError(
        f"The NeuronTrainer only accepts NeuronTrainingArguments and Seq2SeqNeuronTrainingArguments, "
        f"but {type(self.args)} was provided."
    )
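
If editing the installed package is undesirable, an alternative is a small shim class that inherits from both argument classes, so the existing isinstance gate passes unmodified. This is an untested sketch, and the class name CompatSeq2SeqNeuronTrainingArguments is my own; depending on how the two dataclasses are defined in your optimum-neuron version, the multiple inheritance may need adjusting.

# Untested workaround sketch: avoid patching site-packages by defining a
# shim args class that is an instance of BOTH classes, so the trainer's
# isinstance(self.args, NeuronTrainingArguments) check passes.
# CompatSeq2SeqNeuronTrainingArguments is a hypothetical name, not part of
# the optimum-neuron API.
from optimum.neuron.training_args import (
    NeuronTrainingArguments,
    Seq2SeqNeuronTrainingArguments,
)

class CompatSeq2SeqNeuronTrainingArguments(
    Seq2SeqNeuronTrainingArguments, NeuronTrainingArguments
):
    pass

# Use it exactly like Seq2SeqNeuronTrainingArguments in the script above:
# training_args = CompatSeq2SeqNeuronTrainingArguments(output_dir="./results", ...)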
