-
Notifications
You must be signed in to change notification settings - Fork 215
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #140 from aymeric-roucher/main
Make agent text-to-sql cookbook
- Loading branch information
Showing
4 changed files
with
768 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,318 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Create a Transformers Agent from any LLM inference provider\n", | ||
"_Authored by: [Aymeric Roucher](https://huggingface.co/m-ric)_\n", | ||
"\n", | ||
"> This tutorial builds upon agent knowledge: to know more about agents, you can start with [this introductory notebook](agents)\n", | ||
"\n", | ||
"[Transformers Agents](https://huggingface.co/docs/transformers/en/agents) is a library to build agents, using an LLM to power it in the `llm_engine` argument. This argument was designed to leave the user maximal freedom to choose any LLM.\n", | ||
"\n", | ||
"Let's see how to build this `llm_engine` from the APIs of a few leading providers." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## HuggingFace Serverless API and Dedicated Endpoints\n", | ||
"\n", | ||
"Transformers agents provides a built-in `HfEngine` class that lets you use any model on the Hub via the Serverless API or your own dedicated Endpoint. This is the preferred way to use HF agents." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"\u001b[33;1m======== New task ========\u001b[0m\n", | ||
"\u001b[37;1mWhat's the 10th Fibonacci number?\u001b[0m\n" | ||
] | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"['unicodedata', 're', 'math', 'collections', 'queue', 'itertools', 'random', 'time', 'stat', 'statistics']\n" | ||
] | ||
}, | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"\u001b[33;1m==== Agent is executing the code below:\u001b[0m\n", | ||
"\u001b[0m\u001b[38;5;7ma\u001b[39m\u001b[38;5;7m,\u001b[39m\u001b[38;5;7m \u001b[39m\u001b[38;5;7mb\u001b[39m\u001b[38;5;7m \u001b[39m\u001b[38;5;109;01m=\u001b[39;00m\u001b[38;5;7m \u001b[39m\u001b[38;5;139m0\u001b[39m\u001b[38;5;7m,\u001b[39m\u001b[38;5;7m \u001b[39m\u001b[38;5;139m1\u001b[39m\n", | ||
"\u001b[38;5;109;01mfor\u001b[39;00m\u001b[38;5;7m \u001b[39m\u001b[38;5;7m_\u001b[39m\u001b[38;5;7m \u001b[39m\u001b[38;5;109;01min\u001b[39;00m\u001b[38;5;7m \u001b[39m\u001b[38;5;109mrange\u001b[39m\u001b[38;5;7m(\u001b[39m\u001b[38;5;139m9\u001b[39m\u001b[38;5;7m)\u001b[39m\u001b[38;5;7m:\u001b[39m\n", | ||
"\u001b[38;5;7m \u001b[39m\u001b[38;5;7ma\u001b[39m\u001b[38;5;7m,\u001b[39m\u001b[38;5;7m \u001b[39m\u001b[38;5;7mb\u001b[39m\u001b[38;5;7m \u001b[39m\u001b[38;5;109;01m=\u001b[39;00m\u001b[38;5;7m \u001b[39m\u001b[38;5;7mb\u001b[39m\u001b[38;5;7m,\u001b[39m\u001b[38;5;7m \u001b[39m\u001b[38;5;7ma\u001b[39m\u001b[38;5;7m \u001b[39m\u001b[38;5;109;01m+\u001b[39;00m\u001b[38;5;7m \u001b[39m\u001b[38;5;7mb\u001b[39m\n", | ||
"\u001b[38;5;109mprint\u001b[39m\u001b[38;5;7m(\u001b[39m\u001b[38;5;7mb\u001b[39m\u001b[38;5;7m)\u001b[39m\u001b[0m\n", | ||
"\u001b[33;1m====\u001b[0m\n", | ||
"\u001b[33;1mPrint outputs:\u001b[0m\n", | ||
"\u001b[32;20m55\n", | ||
"\u001b[0m\n", | ||
"\u001b[33;1m==== Agent is executing the code below:\u001b[0m\n", | ||
"\u001b[0m\u001b[38;5;7ma\u001b[39m\u001b[38;5;7m,\u001b[39m\u001b[38;5;7m \u001b[39m\u001b[38;5;7mb\u001b[39m\u001b[38;5;7m \u001b[39m\u001b[38;5;109;01m=\u001b[39;00m\u001b[38;5;7m \u001b[39m\u001b[38;5;139m0\u001b[39m\u001b[38;5;7m,\u001b[39m\u001b[38;5;7m \u001b[39m\u001b[38;5;139m1\u001b[39m\n", | ||
"\u001b[38;5;109;01mfor\u001b[39;00m\u001b[38;5;7m \u001b[39m\u001b[38;5;7m_\u001b[39m\u001b[38;5;7m \u001b[39m\u001b[38;5;109;01min\u001b[39;00m\u001b[38;5;7m \u001b[39m\u001b[38;5;109mrange\u001b[39m\u001b[38;5;7m(\u001b[39m\u001b[38;5;139m9\u001b[39m\u001b[38;5;7m)\u001b[39m\u001b[38;5;7m:\u001b[39m\n", | ||
"\u001b[38;5;7m \u001b[39m\u001b[38;5;7ma\u001b[39m\u001b[38;5;7m,\u001b[39m\u001b[38;5;7m \u001b[39m\u001b[38;5;7mb\u001b[39m\u001b[38;5;7m \u001b[39m\u001b[38;5;109;01m=\u001b[39;00m\u001b[38;5;7m \u001b[39m\u001b[38;5;7mb\u001b[39m\u001b[38;5;7m,\u001b[39m\u001b[38;5;7m \u001b[39m\u001b[38;5;7ma\u001b[39m\u001b[38;5;7m \u001b[39m\u001b[38;5;109;01m+\u001b[39;00m\u001b[38;5;7m \u001b[39m\u001b[38;5;7mb\u001b[39m\u001b[0m\n", | ||
"\u001b[33;1m====\u001b[0m\n", | ||
"\u001b[33;1mPrint outputs:\u001b[0m\n", | ||
"\u001b[32;20m\u001b[0m\n", | ||
"\u001b[33;1m==== Agent is executing the code below:\u001b[0m\n", | ||
"\u001b[0m\u001b[38;5;7mfinal_answer\u001b[39m\u001b[38;5;7m(\u001b[39m\u001b[38;5;7mb\u001b[39m\u001b[38;5;7m)\u001b[39m\u001b[0m\n", | ||
"\u001b[33;1m====\u001b[0m\n", | ||
"\u001b[33;1mPrint outputs:\u001b[0m\n", | ||
"\u001b[32;20m\u001b[0m\n", | ||
"\u001b[33;1m>>> Final answer:\u001b[0m\n", | ||
"\u001b[32;20m55\u001b[0m\n" | ||
] | ||
}, | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"55" | ||
] | ||
}, | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"from transformers.agents import HfEngine, ReactCodeAgent\n", | ||
"\n", | ||
"repo_id = \"meta-llama/Meta-Llama-3-8B-Instruct\"\n", | ||
"endpoint_url = \"your_endpoint_url\"\n", | ||
"\n", | ||
"llm_engine = HfEngine(model=repo_id) # you could use model=endpoint_url here\n", | ||
"\n", | ||
"agent = ReactCodeAgent(tools=[], llm_engine=llm_engine)\n", | ||
"\n", | ||
"agent.run(\"What's the 10th Fibonacci number?\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"The `llm_engine` initialization arg of the agent could be a simple callable such as:\n", | ||
"```py\n", | ||
"def llm_engine(messages, stop_sequences=[]) -> str:\n", | ||
" return response(messages)\n", | ||
"```\n", | ||
"This callable is the heart of the llm engine. It should respect these requirements:\n", | ||
"- takes as input a list of messages in [chat template](https://huggingface.co/docs/transformers/main/en/chat_templating) format and outputs a `str`.\n", | ||
"- accepts a `stop_sequences` argument where the agent system will pass it sequences where it should stop generation.\n", | ||
"\n", | ||
"Let's take a closer look at the code for the `HfEngine` that we used:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from typing import List, Dict\n", | ||
"from transformers.agents.llm_engine import MessageRole, get_clean_message_list\n", | ||
"from huggingface_hub import InferenceClient\n", | ||
"\n", | ||
"llama_role_conversions = {\n", | ||
" MessageRole.TOOL_RESPONSE: MessageRole.USER,\n", | ||
"}\n", | ||
"\n", | ||
"\n", | ||
"class HfEngine:\n", | ||
" def __init__(self, model: str = \"meta-llama/Meta-Llama-3-8B-Instruct\"):\n", | ||
" self.model = model\n", | ||
" self.client = InferenceClient(model=self.model, timeout=120)\n", | ||
"\n", | ||
" def __call__(self, messages: List[Dict[str, str]], stop_sequences=[]) -> str:\n", | ||
" # Get clean message list\n", | ||
" messages = get_clean_message_list(\n", | ||
" messages, role_conversions=llama_role_conversions\n", | ||
" )\n", | ||
"\n", | ||
" # Get LLM output\n", | ||
" response = self.client.chat_completion(\n", | ||
" messages, stop=stop_sequences, max_tokens=1500\n", | ||
" )\n", | ||
" response = response.choices[0].message.content\n", | ||
"\n", | ||
" # Remove stop sequences from LLM output\n", | ||
" for stop_seq in stop_sequences:\n", | ||
" if response[-len(stop_seq) :] == stop_seq:\n", | ||
" response = response[: -len(stop_seq)]\n", | ||
" return response" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Here the engine is not a function, but a class with a `__call__` method, which adds the possibility to store attributes such as the client.\n", | ||
"\n", | ||
"We also use `get_clean_message_list()` utility to concatenate successive messages to the same role\n", | ||
"This method takes a `role_conversions` arg to convert the range of roles supported in Transformers Agents to only the ones accepted by your LLM.\n", | ||
"\n", | ||
"\n", | ||
"This recipe can be adapted for any LLM! Let's look at other examples." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Adapting the recipe for any LLM\n", | ||
"\n", | ||
"Using the above recipe, you can use any LLM inference source as your `llm_engine`.\n", | ||
"Just keep in mind the two main constraints:\n", | ||
"- `llm_engine` is a callable that takes as input a list of messages in [chat template](https://huggingface.co/docs/transformers/main/en/chat_templating) format and outputs a `str`.\n", | ||
"- It accepts a `stop_sequences` argument.\n", | ||
"\n", | ||
"\n", | ||
"### OpenAI" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import os\n", | ||
"from openai import OpenAI\n", | ||
"\n", | ||
"openai_role_conversions = {\n", | ||
" MessageRole.TOOL_RESPONSE: MessageRole.USER,\n", | ||
"}\n", | ||
"\n", | ||
"\n", | ||
"class OpenAIEngine:\n", | ||
" def __init__(self, model_name=\"gpt-4o\"):\n", | ||
" self.model_name = model_name\n", | ||
" self.client = OpenAI(\n", | ||
" api_key=os.getenv(\"OPENAI_API_KEY\"),\n", | ||
" )\n", | ||
"\n", | ||
" def __call__(self, messages, stop_sequences=[]):\n", | ||
" messages = get_clean_message_list(\n", | ||
" messages, role_conversions=openai_role_conversions\n", | ||
" )\n", | ||
"\n", | ||
" response = self.client.chat.completions.create(\n", | ||
" model=self.model_name,\n", | ||
" messages=messages,\n", | ||
" stop=stop_sequences,\n", | ||
" temperature=0.5,\n", | ||
" )\n", | ||
" return response.choices[0].message.content" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Anthropic" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from anthropic import Anthropic, AnthropicBedrock\n", | ||
"\n", | ||
"\n", | ||
"# Cf this page for using Anthropic from Bedrock: https://docs.anthropic.com/en/api/claude-on-amazon-bedrock\n", | ||
"class AnthropicEngine:\n", | ||
" def __init__(self, model_name=\"claude-3-5-sonnet-20240620\", use_bedrock=False):\n", | ||
" self.model_name = model_name\n", | ||
" if use_bedrock:\n", | ||
" self.model_name = \"anthropic.claude-3-5-sonnet-20240620-v1:0\"\n", | ||
" self.client = AnthropicBedrock(\n", | ||
" aws_access_key=os.getenv(\"AWS_BEDROCK_ID\"),\n", | ||
" aws_secret_key=os.getenv(\"AWS_BEDROCK_KEY\"),\n", | ||
" aws_region=\"us-east-1\",\n", | ||
" )\n", | ||
" else:\n", | ||
" self.client = Anthropic(\n", | ||
" api_key=os.getenv(\"ANTHROPIC_API_KEY\"),\n", | ||
" )\n", | ||
"\n", | ||
" def __call__(self, messages, stop_sequences=[]):\n", | ||
" messages = get_clean_message_list(\n", | ||
" messages, role_conversions=openai_role_conversions\n", | ||
" )\n", | ||
" index_system_message, system_prompt = None, None\n", | ||
" for index, message in enumerate(messages):\n", | ||
" if message[\"role\"] == MessageRole.SYSTEM:\n", | ||
" index_system_message = index\n", | ||
" system_prompt = message[\"content\"]\n", | ||
" if system_prompt is None:\n", | ||
" raise Exception(\"No system prompt found!\")\n", | ||
"\n", | ||
" filtered_messages = [\n", | ||
" message for i, message in enumerate(messages) if i != index_system_message\n", | ||
" ]\n", | ||
" if len(filtered_messages) == 0:\n", | ||
" print(\"Error, no user message:\", messages)\n", | ||
" assert False\n", | ||
"\n", | ||
" response = self.client.messages.create(\n", | ||
" model=self.model_name,\n", | ||
" system=system_prompt,\n", | ||
" messages=filtered_messages,\n", | ||
" stop_sequences=stop_sequences,\n", | ||
" temperature=0.5,\n", | ||
" max_tokens=2000,\n", | ||
" )\n", | ||
" full_response_text = \"\"\n", | ||
" for content_block in response.content:\n", | ||
" if content_block.type == \"text\":\n", | ||
" full_response_text += content_block.text\n", | ||
" return full_response_text" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Next steps\n", | ||
"\n", | ||
"Go on and implement your `llm_engine` for `transformers.agents` with your own LLM inference provider!\n", | ||
"\n", | ||
"Then to use this shiny new `llm_engine`, check out these use cases:\n", | ||
"- [Agentic RAG: turbocharge your RAG with query reformulation and self-query](agent_rag)\n", | ||
"- [Agent for text-to-SQL with automatic error correction](agent_text_to_sql)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "disposable", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.14" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.