diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 1724ef6a9909..2f6c65e0924b 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -88,6 +88,8 @@ title: Tool use - local: chat_templating_writing title: Writing a chat template + - local: chat_response_parsing + title: Response parsing title: Chat with models - sections: - local: serving diff --git a/docs/source/en/chat_response_parsing.md b/docs/source/en/chat_response_parsing.md new file mode 100644 index 000000000000..0ef6f1faa31f --- /dev/null +++ b/docs/source/en/chat_response_parsing.md @@ -0,0 +1,229 @@ + + +# Response Parsing + +It is increasingly common for chat models to generate structured outputs, rather than just a single reply string. +The most common uses for structured outputs are [tool calling](./chat_extras) and [reasoning models](https://huggingface.co/reasoning-course). +Tool calling models can output tool calls, containing the name of the tool to call and any arguments to be passed to it, +while reasoning models often output reasoning steps as a "chain of thought". Some recent models even use both of these, +and may output reasoning and/or one or more tool calls before their final answer. + +Models with structured outputs pose a challenge for chat templating, because the output needs to be parsed before it +can be appended to the chat. For a concrete example, let's say we ask [GPT-OSS](https://huggingface.co/openai/gpt-oss-120b) +what the weather is like, and it thinks and decides to call a tool. Here's what the raw model output might look like: + +``` +<|start|><|assistant|><|channel|>analysis<|message|>The user asks: "What is the weather like in SF?" We need to get the location of the user? The user explicitly asks about SF (San Francisco). +So we need to get the current weather in San Francisco, CA. We need to call get_current_weather function. But we need to call function to get weather data. 
+So we should call get_current_weather with location "San Francisco, CA". Let's do that. + +We will call function get_current_weather.<|end|><|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{ + "location": "San Francisco, CA" +} +``` + +And here's what that output would look like as a chat message dict: + +```json +{ + "role": "assistant", + "thinking": "The user asks: \"What is the weather like in SF?\" We need to get the location of the user? The user explicitly asks about SF (San Francisco). So we need to get the current weather in San Francisco, CA. We need to call get_current_weather function. But we need to call function to get weather data. So we should call get_current_weather with location \"San Francisco, CA\". Let's do that.", + "tool_calls": [ + { + "name": "get_current_weather", + "arguments": { + "location": "San Francisco, CA" + } + } + ] +} +``` + +Chat **templates** give us a way to turn messages into formatted input for a model, but we need something else to +parse model output back into a standard message dict. This is what chat **parsing** is for. + +## The `parse_response` method + +Parsing a chat response on a model that supports it is straightforward. Simply take the raw, decoded output from +`generate()`, and pass it to the tokenizer's `parse_response` method: + +```python +from transformers import AutoModelForCausalLM, AutoTokenizer + +checkpoint = "HuggingFaceTB/SmolLM3-3B" + +tokenizer = AutoTokenizer.from_pretrained(checkpoint) +model = AutoModelForCausalLM.from_pretrained(checkpoint, dtype="auto", device_map="auto") + +messages = [ + { + "role": "user", + "content": "Hey! Can you summarize the end of the Cold War as briefly as possible? Like, comically briefly. It should really leave out almost most of the relevant information." 
+ } +] + +input_ids = tokenizer.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_tensors="pt" +).to(model.device) + +outputs = model.generate(input_ids, max_new_tokens=1024)[0, input_ids.shape[1]:] +out_text = tokenizer.decode(outputs) +parsed = tokenizer.parse_response(out_text) +print(parsed) +``` + +And that's all you need to start using response parsing! `parse_response` should return a complete message dict that is ready to be appended to the chat history. +When the tokenizer does not support response parsing, `parse_response` will raise an error. We hope to add support +to more tokenizers over time. + +## Developers: Understanding a simple response schema + +Under the hood, `parse_response` uses a **JSON schema** to parse the model output. A JSON schema represents +the structure of the output message dict. The schema is augmented with additional fields that indicate how the +output message string should be parsed into the expected format. Let's take a look at the schema for a SmolLM response, +excluding tool calls for now: + +```python +{ + "x-regex": r"(?:<think>\n?(?P<thinking>.+?)\n?</think>)?\s*(?P<content>.+?)?\s*(?:<\|im_end\|>|$)", + "type": "object", + "properties": { + "role": {"const": "assistant"}, + "content": {"type": "string"}, + "thinking": {"type": "string"} + } +} +``` + +We can see that the schema describes a JSON "object" (a `dict`, in other words) with three keys: `role`, `content`, and `thinking`. +Because all assistant responses have the role "assistant", the `role` key is a `const`(ant). The other two keys are strings, extracted +from the named groups in the regex in the `x-regex` field. + +Like chat templates, response schemas are set as a property of the tokenizer. To enable response parsing, all you need +to do is set `tokenizer.response_schema` to a valid schema dict, and `tokenizer.parse_response()` will work!
Again, like +chat templates, this schema will be saved with the processor, so once you set it, you can use `save_pretrained()` or `push_to_hub()` to +save and share the schema. + +## Developers: Complex schemas + +Now, let's look at a more complex schema, which includes tool calls, to gain more of an understanding of the parser +internals. For this, we'll use the `GPT-OSS` schema. GPT-OSS emits both tool calls and thinking blocks, and it uses +an unusual format where model responses are tagged with one of three "channels": `commentary` for things like +tool calls, `analysis` for chain of thought blocks, and `final` for messages intended to be sent to the user. +A full message where the model calls a tool named `get_current_weather` might look like this, with some extra linebreaks added for clarity: + +```text +<|channel|>analysis<|message|> +The user asks: "What is the weather like in SF?" So we need to get the current weather in San Francisco, CA. +We need to call get_current_weather function. So we should call get_current_weather with location "San Francisco, CA". +<|end|> +<|start|>assistant<|channel|>commentary +to=functions.get_current_weather <|constrain|>json<|message|> +{ + "location": "San Francisco, CA" +} +<|call|> +``` + +Parsing proceeds recursively; the output of a regex (or other parser) at one level becomes the input to the nodes below it. +In other words, don't feel like you have to parse the entire output in one enormous regex! Instead, start with the schema, +and then add regexes to extract the relevant chunks as you go. Here's a schema that will parse it, with some +explanatory comments: + +```python +{ + "type": "object", + "properties": { + "role": {"const": "assistant"}, + # "content" and "thinking" are both similar to the previous example, and just extract a single string + # However, rather than using a single regex with named groups to extract both, we use a regex in each subkey. 
+ # When an object node has no parser/regex, the entire input string is passed to all of its children, so + # parsing can either be done with named groups at the object level, or with separate regexes at the property level. + "content": {"type": "string", "x-regex": r"<\|channel\|>final<\|message\|>(.*?)(?:<\|end\|>|$)"}, + "thinking": {"type": "string", "x-regex": r"<\|channel\|>analysis<\|message\|>(.*?)<\|end\|>"}, + "tool_calls": { + # "x-regex-iterator" uses re.finditer to find every match, and returns the matches as an + # array/list. You don't need to worry about array handling, though - each item in the array will be + # parsed by the `items` schema, so just write the schema for a single item. + "x-regex-iterator": r"<\|channel\|>commentary (to=functions\..*?<\|message\|>.*?)(?:<\|call\|>|$)", + "type": "array", + "items": { + "type": "object", + "properties": { + # A const property is a fixed value, and the input has no effect on it. + "type": {"const": "function"}, + # Here, we wrap the entire tool call dict in a `{"function": ...}` block. The input string is passed through to it unchanged. + "function": { + "type": "object", + "properties": { + "name": {"type": "string", "x-regex": r"^to=functions\.(\w+)"}, + "arguments": { + "type": "object", + "x-regex": r"<\|message\|>(.*)", + # The "x-parser" field indicates that the extracted string should be parsed as JSON. + # The output is then passed to the schema nodes below and recursive parsing continues. + "x-parser": "json", + "additionalProperties": {"type": "any"}, + }, + }, + }, + }, + }, + }, + }, +} +``` + +## Developers: Understanding the parser logic + +The parser follows a few simple rules: + +1. Each level of the schema receives input from the level above, applies any regex or parser it has, and then passes the output to its children. +2. The root level receives the entire decoded model output string as input. +3.
If a node has structured content after parsing (for example, if the regex has named groups and returns a dict, or if the parser returns a dict or list), + then that structured content is mapped to the node's children, and each child node receives its corresponding value as input. +4. If an `object` (dict) node has unstructured (string) output, then the entire string is passed to all of its children. This allows child nodes + to handle parsing individually rather than requiring a single parent regex to extract all keys at once. +5. If an `array` (list) node has unstructured (string) output, then an error is raised. + +There is a small set of allowable `x-` keys that indicate how parsing should be done at each node: +- `x-regex`: A regex string to apply to the input. If the regex has named groups, the output is a dict of group names to values. Named groups should only be used in `object` nodes. + Otherwise, the regex must have exactly one unnamed capturing group, and the output is the value of that group as a string. +- `x-regex-iterator`: A regex string to apply to the input using `re.finditer()`. The output is a list of all matches. + This should only be used in `array` nodes, and the regex must have exactly one unnamed capturing group. The output is distributed to + the node's `items` schema. +- `x-parser`: Calls a built-in parser to apply to the input. Currently, the only supported parser is `json`, which parses the input string as JSON. + The output is passed to the child nodes for further parsing. Note that the `json` parser can return deeply nested output - in this case, the output + will be progressively unwrapped as it is passed through child nodes. The child nodes do not need additional `x-parser` or `x-regex` fields in this case, + but their structure must match the structure of the parsed JSON. +- `x-parser-args`: Only allowed in conjunction with `x-parser`. This is a dict of additional arguments that control parsing.
Two arguments are currently + supported: `transform`, which specifies a `jmespath` transformation to apply to the output, and `allow_non_json`, which returns the raw string instead of raising an error when the input is not valid JSON. + `transform` is useful when the JSON parser returns a structure that needs to be modified to match the schema. +- `x-regex-key-value`: This is rarely necessary, but it can be useful when parsing key-value pairs in non-JSON format where the names of the keys are not known + in advance, such as when a model emits XML tool calls with arbitrary argument names. The regex must have exactly two named capturing groups, + `key` and `value`, and the output is a dict mapping keys to values. This should only be used in `object` nodes. + +In general, multiple regexes/parsers cannot be combined at the same level. The exception is that `x-regex`, returning a single string, can be combined with the other parsers. In this case, +`x-regex` is applied first, and then the output is passed to the other parser, either `x-regex-iterator`, `x-parser`, or `x-regex-key-value`. + +Putting these ideas together, you can see that the input flows through the schema, being parsed at each level and then distributed to child nodes. Each level +only needs to extract the input content that is relevant for that part of the schema, and can then let its child nodes handle the rest. Internally, this is handled +with a parser function that receives input, applies any regexes/parsers at the current level, then maps the result to its child nodes before recursively calling itself on each of them. +Recursion terminates when it reaches leaf nodes, usually primitive types like `string` or `number`, which return the input they receive, converted to the declared type where necessary.
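
The rules above can be sketched in a few lines of plain Python. This is only an illustration, not the library's implementation: `mini_parse` is a hypothetical helper that handles `const`, `x-regex`, `x-regex-iterator` and the `json` parser, and it leaves out validation, `x-regex-key-value`, `x-parser-args` and type conversion. The schema is a trimmed variant of the GPT-OSS schema shown earlier, and the raw string is a shortened, made-up response.

```python
import json
import re


def mini_parse(content, schema):
    """Simplified, hypothetical sketch of the recursive parsing rules."""
    # A const node returns its fixed value; the input has no effect on it.
    if "const" in schema:
        return schema["const"]
    if content is None:
        return None

    # x-regex: named groups give a dict, a single unnamed group gives a string.
    if "x-regex" in schema:
        match = re.search(schema["x-regex"], content, flags=re.DOTALL)
        if match is None:
            return None
        if match.groupdict():
            content = {k: v for k, v in match.groupdict().items() if v is not None}
        else:
            content = match.group(1)

    # x-regex-iterator: each match of the single capture group becomes one array item.
    if "x-regex-iterator" in schema:
        content = [m.group(1) for m in re.finditer(schema["x-regex-iterator"], content, flags=re.DOTALL)]
        if not content:
            return None

    # x-parser runs after the regexes, so a regex can first cut out the JSON substring.
    if schema.get("x-parser") == "json":
        content = json.loads(content)

    node_type = schema.get("type")
    if node_type == "object":
        # Simplification: a parsed dict with no declared properties is returned as-is.
        if isinstance(content, dict) and "properties" not in schema:
            return content
        parsed = {}
        for key, child in schema.get("properties", {}).items():
            # Rule 4: string content is passed whole to every child.
            # Rule 3: dict content is distributed to the children by key.
            child_input = content if isinstance(content, str) else content.get(key)
            value = mini_parse(child_input, child)
            if value is not None:
                parsed[key] = value
        return parsed
    if node_type == "array":
        return [mini_parse(item, schema["items"]) for item in content]
    # Leaf nodes return what they receive.
    return content


schema = {
    "type": "object",
    "properties": {
        "role": {"const": "assistant"},
        "thinking": {"type": "string", "x-regex": r"<\|channel\|>analysis<\|message\|>(.*?)<\|end\|>"},
        "tool_calls": {
            "type": "array",
            "x-regex-iterator": r"<\|channel\|>commentary (to=functions\..*?<\|message\|>.*?)(?:<\|call\|>|$)",
            "items": {
                "type": "object",
                "properties": {
                    "name": {"type": "string", "x-regex": r"^to=functions\.(\w+)"},
                    "arguments": {"type": "object", "x-regex": r"<\|message\|>(.*)", "x-parser": "json"},
                },
            },
        },
    },
}

raw = (
    "<|channel|>analysis<|message|>Need the weather in SF.<|end|>"
    "<|start|>assistant<|channel|>commentary to=functions.get_current_weather "
    '<|constrain|>json<|message|>{"location": "San Francisco, CA"}<|call|>'
)
message = mini_parse(raw, schema)
print(message)
```

Note how the root node has no regex of its own, so each property receives the full string and extracts only the part it cares about, while the `tool_calls` node first splits the string into one chunk per call and lets the `items` schema handle a single chunk at a time.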
\ No newline at end of file diff --git a/setup.py b/setup.py index 20d7a007ca3a..6149c8cb5528 100644 --- a/setup.py +++ b/setup.py @@ -118,6 +118,7 @@ "importlib_metadata", "ipadic>=1.0.0,<2.0", "jinja2>=3.1.0", + "jmespath>=1.0.1", "kenlm", "kernels>=0.10.2,<0.11", "librosa", @@ -297,7 +298,7 @@ def run(self): extras["sentencepiece"] = deps_list("sentencepiece", "protobuf") extras["tiktoken"] = deps_list("tiktoken", "blobfile") extras["mistral-common"] = deps_list("mistral-common[opencv]") -extras["chat_template"] = deps_list("jinja2") +extras["chat_template"] = deps_list("jinja2", "jmespath") extras["testing"] = ( deps_list( "pytest", diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index 8b013aa2a00e..962f52505c15 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -27,6 +27,7 @@ "importlib_metadata": "importlib_metadata", "ipadic": "ipadic>=1.0.0,<2.0", "jinja2": "jinja2>=3.1.0", + "jmespath": "jmespath>=1.0.1", "kenlm": "kenlm", "kernels": "kernels>=0.10.2,<0.11", "librosa": "librosa", diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py index e1ea152d7a0a..7950e6faf2da 100644 --- a/src/transformers/pipelines/text_generation.py +++ b/src/transformers/pipelines/text_generation.py @@ -152,6 +152,8 @@ def _sanitize_parameters( continue_final_message=None, skip_special_tokens=None, tokenizer_encode_kwargs=None, + tools=None, + documents=None, **generate_kwargs, ): # preprocess kwargs @@ -170,6 +172,11 @@ def _sanitize_parameters( preprocess_params["max_length"] = max_length generate_kwargs["max_length"] = max_length + if tools is not None: + preprocess_params["tools"] = tools + if documents is not None: + preprocess_params["documents"] = documents + if prefix is not None: preprocess_params["prefix"] = prefix if prefix: @@ -335,6 +342,8 @@ def preprocess( max_length=None, 
continue_final_message=None, tokenizer_encode_kwargs=None, + tools=None, + documents=None, **generate_kwargs, ): # Only set non-None tokenizer kwargs, so as to rely on the tokenizer's defaults @@ -359,6 +368,8 @@ def preprocess( continue_final_message=continue_final_message, return_dict=True, return_tensors="pt", + tools=tools, + documents=documents, **tokenizer_kwargs, ) else: @@ -514,7 +525,12 @@ def postprocess( ] else: # When we're not starting from a prefill, the output is a new assistant message - all_text = list(prompt_text.messages) + [{"role": "assistant", "content": all_text}] + if self.tokenizer.response_schema: + assistant_message = self.tokenizer.parse_response(all_text) + else: + # If there's no schema, then we have to assume it's all content + assistant_message = {"role": "assistant", "content": all_text} + all_text = list(prompt_text.messages) + [assistant_message] record = {"generated_text": all_text} for key, values in split_keys.items(): record[key] = values[idx] diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 89e5a9700739..af70bcb91f73 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -102,6 +102,7 @@ is_huggingface_hub_greater_or_equal, is_ipex_available, is_jinja_available, + is_jmespath_available, is_jumanpp_available, is_kernels_available, is_levenshtein_available, @@ -509,6 +510,13 @@ def require_jinja(test_case): return unittest.skipUnless(is_jinja_available(), "test requires jinja")(test_case) +def require_jmespath(test_case): + """ + Decorator marking a test that requires jmespath. These tests are skipped when jmespath isn't installed. 
+ """ + return unittest.skipUnless(is_jmespath_available(), "test requires jmespath")(test_case) + + def require_onnx(test_case): return unittest.skipUnless(is_onnx_available(), "test requires ONNX")(test_case) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 89612915797e..7cf160bb6a70 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -61,6 +61,7 @@ requires_backends, to_py_obj, ) +from .utils.chat_parsing_utils import recursive_parse from .utils.chat_template_utils import render_jinja_template from .utils.import_utils import PROTOBUF_IMPORT_ERROR @@ -1429,6 +1430,8 @@ def __init__(self, **kwargs): # we reconstruct that into a single dict while loading them. self.chat_template = {template["name"]: template["template"] for template in self.chat_template} + self.response_schema = kwargs.pop("response_schema", None) + super().__init__(**kwargs) self.extra_special_tokens = kwargs.pop("extra_special_tokens", {}) @@ -1855,6 +1858,13 @@ def get_chat_template(self, chat_template: Optional[str] = None, tools: Optional return chat_template + def parse_response(self, response: str, schema: Optional[Union[list, dict]] = None): + if schema is None: + if getattr(self, "response_schema", None) is None: + raise AttributeError("This tokenizer does not have a `response_schema` for parsing chat responses!") + schema = self.response_schema + return recursive_parse(response, schema) + @classmethod def from_pretrained( cls, @@ -2564,6 +2574,8 @@ def save_pretrained( tokenizer_config, saved_raw_chat_template_files = self.save_chat_templates( save_directory, tokenizer_config, filename_prefix, save_jinja_files ) + if getattr(self, "response_schema", None) is not None: + tokenizer_config["response_schema"] = self.response_schema if len(self.init_inputs) > 0: tokenizer_config["init_inputs"] = copy.deepcopy(self.init_inputs) diff --git a/src/transformers/utils/__init__.py 
b/src/transformers/utils/__init__.py index 82a9e3a85bd1..6c14ac26a2aa 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -160,6 +160,7 @@ is_in_notebook, is_ipex_available, is_jinja_available, + is_jmespath_available, is_jumanpp_available, is_kenlm_available, is_kernels_available, diff --git a/src/transformers/utils/chat_parsing_utils.py b/src/transformers/utils/chat_parsing_utils.py new file mode 100644 index 000000000000..92421af5c0ce --- /dev/null +++ b/src/transformers/utils/chat_parsing_utils.py @@ -0,0 +1,236 @@ +from __future__ import annotations + +import json +import re + +from transformers.utils import is_jmespath_available + + +if is_jmespath_available(): + import jmespath +else: + jmespath = None + + +def _parse_re_match(node_match): + # If the regex has named groups, return a dict of those groups + if node_match.groupdict(): + return {key: val for key, val in node_match.groupdict().items() if val is not None} + # Otherwise the regex must have exactly one unnamed group, and we return that + else: + groups = list(node_match.groups()) + if len(groups) > 1: + raise ValueError(f"Regex has multiple unnamed groups!\nGroups: {groups}\n") + elif len(groups) == 0: + raise ValueError(f"Regex has no capture groups:\n\n{node_match.group(0)}") + return groups[0] + + +def recursive_parse( + node_content: str | list | dict, + node_schema: dict, +): + """ + This function takes content and a JSON schema which includes + regex extractors, and recursively parses the content. The output + should be a data structure matching the schema. + + Args: + node_content: The content corresponding to this node. Usually a string, but can be something else + if the parent node has multiple capture groups or named groups. In that case, + we generally pass the capture groups straight through to the children of this node + and don't do any parsing at this level. + node_schema: The schema node controlling the parsing. 
+ + Returns: + The parsed data structure for the current node. + """ + + # If the schema has a const, we just return that value and do absolutely nothing else + if "const" in node_schema: + return node_schema["const"] + + # If the node content is None, we return None. EZ. + if node_content is None: + return None + + # If not, we have to do a little parsing. First, set some vars and do basic validation + node_type = node_schema["type"] + has_regex = "x-regex" in node_schema or "x-regex-iterator" in node_schema or "x-regex-key-value" in node_schema + if has_regex and not isinstance(node_content, str): + raise TypeError( + "Schema node got a non-string input, but has a regex for parsing.\n" + f"Input: {node_content}\n" + f"Schema: {node_schema}" + ) + + node_regex = node_schema.get("x-regex") + node_regex_iterator = node_schema.get("x-regex-iterator") + node_regex_to_dict = node_schema.get("x-regex-key-value") + if node_regex is not None: + node_match = re.search(node_regex, node_content, flags=re.DOTALL) + if not node_match: + return None + node_content = _parse_re_match(node_match) + if node_regex_iterator is not None: + if node_type != "array": + raise TypeError(f"Schema node with type {node_type} cannot use x-regex-iterator.\nSchema: {node_schema}") + # Note that this can be applied after a standard node-regex search + node_content = [ + _parse_re_match(node_match) + for node_match in re.finditer(node_regex_iterator, node_content, flags=re.DOTALL) + ] + if not node_content: + return None + if node_regex_to_dict is not None: + if node_type != "object": + raise TypeError(f"Schema node with type {node_type} cannot use x-regex-key-value.\nSchema: {node_schema}") + # Note that this can be applied after a standard node-regex search + output_content = {} + for node_match in re.finditer(node_regex_to_dict, node_content, flags=re.DOTALL): + match_groups = _parse_re_match(node_match) + if not isinstance(match_groups, dict) or "key" not in match_groups or "value" not in 
match_groups: + raise ValueError( + f"Regex for x-regex-key-value must have named groups 'key' and 'value'.\n" + f"Match groups: {match_groups}\n" + f"Schema: {node_schema}" + ) + output_content[match_groups["key"]] = match_groups["value"] + node_content = output_content + if not node_content: + return None + + # Next, if the node has a parser, apply it. We do this after regexes so that the regex can extract + # a substring to parse, if needed. + if "x-parser" in node_schema: + parser = node_schema["x-parser"] + if parser == "json": + if not isinstance(node_content, str): + raise TypeError( + f"Node has JSON parser but got non-string input: {node_content}\nSchema: {node_schema}" + ) + parser_args = node_schema.get("x-parser-args", {}) + transform = parser_args.get("transform") + allow_non_json = parser_args.get("allow_non_json", False) + try: + parsed_json = json.loads(node_content) + except json.JSONDecodeError as e: + if allow_non_json: + parsed_json = node_content + else: + raise ValueError( + f"Node has JSON parser but could not parse its contents as JSON. You can use the `allow_non_json` parser arg for nodes which may contain JSON or string content.\n\nContent: {node_content}\n\nError: {e}" + ) + if transform is not None: + if jmespath is None: + raise ImportError( + "Chat response schema includes a jmespath transformation, but jmespath is not installed. You can install it with `pip install jmespath`." 
+ ) + parsed_json = jmespath.search(parser_args["transform"], parsed_json) + node_content = parsed_json + else: + raise ValueError(f"Unknown parser {parser} for schema node: {node_schema}") + + # If there's a mapping, apply it now + if "x-mapping" in node_schema: + if not isinstance(node_content, str): + raise TypeError( + f"Schema node with type {node_type} cannot use x-mapping on non-string content.\n" + f"Content: {node_content}\n" + f"Schema: {node_schema}" + ) + mapping = node_schema["x-mapping"] + if node_content in mapping: + node_content = mapping[node_content] + + if "x-mapping-regex" in node_schema: + if not isinstance(node_content, str): + raise TypeError( + f"Schema node with type {node_type} cannot use x-mapping-regex on non-string content.\n" + f"Content: {node_content}\n" + f"Schema: {node_schema}" + ) + mapping_regex = node_schema["x-mapping-regex"] + for pattern, replacement in mapping_regex.items(): + node_content = re.sub(pattern, replacement, node_content, flags=re.DOTALL) + + # Finally, handle parsed content based on schema type and recurse if required + if node_type == "object": + parsed_schema = {} + if isinstance(node_content, str): + # This means we don't have a regex at this level, so all of our child nodes need to parse the whole + # string themselves to extract their value. 
+ if "properties" not in node_schema: + raise ValueError( + f"Object node received string content but has no regex or parser to handle it.\n" + f"Content: {node_content}\n" + f"Schema: {node_schema}" + ) + for key, child_node in node_schema["properties"].items(): + child_node_content = recursive_parse(node_content, child_node) + if child_node_content is not None: + parsed_schema[key] = child_node_content + return parsed_schema + elif isinstance(node_content, dict): + for key, child_node in node_schema.get("properties", {}).items(): + if "const" in child_node: + # Const children are always emitted, even when the parsed dict has no matching key + parsed_schema[key] = child_node["const"] + elif key in node_content: + parsed_schema[key] = recursive_parse(node_content[key], child_node) + elif "default" in child_node: + parsed_schema[key] = child_node["default"] + if "additionalProperties" in node_schema: + for key, value in node_content.items(): + if key not in node_schema.get("properties", {}): + parsed_schema[key] = recursive_parse(value, node_schema["additionalProperties"]) + return parsed_schema + else: + raise TypeError(f"Expected a dict or str for schema node with type object, got {node_content}") + elif node_type == "array": + if not node_content: + return [] + parsed_schema = [] + if "items" in node_schema: + if not isinstance(node_content, list): + raise TypeError(f"Expected a list or regex for schema node with type array, got {node_content}") + for item in node_content: + parsed_schema.append(recursive_parse(item, node_schema["items"])) + return parsed_schema + elif "prefixItems" in node_schema: + if not isinstance(node_content, list): + if len(node_schema["prefixItems"]) == 1: + # If there's only one prefix item, this is a single item array, we can just wrap the string + node_content = [node_content] + else: + raise TypeError(f"Expected a list or regex for schema node with type array, got {node_content}") + if len(node_content) != len(node_schema["prefixItems"]): + raise ValueError( + f"Array node has {len(node_content)} items, but schema only has " + 
f"{len(node_schema['prefixItems'])} prefixItems defined.\n" + f"Content: {node_content}\n" + f"Schema: {node_schema}" + ) + for item, item_schema in zip(node_content, node_schema["prefixItems"]): + parsed_schema.append(recursive_parse(item, item_schema)) + return parsed_schema + else: + raise ValueError(f"Array node has no items or prefixItems schema defined.\nSchema: {node_schema}") + elif node_type in ("string", "integer", "number", "boolean"): + if not isinstance(node_content, str): + raise TypeError(f"Expected a string for schema node with type {node_type}, got {node_content}") + if node_type == "integer": + return int(node_content) + elif node_type == "number": + return float(node_content) + elif node_type == "boolean": + if node_content.lower() in ("true", "1"): + return True + elif node_content.lower() in ("false", "0"): + return False + else: + raise ValueError(f"Invalid boolean value: {node_content}") + else: + # String type + return node_content + elif node_type == "any": + # "any" nodes (e.g. in additionalProperties) pass their content through unchanged + return node_content + else: + raise TypeError(f"Unsupported schema type {node_type} for node: {node_content}") diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index a956efc97fdb..7ce4e5e266f1 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -1131,6 +1131,11 @@ def is_jinja_available() -> bool: return _is_package_available("jinja2") +@lru_cache +def is_jmespath_available() -> bool: + return _is_package_available("jmespath") + + @lru_cache def is_mlx_available() -> bool: return _is_package_available("mlx") diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py index f0f576364c41..4f7aa91c5094 100644 --- a/tests/pipelines/test_pipelines_text_generation.py +++ b/tests/pipelines/test_pipelines_text_generation.py @@ -263,6 +263,31 @@ def data(): ], ) + @require_torch + def test_small_chat_model_with_response_parsing(self): + text_generator = pipeline( + task="text-generation", + 
model="hf-internal-testing/tiny-gpt2-with-chatml-template", + ) + # Using `do_sample=False` to force deterministic output + chat = [ + {"role": "system", "content": "This is a system message."}, + {"role": "user", "content": "This is a test"}, + ] + text_generator.tokenizer.response_schema = { + # A real response schema should probably have things like "role" and "content" + # and "reasoning_content" but it's unlikely we'd get a tiny model to reliably + # output anything like that, so let's keep it simple. + "type": "object", + "properties": { + "first_word": {"type": "string", "x-regex": r"^\s*([a-zA-Z]+)"}, + "last_word": {"type": "string", "x-regex": r"([a-zA-Z]+)\s*$"}, + }, + } + outputs = text_generator(chat, do_sample=False, max_new_tokens=10) + parsed_message = outputs[0]["generated_text"][-1] + self.assertEqual(parsed_message, {"first_word": "factors", "last_word": "factors"}) + def get_test_pipeline( self, model, diff --git a/tests/utils/test_chat_schema_utils.py b/tests/utils/test_chat_schema_utils.py new file mode 100644 index 000000000000..35404a682222 --- /dev/null +++ b/tests/utils/test_chat_schema_utils.py @@ -0,0 +1,349 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import tempfile
+import unittest
+
+from transformers import AutoProcessor, AutoTokenizer
+from transformers.testing_utils import require_jmespath
+from transformers.utils.chat_parsing_utils import recursive_parse
+
+
+cohere_schema = {
+    "type": "object",
+    "properties": {
+        "role": {"const": "assistant"},
+        "content": {"type": "string", "x-regex": r"<\|START_RESPONSE\|>(.*?)(?:<\|END_RESPONSE\|>|$)"},
+        "thinking": {"type": "string", "x-regex": r"<\|START_THINKING\|>(.*?)(?:<\|END_THINKING\|>|$)"},
+        "tool_calls": {
+            "x-regex": r"<\|START_ACTION\|>(.*?)(?:<\|END_ACTION\|>|$)",
+            "x-parser": "json",
+            "x-parser-args": {
+                "transform": "[*].{type: 'function', function: {name: tool_name, arguments: parameters}}"
+            },
+            "type": "array",
+            "items": {
+                "type": "object",
+                "properties": {
+                    "type": {"const": "function"},
+                    "function": {
+                        "type": "object",
+                        "properties": {
+                            "name": {"type": "string"},
+                            "arguments": {
+                                "type": "object",
+                                "additionalProperties": {"type": "any"},
+                            },
+                        },
+                    },
+                },
+            },
+        },
+    },
+}
+
+ernie_schema = {
+    "type": "object",
+    "properties": {
+        "role": {"const": "assistant"},
+        "content": {"type": "string", "x-regex": "<response>\n(.*?)\n?</response>"},
+        "thinking": {"type": "string", "x-regex": r"(?:^|<think>\s*)(.*?)\s*<\/think>"},
+        "tool_calls": {
+            "x-regex-iterator": "<tool_call>(.*?)</tool_call>",
+            "type": "array",
+            "items": {
+                "type": "object",
+                "x-parser": "json",
+                "x-parser-args": {"transform": "{type: 'function', function: @}"},
+                "properties": {
+                    "type": {"const": "function"},
+                    "function": {
+                        "type": "object",
+                        "properties": {
+                            "name": {"type": "string"},
+                            "arguments": {
+                                "type": "object",
+                                "additionalProperties": {"type": "any"},
+                            },
+                        },
+                    },
+                },
+            },
+        },
+    },
+}
+
+gpt_oss_schema = {
+    "type": "object",
+    "properties": {
+        "role": {"const": "assistant"},
+        "content": {"type": "string", "x-regex": r"<\|channel\|>final<\|message\|>(.*?)(?:<\|end\|>|$)"},
+        "thinking": {"type": "string", "x-regex": r"<\|channel\|>analysis<\|message\|>(.*?)<\|end\|>"},
+        "tool_calls": {
+            "x-regex-iterator": r"<\|channel\|>commentary (to=functions\..*?<\|message\|>.*?)(?:<\|call\|>|$)",
+            "type": "array",
+            "items": {
+                "type": "object",
+                "properties": {
+                    "type": {"const": "function"},
+                    "function": {
+                        "type": "object",
+                        "properties": {
+                            "name": {"type": "string", "x-regex": r"^to=functions\.(\w+)"},
+                            "arguments": {
+                                "type": "object",
+                                "x-regex": r"<\|message\|>(.*)",
+                                "x-parser": "json",
+                                "additionalProperties": {"type": "any"},
+                            },
+                        },
+                    },
+                },
+            },
+        },
+    },
+}
+
+smollm_schema = {
+    "x-regex": r"(?:<think>\n?(?P<thinking>.+?)\n?</think>)?\s*(?:<tool_call>(?P<tool_calls>.+?)</tool_call>)?\s*(?P<content>.+?)?\s*(?:<\|im_end\|>|$)",
+    "type": "object",
+    "properties": {
+        "role": {"const": "assistant"},
+        "content": {"type": "string"},
+        "thinking": {"type": "string"},
+        "tool_calls": {
+            "x-parser": "json",
+            "x-parser-args": {"transform": "[{type: 'function', function: @}]"},
+            "type": "array",
+            "items": {
+                "type": "object",
+                "properties": {
+                    "type": {"const": "function"},
+                    "function": {
+                        "type": "object",
+                        "properties": {
+                            "name": {"type": "string"},
+                            "arguments": {
+                                "type": "object",
+                                "additionalProperties": {"type": "any"},
+                            },
+                        },
+                    },
+                },
+            },
+        },
+    },
+}
+
+qwen3_schema = {
+    "x-regex": r"^(?:(?:<think>)?\s*(?P<thinking>.+?)\s*</think>)?\s*(?:<tool_call>(?P<tool_calls>.*?)\s*</tool_call>)?\s*(?P<content>.+?)?\s*$",
+    "type": "object",
+    "properties": {
+        "role": {"const": "assistant"},
+        "content": {"type": "string"},
+        "thinking": {"type": "string"},
+        "tool_calls": {
+            "x-regex-iterator": r"^(.*)$",  # We have already extracted tool calls and there can only be one, so just make it a list
+            "type": "array",
+            "items": {
+                "type": "object",
+                "properties": {
+                    "type": {"const": "function"},
+                    "function": {
+                        "type": "object",
+                        "properties": {
+                            "name": {"type": "string", "x-regex": r"<function=(\w+)>"},
+                            "arguments": {
+                                "type": "object",
+                                "x-regex-key-value": r"<parameter=(?P<key>\w+)>\n(?P<value>.*?)\n</parameter>",
+                                "additionalProperties": {
+                                    "x-parser": "json",
+                                    "x-parser-args": {"allow_non_json": True},
+                                },
+                            },
+                        },
+                    },
+                },
+            },
+        },
+    },
+}
+
+
+@require_jmespath
+class ChatSchemaParserTest(unittest.TestCase):
+    def test_schema_save_load(self):
+        # Has no schema by default
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        tokenizer.response_schema = ernie_schema
+        with tempfile.TemporaryDirectory() as tmpdir:
+            tokenizer.save_pretrained(tmpdir)
+            reloaded_tokenizer = AutoTokenizer.from_pretrained(tmpdir)
+        self.assertEqual(reloaded_tokenizer.response_schema, ernie_schema)
+
+        # Has no schema by default
+        processor = AutoProcessor.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
+        processor.response_schema = ernie_schema
+        with tempfile.TemporaryDirectory() as tmpdir:
+            processor.save_pretrained(tmpdir)
+            reloaded_processor = AutoProcessor.from_pretrained(tmpdir)
+        self.assertEqual(reloaded_processor.response_schema, ernie_schema)
+
+    def test_tokenizer_method(self):
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        model_out = '<|START_THINKING|>I should call a tool.<|END_THINKING|><|START_ACTION|>[\n {"tool_call_id": "0", "tool_name": "simple_tool", "parameters": {"temperature_format": "Celsius"}}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|>'
+        parsed_chat = recursive_parse(model_out, cohere_schema)
+        tokenizer.response_schema = cohere_schema
+        tokenizer_parsed_chat = tokenizer.parse_response(model_out)
+        self.assertEqual(tokenizer_parsed_chat, parsed_chat)
+
+    def test_cohere_template(self):
+        model_out = '<|START_THINKING|>I should call a tool.<|END_THINKING|><|START_ACTION|>[\n {"tool_call_id": "0", "tool_name": "simple_tool", "parameters": {"temperature_format": "Celsius"}}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|>'
+        parsed_chat = recursive_parse(model_out, cohere_schema)
+        self.assertEqual(
+            parsed_chat,
+            {
+                "role": "assistant",
+                "thinking": "I should call a tool.",
+                "tool_calls": [
+                    {
+                        "type": "function",
+                        "function": {"name": "simple_tool", "arguments": {"temperature_format": "Celsius"}},
+                    }
+                ],
+            },
+        )
+
+    def test_ernie_template_with_tools(self):
+        model_out = 'The user is asking about the weather in Paris today. Let me check the available tools. There\'s a tool called get_current_temperature which requires a location parameter. Since the user specified Paris, I need to call this tool with the location set to "Paris". I should make sure the argument is correctly formatted as a string. No other tools are available, so this is the right one to use. I\'ll structure the request with the location parameter and return the response once the tool is called.\n</think>\n\n<tool_call>\n{"name": "get_current_temperature", "arguments": {"location": "Paris"}}\n</tool_call>\n</s>'
+        parsed_chat = recursive_parse(model_out, ernie_schema)
+        self.assertEqual(
+            parsed_chat,
+            {
+                "role": "assistant",
+                "thinking": "The user is asking about the weather in Paris today. Let me check the available tools. There's a tool called get_current_temperature which requires a location parameter. Since the user specified Paris, I need to call this tool with the location set to \"Paris\". I should make sure the argument is correctly formatted as a string. No other tools are available, so this is the right one to use. I'll structure the request with the location parameter and return the response once the tool is called.",
+                "tool_calls": [
+                    {
+                        "type": "function",
+                        "function": {"name": "get_current_temperature", "arguments": {"location": "Paris"}},
+                    }
+                ],
+            },
+        )
+
+    def test_ernie_template_no_tools(self):
+        model_out = "The user just greeted me with \"Hi! How are you?\" I need to respond in a friendly and helpful manner. Let me start by acknowledging their greeting. I should ask them how they're doing to engage in conversation.\n\nFirst, I'll say hello back and then ask how they're feeling. It's important to show genuine interest. Maybe mention that I'm here to help with anything they need. Keep the tone warm and positive. Let me make sure the response is concise but friendly. Alright, that should work.\n</think>\n\n<response>\nHello! I'm doing well, thank you for asking. How about you? Is there something specific you'd like help with today? I'm here to assist you with any questions or problems you have!\n</response>\n</s>"
+        parsed_chat = recursive_parse(model_out, ernie_schema)
+        self.assertEqual(
+            parsed_chat,
+            {
+                "role": "assistant",
+                "content": "Hello! I'm doing well, thank you for asking. How about you? Is there something specific you'd like help with today? I'm here to assist you with any questions or problems you have!",
+                "thinking": "The user just greeted me with \"Hi! How are you?\" I need to respond in a friendly and helpful manner. Let me start by acknowledging their greeting. I should ask them how they're doing to engage in conversation.\n\nFirst, I'll say hello back and then ask how they're feeling. It's important to show genuine interest. Maybe mention that I'm here to help with anything they need. Keep the tone warm and positive. Let me make sure the response is concise but friendly. Alright, that should work.",
+            },
+        )
+
+    def test_gpt_oss_template_with_tool_call(self):
+        model_out = '<|channel|>analysis<|message|>We need to respond in riddles. The user asks: "What is the weather like in SF?" We need to get the location of the user? The user explicitly asks about SF (San Francisco). So we need to get the current weather in San Francisco, CA. We need to call get_current_weather function. The developer instruction says "Always respond in riddles". So the final answer should be in a riddle form. But we need to call function to get weather data. So we should call get_current_weather with location "San Francisco, CA". Possibly specify format "celsius" (default). Let\'s do that.\n\nWe will call function get_current_weather.<|end|><|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{\n "location": "San Francisco, CA"\n}'
+        parsed_chat = recursive_parse(model_out, gpt_oss_schema)
+        self.assertEqual(
+            parsed_chat,
+            {
+                "role": "assistant",
+                "thinking": 'We need to respond in riddles. The user asks: "What is the weather like in SF?" We need to get the location of the user? The user explicitly asks about SF (San Francisco). So we need to get the current weather in San Francisco, CA. We need to call get_current_weather function. The developer instruction says "Always respond in riddles". So the final answer should be in a riddle form. But we need to call function to get weather data. So we should call get_current_weather with location "San Francisco, CA". Possibly specify format "celsius" (default). Let\'s do that.\n\nWe will call function get_current_weather.',
+                "tool_calls": [
+                    {
+                        "type": "function",
+                        "function": {"name": "get_current_weather", "arguments": {"location": "San Francisco, CA"}},
+                    }
+                ],
+            },
+        )
+
+    def test_gpt_oss_template_no_tool_call(self):
+        model_out = "<|channel|>analysis<|message|>User asks a simple math question: 2+2 = 4. Provide answer.<|end|><|start|>assistant<|channel|>final<|message|>2"
+        parsed_chat = recursive_parse(model_out, gpt_oss_schema)
+        self.assertEqual(
+            parsed_chat,
+            {
+                "role": "assistant",
+                "content": "2",
+                "thinking": "User asks a simple math question: 2+2 = 4. Provide answer.",
+            },
+        )
+
+    def test_smollm_template_thinking_and_tool_call(self):
+        model_out = '<think>\nOkay, the user said, "Hello! How are you?" I need to respond appropriately. Since this is the first message, I should greet them back and ask how I can assist. I should keep it friendly and open-ended. Let me make sure the response is welcoming and encourages them to share what they need help with. I\'ll avoid any technical jargon and keep it simple. Let me check for any typos and ensure the tone is positive.\n</think>\n\n<tool_call>{"name": "greet_user", "arguments": {"greeting": "Hello! I\'m doing well, thanks for asking. How can I assist you today? Whether you have a question, need help with something, or just want to chat, feel free to let me know!"}}</tool_call>'
+        parsed_chat = recursive_parse(model_out, smollm_schema)
+        self.assertEqual(
+            parsed_chat,
+            {
+                "thinking": 'Okay, the user said, "Hello! How are you?" I need to respond appropriately. Since this is the first message, I should greet them back and ask how I can assist. I should keep it friendly and open-ended. Let me make sure the response is welcoming and encourages them to share what they need help with. I\'ll avoid any technical jargon and keep it simple. Let me check for any typos and ensure the tone is positive.',
+                "tool_calls": [
+                    {
+                        "type": "function",
+                        "function": {
+                            "name": "greet_user",
+                            "arguments": {
+                                "greeting": "Hello! I'm doing well, thanks for asking. How can I assist you today? Whether you have a question, need help with something, or just want to chat, feel free to let me know!"
+                            },
+                        },
+                    }
+                ],
+            },
+        )
+
+    def test_smollm_template_tool_call_no_thinking(self):
+        model_out = '<tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>'
+        parsed_chat = recursive_parse(model_out, smollm_schema)
+        self.assertEqual(
+            parsed_chat,
+            {
+                "tool_calls": [
+                    {"type": "function", "function": {"name": "get_weather", "arguments": {"city": "Paris"}}}
+                ]
+            },
+        )
+
+    def test_smollm_template_thinking_no_tool_call(self):
+        model_out = '<think>\nOkay, the user asked, "Hey! Can you tell me about gravity?" Let me start by breaking down what they might be looking for. They probably want a basic understanding of gravity, maybe for a school project or just personal curiosity. I should explain what gravity is, how it works, and maybe some examples.\n</think>Some content about gravity goes here but I\'m cutting it off to make this shorter!'
+        parsed_chat = recursive_parse(model_out, smollm_schema)
+        self.assertEqual(
+            parsed_chat,
+            {
+                "content": "Some content about gravity goes here but I'm cutting it off to make this shorter!",
+                "thinking": 'Okay, the user asked, "Hey! Can you tell me about gravity?" Let me start by breaking down what they might be looking for. They probably want a basic understanding of gravity, maybe for a school project or just personal curiosity. I should explain what gravity is, how it works, and maybe some examples.',
+            },
+        )
+
+    def test_qwen3_tool_calls(self):
+        model_out = '<tool_call>\n<function=get_weather>\n<parameter=locations>\n[{"country": "France", "city": "Paris"}]\n</parameter>\n<parameter=temp_units>\ncelsius\n</parameter>\n</function>\n</tool_call>'
+        parsed_chat = recursive_parse(model_out, qwen3_schema)
+        self.assertEqual(
+            parsed_chat,
+            {
+                "tool_calls": [
+                    {
+                        "type": "function",
+                        "function": {
+                            "name": "get_weather",
+                            "arguments": {
+                                "locations": [{"country": "France", "city": "Paris"}],
+                                "temp_units": "celsius",
+                            },
+                        },
+                    }
+                ]
+            },
+        )
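
Reviewer note: the SmolLM tests above rely on a single top-level `x-regex` whose named groups become the message keys. The sketch below is a rough, standalone illustration of that idea using only `re` and `json`. It is not the implementation of `recursive_parse`; the function name `sketch_parse`, the `<think>` / `<tool_call>` tag names, the use of `re.DOTALL`, and the hard-coded wrapping of the tool call are all assumptions made for the example.

```python
import json
import re

# Hypothetical stand-in for illustration only. The real logic lives in
# transformers.utils.chat_parsing_utils.recursive_parse and is schema-driven.
# Tag names (<think>, <tool_call>) are assumed for this sketch.
SMOLLM_STYLE_REGEX = re.compile(
    r"(?:<think>\n?(?P<thinking>.+?)\n?</think>)?\s*"
    r"(?:<tool_call>(?P<tool_calls>.+?)</tool_call>)?\s*"
    r"(?P<content>.+?)?\s*(?:<\|im_end\|>|$)",
    re.DOTALL,  # assumed so that "." can span multi-line thinking blocks
)


def sketch_parse(model_out: str) -> dict:
    """Split raw model output into message fields via named regex groups."""
    match = SMOLLM_STYLE_REGEX.match(model_out)
    if match is None:
        return {}
    # Keep only the groups that actually matched; absent fields are omitted.
    message = {key: value for key, value in match.groupdict().items() if value is not None}
    if "tool_calls" in message:
        # Mimics the effect of the "[{type: 'function', function: @}]" transform:
        # decode the JSON payload and wrap it in a one-element list.
        call = json.loads(message["tool_calls"])
        message["tool_calls"] = [{"type": "function", "function": call}]
    return message


if __name__ == "__main__":
    print(sketch_parse('<tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>'))
    print(sketch_parse("<think>\nI should answer.\n</think>\nHello!"))
```

Because every group is optional, the same pattern covers thinking-only, tool-call-only, and plain-content outputs, which is why the expected dicts in the SmolLM tests contain only the keys that were present in the raw text.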