Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions align_system/algorithms/lib/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from importlib import reload

def reload_all():
    """Reload the algorithm library modules in dependency order.

    Handy when developing in an interactive session (e.g. a notebook):
    call this to pick up source changes without restarting the kernel.
    """
    from align_system.algorithms.lib import util
    from align_system.algorithms.lib import language_model as lm
    from align_system.algorithms.lib.chat import dialog_tokenizer as dt
    from align_system.algorithms.lib.chat import chat_language_model as clm
    from align_system.algorithms import llama_2_kdma_predicting_adm as kpa

    # Reload order matters: dependencies first, dependents last.
    modules_in_order = (util, lm, dt, clm, kpa)
    for mod in modules_in_order:
        reload(mod)
Empty file.
156 changes: 156 additions & 0 deletions align_system/algorithms/lib/chat/chat_language_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
from typing import List, Dict, Optional, Callable, Union, TextIO

from align_system.algorithms.lib.language_model import LanguageModel
from align_system.algorithms.lib.chat.dialog_tokenizer import dialog_tokenizers
from align_system.algorithms.lib.util import read_template, format_template, dialog_from_string, dialog_to_string

class ChatLanguageModel(LanguageModel):
    """LanguageModel specialization that generates from chat-style dialogs.

    A dialog is a list of ``{'role': ..., 'content': ...}`` messages. A
    model-specific ``DialogTokenizer`` (looked up in ``dialog_tokenizers``
    by the model's ``name_or_path``) converts dialogs into token ids
    before generation.
    """

    def __init__(self, model: LanguageModel, tokenizer: Callable[[str], List[str]]):
        """
        Initializes the chat language model.

        :param model: Pretrained language model.
        :param tokenizer: Tokenizer function.
            NOTE(review): the annotation looks too loose - this object must
            also expose ``encode(..., add_special_tokens=...)`` (see
            ``generate_responses``), i.e. a HuggingFace tokenizer; kept
            as-is for backward compatibility - confirm intended type.
        :raises AssertionError: If no dialog tokenizer is registered for
            the model's ``name_or_path``.
        """
        super().__init__(model, tokenizer)
        model_name = model.name_or_path
        assert model_name in dialog_tokenizers, f'No dialog tokenizer found for model {model_name}'
        self.dialog_tokenizer = dialog_tokenizers[model_name](tokenizer)

    def generate_responses(self,
                           dialogs: List[List[Dict[str, str]]],
                           log_file: Optional[TextIO] = None,
                           max_new_tokens: int = 512,
                           temperature: float = 0.6) -> List[str]:
        """
        Generates responses for given dialogs.

        :param dialogs: List of dialogs; each dialog is a list of
            role/content message dicts. (Fix: the original annotation
            described a single dialog, not a list of them.)
        :param log_file: Optional file to log the process.
        :param max_new_tokens: Maximum number of new tokens to generate.
        :param temperature: Temperature for sampling.
        :return: Generated responses, one per input dialog.
        """
        # If logging is requested, write the dialogues into the log file
        if log_file is not None:
            log_file.write('**Dialogs:**\n')
            for i, dialog in enumerate(dialogs):
                log_file.write(f'*Dialog {i}:*\n{dialog_to_string(dialog)}\n')
            log_file.flush()

        # If a dialog ends with an assistant turn, treat that content as a
        # generation prefix: drop it from the prompt and re-prepend it to
        # the model's continuation afterwards.
        user_last_dialogs = []
        prefixes = []
        for dialog in dialogs:
            prefix = ''
            if dialog[-1]['role'] == 'assistant':
                prefix = dialog[-1]['content']
                dialog = dialog[:-1]
            user_last_dialogs.append(dialog)
            prefixes.append(prefix)

        # Tokenization step. (Fix: the original wrapped each token list in
        # a single-element list only to unwrap it again below.)
        prompt_token_lists = [
            self.dialog_tokenizer.dialog_to_tokens(dialog)
            for dialog in user_last_dialogs
        ]

        # Append the prefix tokens so generation continues the prefix.
        for prompt_tokens, prefix in zip(prompt_token_lists, prefixes):
            if len(prefix) > 0:
                prefix_tokens = self.tokenizer.encode(prefix, add_special_tokens=False)
                prompt_tokens += prefix_tokens

        # Generate responses using tokens
        responses = self.generate_from_tokens(prompt_token_lists, max_new_tokens=max_new_tokens, temperature=temperature)
        prefixed_responses = [
            f'{prefix}{response}'
            for prefix, response in zip(prefixes, responses)
        ]

        # If logging is requested, write the generated responses into the log file
        if log_file is not None:
            log_file.write('**Generated Responses:**\n')
            for i, response in enumerate(prefixed_responses):
                log_file.write(f'*Response {i}:*\n{response}\n')
            log_file.flush()

        return prefixed_responses

    def generate_from_template(
        self,
        template_files: Union[List[str], str],
        substitution_dicts: Union[List[Dict[str, str]], Dict[str, str]],
        parse_generation_fn: Optional[Callable[[str], str]] = None,
        batch_size: int = 5,
        log_file: Optional[TextIO] = None,
        max_tokens: int = 512,
        temperature: float = 0.6,
        max_retry: int = 10,
        verbose: bool = False) -> List[str]:
        """
        Generates responses for given templates with substitutions.

        :param template_files: Template file(s); a single path is reused
            for every substitution dict.
        :param substitution_dicts: Substitution dictionaries for the templates.
        :param parse_generation_fn: Function to parse the generated responses.
            Samples whose generation fails to parse are retried.
        :param batch_size: Batch size for generating responses.
        :param log_file: Optional file to log the process.
        :param max_tokens: Maximum number of tokens to generate.
        :param temperature: Temperature for sampling.
        :param max_retry: Maximum number of attempts to generate a valid output.
        :param verbose: If True, verbose logging is enabled.
        :return: Generated (parsed) responses, in input order.
        :raises Exception: If a sample exceeds ``max_retry`` attempts.
        """
        # Normalize inputs to parallel lists of equal length.
        if isinstance(substitution_dicts, dict):
            substitution_dicts = [substitution_dicts]

        if isinstance(template_files, str):
            template_files = [template_files] * len(substitution_dicts)

        assert len(template_files) == len(substitution_dicts), 'Number of templates and substitutions do not match'

        # Create a dialogue for each template/substitution pair, keyed by
        # input position so results can be returned in order.
        dialogs = {
            i: dialog_from_string(format_template(read_template(template_file), **substitutions))
            for i, (template_file, substitutions) in enumerate(zip(template_files, substitution_dicts))
        }

        outputs = {}
        input_counts = {}
        # Samples stay in `dialogs` until they produce a parseable output.
        while len(dialogs) > 0:
            sample_ids = list(dialogs.keys())[:batch_size]
            batch = [dialogs[i] for i in sample_ids]
            generations = self.generate_responses(batch, log_file=log_file, max_new_tokens=max_tokens, temperature=temperature)

            # Process the generated responses
            for sample_id, generation in zip(sample_ids, generations):
                input_counts[sample_id] = input_counts.get(sample_id, 0) + 1

                # If the maximum number of try-outs is exceeded, throw an error
                if input_counts[sample_id] > max_retry:
                    raise Exception(f'Could not generate valid output for sample [{sample_id}]')

                # If there's a specific function to parse the generations, try to apply it
                if parse_generation_fn is not None:
                    try:
                        outputs[sample_id] = parse_generation_fn(generation)
                        del dialogs[sample_id]
                    except Exception as e:
                        # Parse failure: leave the sample queued for retry.
                        if verbose:
                            print(f'Error: could not parse output for sample [{sample_id}]')
                            print(e)
                else:
                    outputs[sample_id] = generation
                    del dialogs[sample_id]

        assert len(outputs) == len(substitution_dicts), 'Unexpected state: number of outputs and substitutions do not match'

        # Keys are 0..n-1, so this returns results in input order.
        return [
            outputs[i]
            for i in range(len(outputs))
        ]
82 changes: 82 additions & 0 deletions align_system/algorithms/lib/chat/dialog_tokenizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from abc import ABC, abstractmethod
from typing import Dict, List

from transformers import PreTrainedTokenizerBase

class DialogTokenizer(ABC):
    """
    Abstract base class for dialog tokenizers.

    Subclasses implement ``dialog_to_tokens`` to convert a list of
    role/content chat messages into model-ready token ids.

    Fix: the original used ``@abstractmethod`` without inheriting from
    ``abc.ABC``, so the abstract method was not enforced and the base
    class could be instantiated silently.
    """
    def __init__(self, tokenizer: 'PreTrainedTokenizerBase'):
        """
        Initializes the dialog tokenizer.

        :param tokenizer: Pretrained tokenizer.
        """
        self.tokenizer = tokenizer

    @abstractmethod
    def dialog_to_tokens(self, dialog_messages: List[Dict[str, str]]) -> List[int]:
        """
        Transforms a dialog to tokens.

        :param dialog_messages: Chat messages, each ``{'role': ..., 'content': ...}``.
        :returns: List of tokens representing the dialog.
        """
        pass


class Llama2DialogTokenizer(DialogTokenizer):
    """
    Dialog tokenizer for the Llama-2 chat models.
    """

    def dialog_to_tokens(self, dialog_messages: List[Dict[str, str]]) -> List[int]:
        """
        Transforms a dialog to tokens. Llama communicates using system, user
        and assistant roles.

        :param dialog_messages: Chat messages, each ``{'role': ..., 'content': ...}``.
        :returns: List of tokens representing the dialog.
        """
        # Llama-2 chat delimiters for instruction and system segments.
        inst_open, inst_close = "[INST]", "[/INST]"
        sys_open, sys_close = "<<SYS>>\n", "\n<</SYS>>\n\n"

        # Fold a leading system message into the content of the message
        # that follows it, so the rest of the dialog alternates cleanly.
        if dialog_messages[0]["role"] == "system":
            merged_content = (sys_open + dialog_messages[0]["content"]
                              + sys_close + dialog_messages[1]["content"])
            merged = {"role": dialog_messages[1]["role"], "content": merged_content}
            dialog_messages = [merged, *dialog_messages[2:]]

        # Roles must strictly alternate: user on even slots, assistant on odd.
        users_ok = all(msg["role"] == "user" for msg in dialog_messages[::2])
        assistants_ok = all(msg["role"] == "assistant" for msg in dialog_messages[1::2])
        assert users_ok and assistants_ok, \
            "Model only supports 'system', 'user' and 'assistant' roles, in the sequence (s/u/a/u/a...)"

        encode = self.tokenizer.encode
        bos = self.tokenizer.bos_token_id
        eos = self.tokenizer.eos_token_id

        # Each completed (user, assistant) exchange becomes one BOS...EOS segment.
        dialog_tokens: List[int] = []
        for prompt, answer in zip(dialog_messages[::2], dialog_messages[1::2]):
            exchange = f"{inst_open} {prompt['content'].strip()} {inst_close} {answer['content'].strip()} "
            dialog_tokens += [bos] + encode(exchange, add_special_tokens=False) + [eos]

        # The dialog must end on a user turn that awaits the model's reply.
        assert dialog_messages[-1]["role"] == "user", "Last message must be from the user."

        # The final user message opens a segment with no EOS, prompting generation.
        final_prompt = f"{inst_open} {dialog_messages[-1]['content'].strip()} {inst_close}"
        dialog_tokens += [bos] + encode(final_prompt, add_special_tokens=False)

        return dialog_tokens


# Registry mapping a HuggingFace model name to the DialogTokenizer subclass
# that formats dialogs for it; ChatLanguageModel.__init__ looks its model's
# name_or_path up here and instantiates the class with its tokenizer.
# This mapping should ideally be updated when adding any new tokenizer classes to the project
dialog_tokenizers = {
    'meta-llama/Llama-2-7b-chat-hf': Llama2DialogTokenizer,
    'meta-llama/Llama-2-13b-chat-hf': Llama2DialogTokenizer,
}
Loading