Skip to content

Commit

Permalink
Merge pull request #33 from aurelio-labs/luca/fix-on-embeddings-check
Browse files: browse the repository at this point in the history
Fix for embeddings
  • Loading branch information
simjak committed Dec 18, 2023
2 parents 45ce599 + 8011844 commit bfbf10b
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 74 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ mac.env
.coverage
.coverage.*
.pytest_cache
test.py
138 changes: 66 additions & 72 deletions docs/examples/function_calling.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,22 @@
},
{
"cell_type": "code",
"execution_count": 213,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%reload_ext dotenv\n",
"%dotenv"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# OpenAI\n",
"import os\n",
"import openai\n",
"from semantic_router.utils.logger import logger\n",
"\n",
Expand All @@ -39,7 +50,7 @@
},
{
"cell_type": "code",
"execution_count": 214,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -48,7 +59,7 @@
"import requests\n",
"\n",
"# Docs https://huggingface.co/docs/transformers/main_classes/text_generation\n",
"HF_API_TOKEN = os.environ[\"HF_API_TOKEN\"]\n",
"HF_API_TOKEN = os.getenv(\"HF_API_TOKEN\")\n",
"\n",
"\n",
"def llm_mistral(prompt: str) -> str:\n",
Expand Down Expand Up @@ -180,7 +191,7 @@
},
{
"cell_type": "code",
"execution_count": 217,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -242,6 +253,23 @@
"Set up the routing layer"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from semantic_router.schema import Route\n",
"from semantic_router.encoders import CohereEncoder, OpenAIEncoder\n",
"from semantic_router.layer import RouteLayer\n",
"from semantic_router.utils.logger import logger\n",
"\n",
"\n",
"def create_router(routes: list[dict]) -> RouteLayer:\n",
" logger.info(\"Creating route layer...\")\n",
" encoder = OpenAIEncoder"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -256,7 +284,7 @@
"\n",
"def create_router(routes: list[dict]) -> RouteLayer:\n",
" logger.info(\"Creating route layer...\")\n",
" encoder = CohereEncoder()\n",
" encoder = OpenAIEncoder()\n",
"\n",
" route_list: list[Route] = []\n",
" for route in routes:\n",
Expand All @@ -278,7 +306,7 @@
},
{
"cell_type": "code",
"execution_count": 219,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -349,72 +377,38 @@
},
{
"cell_type": "code",
"execution_count": 220,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_time(location: str) -> str:\n",
" \"\"\"Useful to get the time in a specific location\"\"\"\n",
" print(f\"Calling `get_time` function with location: {location}\")\n",
" return \"get_time\"\n",
"\n",
"\n",
"def get_news(category: str, country: str) -> str:\n",
" \"\"\"Useful to get the news in a specific country\"\"\"\n",
" print(\n",
" f\"Calling `get_news` function with category: {category} and country: {country}\"\n",
" )\n",
" return \"get_news\"\n",
"\n",
"\n",
"# Registering functions to the router\n",
"route_get_time = generate_route(get_time)\n",
"route_get_news = generate_route(get_news)\n",
"\n",
"routes = [route_get_time, route_get_news]\n",
"router = create_router(routes)\n",
"\n",
"# Tools\n",
"tools = [get_time, get_news]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[32m2023-12-15 11:41:54 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n",
"\u001b[32m2023-12-15 11:41:54 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n",
"\u001b[32m2023-12-15 11:41:55 INFO semantic_router.utils.logger AI message: \n",
" {\n",
" 'location': 'Stockholm'\n",
" }\u001b[0m\n",
"\u001b[32m2023-12-15 11:41:55 INFO semantic_router.utils.logger Extracted parameters: {'location': 'Stockholm'}\u001b[0m\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"parameters: {'location': 'Stockholm'}\n",
"Calling `get_time` function with location: Stockholm\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[32m2023-12-15 11:41:55 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n",
"\u001b[32m2023-12-15 11:41:55 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n",
"\u001b[32m2023-12-15 11:41:56 INFO semantic_router.utils.logger AI message: \n",
" {\n",
" 'category': 'tech',\n",
" 'country': 'Lithuania'\n",
" }\u001b[0m\n",
"\u001b[32m2023-12-15 11:41:56 INFO semantic_router.utils.logger Extracted parameters: {'category': 'tech', 'country': 'Lithuania'}\u001b[0m\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"parameters: {'category': 'tech', 'country': 'Lithuania'}\n",
"Calling `get_news` function with category: tech and country: Lithuania\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[33m2023-12-15 11:41:57 WARNING semantic_router.utils.logger No function found\u001b[0m\n",
"\u001b[32m2023-12-15 11:41:57 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n",
"\u001b[32m2023-12-15 11:41:57 INFO semantic_router.utils.logger AI message: How can I help you today?\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"' How can I help you today?'"
]
},
"execution_count": 220,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"call(query=\"What is the time in Stockholm?\", functions=tools, router=router)\n",
"call(query=\"What is the tech news in the Lithuania?\", functions=tools, router=router)\n",
Expand All @@ -438,7 +432,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
"version": "3.11.5"
}
},
"nbformat": 4,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "semantic-router"
version = "0.0.9"
version = "0.0.10"
description = "Super fast semantic router for AI decision making"
authors = [
"James Briggs <[email protected]>",
Expand Down
2 changes: 1 addition & 1 deletion semantic_router/encoders/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __call__(self, docs: list[str]) -> list[list[float]]:
try:
logger.info(f"Encoding {len(docs)} documents...")
embeds = self.client.embeddings.create(input=docs, model=self.name)
if isinstance(embeds, dict) and "data" in embeds:
if "data" in embeds:
break
except OpenAIError as e:
sleep(2**j)
Expand Down
1 change: 1 addition & 0 deletions tests/unit/encoders/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def test_openai_encoder_call_failure_non_openai_error(self, openai_encoder, mock
)
with pytest.raises(ValueError) as e:
openai_encoder(["test document"])

assert "OpenAI API call failed. Error: Non-OpenAIError" in str(e.value)

def test_openai_encoder_call_successful_retry(self, openai_encoder, mocker):
Expand Down

0 comments on commit bfbf10b

Please sign in to comment.