Please reopen issue #30361 #31635

Closed
kirk86 opened this issue Jun 26, 2024 · 5 comments
kirk86 commented Jun 26, 2024

System Info

transformers: 4.39.3
python: 3.12
system: linux

Who can help?

@muellerz @SunMarc @gante

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Reproduction steps are outlined in the original issue (#30361).

Did some further investigation: created a custom trainer by subclassing Seq2SeqTrainer and called model.generate directly.

custom_generation_kwargs = {
    "max_new_tokens": 150,
}

import numpy as np
import torch

from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainer

def compute_metrics(eval_preds):
    predictions, labels = eval_preds
    print(f"Predictions.shape = {predictions.shape}, Labels.shape = {labels.shape}")
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    if predictions.shape[1] != labels.shape[1]:
        predictions = predictions[..., 1:]
        #labels  = labels[..., :-1]
    predictions = np.where(predictions == -100, tokenizer.pad_token_id, predictions)
    labels = np.where(labels == -100, tokenizer.pad_token_id, labels)
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    acc = utils.accuracy(preds=predictions, labels=labels)
    return {"accuracy": acc}

class CustomSeq2SeqTrainer(Seq2SeqTrainer):
    def __init__(self, *args, custom_generation_kwargs=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.custom_generation_kwargs = custom_generation_kwargs or {}
    
    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        if not self.args.predict_with_generate or prediction_loss_only:
            return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)

        # Generate predictions
        print(f"Name of self.model.main_input_name = {self.model.main_input_name}")
        generation_inputs = inputs[self.model.main_input_name]
        generated_tokens = self.model.generate(
            generation_inputs,
            **self.custom_generation_kwargs
        )
        print(f"Generated tokens {generated_tokens.shape}")

        # If the model has past_key_values, it will also return them, but we don't need them here
        if self.args.prediction_loss_only:
            return (None, generated_tokens, None)

        # Compute loss
        with torch.no_grad():
            outputs = model(**inputs)
            print(f"Logits.shape = {outputs.logits.shape}")
            if self.label_smoother is not None:
                loss = self.label_smoother(outputs, inputs["labels"]).mean().detach()
            else:
                loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()

        return (loss, generated_tokens, inputs["labels"])

# Initialize the trainer with custom generation parameters
trainer = CustomSeq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets.select(range(100)),
    eval_dataset=tokenized_datasets.select(range(1000)),
    data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    custom_generation_kwargs=custom_generation_kwargs,
)

trainer.train()
eval_results = trainer.evaluate()
print(eval_results)

What we observe is that the number of generated tokens fluctuates on every batch:

Generated tokens torch.Size([32, 104])
Logits.shape = torch.Size([32, 150, 50265])
Generated tokens torch.Size([32, 124])
Logits.shape = torch.Size([32, 150, 50265])
Generated tokens torch.Size([32, 115])
Logits.shape = torch.Size([32, 150, 50265])
Generated tokens torch.Size([32, 129])
Logits.shape = torch.Size([32, 150, 50265])
Generated tokens torch.Size([32, 122])
Logits.shape = torch.Size([32, 150, 50265])

Expected behavior

Generation should work as expected, with the number of generated tokens adhering to either max_new_tokens or max_length.

Edit

This only works if you call model.generate(inputs, generation_config=gen_conf_obj) with a GenerationConfig object, but it seems not to work if you call model.generate(inputs, max_new_tokens=128): it falls back to generating with the default max_length of 20 tokens, which contradicts the examples shown in the text-generation docs.
Stated as:

Customize text generation
You can override any generation_config by passing the parameters and their values directly to the generate method:

model.generate(**inputs, num_beams=4, do_sample=True)
outputs = model.generate(**inputs, penalty_alpha=0.6, top_k=4, max_new_tokens=100)

Finally, this seems to work with Seq2SeqTrainingArguments when passing a generation_config object, but how exactly do we get the default values from model.generation_config? For instance, if we have models X, Y, Z, each with a different model.generation_config, I want to get each model's defaults and update only the ones I'm interested in.

Calling model.generation_config.to_dict() returns every attribute, including the unset ones, which is not what I need. Extracting only the parameters that are actually set to some value also doesn't work as intended:

generation_config = model.generation_config
set_params = {k: getattr(generation_config, k) for k in dir(generation_config)
              if not k.startswith('_') and getattr(generation_config, k) is not None}

generation_config = GenerationConfig(**set_params)

model.generation_config should have an easy-to-access method that returns the per-model default values, something like:

 GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_eos_token_id": 2,
  "no_repeat_ngram_size": 3,
  "num_beams": 4,
  "pad_token_id": 1
}

Is there a way to access those default values, e.g., model.generation_config.get_defaults().to_dict(), so that I can instantiate a GenerationConfig from them and add or update the ones I need for the specific task? I did it by simply parsing the printed output, but there should be a proper way to access them.
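
One sketch that comes close (this is just a workaround, assuming that diffing against a freshly constructed GenerationConfig is enough to isolate the checkpoint-specific defaults):

from transformers import GenerationConfig

# Keep only the fields where the model's generation config deviates from the
# library defaults, then rebuild a config from those and tweak what's needed.
library_defaults = GenerationConfig().to_dict()
model_defaults = {
    k: v for k, v in model.generation_config.to_dict().items()
    if library_defaults.get(k) != v
}

gen_conf = GenerationConfig(**model_defaults)
gen_conf.max_new_tokens = 128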

Edit 2

Passing a GenerationConfig object in the Seq2SeqTrainingArguments is completely ignored for some models (e.g., bart-base, flan-t5-small, etc.). Interestingly, passing gen_conf = GenerationConfig(max_new_tokens=128) directly to model.generate(inputs, generation_config=gen_conf) (in the CustomSeq2SeqTrainer) is also ignored, and the same goes for model.generate(inputs, max_length=128) as well as model.generate(inputs, max_new_tokens=128) in the case of bart-base.

Please provide a workaround in this case. We need a modular and robust solution that works across different models.
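
The closest thing to a workaround I have for now (a sketch only, not verified across models; it assumes generate() falls back to model.generation_config when no explicit config survives the Trainer's argument forwarding):

from transformers import Seq2SeqTrainingArguments

# Overwrite the model's own defaults so both direct model.generate() calls and
# the Seq2SeqTrainer predict_with_generate path see the same limit.
model.generation_config.max_new_tokens = 128

# Seq2SeqTrainingArguments also exposes generation_max_length /
# generation_num_beams, which Seq2SeqTrainer forwards to generate() at eval time.
training_args = Seq2SeqTrainingArguments(
    output_dir="out",
    predict_with_generate=True,
    generation_max_length=128,
)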


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.


gante commented Jul 27, 2024

Expected behavior
To work as expected and generated tokens to adhere to either max_new_tokens or max_length.

See my previous answer -- this expectation is incorrect: #30361 (comment)


As for the other issues, they are explained by the arguments not being passed correctly inside the trainer. Have you confirmed your issues against a recent version? (>= 4.43)


kirk86 commented Jul 29, 2024

Hi @gante , thanks for the reply.

See my previous answer -- this expectation is incorrect:

I've read the previous answers and I think some things are still stated incorrectly.

E.g., I suggest setting the min_length flag (see the GenerationConfig docs), which prevents the EOS token from being generated up to a certain point.

Setting min_length still doesn't guarantee that the model emits tokens up to the requested length; there are still cases where the number of generated tokens fluctuates and does not adhere to min_length.

It only works if one sets both min_length and max_length to the desired number of output tokens, which, to the best of my knowledge, is not described in any docs.
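
For completeness, a sketch of that fixed-length behaviour (bart-base and the lengths are just illustrative; pinning both bounds to the same value suppresses EOS until that length and cuts generation off at it):

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("facebook/bart-base")
bart = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-base")
enc = tok("The quick brown fox jumps over the lazy dog.", return_tensors="pt")

# Pin both bounds to the same value so the generated length cannot fluctuate.
out = bart.generate(**enc, min_length=64, max_length=64)
print(out.shape)  # should be (1, 64) for every input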

Have you confirmed your issues against a recent version? (>= 4.43)

I haven't had the chance to do that because the code base relies on an earlier version of transformers.
Once I get some free time, I'll try to verify this.


gante commented Jul 29, 2024

Setting the min_length still doesn't guarantee that the model emits tokens up to the requested length. Still there are cases where the number of tokens generated fluctuate and do not adhere to min_length.

max_length is 20 by default, so you might be hitting this maximum length :) We throw a warning in that case (max_length < min_length). If you've set max_length to a value larger than min_length AND min_length is still not respected, then we do indeed have a bug -- a reproducer would be greatly appreciated!

I haven't had the chance to do that cause the code base relies on an earlier version of HF.

Regarding previous versions: I'm afraid I won't be able to fix bad behavior in previous versions. But if the bad behaviour is still present, I'd be glad to explore 🤗


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
