diff --git a/src/axolotl/core/chat/format/__init__.py b/src/axolotl/core/chat/format/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/axolotl/core/chat/format/chatml.py b/src/axolotl/core/chat/format/chatml.py
index d38265454b..315d101a86 100644
--- a/src/axolotl/core/chat/format/chatml.py
+++ b/src/axolotl/core/chat/format/chatml.py
@@ -1,9 +1,16 @@
+"""
+ChatML transformation functions for MessageContents
+"""
from typing import Optional
from ..messages import MessageContents, Messages
+from .shared import wrap_tools
-def format_message(message: Messages, message_index: Optional[int] = None) -> Messages:
+def format_message(
+ message: Messages,
+ message_index: Optional[int] = None, # pylint: disable=unused-argument
+) -> Messages:
if message.is_chat_formatted:
return message
@@ -21,37 +28,7 @@ def format_message(message: Messages, message_index: Optional[int] = None) -> Me
)
message.content.append(MessageContents(type="text", value="\n", weight=0))
- # loop over message.content by index to find tool calls, we need to wrap each with tags,
- # so be wary of indexing issues when changing the list while iterating
- # iterate over the range in reverse order to avoid index shifting
- for i in range(len(message.content) - 1, -1, -1):
- if message.content[i].type == "tool_call":
- # append a MessageContents text tag after
- message.content.insert(
- i + 1,
- MessageContents(
- type="text", value="\n", weight=message.weight
- ),
- )
- # make sure the actual tool call content ends with a newline
- message.content[i].has_newline = True
- # prepend a MessageContents text tag before
- message.content.insert(
- i, MessageContents(type="text", value="\n", weight=message.weight)
- )
- elif message.content[i].type == "tool_response":
- # append a MessageContents text tag after
- message.content.insert(
- i + 1,
- MessageContents(
- type="text", value="\n", weight=message.weight
- ),
- )
- # make sure the actual tool response content ends with a newline
- message.content[i].has_newline = True
- # prepend a MessageContents text tag before
- message.content.insert(
- i, MessageContents(type="text", value="\n", weight=message.weight)
- )
+ message = wrap_tools(message)
+
message.is_chat_formatted = True
return message
diff --git a/src/axolotl/core/chat/format/llama3x.py b/src/axolotl/core/chat/format/llama3x.py
index e5bc7b8494..352b741605 100644
--- a/src/axolotl/core/chat/format/llama3x.py
+++ b/src/axolotl/core/chat/format/llama3x.py
@@ -1,6 +1,10 @@
+"""
+Llama 3.x chat formatting functions for MessageContents
+"""
from typing import Optional
from ..messages import MessageContents, Messages
+from .shared import wrap_tools
def format_message(message: Messages, message_index: Optional[int] = None) -> Messages:
@@ -24,38 +28,7 @@ def format_message(message: Messages, message_index: Optional[int] = None) -> Me
MessageContents(type="text", value="<|eot_id|>", weight=message.weight)
)
- # loop over message.content by index to find tool calls, we need to wrap each with tags,
- # so be wary of indexing issues when changing the list while iterating
- # iterate over the range in reverse order to avoid index shifting
- for i in range(len(message.content) - 1, -1, -1):
- if message.content[i].type == "tool_call":
- # append a MessageContents text tag after
- message.content.insert(
- i + 1,
- MessageContents(
- type="text", value="\n", weight=message.weight
- ),
- )
- # make sure the actual tool call content ends with a newline
- message.content[i].has_newline = True
- # prepend a MessageContents text tag before
- message.content.insert(
- i, MessageContents(type="text", value="\n", weight=message.weight)
- )
- elif message.content[i].type == "tool_response":
- # append a MessageContents text tag after
- message.content.insert(
- i + 1,
- MessageContents(
- type="text", value="\n", weight=message.weight
- ),
- )
- # make sure the actual tool response content ends with a newline
- message.content[i].has_newline = True
- # prepend a MessageContents text tag before
- message.content.insert(
- i, MessageContents(type="text", value="\n", weight=message.weight)
- )
+ message = wrap_tools(message)
if message_index == 0:
message.content.insert(
diff --git a/src/axolotl/core/chat/format/shared.py b/src/axolotl/core/chat/format/shared.py
new file mode 100644
index 0000000000..9efa2353db
--- /dev/null
+++ b/src/axolotl/core/chat/format/shared.py
@@ -0,0 +1,38 @@
+"""
+shared functions for format transforms
+"""
+from axolotl.core.chat.messages import MessageContents, Messages
+
+
+def wrap_tools(message: Messages):
+    """
+    Wrap every tool call and tool response in ``message.content`` with newlines.
+
+    Each content part whose type is ``tool_call`` or ``tool_response`` gets a
+    newline text part inserted directly before and directly after it, and has
+    its ``has_newline`` flag set. The inserted parts carry ``message.weight``.
+
+    :param message: message whose ``content`` list is modified in place
+    :return: the same ``message`` object that was passed in
+    """
+    wrapped_types = ("tool_call", "tool_response")
+
+    # walk the indices from last to first: inserting around position ``i`` only
+    # shifts entries at higher indices, which have already been visited
+    for i in range(len(message.content) - 1, -1, -1):
+        if message.content[i].type not in wrapped_types:
+            continue
+        # make sure the actual tool call / response content ends with a newline
+        message.content[i].has_newline = True
+        # append a MessageContents text tag after
+        message.content.insert(
+            i + 1,
+            MessageContents(type="text", value="\n", weight=message.weight),
+        )
+        # prepend a MessageContents text tag before
+        message.content.insert(
+            i,
+            MessageContents(type="text", value="\n", weight=message.weight),
+        )
+
+    return message
diff --git a/src/axolotl/core/chat/messages.py b/src/axolotl/core/chat/messages.py
index 95fc334c60..76ab2f733b 100644
--- a/src/axolotl/core/chat/messages.py
+++ b/src/axolotl/core/chat/messages.py
@@ -1,25 +1,35 @@
+"""
+internal message representations of chat messages
+"""
import json
-import logging
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, List, Optional, Union
from transformers import PreTrainedTokenizer
-from transformers import PreTrainedTokenizer
class MessageRoles(str, Enum):
+ """
+ Message roles for the system, user, assistant, and tools
+ """
+
system = "system" # pylint: disable=invalid-name
user = "user" # pylint: disable=invalid-name
assistant = "assistant" # pylint: disable=invalid-name
tool = "tool" # pylint: disable=invalid-name
- ipython = (
- "ipython" # pylint: disable=invalid-name # for responses from builtin tools
+ ipython = ( # pylint: disable=invalid-name
+ # for responses from builtin tools
+ "ipython"
)
class MessageContentTypes(str, Enum):
- special_token = "special_token" # pylint: disable=invalid-name
+ """
+ Message content types for text, image, audio, tool calls, and tool responses
+ """
+
+ special_token = "special_token" # pylint: disable=invalid-name # nosec B105
text = "text" # pylint: disable=invalid-name
image = "image" # pylint: disable=invalid-name
audio = "audio" # pylint: disable=invalid-name
@@ -29,18 +39,30 @@ class MessageContentTypes(str, Enum):
@dataclass
class SpecialToken(str, Enum):
- bos_token = "bos_token" # pylint: disable=invalid-name
- eos_token = "eos_token" # pylint: disable=invalid-name
+ """
+ Special tokens for beginning of string and end of string
+ """
+
+ bos_token = "bos_token" # pylint: disable=invalid-name # nosec B105
+ eos_token = "eos_token" # pylint: disable=invalid-name # nosec B105
@dataclass
class ToolCallFunction:
+ """
+ Tool call function with name and arguments
+ """
+
name: str
arguments: dict[str, str]
@dataclass
class Tool:
+ """
+ Tool with description, function, and parameters
+ """
+
description: str
function: ToolCallFunction
parameters: dict[str, str] # .properties
@@ -48,9 +70,13 @@ class Tool:
@dataclass
class ToolCallContents:
+ """
+ Tool call contents with name, arguments, and optional id
+ """
+
name: str
arguments: dict[str, Union[str, int]]
- id: Optional[str] = None
+ id: Optional[str] = None # pylint: disable=invalid-name
def __str__(self) -> str:
data = {"name": self.name, "arguments": self.arguments}
@@ -61,9 +87,13 @@ def __str__(self) -> str:
@dataclass
class ToolResponseContents:
+ """
+ Tool response contents with name, content, and optional id
+ """
+
name: str
content: Union[str, dict[str, Union[str, int, float]]]
- id: Optional[str] = None
+ id: Optional[str] = None # pylint: disable=invalid-name
def __str__(self) -> str:
data = {"name": self.name, "content": self.content}
@@ -74,6 +104,10 @@ def __str__(self) -> str:
@dataclass
class MessageContents:
+ """
+ Message contents with type, value, metadata, weight, newline, and end of contents
+ """
+
type: Union[str, MessageContentTypes]
value: Union[str, ToolCallContents, ToolResponseContents, SpecialToken]
meta: Optional[dict[str, Any]] = None # support additional arbitrary metadata
@@ -90,6 +124,10 @@ def __str__(self) -> str:
@dataclass
class Messages:
+ """
+ Messages with role, content, metadata, weight, and chat formatting
+ """
+
role: Union[MessageRoles, str] # allows for arbitrary roles
content: List["MessageContents"]
meta: Optional[dict[str, Any]] = None # support additional arbitrary metadata
@@ -148,6 +186,10 @@ def tokenized(
@dataclass
class Chats:
+ """
+ top level data structure for chat conversations
+ """
+
conversation: List[Messages]
def __str__(self) -> str:
@@ -173,7 +215,11 @@ def tokenized(
@dataclass
class ChatFormattedChats(Chats):
- formatter: Callable # [[Union[dict, Chats]], Chats]
+ """
+ Chat formatted chats with formatter and optional train on inputs
+ """
+
+ formatter: Callable # [[Union[dict, Chats]], Chats]
train_on_inputs: bool = False
def __post_init__(self):
@@ -185,6 +231,10 @@ def __post_init__(self):
@dataclass
class PreferenceChats:
+ """
+ representation for preference data for chat
+ """
+
prompt: List[Messages]
chosen: Messages
rejected: Messages
diff --git a/src/axolotl/core/datasets/chat.py b/src/axolotl/core/datasets/chat.py
index 2caa719754..863642b246 100644
--- a/src/axolotl/core/datasets/chat.py
+++ b/src/axolotl/core/datasets/chat.py
@@ -1,3 +1,6 @@
+"""
+chat dataset module
+"""
import os
from typing import Callable, Optional, Union
@@ -5,10 +8,14 @@
from datasets import Dataset
from transformers import PreTrainedTokenizer
-from axolotl.core.chat.messages import ChatFormattedChats, Chats
+from axolotl.core.chat.messages import ChatFormattedChats
class TokenizedChatDataset(Dataset):
+ """
+ Tokenized chat dataset
+ """
+
def __init__(
self,
data: Dataset,
@@ -35,7 +42,10 @@ def map_fn(ex):
)
return ex.tokenized(model_transform)
- num_proc = min(64, process_count if process_count else os.cpu_count())
+ process_or_cpu_count: int = (
+ process_count or os.cpu_count() or 1 # os.cpu_count() may return None
+ )
+ num_proc = min(64, process_or_cpu_count)
tokenized_data = data.map(
map_fn,
num_proc=num_proc,
diff --git a/src/axolotl/core/datasets/transforms/chat_builder.py b/src/axolotl/core/datasets/transforms/chat_builder.py
index 810967d5d3..3499ca09ef 100644
--- a/src/axolotl/core/datasets/transforms/chat_builder.py
+++ b/src/axolotl/core/datasets/transforms/chat_builder.py
@@ -1,12 +1,22 @@
-from typing import Mapping, Union
+"""
+This module contains a function that builds a transform that takes a row from the dataset and converts it to a Chat.
+"""
+from typing import Any, Mapping, Union
-def chat_message_transform_builder(
+def chat_message_transform_builder( # pylint: disable=dangerous-default-value
train_on_inputs=False,
conversations_field: str = "conversations",
message_field_role: Union[str, list[str]] = ["role", "from"], # commonly "role"
- message_field_content: Union[str, list[str]] = ["value", "text", "content"], # commonly "content"
- message_field_training: Union[str, list[str]] = ["train", "weight"], # commonly "weight"
+ message_field_content: Union[str, list[str]] = [
+ "value",
+ "text",
+ "content",
+ ], # commonly "content"
+ message_field_training: Union[str, list[str]] = [
+ "train",
+ "weight",
+ ], # commonly "weight"
):
"""Builds a transform that takes a row from the dataset and converts it to a Chat
@@ -28,12 +38,20 @@ def chat_message_transform_builder(
A function that takes a list of conversations and returns a list of messages.
"""
- message_field_role = [message_field_role] if isinstance(message_field_role, str) else message_field_role
+ message_field_role = (
+ [message_field_role]
+ if isinstance(message_field_role, str)
+ else message_field_role
+ )
message_field_content = (
- [message_field_content] if isinstance(message_field_content, str) else message_field_content
+ [message_field_content]
+ if isinstance(message_field_content, str)
+ else message_field_content
)
message_weight_fields = (
- [message_field_training] if isinstance(message_field_training, str) else message_field_training
+ [message_field_training]
+ if isinstance(message_field_training, str)
+ else message_field_training
)
role_value_mappings = {
@@ -62,38 +80,62 @@ def chat_message_transform_builder(
"ipython": 0,
}
- def transform_builder(sample: Mapping[str, any]):
+ def transform_builder(sample: Mapping[str, Any]):
if conversations_field not in sample:
raise ValueError(f"Field '{conversations_field}' not found in sample.")
# if none of the role fields are in the message, raise an error
- if not any(role in sample[conversations_field][0] for role in message_field_role):
- raise ValueError(f"No role field found in message.")
- else:
- role_field = next(role for role in message_field_role if role in sample[conversations_field][0])
- if not any(field in sample[conversations_field][0] for field in message_field_content):
- raise ValueError(f"No message_content field found in message.")
- else:
- message_content_field = next(field for field in message_field_content if field in sample[conversations_field][0])
- if not any(field in sample[conversations_field][0] for field in message_field_training):
+ if not any(
+ role in sample[conversations_field][0] for role in message_field_role
+ ):
+ raise ValueError("No role field found in message.")
+ role_field = next(
+ role
+ for role in message_field_role
+ if role in sample[conversations_field][0]
+ )
+ if not any(
+ field in sample[conversations_field][0] for field in message_field_content
+ ):
+ raise ValueError("No message_content field found in message.")
+ message_content_field = next(
+ field
+ for field in message_field_content
+ if field in sample[conversations_field][0]
+ )
+ if not any( # check the normalized list; the raw argument may be a bare str
+ field in sample[conversations_field][0] for field in message_weight_fields
+ ):
message_weight_field = None
else:
- message_weight_field = next(field for field in message_weight_fields if field in sample[conversations_field][0])
+ message_weight_field = next(
+ field
+ for field in message_weight_fields
+ if field in sample[conversations_field][0]
+ )
messages = []
for message in sample[conversations_field]:
role = role_value_mappings[message[role_field]]
- weight = int(message[message_weight_field]) if message_weight_field else role_default_weights_mappings[role]
+ weight = (
+ int(message[message_weight_field])
+ if message_weight_field
+ else role_default_weights_mappings[role]
+ )
# TODO if "tool_calls" in message[message_content_field]: then convert tool call to ToolCallContents
- messages.append({
- "role": role,
- "content": [{
- "type": "text",
- "value": message[message_content_field],
- }],
- "weight": weight,
- })
+ messages.append(
+ {
+ "role": role,
+ "content": [
+ {
+ "type": "text",
+ "value": message[message_content_field],
+ }
+ ],
+ "weight": weight,
+ }
+ )
return {"conversation": messages}
diff --git a/src/axolotl/prompt_strategies/__init__.py b/src/axolotl/prompt_strategies/__init__.py
index 86f98b78e6..56e2e87cd8 100644
--- a/src/axolotl/prompt_strategies/__init__.py
+++ b/src/axolotl/prompt_strategies/__init__.py
@@ -30,4 +30,4 @@ def load(strategy, tokenizer, cfg, ds_cfg):
except Exception as exc: # pylint: disable=broad-exception-caught
LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
raise exc
- return None
+ return None
diff --git a/src/axolotl/prompt_strategies/chat.py b/src/axolotl/prompt_strategies/chat.py
index c54fdb8ab1..35d7649026 100644
--- a/src/axolotl/prompt_strategies/chat.py
+++ b/src/axolotl/prompt_strategies/chat.py
@@ -1,4 +1,7 @@
-from typing import Optional, Dict, Any, Callable
+"""
+Chat dataset wrapping strategy for new internal messages representations
+"""
+from typing import Any, Callable, Dict, Optional
from axolotl.core.datasets.chat import TokenizedChatDataset
from axolotl.core.datasets.transforms.chat_builder import chat_message_transform_builder
@@ -6,7 +9,17 @@
class ChatMessageDatasetWrappingStrategy(DatasetWrappingStrategy):
- def __init__(self, processor, message_transform=None, formatter=None, **kwargs):
+ """
+ Chat dataset wrapping strategy for new internal messages representations
+ """
+
+ def __init__(
+ self,
+ processor,
+ message_transform=None,
+ formatter=None,
+ **kwargs, # pylint: disable=unused-argument
+ ):
"""
:param processor: tokenizer or image processor
:param kwargs:
@@ -21,7 +34,7 @@ def wrap_dataset(
dataset,
process_count: Optional[int] = None,
keep_in_memory: Optional[bool] = False,
- **kwargs,
+ **kwargs, # pylint: disable=unused-argument
):
self.dataset = TokenizedChatDataset(
dataset,
@@ -53,15 +66,19 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
builder_kwargs["message_field_training"] = message_field_training
chat_template = ds_cfg.get("chat_template", cfg.get("chat_template", "chatml"))
- format_message = lambda x: x
+ format_message = (
+ lambda x: x # noqa E731 # pylint: disable=unnecessary-lambda-assignment
+ )
if chat_template == "chatml":
- from axolotl.core.chat.format.chatml import format_message
+ from axolotl.core.chat.format.chatml import format_message # noqa F811
if chat_template.startswith("llama3"):
- from axolotl.core.chat.format.llama3x import format_message
+ from axolotl.core.chat.format.llama3x import format_message # noqa F811
message_transform: Callable = chat_message_transform_builder(
train_on_inputs=ds_cfg.get("train_on_inputs", False),
**builder_kwargs,
)
- strategy = ChatMessageDatasetWrappingStrategy(tokenizer, message_transform=message_transform, formatter=format_message)
+ strategy = ChatMessageDatasetWrappingStrategy(
+ tokenizer, message_transform=message_transform, formatter=format_message
+ )
return strategy
diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py
index 12945a8128..b5a6dfa7de 100644
--- a/src/axolotl/utils/data/sft.py
+++ b/src/axolotl/utils/data/sft.py
@@ -23,10 +23,11 @@
AlpacaMultipleChoicePromptTokenizingStrategy,
AlpacaPromptTokenizingStrategy,
AlpacaReflectionPTStrategy,
+ DatasetWrappingStrategy,
GPTeacherPromptTokenizingStrategy,
JeopardyPromptTokenizingStrategy,
OpenAssistantPromptTokenizingStrategy,
- SummarizeTLDRPromptTokenizingStrategy, PromptTokenizingStrategy, DatasetWrappingStrategy,
+ SummarizeTLDRPromptTokenizingStrategy,
)
from axolotl.prompters import (
AlpacaPrompter,
diff --git a/tests/core/chat/test_messages.py b/tests/core/chat/test_messages.py
index 8383b723f6..fb63e7230f 100644
--- a/tests/core/chat/test_messages.py
+++ b/tests/core/chat/test_messages.py
@@ -1,3 +1,6 @@
+"""
+Tests for the chat messages module
+"""
import unittest
import pytest
@@ -8,8 +11,8 @@
from axolotl.core.chat.messages import ChatFormattedChats, Chats
-@pytest.fixture(scope="session")
-def llama_tokenizer():
+@pytest.fixture(scope="session", name="llama_tokenizer")
+def llama_tokenizer_fixture():
return AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3.1-8B")
@@ -109,6 +112,10 @@ def llama_tokenizer_w_chatml(llama_tokenizer):
class TestMessagesCase:
+ """
+ Test cases for the chat messages module
+ """
+
def test_tool_call_stringify(self):
assert '{"name": "get_stock_price", "arguments": {"symbol": "AAPL"}}' == str(
chat_msgs_as_obj.conversation[2].content[1].value