
the continue_final_message parameter in apply_chat_template can behave unexpectedly in some cases #40687

@usepr

Description


System Info

In render_jinja_template() in transformers/src/transformers/utils/chat_template_utils.py, if continue_final_message is set to True, the content of the final message is used to find the position at which the rendered_chat string is truncated:

final_msg_loc = rendered_chat.rindex(final_message.strip())
...
rendered_chat = rendered_chat[: final_msg_loc + len(final_message.strip())]

Because rindex() matches the last occurrence of the final message's content anywhere in the string, this breaks whenever that content is empty or happens to appear as a substring later in the rendered template (for example inside a special token such as <|im_end|>).
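The failure mode can be reproduced with plain strings, no transformers install needed. This is a minimal sketch of the truncation logic quoted above; the token strings mirror Qwen's chat template output from the reproduction below.

```python
# Sketch of the truncation logic in render_jinja_template when
# continue_final_message=True, applied to a Qwen-style rendered chat.
rendered_chat = (
    "<|im_start|>user\nuser-input<|im_end|>\n"
    "<|im_start|>assistant\n<think>\n\n</think>\n\nend<|im_end|>\n"
)
final_message = "end"

# rindex finds the LAST occurrence of "end" -- which is the "end" inside
# the trailing "<|im_end|>" special token, not the assistant's message.
final_msg_loc = rendered_chat.rindex(final_message.strip())
truncated = rendered_chat[: final_msg_loc + len(final_message.strip())]
print(truncated)  # ends with "end<|im_end" instead of "end"
```

The same code also shows the empty-content edge case: `"".strip()` is `""`, and `rendered_chat.rindex("")` returns `len(rendered_chat)`, so nothing is truncated at all.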

Who can help?

@ArthurZucker @itazap

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

from transformers import AutoTokenizer

m = [
        {'role': 'user', 'content': 'user-input'},
        {'role': 'assistant', 'content': 'end'}
]

tokenizer = AutoTokenizer.from_pretrained("Qwen3-0.6B")

print(tokenizer.apply_chat_template(m, continue_final_message=True, tokenize=False))

will output

<|im_start|>user
user-input<|im_end|>
<|im_start|>assistant
<think>

</think>

end<|im_end

Expected behavior

Expected output

<|im_start|>user
user-input<|im_end|>
<|im_start|>assistant
<think>

</think>

end
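One possible direction for a fix (a sketch only, not the library's actual patch): search forward from the end of the rendered prefix that precedes the final message, instead of searching backward over the whole string, so the match cannot land inside a later special token. Here `prefix` stands in for the chat rendered without the final message's content, a value the renderer would have to compute.

```python
# Hypothetical workaround sketch: locate the final message by searching
# forward from the end of the known prefix rather than with rindex().
prefix = (
    "<|im_start|>user\nuser-input<|im_end|>\n"
    "<|im_start|>assistant\n<think>\n\n</think>\n\n"
)
rendered_chat = prefix + "end<|im_end|>\n"
final_message = "end"

# index() with a start offset matches the first occurrence at or after the
# prefix -- i.e. the assistant's actual message, not a token further on.
loc = rendered_chat.index(final_message.strip(), len(prefix))
fixed = rendered_chat[: loc + len(final_message.strip())]
print(fixed)  # ends with "end" -- the trailing <|im_end|> is dropped
```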
