Skip to content

Commit

Permalink
add warning that sharegpt will be deprecated (#1957)
Browse files Browse the repository at this point in the history
* add warning that sharegpt will be deprecated

* add helper script for chat_templates and document deprecation

* Update src/axolotl/prompt_strategies/sharegpt.py

Co-authored-by: NanoCode012 <[email protected]>

---------

Co-authored-by: NanoCode012 <[email protected]>
  • Loading branch information
winglian and NanoCode012 authored Oct 11, 2024
1 parent 922db77 commit 7688385
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 1 deletion.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- typescript
type: ... # unimplemented custom format
# fastchat conversation
# fastchat conversation (deprecation soon, use chat_template)
# See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
- path: ...
type: sharegpt
Expand Down
60 changes: 60 additions & 0 deletions scripts/chat_datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""
helper script to parse chat datasets into a usable yaml
"""
import click
import yaml
from datasets import load_dataset


@click.command()
@click.argument("dataset", type=str)
@click.option("--split", type=str, default="train")
def parse_dataset(dataset=None, split="train"):
ds_cfg = {}
ds_cfg["path"] = dataset
ds_cfg["split"] = split
ds_cfg["type"] = "chat_template"
ds_cfg["chat_template"] = "<<<Replace based on your model>>>"

dataset = load_dataset(dataset, split=split)
features = dataset.features
feature_keys = features.keys()
field_messages = None
for key in ["conversation", "conversations", "messages"]:
if key in feature_keys:
field_messages = key
break
if not field_messages:
raise ValueError(
f'No conversation field found in dataset: {", ".join(feature_keys)}'
)
ds_cfg["field_messages"] = field_messages

message_fields = features["conversations"][0].keys()
message_field_role = None
for key in ["from", "role"]:
if key in message_fields:
message_field_role = key
break
if not message_field_role:
raise ValueError(
f'No role field found in messages: {", ".join(message_fields)}'
)
ds_cfg["message_field_role"] = message_field_role

message_field_content = None
for key in ["content", "text", "value"]:
if key in message_fields:
message_field_content = key
break
if not message_field_content:
raise ValueError(
f'No content field found in messages: {", ".join(message_fields)}'
)
ds_cfg["message_field_content"] = message_field_content

print(yaml.dump({"datasets": [ds_cfg]}))


if __name__ == "__main__":
parse_dataset()
3 changes: 3 additions & 0 deletions src/axolotl/prompt_strategies/sharegpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ def build_loader(
default_conversation: Optional[str] = None,
):
def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
LOG.warning(
"sharegpt type support will be deprecated in the next release of Axolotl. Please use chat_template instead.",
)
conversation = (
ds_cfg["conversation"]
if ds_cfg and "conversation" in ds_cfg
Expand Down

0 comments on commit 7688385

Please sign in to comment.