
Commit

Update chat_templates.py
danielhanchen committed Jun 3, 2024
1 parent 86804dc commit 87fdd3a
Showing 1 changed file with 81 additions and 1 deletion.
unsloth/chat_templates.py (82 changes: 81 additions & 1 deletion)
@@ -18,6 +18,7 @@
"test_hf_gguf_equivalence",
"remove_special_tokens",
"create_ollama_modelfile",
"standardize_dataset",
]

from transformers import StoppingCriteria, StoppingCriteriaList
@@ -700,8 +701,87 @@ def remove_special_tokens(tokenizer, prompt):
pass


def create_ollama_modelfile(tokenizer, gguf_location):
def standardize_dataset(
    dataset,
    conversation_key = "conversations",
    system_message = None,
    aliases_for_system    = ["system",],
    aliases_for_user      = ["user", "human", "input",],
    aliases_for_assistant = ["gpt", "assistant", "output",],
):
    """
    Standardizes ShareGPT and other formats to user/assistant Hugging Face format.
    """
    import collections
    import itertools

    convos = dataset[:10][conversation_key]
    uniques = collections.defaultdict(list)
    for convo in convos:
        for message in convo:
            for key, value in message.items():
                uniques[key].append(value)
    pass

    # Must be only 2 entries
    assert(len(uniques.keys()) == 2)

    keys = list(uniques.keys())
    length_first  = len(set(uniques[keys[0]]))
    length_second = len(set(uniques[keys[1]]))

    if length_first < length_second:
        # Role is assigned to the first element
        role_key    = keys[0]
        content_key = keys[1]
    else:
        role_key    = keys[1]
        content_key = keys[0]
    pass

    # Check roles are in aliases
    all_aliases = set(aliases_for_system + aliases_for_user + aliases_for_assistant)
    roles = set(uniques[role_key])
    leftover_aliases = (all_aliases | roles) - all_aliases
    if len(leftover_aliases) != 0:
        raise TypeError(
            f"Unsloth: {list(leftover_aliases)} are not in aliases. Please update aliases."
        )
    pass

    # Mapping for aliases
    aliases_mapping = {}
    for x in aliases_for_system:    aliases_mapping[x] = "system"
    for x in aliases_for_user:      aliases_mapping[x] = "user"
    for x in aliases_for_assistant: aliases_mapping[x] = "assistant"

    def _standardize_dataset(examples):
        convos = examples[conversation_key]
        all_convos = []
        for convo in convos:
            new_convo = []
            if len(convo) == 0: continue
            has_system = aliases_mapping[convo[0][role_key]] == "system"
            if not has_system and system_message is not None:
                new_convo.append({ "role" : "system", "content" : system_message, })
            for message in convo:
                role = aliases_mapping[message[role_key]]
                new_convo.append({ "role" : role, "content" : message[content_key], })
            pass
            all_convos.append(new_convo)
        pass
        return { conversation_key : all_convos, }
    pass

    return dataset.map(_standardize_dataset, batched = True,)
pass


def create_ollama_modelfile(tokenizer, gguf_location):
"""
Creates an Ollama Modelfile.
Use ollama.create(model = "new_ollama_model", modelfile = modelfile)
"""
modelfile = getattr(tokenizer, "_ollama_modelfile", None)
if modelfile is None:
raise RuntimeError(
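
A minimal usage sketch of the standardize_dataset function added above. The toy dataset, its "from"/"value" keys, the "human"/"gpt" roles, and the system message are illustrative assumptions only; the function detects the role and content keys on its own, so any two-key message format whose roles appear in the aliases should work.

from datasets import Dataset

# Toy ShareGPT-style data; the keys and roles here are placeholders for illustration.
sharegpt_data = Dataset.from_dict({
    "conversations" : [
        [
            {"from" : "human", "value" : "What is 2 + 2?"},
            {"from" : "gpt",   "value" : "2 + 2 equals 4."},
        ],
        [
            {"from" : "human", "value" : "Name a prime number."},
            {"from" : "gpt",   "value" : "7 is a prime number."},
        ],
    ],
})

standardized = standardize_dataset(
    sharegpt_data,
    conversation_key = "conversations",
    system_message   = "You are a helpful assistant.",  # assumed example message
)

# Each conversation now holds {"role": ..., "content": ...} entries, with the
# system message prepended when the original conversation had none, roughly:
# [{"role": "system",    "content": "You are a helpful assistant."},
#  {"role": "user",      "content": "What is 2 + 2?"},
#  {"role": "assistant", "content": "2 + 2 equals 4."}]
print(standardized[0]["conversations"])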

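For reference, the workflow hinted at in the create_ollama_modelfile docstring might look like the sketch below. The gguf path and model name are placeholders, and it assumes the ollama Python client is installed and that the tokenizer was previously prepared by Unsloth so it carries the _ollama_modelfile template.

import ollama

# Hypothetical gguf path and model name, for illustration only; `tokenizer`
# is assumed to come from an earlier Unsloth setup step.
modelfile = create_ollama_modelfile(tokenizer, "model-unsloth.Q8_0.gguf")
ollama.create(model = "new_ollama_model", modelfile = modelfile)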