Skip to content

Commit

Permalink
Chat templates
Browse files Browse the repository at this point in the history
  • Loading branch information
danielhanchen committed Jul 15, 2024
1 parent e32fc24 commit 0f2e484
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 65 deletions.
134 changes: 70 additions & 64 deletions unsloth/chat_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -1125,11 +1125,6 @@ def construct_chat_template( \
for eos in extra_eos_tokens:
count_eos += len(re.findall(r"{OUTPUT}" + re.escape(eos), chat_template))
pass
if count_eos == 0:
logger.warning("Unsloth: We automatically added an EOS token to stop endless generations.")
eos = extra_eos_tokens[0]
chat_template = re.sub(r"{OUTPUT}", r"{OUTPUT}" + eos, chat_template)
pass

# This forces you to provide 2 input and output pairs
final_combined_check = False
Expand All @@ -1151,72 +1146,83 @@ def construct_chat_template( \

# Must be equivalent to left
final_combined_check = True
except:
# Simple 1 singular input and output
system_count = chat_template.count("{SYSTEM}")
input_count = chat_template.count("{INPUT}")
output_count = chat_template.count("{OUTPUT}")
if system_count > 1:
raise RuntimeError("You must only provide 1 {SYSTEM} in the chat template")
if input_count > 1:
raise RuntimeError("You must only provide 1 {INPUT} in the chat template")
if output_count > 1:
raise RuntimeError("You must only provide 1 {OUTPUT} in the chat template")

if system_count != 0:
j = next(re.finditer(r"\{SYSTEM\}[\s]{0,}", chat_template)).span(0)[1]

# Repeated text
instruction_response = chat_template[j:]
if instruction_response.count("{INPUT}") != 1 or instruction_response.count("{OUTPUT}") != 1:
raise RuntimeError(error_msg)
pass

# 1st System, Instruction, Output pair
left = chat_template[:j]
# 2nd Instruction, Output pair
right = chat_template[j:]

final_combined_check = left if final_combined_check else chat_template

# Isolate input
extra_eos_tokens_regex = "|".join(f"(?:{re.escape(x)})" for x in extra_eos_tokens)
if len(extra_eos_tokens_regex) != 0:
find_end = f"(?:{extra_eos_tokens_regex})?"
else:
j = 0
find_end = ""
find_end = r"\{INPUT\}[\s\n]{0,}" + find_end
input_end = list(re.finditer(find_end, right))
assert(len(input_end) == 1)
input_end = input_end[0]
input_end = input_end.span(0)[1]
input_part = right[:input_end]

# Isolate output
output_part = right[input_end:]

# Isolate system
where_system = left.find(input_part)
system_part = left[:where_system if where_system != -1 else len(left)]

# Check if the user provided a correct prompt
combined = system_part + input_part + output_part
if combined != final_combined_check:
combined_changed = combined .replace('\n', '\\n')
left_changed = final_combined_check.replace('\n', '\\n')
raise RuntimeError(
"Unsloth: The prompt template you provided isn't correct. You gave:\n"\
f"{combined_changed}\n\n"\
"But we require the following:\n"\
f"{left_changed}"
)
pass
except:
ending = chat_template[chat_template.find("{OUTPUT}") + len("{OUTPUT}"):]

# Must be equivalent to the original text
final_combined_check = False
pass
ending = re.escape(ending)
find_text = "{INPUT}" + ending + "(.+?{OUTPUT}" + ending + ")"
response_part = re.findall(find_text, chat_template, flags = re.DOTALL | re.MULTILINE)
response_part = response_part[0]

# Repeated text
instruction_response = chat_template[j:]
if instruction_response.count("{INPUT}") != 1 or instruction_response.count("{OUTPUT}") != 1:
raise RuntimeError(error_msg)
pass
for j in range(1, len(response_part)):
try_find = re.escape(response_part[:j])
try: found = next(re.finditer("(" + try_find + ").+?\{INPUT\}", chat_template, flags = re.DOTALL | re.MULTILINE))
except: break
pass
separator = found.group(1)

# 1st System, Instruction, Output pair
left = chat_template[:j]
# 2nd Instruction, Output pair
right = chat_template[j:]
response_start = chat_template.find(response_part)
start_instruction = chat_template[:response_start].rfind(separator)
if start_instruction == -1: start_instruction = 0
instruction_part = chat_template[start_instruction:response_start]

final_combined_check = left if final_combined_check else chat_template
combined = instruction_part + response_part
where = chat_template.find(combined)
system_part = chat_template[:where]

# Isolate input
extra_eos_tokens_regex = "|".join(f"(?:{re.escape(x)})" for x in extra_eos_tokens)
if len(extra_eos_tokens_regex) != 0:
find_end = f"(?:{extra_eos_tokens_regex})?"
else:
find_end = ""
find_end = r"\{INPUT\}[\s\n]{0,}" + find_end
input_end = list(re.finditer(find_end, right))
assert(len(input_end) == 1)
input_end = input_end[0]
input_end = input_end.span(0)[1]
input_part = right[:input_end]

# Isolate output
output_part = right[input_end:]

# Isolate system
where_system = left.find(input_part)
system_part = left[:where_system if where_system != -1 else len(left)]

# Check if the user provided a correct prompt
combined = system_part + input_part + output_part
if combined != final_combined_check:
combined_changed = combined .replace('\n', '\\n')
left_changed = final_combined_check.replace('\n', '\\n')
raise RuntimeError(
"Unsloth: The prompt template you provided isn't correct. You gave:\n"\
f"{combined_changed}\n\n"\
"But we require the following:\n"\
f"{left_changed}"
)
system_part, input_part, output_part = system_part, instruction_part, response_part
pass

if count_eos == 0:
logger.warning("Unsloth: We automatically added an EOS token to stop endless generations.")
eos = extra_eos_tokens[0]
output_part = output_part + eos
pass

# Ollama modelfile parts
Expand Down
2 changes: 1 addition & 1 deletion unsloth/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -931,7 +931,7 @@ def save_to_gguf(
print_info = \
f"==((====))== Unsloth: Conversion from QLoRA to GGUF information\n"\
f" \\\ /| [0] Installing llama.cpp will take 3 minutes.\n"\
f"O^O/ \_/ \\ [1] Converting HF to GUUF 16bits will take 3 minutes.\n"\
f"O^O/ \_/ \\ [1] Converting HF to GGUF 16bits will take 3 minutes.\n"\
f"\ / [2] Converting GGUF 16bits to {quantization_method} will take 10 minutes each.\n"\
f' "-____-" In total, you will have to wait at least 16 minutes.\n'
print(print_info)
Expand Down

0 comments on commit 0f2e484

Please sign in to comment.