Commit

Merge pull request #140 from aymeric-roucher/main
Make agent text-to-sql cookbook
stevhliu committed Jul 8, 2024
2 parents fd1c449 + 9eada20 commit 7cefd6a
Showing 4 changed files with 768 additions and 1 deletion.
4 changes: 4 additions & 0 deletions notebooks/en/_toctree.yml
@@ -76,6 +76,10 @@
  title: Build an agent with tool-calling superpowers using Transformers Agents
- local: agent_rag
  title: Agentic RAG - turbocharge your RAG with query reformulation and self-query
- local: agent_change_llm
  title: Create a Transformers Agent from any LLM inference provider
- local: agent_text_to_sql
  title: Agent for Text-to-SQL with automatic error correction

- title: Enterprise Hub Cookbook
  isExpanded: True
318 changes: 318 additions & 0 deletions notebooks/en/agent_change_llm.ipynb
@@ -0,0 +1,318 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Create a Transformers Agent from any LLM inference provider\n",
"_Authored by: [Aymeric Roucher](https://huggingface.co/m-ric)_\n",
"\n",
"> This tutorial builds upon agent knowledge: to know more about agents, you can start with [this introductory notebook](agents)\n",
"\n",
"[Transformers Agents](https://huggingface.co/docs/transformers/en/agents) is a library to build agents, using an LLM to power it in the `llm_engine` argument. This argument was designed to leave the user maximal freedom to choose any LLM.\n",
"\n",
"Let's see how to build this `llm_engine` from the APIs of a few leading providers."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## HuggingFace Serverless API and Dedicated Endpoints\n",
"\n",
"Transformers agents provides a built-in `HfEngine` class that lets you use any model on the Hub via the Serverless API or your own dedicated Endpoint. This is the preferred way to use HF agents."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[33;1m======== New task ========\u001b[0m\n",
"\u001b[37;1mWhat's the 10th Fibonacci number?\u001b[0m\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"['unicodedata', 're', 'math', 'collections', 'queue', 'itertools', 'random', 'time', 'stat', 'statistics']\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[33;1m==== Agent is executing the code below:\u001b[0m\n",
"\u001b[0m\u001b[38;5;7ma\u001b[39m\u001b[38;5;7m,\u001b[39m\u001b[38;5;7m \u001b[39m\u001b[38;5;7mb\u001b[39m\u001b[38;5;7m \u001b[39m\u001b[38;5;109;01m=\u001b[39;00m\u001b[38;5;7m \u001b[39m\u001b[38;5;139m0\u001b[39m\u001b[38;5;7m,\u001b[39m\u001b[38;5;7m \u001b[39m\u001b[38;5;139m1\u001b[39m\n",
"\u001b[38;5;109;01mfor\u001b[39;00m\u001b[38;5;7m \u001b[39m\u001b[38;5;7m_\u001b[39m\u001b[38;5;7m \u001b[39m\u001b[38;5;109;01min\u001b[39;00m\u001b[38;5;7m \u001b[39m\u001b[38;5;109mrange\u001b[39m\u001b[38;5;7m(\u001b[39m\u001b[38;5;139m9\u001b[39m\u001b[38;5;7m)\u001b[39m\u001b[38;5;7m:\u001b[39m\n",
"\u001b[38;5;7m \u001b[39m\u001b[38;5;7ma\u001b[39m\u001b[38;5;7m,\u001b[39m\u001b[38;5;7m \u001b[39m\u001b[38;5;7mb\u001b[39m\u001b[38;5;7m \u001b[39m\u001b[38;5;109;01m=\u001b[39;00m\u001b[38;5;7m \u001b[39m\u001b[38;5;7mb\u001b[39m\u001b[38;5;7m,\u001b[39m\u001b[38;5;7m \u001b[39m\u001b[38;5;7ma\u001b[39m\u001b[38;5;7m \u001b[39m\u001b[38;5;109;01m+\u001b[39;00m\u001b[38;5;7m \u001b[39m\u001b[38;5;7mb\u001b[39m\n",
"\u001b[38;5;109mprint\u001b[39m\u001b[38;5;7m(\u001b[39m\u001b[38;5;7mb\u001b[39m\u001b[38;5;7m)\u001b[39m\u001b[0m\n",
"\u001b[33;1m====\u001b[0m\n",
"\u001b[33;1mPrint outputs:\u001b[0m\n",
"\u001b[32;20m55\n",
"\u001b[0m\n",
"\u001b[33;1m==== Agent is executing the code below:\u001b[0m\n",
"\u001b[0m\u001b[38;5;7ma\u001b[39m\u001b[38;5;7m,\u001b[39m\u001b[38;5;7m \u001b[39m\u001b[38;5;7mb\u001b[39m\u001b[38;5;7m \u001b[39m\u001b[38;5;109;01m=\u001b[39;00m\u001b[38;5;7m \u001b[39m\u001b[38;5;139m0\u001b[39m\u001b[38;5;7m,\u001b[39m\u001b[38;5;7m \u001b[39m\u001b[38;5;139m1\u001b[39m\n",
"\u001b[38;5;109;01mfor\u001b[39;00m\u001b[38;5;7m \u001b[39m\u001b[38;5;7m_\u001b[39m\u001b[38;5;7m \u001b[39m\u001b[38;5;109;01min\u001b[39;00m\u001b[38;5;7m \u001b[39m\u001b[38;5;109mrange\u001b[39m\u001b[38;5;7m(\u001b[39m\u001b[38;5;139m9\u001b[39m\u001b[38;5;7m)\u001b[39m\u001b[38;5;7m:\u001b[39m\n",
"\u001b[38;5;7m \u001b[39m\u001b[38;5;7ma\u001b[39m\u001b[38;5;7m,\u001b[39m\u001b[38;5;7m \u001b[39m\u001b[38;5;7mb\u001b[39m\u001b[38;5;7m \u001b[39m\u001b[38;5;109;01m=\u001b[39;00m\u001b[38;5;7m \u001b[39m\u001b[38;5;7mb\u001b[39m\u001b[38;5;7m,\u001b[39m\u001b[38;5;7m \u001b[39m\u001b[38;5;7ma\u001b[39m\u001b[38;5;7m \u001b[39m\u001b[38;5;109;01m+\u001b[39;00m\u001b[38;5;7m \u001b[39m\u001b[38;5;7mb\u001b[39m\u001b[0m\n",
"\u001b[33;1m====\u001b[0m\n",
"\u001b[33;1mPrint outputs:\u001b[0m\n",
"\u001b[32;20m\u001b[0m\n",
"\u001b[33;1m==== Agent is executing the code below:\u001b[0m\n",
"\u001b[0m\u001b[38;5;7mfinal_answer\u001b[39m\u001b[38;5;7m(\u001b[39m\u001b[38;5;7mb\u001b[39m\u001b[38;5;7m)\u001b[39m\u001b[0m\n",
"\u001b[33;1m====\u001b[0m\n",
"\u001b[33;1mPrint outputs:\u001b[0m\n",
"\u001b[32;20m\u001b[0m\n",
"\u001b[33;1m>>> Final answer:\u001b[0m\n",
"\u001b[32;20m55\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"55"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from transformers.agents import HfEngine, ReactCodeAgent\n",
"\n",
"repo_id = \"meta-llama/Meta-Llama-3-8B-Instruct\"\n",
"endpoint_url = \"your_endpoint_url\"\n",
"\n",
"llm_engine = HfEngine(model=repo_id) # you could use model=endpoint_url here\n",
"\n",
"agent = ReactCodeAgent(tools=[], llm_engine=llm_engine)\n",
"\n",
"agent.run(\"What's the 10th Fibonacci number?\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `llm_engine` initialization arg of the agent could be a simple callable such as:\n",
"```py\n",
"def llm_engine(messages, stop_sequences=[]) -> str:\n",
" return response(messages)\n",
"```\n",
"This callable is the heart of the llm engine. It should respect these requirements:\n",
"- takes as input a list of messages in [chat template](https://huggingface.co/docs/transformers/main/en/chat_templating) format and outputs a `str`.\n",
"- accepts a `stop_sequences` argument where the agent system will pass it sequences where it should stop generation.\n",
"\n",
"Let's take a closer look at the code for the `HfEngine` that we used:"
]
},
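{
"cell_type": "markdown",
"metadata": {},
"source": [
"For illustration, here is a minimal sketch of an engine that satisfies both requirements. The `dummy_llm_engine` name and the canned reply are placeholders standing in for a real LLM call, not part of the library:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from typing import Dict, List\n",
"\n",
"\n",
"def dummy_llm_engine(messages: List[Dict[str, str]], stop_sequences=[]) -> str:\n",
"    # `messages` arrives in chat template format:\n",
"    # [{\"role\": \"user\", \"content\": \"...\"}, ...]\n",
"    response = \"This canned reply stands in for a real model call.\"\n",
"\n",
"    # Truncate the output at the first stop sequence, as the agent expects\n",
"    for stop_seq in stop_sequences:\n",
"        if stop_seq in response:\n",
"            response = response[: response.index(stop_seq)]\n",
"    return response"
]
},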
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from typing import List, Dict\n",
"from transformers.agents.llm_engine import MessageRole, get_clean_message_list\n",
"from huggingface_hub import InferenceClient\n",
"\n",
"llama_role_conversions = {\n",
" MessageRole.TOOL_RESPONSE: MessageRole.USER,\n",
"}\n",
"\n",
"\n",
"class HfEngine:\n",
" def __init__(self, model: str = \"meta-llama/Meta-Llama-3-8B-Instruct\"):\n",
" self.model = model\n",
" self.client = InferenceClient(model=self.model, timeout=120)\n",
"\n",
" def __call__(self, messages: List[Dict[str, str]], stop_sequences=[]) -> str:\n",
" # Get clean message list\n",
" messages = get_clean_message_list(\n",
" messages, role_conversions=llama_role_conversions\n",
" )\n",
"\n",
" # Get LLM output\n",
" response = self.client.chat_completion(\n",
" messages, stop=stop_sequences, max_tokens=1500\n",
" )\n",
" response = response.choices[0].message.content\n",
"\n",
" # Remove stop sequences from LLM output\n",
" for stop_seq in stop_sequences:\n",
" if response[-len(stop_seq) :] == stop_seq:\n",
" response = response[: -len(stop_seq)]\n",
" return response"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here the engine is not a function, but a class with a `__call__` method, which adds the possibility to store attributes such as the client.\n",
"\n",
"We also use `get_clean_message_list()` utility to concatenate successive messages to the same role\n",
"This method takes a `role_conversions` arg to convert the range of roles supported in Transformers Agents to only the ones accepted by your LLM.\n",
"\n",
"\n",
"This recipe can be adapted for any LLM! Let's look at other examples."
]
},
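{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a small sketch of the conversion in action (the toy messages below are hypothetical): since `llama_role_conversions` maps `MessageRole.TOOL_RESPONSE` to `MessageRole.USER`, the two successive messages should come out concatenated into a single user message."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"toy_messages = [\n",
"    {\"role\": MessageRole.USER, \"content\": \"What's the 10th Fibonacci number?\"},\n",
"    {\"role\": MessageRole.TOOL_RESPONSE, \"content\": \"Observation: 55\"},\n",
"]\n",
"print(get_clean_message_list(toy_messages, role_conversions=llama_role_conversions))"
]
},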
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Adapting the recipe for any LLM\n",
"\n",
"Using the above recipe, you can use any LLM inference source as your `llm_engine`.\n",
"Just keep in mind the two main constraints:\n",
"- `llm_engine` is a callable that takes as input a list of messages in [chat template](https://huggingface.co/docs/transformers/main/en/chat_templating) format and outputs a `str`.\n",
"- It accepts a `stop_sequences` argument.\n",
"\n",
"\n",
"### OpenAI"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from openai import OpenAI\n",
"\n",
"openai_role_conversions = {\n",
" MessageRole.TOOL_RESPONSE: MessageRole.USER,\n",
"}\n",
"\n",
"\n",
"class OpenAIEngine:\n",
" def __init__(self, model_name=\"gpt-4o\"):\n",
" self.model_name = model_name\n",
" self.client = OpenAI(\n",
" api_key=os.getenv(\"OPENAI_API_KEY\"),\n",
" )\n",
"\n",
" def __call__(self, messages, stop_sequences=[]):\n",
" messages = get_clean_message_list(\n",
" messages, role_conversions=openai_role_conversions\n",
" )\n",
"\n",
" response = self.client.chat.completions.create(\n",
" model=self.model_name,\n",
" messages=messages,\n",
" stop=stop_sequences,\n",
" temperature=0.5,\n",
" )\n",
" return response.choices[0].message.content"
]
},
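{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can then plug this engine into an agent exactly as before (this assumes your `OPENAI_API_KEY` environment variable is set):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"agent = ReactCodeAgent(tools=[], llm_engine=OpenAIEngine())\n",
"agent.run(\"What's the 10th Fibonacci number?\")"
]
},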
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Anthropic"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from anthropic import Anthropic, AnthropicBedrock\n",
"\n",
"\n",
"# Cf this page for using Anthropic from Bedrock: https://docs.anthropic.com/en/api/claude-on-amazon-bedrock\n",
"class AnthropicEngine:\n",
" def __init__(self, model_name=\"claude-3-5-sonnet-20240620\", use_bedrock=False):\n",
" self.model_name = model_name\n",
" if use_bedrock:\n",
" self.model_name = \"anthropic.claude-3-5-sonnet-20240620-v1:0\"\n",
" self.client = AnthropicBedrock(\n",
" aws_access_key=os.getenv(\"AWS_BEDROCK_ID\"),\n",
" aws_secret_key=os.getenv(\"AWS_BEDROCK_KEY\"),\n",
" aws_region=\"us-east-1\",\n",
" )\n",
" else:\n",
" self.client = Anthropic(\n",
" api_key=os.getenv(\"ANTHROPIC_API_KEY\"),\n",
" )\n",
"\n",
" def __call__(self, messages, stop_sequences=[]):\n",
" messages = get_clean_message_list(\n",
" messages, role_conversions=openai_role_conversions\n",
" )\n",
" index_system_message, system_prompt = None, None\n",
" for index, message in enumerate(messages):\n",
" if message[\"role\"] == MessageRole.SYSTEM:\n",
" index_system_message = index\n",
" system_prompt = message[\"content\"]\n",
" if system_prompt is None:\n",
" raise Exception(\"No system prompt found!\")\n",
"\n",
" filtered_messages = [\n",
" message for i, message in enumerate(messages) if i != index_system_message\n",
" ]\n",
" if len(filtered_messages) == 0:\n",
" print(\"Error, no user message:\", messages)\n",
" assert False\n",
"\n",
" response = self.client.messages.create(\n",
" model=self.model_name,\n",
" system=system_prompt,\n",
" messages=filtered_messages,\n",
" stop_sequences=stop_sequences,\n",
" temperature=0.5,\n",
" max_tokens=2000,\n",
" )\n",
" full_response_text = \"\"\n",
" for content_block in response.content:\n",
" if content_block.type == \"text\":\n",
" full_response_text += content_block.text\n",
" return full_response_text"
]
},
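{
"cell_type": "markdown",
"metadata": {},
"source": [
"Usage is the same as before; pass `use_bedrock=True` to route the calls through Amazon Bedrock instead (this assumes your `ANTHROPIC_API_KEY`, or the AWS credentials above, are set in your environment):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"agent = ReactCodeAgent(tools=[], llm_engine=AnthropicEngine())\n",
"# or, through Amazon Bedrock:\n",
"# agent = ReactCodeAgent(tools=[], llm_engine=AnthropicEngine(use_bedrock=True))\n",
"agent.run(\"What's the 10th Fibonacci number?\")"
]
},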
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Next steps\n",
"\n",
"Go on and implement your `llm_engine` for `transformers.agents` with your own LLM inference provider!\n",
"\n",
"Then to use this shiny new `llm_engine`, check out these use cases:\n",
"- [Agentic RAG: turbocharge your RAG with query reformulation and self-query](agent_rag)\n",
"- [Agent for text-to-SQL with automatic error correction](agent_text_to_sql)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "disposable",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
2 changes: 1 addition & 1 deletion notebooks/en/agent_rag.ipynb
@@ -733,7 +733,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 65/65 [02:17<00:00, 2.12s/it]"
"100%|██████████| 65/65 [02:17<00:00, 2.12s/it]\n"
]
},
{
