Skip to content

Commit

Permalink
Fix get_chat_template call for trainer builder (#2003)
Browse files Browse the repository at this point in the history
  • Loading branch information
chiragjn authored Oct 30, 2024
1 parent e62554c commit 74db2a1
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/axolotl/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def do_inference_gradio(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template)
chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer)

model = model.to(cfg.device, dtype=cfg.torch_dtype)

Expand Down
3 changes: 2 additions & 1 deletion src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1595,7 +1595,8 @@ def build(self, total_num_steps):
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
if self.cfg.chat_template:
training_arguments_kwargs["chat_template"] = get_chat_template(
self.cfg.chat_template
self.cfg.chat_template,
tokenizer=self.tokenizer,
)

if self.cfg.rl == "orpo":
Expand Down

0 comments on commit 74db2a1

Please sign in to comment.