
Commit deac14a

🧹 Remove max_batch_tokens, num_blocks and block_size from generation kwargs (#4065)
1 parent 3d5a30b commit deac14a

3 files changed: +3 additions, -12 deletions

trl/trainer/grpo_trainer.py

Lines changed: 1 addition & 4 deletions

@@ -566,10 +566,6 @@ def __init__(
             "repetition_penalty": self.repetition_penalty,
             "cache_implementation": args.cache_implementation,
         }
-        if args.use_transformers_paged:
-            generation_kwargs["max_batch_tokens"] = 512
-            generation_kwargs["num_blocks"] = 1024
-            generation_kwargs["block_size"] = 128
         if args.generation_kwargs is not None:
             generation_kwargs.update(args.generation_kwargs)
         self.generation_config = GenerationConfig(**generation_kwargs)

@@ -1306,6 +1302,7 @@ def _generate_and_score_completions(
         all_outputs = unwrapped_model.generate_batch(
             paged_prompt_inputs.input_ids, generation_config=self.generation_config, progress_bar=False
         )
+        unwrapped_model.train()  # restore training mode, as generate_batch forces eval mode
         completion_ids = [output.generated_tokens for output in all_outputs.values()]
         completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
         completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right")
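Since `generation_kwargs.update(args.generation_kwargs)` is kept in all three trainers, the removed paged-attention defaults can still be supplied explicitly through the trainer config. A minimal sketch, assuming the paged generation backend still honors these keys when they are passed this way (not verified against this commit):

    from trl import GRPOConfig

    # Hedged sketch: re-supply the previously hard-coded paged-attention settings
    # via the config's generation_kwargs, which the trainer merges into its
    # GenerationConfig. The values are the ones removed by this commit.
    args = GRPOConfig(
        output_dir="grpo-output",
        use_transformers_paged=True,
        generation_kwargs={
            "max_batch_tokens": 512,
            "num_blocks": 1024,
            "block_size": 128,
        },
    )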

trl/trainer/online_dpo_trainer.py

Lines changed: 1 addition & 4 deletions

@@ -572,10 +572,6 @@ def __init__(
         generation_kwargs["min_p"] = self.min_p
         if args.generation_kwargs is not None:
             generation_kwargs.update(args.generation_kwargs)
-        if self.use_transformers_paged:
-            generation_kwargs["max_batch_tokens"] = 512
-            generation_kwargs["num_blocks"] = 1024
-            generation_kwargs["block_size"] = 128
         # Remove None values
         generation_kwargs = {k: v for k, v in generation_kwargs.items() if v is not None}
         self.generation_config = GenerationConfig(**generation_kwargs)

@@ -1112,6 +1108,7 @@ def _generate(self, model, prompts, images=None):
             generation_config=self.generation_config,
             progress_bar=False,
         )
+        unwrapped_model.train()  # restore training mode, as generate_batch forces eval mode
         completion_ids = [output.generated_tokens for output in all_outputs.values()]
         completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
         completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right")

trl/trainer/rloo_trainer.py

Lines changed: 1 addition & 4 deletions

@@ -638,10 +638,6 @@ def decode(example, tokenizer):
             "repetition_penalty": self.repetition_penalty,
             "cache_implementation": args.cache_implementation,
         }
-        if args.use_transformers_paged:
-            generation_kwargs["max_batch_tokens"] = 512
-            generation_kwargs["num_blocks"] = 1024
-            generation_kwargs["block_size"] = 128
         if args.generation_kwargs is not None:
             generation_kwargs.update(args.generation_kwargs)
         self.generation_config = GenerationConfig(**generation_kwargs)

@@ -1284,6 +1280,7 @@ def _generate_and_score_completions(
         all_outputs = unwrapped_model.generate_batch(
             paged_prompt_inputs.input_ids, generation_config=self.generation_config, progress_bar=False
         )
+        unwrapped_model.train()  # restore training mode, as generate_batch forces eval mode
         completion_ids = [output.generated_tokens for output in all_outputs.values()]
         completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
         completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right")
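Each generation hunk adds the same one-line fix: `generate_batch` switches the model to eval mode, and the trainers now switch it back to train mode before continuing. A standalone PyTorch sketch of why the restore matters (toy model for illustration, not TRL code):

    import torch.nn as nn

    # Toy module with dropout, whose behavior depends on train/eval mode.
    model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5))

    def generate_batch_like(module):
        # Stand-in for a generation helper that forces eval mode internally,
        # as generate_batch does in the diffs above.
        module.eval()
        # ... generation would happen here ...

    model.train()
    generate_batch_like(model)
    assert not model.training   # eval mode leaked out of the helper
    model.train()               # restore training mode, mirroring the added line
    assert model.training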
