From 4e66292f1c221afa34c93f8c1ab4fb7d4d5925d3 Mon Sep 17 00:00:00 2001 From: tolgadevAI <164843802+tolgadevAI@users.noreply.github.com> Date: Tue, 16 Jul 2024 14:31:07 +0300 Subject: [PATCH] update: integrate the OpenAILLM async into the RL --- docs/02-dynamic-routes.ipynb | 491 +++++++++++++++++++++++++-------- semantic_router/layer.py | 11 +- semantic_router/llms/openai.py | 68 ++++- 3 files changed, 445 insertions(+), 125 deletions(-) diff --git a/docs/02-dynamic-routes.ipynb b/docs/02-dynamic-routes.ipynb index 6353e71a..bb2e439b 100644 --- a/docs/02-dynamic-routes.ipynb +++ b/docs/02-dynamic-routes.ipynb @@ -70,7 +70,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 63, "metadata": { "id": "dLElfRhgur0v", "outputId": "da0e506e-24cf-43da-9243-894a7c4955db" @@ -80,16 +80,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Requirement already satisfied: tzdata in c:\\users\\siraj\\documents\\personal\\work\\aurelio\\virtual environments\\semantic_router_3\\lib\\site-packages (2024.1)\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n", - "[notice] A new release of pip is available: 23.1.2 -> 24.0\n", - "[notice] To update, run: python.exe -m pip install --upgrade pip\n" + "Requirement already satisfied: tzdata in /opt/anaconda3/envs/semantic-router/lib/python3.12/site-packages (2024.1)\n" ] } ], @@ -118,21 +109,12 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 64, "metadata": { "id": "kc9Ty6Lgur0x", "outputId": "f32e3a25-c073-4802-ced3-d7a5663670c1" }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "c:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\Virtual Environments\\semantic_router_3\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - } - ], + "outputs": [], "source": [ "from semantic_router import Route\n", "\n", @@ -171,7 +153,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 65, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -179,15 +161,7 @@ "id": "BI9AiDspur0y", "outputId": "27329a54-3f16-44a5-ac20-13a6b26afb97" }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[32m2024-05-08 01:57:55 INFO semantic_router.utils.logger local\u001b[0m\n" - ] - } - ], + "outputs": [], "source": [ "import os\n", "from getpass import getpass\n", @@ -220,7 +194,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 66, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -235,15 +209,37 @@ "RouteChoice(name='chitchat', function_call=None, similarity_score=None)" ] }, - "execution_count": 4, + "execution_count": 66, "metadata": {}, "output_type": "execute_result" } ], "source": [ + "# sync\n", "rl(\"how's the weather today?\")" ] }, + { + "cell_type": "code", + "execution_count": 67, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "RouteChoice(name='chitchat', function_call=None, similarity_score=None)" + ] + }, + "execution_count": 67, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# async\n", + "await rl.acall(\"how's the weather today?\")" + ] + }, { "cell_type": "markdown", "metadata": { @@ -264,7 +260,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 68, "metadata": { "id": "5jaF1Xa5ur0y" }, @@ -290,7 +286,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 69, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -303,10 +299,10 @@ { "data": { "text/plain": [ - "'17:57'" + "'07:25'" ] }, - "execution_count": 6, + "execution_count": 69, "metadata": {}, "output_type": "execute_result" } @@ -326,7 +322,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 70, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -347,7 +343,7 @@ " 'required': ['timezone']}}}]" ] }, - "execution_count": 7, + "execution_count": 70, "metadata": {}, "output_type": "execute_result" } @@ -370,7 +366,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 71, "metadata": { "id": "iesBG9P3ur0z" }, @@ -387,17 +383,6 @@ ")" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "jmVwEWEIg9hq" - }, - "outputs": [], - "source": [ - "time_route.llm" - ] - }, { "cell_type": "markdown", "metadata": { @@ -409,7 +394,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 72, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -422,7 +407,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2024-05-08 01:57:56 INFO semantic_router.utils.logger Adding `get_time` route\u001b[0m\n" + "\u001b[32m2024-07-16 14:25:29 INFO semantic_router.utils.logger Adding `get_time` route\u001b[0m\n" ] } ], @@ -430,17 +415,6 @@ "rl.add(time_route)" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "mbccVdy5g9hr" - }, - "outputs": [], - "source": [ - "time_route.llm" - ] - }, { "cell_type": "markdown", "metadata": { @@ -452,7 +426,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 73, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -466,8 +440,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[33m2024-05-08 01:57:57 WARNING semantic_router.utils.logger No LLM provided for dynamic route, will use OpenAI LLM default. Ensure API key is set in OPENAI_API_KEY environment variable.\u001b[0m\n", - "\u001b[32m2024-05-08 01:57:58 INFO semantic_router.utils.logger Function inputs: [{'function_name': 'get_time', 'arguments': {'timezone': 'America/New_York'}}]\u001b[0m\n" + "\u001b[33m2024-07-16 14:25:31 WARNING semantic_router.utils.logger No LLM provided for dynamic route, will use OpenAI LLM default. Ensure API key is set in OPENAI_API_KEY environment variable.\u001b[0m\n", + "\u001b[32m2024-07-16 14:25:32 INFO semantic_router.utils.logger Function inputs: [{'function_name': 'get_time', 'arguments': {'timezone': 'America/New_York'}}]\u001b[0m\n" ] }, { @@ -476,19 +450,49 @@ "RouteChoice(name='get_time', function_call=[{'function_name': 'get_time', 'arguments': {'timezone': 'America/New_York'}}], similarity_score=None)" ] }, - "execution_count": 12, + "execution_count": 73, "metadata": {}, "output_type": "execute_result" } ], "source": [ + "# sync\n", "response = rl(\"what is the time in new york city?\")\n", "response" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 74, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2024-07-16 14:25:35 INFO semantic_router.utils.logger Function inputs: [{'function_name': 'get_time', 'arguments': {'timezone': 'America/New_York'}}]\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "RouteChoice(name='get_time', function_call=[{'function_name': 'get_time', 'arguments': {'timezone': 'America/New_York'}}], similarity_score=None)" + ] + }, + "execution_count": 74, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# async\n", + "response = await rl.acall(\"what is the time in new york city?\")\n", + "response" + ] + }, + { + "cell_type": "code", + "execution_count": 75, "metadata": { "id": "92x96x1Og9hr", "outputId": "c1e46a81-b681-4a10-fff6-71e03342a88e" @@ -508,7 +512,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 76, "metadata": { "id": "xvdyUPKqg9hr", "outputId": "4161e7e0-ab6d-4e76-f068-2d66728305ff" @@ -518,7 +522,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "17:57\n" + "07:25\n" ] } ], @@ -579,7 +583,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 77, "metadata": { "id": "dtrksov0g9hs" }, @@ -661,7 +665,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 78, "metadata": { "id": "AjoYy7mFg9hs" }, @@ -672,7 +676,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 79, "metadata": { "id": "DoOkXV2Tg9hs", "outputId": "f1e0fe08-b6ed-4f50-d845-5c54832ca677" @@ -710,7 +714,7 @@ " 'required': ['time', 'from_timezone', 'to_timezone']}}}]" ] }, - "execution_count": 17, + "execution_count": 79, "metadata": {}, "output_type": "execute_result" } @@ -725,7 +729,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 80, "metadata": { "id": "YBRHxhnkg9hs" }, @@ -762,7 +766,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 81, "metadata": { "id": "yEbQadQbg9ht" }, @@ -773,20 +777,12 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 82, "metadata": { "id": "C0aYIXaog9ht", "outputId": "74114a86-4a6f-49c5-8e2e-600f577d63f5" }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[32m2024-05-08 01:57:58 INFO semantic_router.utils.logger local\u001b[0m\n" - ] - } - ], + "outputs": [], "source": [ "rl2 = RouteLayer(encoder=encoder, routes=routes)" ] @@ -802,7 +798,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 83, "metadata": { "id": "PJR97klVg9ht" }, @@ -834,7 +830,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 84, "metadata": { "id": "D2kXFv9Xg9ht", "outputId": "569cf17f-2091-4aea-9cba-11bb0af2ebd4" @@ -846,16 +842,41 @@ "RouteChoice(name='politics', function_call=None, similarity_score=None)" ] }, - "execution_count": 22, + "execution_count": 84, "metadata": {}, "output_type": "execute_result" } ], "source": [ + "# sync\n", + "\n", "response = rl2(\"What is your political leaning?\")\n", "response" ] }, + { + "cell_type": "code", + "execution_count": 85, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "RouteChoice(name='politics', function_call=None, similarity_score=None)" + ] + }, + "execution_count": 85, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# async\n", + "\n", + "response = await rl2.acall(\"What is your political leaning?\")\n", + "response" + ] + }, { "cell_type": "markdown", "metadata": { @@ -867,7 +888,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 86, "metadata": { "id": "YsI5O_bHg9ht", "outputId": "a6e3814b-97e0-4406-ec9a-17b1c7103e40" @@ -879,16 +900,39 @@ "RouteChoice(name='chitchat', function_call=None, similarity_score=None)" ] }, - "execution_count": 23, + "execution_count": 86, "metadata": {}, "output_type": "execute_result" } ], "source": [ + "# sync\n", "response = rl2(\"Hello bot, how are you today?\")\n", "response" ] }, + { + "cell_type": "code", + "execution_count": 87, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "RouteChoice(name='chitchat', function_call=None, similarity_score=None)" + ] + }, + "execution_count": 87, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# async\n", + "response = await rl2.acall(\"Hello bot, how are you today?\")\n", + "response" + ] + }, { "cell_type": "markdown", "metadata": { @@ -900,7 +944,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 88, "metadata": { "id": "BdOfLx-wg9hu", "outputId": "ef55a34c-7c34-4acc-918d-a173fac95171" @@ -910,8 +954,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[33m2024-05-08 01:58:00 WARNING semantic_router.utils.logger No LLM provided for dynamic route, will use OpenAI LLM default. Ensure API key is set in OPENAI_API_KEY environment variable.\u001b[0m\n", - "\u001b[32m2024-05-08 01:58:01 INFO semantic_router.utils.logger Function inputs: [{'function_name': 'get_time', 'arguments': {'timezone': 'America/New_York'}}]\u001b[0m\n" + "\u001b[33m2024-07-16 14:26:13 WARNING semantic_router.utils.logger No LLM provided for dynamic route, will use OpenAI LLM default. Ensure API key is set in OPENAI_API_KEY environment variable.\u001b[0m\n", + "\u001b[32m2024-07-16 14:26:14 INFO semantic_router.utils.logger Function inputs: [{'function_name': 'get_time', 'arguments': {'timezone': 'America/New_York'}}]\u001b[0m\n" ] }, { @@ -920,19 +964,20 @@ "RouteChoice(name='timezone_management', function_call=[{'function_name': 'get_time', 'arguments': {'timezone': 'America/New_York'}}], similarity_score=None)" ] }, - "execution_count": 24, + "execution_count": 88, "metadata": {}, "output_type": "execute_result" } ], "source": [ + "# sync\n", "response = rl2(\"what is the time in New York?\")\n", "response" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 89, "metadata": { "id": "QrpF_JcHg9hu", "outputId": "242d645f-43c3-4e9f-9a46-d3aa3105f02a" @@ -942,7 +987,53 @@ "name": "stdout", "output_type": "stream", "text": [ - "17:58\n" + "07:26\n" + ] + } + ], + "source": [ + "parse_response(response)" + ] + }, + { + "cell_type": "code", + "execution_count": 90, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2024-07-16 14:26:18 INFO semantic_router.utils.logger Function inputs: [{'function_name': 'get_time', 'arguments': {'timezone': 'America/New_York'}}]\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "RouteChoice(name='timezone_management', function_call=[{'function_name': 'get_time', 'arguments': {'timezone': 'America/New_York'}}], similarity_score=None)" + ] + }, + "execution_count": 90, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# async\n", + "response = await rl2.acall(\"what is the time in New York?\")\n", + "response" + ] + }, + { + "cell_type": "code", + "execution_count": 91, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "07:26\n" ] } ], @@ -961,7 +1052,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 92, "metadata": { "id": "W85287lAg9hu", "outputId": "4f247f13-046b-4a5c-f119-de17df29131f" @@ -971,7 +1062,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2024-05-08 01:58:02 INFO semantic_router.utils.logger Function inputs: [{'function_name': 'get_time_difference', 'arguments': {'timezone1': 'America/Los_Angeles', 'timezone2': 'Europe/Istanbul'}}]\u001b[0m\n" + "\u001b[32m2024-07-16 14:26:24 INFO semantic_router.utils.logger Function inputs: [{'function_name': 'get_time_difference', 'arguments': {'timezone1': 'America/Los_Angeles', 'timezone2': 'Europe/Istanbul'}}]\u001b[0m\n" ] }, { @@ -980,19 +1071,20 @@ "RouteChoice(name='timezone_management', function_call=[{'function_name': 'get_time_difference', 'arguments': {'timezone1': 'America/Los_Angeles', 'timezone2': 'Europe/Istanbul'}}], similarity_score=None)" ] }, - "execution_count": 26, + "execution_count": 92, "metadata": {}, "output_type": "execute_result" } ], "source": [ + "# sync\n", "response = rl2(\"What is the time difference between Los Angeles and Istanbul?\")\n", "response" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 93, "metadata": { "id": "2jxAIi6rg9hv", "outputId": "8abff974-602f-4c0d-8d21-3a275b0eee62" @@ -1004,6 +1096,68 @@ "text": [ "The time difference between America/Los_Angeles and Europe/Istanbul is 10.0 hours.\n" ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/7b/r9z99rtd2zg774gf4bh8cr9m0000gn/T/ipykernel_86079/3683005204.py:28: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", + " now_utc = datetime.utcnow().replace(tzinfo=ZoneInfo(\"UTC\"))\n" + ] + } + ], + "source": [ + "parse_response(response)" + ] + }, + { + "cell_type": "code", + "execution_count": 94, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2024-07-16 14:26:30 INFO semantic_router.utils.logger Function inputs: [{'function_name': 'get_time_difference', 'arguments': {'timezone1': 'America/Los_Angeles', 'timezone2': 'Europe/Istanbul'}}]\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "RouteChoice(name='timezone_management', function_call=[{'function_name': 'get_time_difference', 'arguments': {'timezone1': 'America/Los_Angeles', 'timezone2': 'Europe/Istanbul'}}], similarity_score=None)" + ] + }, + "execution_count": 94, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# async\n", + "response = await rl2.acall(\"What is the time difference between Los Angeles and Istanbul?\")\n", + "response" + ] + }, + { + "cell_type": "code", + "execution_count": 95, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The time difference between America/Los_Angeles and Europe/Istanbul is 10.0 hours.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/7b/r9z99rtd2zg774gf4bh8cr9m0000gn/T/ipykernel_86079/3683005204.py:28: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", + " now_utc = datetime.utcnow().replace(tzinfo=ZoneInfo(\"UTC\"))\n" + ] } ], "source": [ @@ -1021,7 +1175,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 96, "metadata": { "id": "PzM1HH7Rg9hv", "outputId": "e123c86f-9754-453a-d895-bfcce26110d4" @@ -1031,7 +1185,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2024-05-08 01:58:04 INFO semantic_router.utils.logger Function inputs: [{'function_name': 'convert_time', 'arguments': {'time': '23:02', 'from_timezone': 'Asia/Dubai', 'to_timezone': 'Asia/Tokyo'}}]\u001b[0m\n" + "\u001b[32m2024-07-16 14:26:35 INFO semantic_router.utils.logger Function inputs: [{'function_name': 'convert_time', 'arguments': {'time': '23:02', 'from_timezone': 'Asia/Dubai', 'to_timezone': 'Asia/Tokyo'}}]\u001b[0m\n" ] }, { @@ -1040,19 +1194,20 @@ "RouteChoice(name='timezone_management', function_call=[{'function_name': 'convert_time', 'arguments': {'time': '23:02', 'from_timezone': 'Asia/Dubai', 'to_timezone': 'Asia/Tokyo'}}], similarity_score=None)" ] }, - "execution_count": 28, + "execution_count": 96, "metadata": {}, "output_type": "execute_result" } ], "source": [ + "# sync\n", "response = rl2(\"What is 23:02 Dubai time in Tokyo time? Please and thank you.\")\n", "response" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 97, "metadata": { "id": "QFKZ757Pg9hv", "outputId": "af5c1328-f6dd-4dc7-c104-e920198885fc" @@ -1070,6 +1225,52 @@ "parse_response(response)" ] }, + { + "cell_type": "code", + "execution_count": 98, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2024-07-16 14:26:40 INFO semantic_router.utils.logger Function inputs: [{'function_name': 'convert_time', 'arguments': {'time': '23:02', 'from_timezone': 'Asia/Dubai', 'to_timezone': 'Asia/Tokyo'}}]\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "RouteChoice(name='timezone_management', function_call=[{'function_name': 'convert_time', 'arguments': {'time': '23:02', 'from_timezone': 'Asia/Dubai', 'to_timezone': 'Asia/Tokyo'}}], similarity_score=None)" + ] + }, + "execution_count": 98, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# async\n", + "response = await rl2.acall(\"What is 23:02 Dubai time in Tokyo time? Please and thank you.\")\n", + "response" + ] + }, + { + "cell_type": "code", + "execution_count": 99, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "04:02\n" + ] + } + ], + "source": [ + "parse_response(response)" + ] + }, { "cell_type": "markdown", "metadata": { @@ -1081,7 +1282,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 100, "metadata": { "id": "Vnj6A3AVg9hv", "outputId": "c8a61c3f-a504-430b-82fb-c211c0523dcb" @@ -1091,11 +1292,12 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2024-05-08 01:58:07 INFO semantic_router.utils.logger Function inputs: [{'function_name': 'get_time', 'arguments': {'timezone': 'Europe/Prague'}}, {'function_name': 'get_time_difference', 'arguments': {'timezone1': 'Europe/Berlin', 'timezone2': 'Asia/Shanghai'}}, {'function_name': 'convert_time', 'arguments': {'time': '05:53', 'from_timezone': 'Europe/Lisbon', 'to_timezone': 'Asia/Bangkok'}}]\u001b[0m\n" + "\u001b[32m2024-07-16 14:26:46 INFO semantic_router.utils.logger Function inputs: [{'function_name': 'get_time', 'arguments': {'timezone': 'Europe/Prague'}}, {'function_name': 'get_time_difference', 'arguments': {'timezone1': 'Europe/Berlin', 'timezone2': 'Asia/Shanghai'}}, {'function_name': 'convert_time', 'arguments': {'time': '05:53', 'from_timezone': 'Europe/Lisbon', 'to_timezone': 'Asia/Bangkok'}}]\u001b[0m\n" ] } ], "source": [ + "# sync\n", "response = rl2(\n", " \"\"\"\n", " What is the time in Prague?\n", @@ -1107,7 +1309,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 101, "metadata": { "id": "L9jq_Yoag9hv", "outputId": "50fae028-4af4-46f5-f6e9-4262b8874caa" @@ -1119,7 +1321,7 @@ "RouteChoice(name='timezone_management', function_call=[{'function_name': 'get_time', 'arguments': {'timezone': 'Europe/Prague'}}, {'function_name': 'get_time_difference', 'arguments': {'timezone1': 'Europe/Berlin', 'timezone2': 'Asia/Shanghai'}}, {'function_name': 'convert_time', 'arguments': {'time': '05:53', 'from_timezone': 'Europe/Lisbon', 'to_timezone': 'Asia/Bangkok'}}], similarity_score=None)" ] }, - "execution_count": 31, + "execution_count": 101, "metadata": {}, "output_type": "execute_result" } @@ -1130,7 +1332,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 102, "metadata": { "id": "Hw3raSVBg9hv", "outputId": "d30b9cba-979d-4bdf-86e0-37c550c4187d" @@ -1140,10 +1342,69 @@ "name": "stdout", "output_type": "stream", "text": [ - "23:58\n", + "13:26\n", + "The time difference between Europe/Berlin and Asia/Shanghai is 6.0 hours.\n", + "11:53\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/7b/r9z99rtd2zg774gf4bh8cr9m0000gn/T/ipykernel_86079/3683005204.py:28: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", + " now_utc = datetime.utcnow().replace(tzinfo=ZoneInfo(\"UTC\"))\n" + ] + } + ], + "source": [ + "parse_response(response)" + ] + }, + { + "cell_type": "code", + "execution_count": 103, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2024-07-16 14:26:52 INFO semantic_router.utils.logger Function inputs: [{'function_name': 'get_time', 'arguments': {'timezone': 'Europe/Prague'}}, {'function_name': 'get_time_difference', 'arguments': {'timezone1': 'Europe/Berlin', 'timezone2': 'Asia/Shanghai'}}, {'function_name': 'convert_time', 'arguments': {'time': '05:53', 'from_timezone': 'Europe/Lisbon', 'to_timezone': 'Asia/Bangkok'}}]\u001b[0m\n" + ] + } + ], + "source": [ + "# async\n", + "response = await rl2.acall(\n", + " \"\"\"\n", + " What is the time in Prague?\n", + " What is the time difference between Frankfurt and Beijing?\n", + " What is 5:53 Lisbon time in Bangkok time?\n", + "\"\"\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 104, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "13:26\n", "The time difference between Europe/Berlin and Asia/Shanghai is 6.0 hours.\n", "11:53\n" ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/7b/r9z99rtd2zg774gf4bh8cr9m0000gn/T/ipykernel_86079/3683005204.py:28: DeprecationWarning: datetime.datetime.utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC).\n", + " now_utc = datetime.utcnow().replace(tzinfo=ZoneInfo(\"UTC\"))\n" + ] } ], "source": [ @@ -1170,9 +1431,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.4" + "version": "3.12.3" } }, "nbformat": 4, "nbformat_minor": 0 -} \ No newline at end of file +} diff --git a/semantic_router/layer.py b/semantic_router/layer.py index 5c2d7228..072b697f 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -297,9 +297,14 @@ async def acall( "Route has a function schema, but no text was provided." ) if route.function_schemas and not isinstance(route.llm, BaseLLM): - raise NotImplementedError( - "Dynamic routes not yet supported for async calls." - ) + if not self.llm: + logger.warning( + "No LLM provided for dynamic route, will use OpenAI LLM default" + ) + self.llm = OpenAILLM(use_async=True) + route.llm = self.llm + else: + route.llm = self.llm return route(text) elif passed and route is not None and simulate_static: return RouteChoice( diff --git a/semantic_router/llms/openai.py b/semantic_router/llms/openai.py index 2a531195..20360655 100644 --- a/semantic_router/llms/openai.py +++ b/semantic_router/llms/openai.py @@ -21,7 +21,7 @@ class OpenAILLM(BaseLLM): - client: Optional[openai.OpenAI] + client: Union[openai.AsyncOpenAI, openai.OpenAI] temperature: Optional[float] max_tokens: Optional[int] @@ -31,6 +31,7 @@ def __init__( openai_api_key: Optional[str] = None, temperature: float = 0.01, max_tokens: int = 200, + use_async=False, ): if name is None: name = EncoderDefault.OPENAI.value["language_model"] @@ -38,12 +39,21 @@ def __init__( api_key = openai_api_key or os.getenv("OPENAI_API_KEY") if api_key is None: raise ValueError("OpenAI API key cannot be 'None'.") - try: - self.client = openai.OpenAI(api_key=api_key) - except Exception as e: - raise ValueError( - f"OpenAI API client failed to initialize. Error: {e}" - ) from e + + if use_async: + try: + self.client = openai.AsyncOpenAI(api_key=api_key) + except Exception as e: + raise ValueError( + f"AsyncOpenAI API client failed to initialize. Error: {e}" + ) from e + else: + try: + self.client = openai.OpenAI(api_key=api_key) + except Exception as e: + raise ValueError( + f"OpenAI API client failed to initialize. Error: {e}" + ) from e self.temperature = temperature self.max_tokens = max_tokens @@ -108,6 +118,50 @@ def __call__( logger.error(f"LLM error: {e}") raise Exception(f"LLM error: {e}") from e + async def acall( + self, + messages: List[Message], + function_schemas: Optional[List[Dict[str, Any]]] = None, + ) -> str: + if self.client is None: + raise ValueError("OpenAI client is not initialized.") + try: + tools: Union[List[Dict[str, Any]], NotGiven] = ( + function_schemas if function_schemas is not None else NOT_GIVEN + ) + + completion = await self.client.chat.completions.create( + model=self.name, + messages=[m.to_openai() for m in messages], + temperature=self.temperature, + max_tokens=self.max_tokens, + tools=tools, # type: ignore # We pass a list of dicts which get interpreted as Iterable[ChatCompletionToolParam]. + ) + + if function_schemas: + tool_calls = completion.choices[0].message.tool_calls + if tool_calls is None: + raise ValueError("Invalid output, expected a tool call.") + if len(tool_calls) < 1: + raise ValueError( + "Invalid output, expected at least one tool to be specified." + ) + + # Collecting multiple tool calls information + output = str( + self._extract_tool_calls_info(tool_calls) + ) # str in keeping with base type. + else: + content = completion.choices[0].message.content + if content is None: + raise ValueError("Invalid output, expected content.") + output = content + return output + + except Exception as e: + logger.error(f"LLM error: {e}") + raise Exception(f"LLM error: {e}") from e + def extract_function_inputs( self, query: str, function_schemas: List[Dict[str, Any]] ) -> List[Dict[str, Any]]: