Add support for Gemma chat template (#1530)
* Add support for Gemma chat template

* Update fschat version to include its newest support for Gemma chat style

* pin fastchat to current HEAD

---------

Co-authored-by: Wing Lian <[email protected]>
Haoxiang-Wang and winglian authored Apr 21, 2024
1 parent 7477a53 commit 60f5ce0
Showing 2 changed files with 9 additions and 1 deletion.
2 changes: 1 addition & 1 deletion requirements.txt
@@ -28,7 +28,7 @@ scipy
 scikit-learn==1.2.2
 pynvml
 art
-fschat==0.2.36
+fschat @ git+https://github.com/lm-sys/FastChat.git@5095615810cf613dba7f27dd155f571fcff976d8
 gradio==3.50.2
 tensorboard

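Since fschat is now pinned to a git revision rather than a PyPI release, one way to confirm an environment actually picked up the pin is to read the install record pip writes for direct-URL installs (a minimal sketch, relying on PEP 610's direct_url.json, which pip records for VCS installs):

# Sketch: show which source pip installed fschat from. For the pin
# above, the recorded vcs_info.commit_id should be
# 5095615810cf613dba7f27dd155f571fcff976d8.
# (read_text returns None for a plain PyPI install.)
from importlib.metadata import distribution

print(distribution("fschat").read_text("direct_url.json"))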
8 changes: 8 additions & 0 deletions src/axolotl/monkeypatch/fastchat_conversation_turns.py
@@ -123,6 +123,14 @@ def get_turns(  # pylint: disable=too-many-return-statements
                 else:
                     yield role, ""
             return
+        if self.sep_style == SeparatorStyle.GEMMA:
+            if self.system_message:
+                raise ValueError("Gemma chat template does not support system messages")
+            for i, (role, message) in enumerate(self.messages):
+                prefix = "<bos>" if i == 0 else ""
+                message_str = message if message else ""
+                yield prefix + "<start_of_turn>" + role + "\n", message_str + "<end_of_turn>\n"
+            return
         if self.sep_style == SeparatorStyle.CHATGLM:
             # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
             # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
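For illustration, a minimal standalone sketch of the turn layout the new branch produces (the roles and messages here are made up; the real code iterates fschat's Conversation state):

# Standalone sketch of the Gemma turn format added above; illustrative
# (role, message) pairs stand in for fschat's Conversation.messages.
def gemma_turns(messages):
    for i, (role, message) in enumerate(messages):
        prefix = "<bos>" if i == 0 else ""
        message_str = message if message else ""
        yield prefix + "<start_of_turn>" + role + "\n", message_str + "<end_of_turn>\n"

turns = gemma_turns([("user", "Hello"), ("model", "Hi there!")])
print("".join(role_part + msg_part for role_part, msg_part in turns))
# <bos><start_of_turn>user
# Hello<end_of_turn>
# <start_of_turn>model
# Hi there!<end_of_turn>

Note that <bos> is emitted only before the first turn, every turn is closed with <end_of_turn>, and a system message raises ValueError because the Gemma template has no system role.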

3 comments on commit 60f5ce0

@teknium1 (Contributor)

This is broken somehow; when training llama-3, I get:

  File "/root/axolotl/.venv/lib/python3.10/site-packages/multiprocess/pool.py", line 125, in worker
    result = (True, func(*args, **kwds))
  File "/root/axolotl/.venv/lib/python3.10/site-packages/datasets/utils/py_utils.py", line 1377, in _write_generator_to_queue
    for i, result in enumerate(func(**kwargs)):
  File "/root/axolotl/.venv/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 3442, in _map_single
    example = apply_function_on_filtered_inputs(example, i, offset=offset)
  File "/root/axolotl/.venv/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 3345, in apply_function_on_filtered_inputs
    processed_inputs = function(*fn_args, *additional_args, **fn_kwargs)
  File "/root/axolotl/src/axolotl/prompt_tokenizers.py", line 373, in tokenize_prompt
    for _, part in enumerate(
  File "/root/axolotl/src/axolotl/prompters.py", line 360, in build_prompt
    for part in turns:
  File "/root/axolotl/src/axolotl/monkeypatch/fastchat_conversation_turns.py", line 126, in get_turns
    if self.sep_style == SeparatorStyle.GEMMA:
  File "/usr/lib/python3.10/enum.py", line 437, in __getattr__
    raise AttributeError(name) from None
AttributeError: GEMMA
"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/root/axolotl/src/axolotl/cli/train.py", line 59, in <module>
    fire.Fire(do_cli)
  File "/root/axolotl/.venv/lib/python3.10/site-packages/fire/core.py", line 143, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/root/axolotl/.venv/lib/python3.10/site-packages/fire/core.py", line 477, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/root/axolotl/.venv/lib/python3.10/site-packages/fire/core.py", line 693, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/root/axolotl/src/axolotl/cli/train.py", line 35, in do_cli
    return do_train(parsed_cfg, parsed_cli_args)
  File "/root/axolotl/src/axolotl/cli/train.py", line 53, in do_train
    dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
  File "/root/axolotl/src/axolotl/cli/__init__.py", line 397, in load_datasets
    train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
  File "/root/axolotl/src/axolotl/utils/data/sft.py", line 66, in prepare_dataset
    train_dataset, eval_dataset, prompters = load_prepare_datasets(
  File "/root/axolotl/src/axolotl/utils/data/sft.py", line 460, in load_prepare_datasets
    dataset, prompters = load_tokenized_prepared_datasets(
  File "/root/axolotl/src/axolotl/utils/data/sft.py", line 399, in load_tokenized_prepared_datasets
    dataset_wrapper, dataset_prompter = get_dataset_wrapper(
  File "/root/axolotl/src/axolotl/utils/data/sft.py", line 553, in get_dataset_wrapper
    dataset_wrapper = TokenizedPromptDataset(
  File "/root/axolotl/src/axolotl/datasets.py", line 43, in __init__
    self.process(dataset).data,
  File "/root/axolotl/src/axolotl/datasets.py", line 55, in process
    return dataset.map(
  File "/root/axolotl/.venv/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 591, in wrapper
    out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
  File "/root/axolotl/.venv/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 556, in wrapper
    out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
  File "/root/axolotl/.venv/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 3181, in map
    for rank, done, content in iflatmap_unordered(
  File "/root/axolotl/.venv/lib/python3.10/site-packages/datasets/utils/py_utils.py", line 1417, in iflatmap_unordered
    [async_result.get(timeout=0.05) for async_result in async_results]
  File "/root/axolotl/.venv/lib/python3.10/site-packages/datasets/utils/py_utils.py", line 1417, in <listcomp>
    [async_result.get(timeout=0.05) for async_result in async_results]
  File "/root/axolotl/.venv/lib/python3.10/site-packages/multiprocess/pool.py", line 774, in get
    raise self._value
AttributeError: GEMMA
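
For context: AttributeError: GEMMA means the SeparatorStyle.GEMMA enum lookup failed, i.e. the installed fschat predates the revision pinned in this commit (the old fschat==0.2.36 pin has no GEMMA member). A quick probe, assuming fschat is importable:

# Probe (sketch): does the installed fschat define SeparatorStyle.GEMMA?
# False means an older fschat is still installed and requirements should
# be reinstalled so the pinned git revision takes effect.
from fastchat.conversation import SeparatorStyle

print(hasattr(SeparatorStyle, "GEMMA"))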

@ehartford (Collaborator)

I just edit fastchat_conversation_turns.py and remove that if statement.
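
A less destructive variant of that workaround (a sketch, not the project's fix) is to resolve the member with a default, so older fschat builds fall through instead of raising:

# Sketch: tolerate fschat versions that lack SeparatorStyle.GEMMA.
# Enum classes raise AttributeError for unknown member names, so
# getattr with a default turns the branch into a no-op on old installs.
from fastchat.conversation import SeparatorStyle

GEMMA_STYLE = getattr(SeparatorStyle, "GEMMA", None)

def is_gemma(sep_style):
    # True only when the installed fschat defines GEMMA and this
    # conversation actually uses it.
    return GEMMA_STYLE is not None and sep_style == GEMMA_STYLE

The monkeypatch's unconditional if self.sep_style == SeparatorStyle.GEMMA: check would then become if is_gemma(self.sep_style):.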

@adamlin120
