[PY] feat: streaming support for Tools Augmentation #2215

Draft · wants to merge 4 commits into base: main
51 changes: 36 additions & 15 deletions python/packages/ai/teams/ai/clients/llm_client.py
@@ -155,22 +155,22 @@ def before_completion(
            if context != ctx:
                return

            nonlocal streamer

            # Check for a streaming response
            if streaming:
                nonlocal is_streaming
                is_streaming = True

                streamer = memory.get("temp.streamer")
                if not streamer:
                    streamer = StreamingResponse(context)
                    memory.set("temp.streamer", streamer)

                if self._enable_feedback_loop is not None:
                    streamer.set_feedback_loop(self._enable_feedback_loop)

                streamer.set_generated_by_ai_label(True)

                if self._start_streaming_message:
                    streamer.queue_informative_update(self._start_streaming_message)

        def chunk_received(
            ctx: TurnContext,
@@ -181,6 +181,13 @@ def chunk_received(
            nonlocal streamer
            if (context != ctx) or (streamer is None):
                return

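            # Skip chunks that carry tool-call fragments; those are assembled into
            # action calls (see openai_model.py below) rather than streamed as text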
            if chunk.delta and (
                (chunk.delta.action_calls and len(chunk.delta.action_calls) > 0)
                or chunk.delta.action_call_id
                or getattr(chunk.delta.content, "tool_calls", None)
            ):
                return

            text = chunk.delta.content if (chunk.delta and chunk.delta.content) else ""
            citations = (
@@ -278,13 +285,27 @@ def chunk_received(
            self._add_message_to_history(memory, self._options.history_variable, res.input)
            self._add_message_to_history(memory, self._options.history_variable, res.message)

            # if is_streaming and res.status == "success":
            #     # Delete message from response to avoid sending it twice
            #     res.message = None

            # if streamer is not None:
            #     await streamer.end_stream()
Contributor comment:
need to also remove L281-283 (check for is_streaming and res.status == "success")

            # Tool call handling
            # Keep the streamer around during tool calls, letting them return as normal
            # messages minus the content. When the tool call completes, reattach to the
            # streamer for continued streaming to the client.
            if res.message and res.message.action_calls and len(res.message.action_calls) > 0:
                res.message.content = ""
            else:
                if res.status == "success":
                    res.message = None

                if streamer is not None:
                    await streamer.end_stream()
                    memory.delete("temp.streamer")
            return res

        except Exception as err:  # pylint: disable=broad-except
            return PromptResponse(status="error", error=str(err))
        finally:
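Taken together, the llm_client.py changes let the streamer survive tool-call round trips: it is created once, parked in "temp.streamer" while actions run, and torn down only when a final text response ends the stream. A minimal sketch of that lifecycle, using a dict-backed stand-in for the planner memory (FakeMemory and make_streamer are illustrative names, not Teams AI APIs):

from typing import Any, Dict, Optional

class FakeMemory:
    """Dict-backed stand-in for the planner memory (illustrative only)."""

    def __init__(self) -> None:
        self._store: Dict[str, Any] = {}

    def get(self, key: str) -> Optional[Any]:
        return self._store.get(key)

    def set(self, key: str, value: Any) -> None:
        self._store[key] = value

    def delete(self, key: str) -> None:
        self._store.pop(key, None)

def get_or_create_streamer(memory: FakeMemory, make_streamer) -> Any:
    # Mirrors before_completion above: reuse the streamer stored by an
    # earlier turn instead of opening a second stream mid tool call
    streamer = memory.get("temp.streamer")
    if not streamer:
        streamer = make_streamer()
        memory.set("temp.streamer", streamer)
    return streamer

memory = FakeMemory()
first = get_or_create_streamer(memory, object)   # initial turn: creates
second = get_or_create_streamer(memory, object)  # tool-call turn: reuses
assert first is second
memory.delete("temp.streamer")  # mirrors the cleanup after end_stream()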
29 changes: 28 additions & 1 deletion python/packages/ai/teams/ai/models/openai_model.py
@@ -271,7 +271,34 @@ async def complete_prompt(
                if delta.content:
                    message_content += delta.content

                # Handle tool calls during streaming
                if is_tools_aug and delta.tool_calls:
                    if not hasattr(message, "action_calls") or message.action_calls is None:
                        message.action_calls = []

                    for curr_tool_call in delta.tool_calls:
                        index = curr_tool_call.index

                        # Ensure the action_calls array is long enough
                        if index >= len(message.action_calls):
                            message.action_calls.append(
                                ActionCall(
                                    id="",
                                    function=ActionFunction(name="", arguments=""),
                                    type="function",
                                )
                            )

                        # Update the action_call at the specific index
                        if curr_tool_call.id:
                            message.action_calls[index].id = curr_tool_call.id
                        if curr_tool_call.function:
                            if curr_tool_call.function.name:
                                message.action_calls[index].function.name += (
                                    curr_tool_call.function.name
                                )
                            if curr_tool_call.function.arguments:
                                message.action_calls[index].function.arguments += (
                                    curr_tool_call.function.arguments
                                )
                        if curr_tool_call.type == "function":
                            message.action_calls[index].type = curr_tool_call.type

                if self._options.logger is not None:
                    self._options.logger.debug(f"CHUNK ${delta}")
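For intuition, here is the same accumulation pattern as a standalone sketch. The dataclasses are simplified stand-ins for the OpenAI SDK's streamed tool-call deltas, not the real types: id, name, and type usually arrive on a call's first fragment, while the JSON arguments string is split across many fragments and must be concatenated in order.

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class FunctionDelta:  # stand-in for the SDK's streamed function fragment
    name: Optional[str] = None
    arguments: Optional[str] = None

@dataclass
class ToolCallDelta:  # stand-in for the SDK's streamed tool-call fragment
    index: int
    id: Optional[str] = None
    type: Optional[str] = None
    function: Optional[FunctionDelta] = None

def accumulate(fragments: List[ToolCallDelta]) -> List[dict]:
    """Stitch streamed fragments together, mirroring the loop above."""
    calls: List[dict] = []
    for d in fragments:
        # Grow the list so d.index is a valid slot (fragments arrive in order)
        while d.index >= len(calls):
            calls.append({"id": "", "name": "", "arguments": "", "type": "function"})
        if d.id:
            calls[d.index]["id"] = d.id
        if d.function and d.function.name:
            calls[d.index]["name"] += d.function.name
        if d.function and d.function.arguments:
            calls[d.index]["arguments"] += d.function.arguments
        if d.type == "function":
            calls[d.index]["type"] = d.type
    return calls

fragments = [
    ToolCallDelta(0, id="call_1", type="function", function=FunctionDelta(name="LightsOn")),
    ToolCallDelta(0, function=FunctionDelta(arguments='{"ro')),
    ToolCallDelta(0, function=FunctionDelta(arguments='om": "kitchen"}')),
]
print(accumulate(fragments))
# [{'id': 'call_1', 'name': 'LightsOn', 'arguments': '{"room": "kitchen"}', 'type': 'function'}]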
15 changes: 13 additions & 2 deletions python/samples/04.ai.c.actionMapping.lightBot/src/bot.py
@@ -9,6 +9,7 @@
import logging
import sys
import time
import traceback
from typing import Any, Dict, List

from botbuilder.core import MemoryStorage, TurnContext
@@ -24,6 +25,9 @@
from config import Config
from state import AppTurnState

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

config = Config()

if config.OPENAI_KEY is None and config.AZURE_OPENAI_KEY is None:
@@ -45,6 +49,8 @@
            default_model="gpt-4o",
            api_version="2023-03-15-preview",
            endpoint=config.AZURE_OPENAI_ENDPOINT,
            logger=logger,
            stream=True,
        )
    )

@@ -84,6 +90,7 @@ async def on_lights_on(
    state: AppTurnState,
):
    state.conversation.lights_on = True
    logging.info("[lights on]")
    await context.send_activity("[lights on]")
    return "the lights are now on"

@@ -94,6 +101,7 @@ async def on_lights_off(
    state: AppTurnState,
):
    state.conversation.lights_on = False
    logging.info("[lights off]")
    await context.send_activity("[lights off]")
    return "the lights are now off"

@@ -104,6 +112,7 @@ async def on_pause(
    _state: AppTurnState,
):
    time_ms = int(context.data["time"]) if context.data["time"] else 1000
    logging.info(f"[pausing for {time_ms / 1000} seconds]")
    await context.send_activity(f"[pausing for {time_ms / 1000} seconds]")
    time.sleep(time_ms / 1000)
    return "done pausing"
@@ -114,15 +123,17 @@ async def on_lights_status(
    _context: ActionTurnContext[Dict[str, Any]],
    state: AppTurnState,
):
    light_status = "the lights are on" if state.conversation.lights_on else "the lights are off"
    logging.info(light_status)
    return light_status


@app.error
async def on_error(context: TurnContext, error: Exception):
    # This check writes out errors to the console log vs. app insights.
    # NOTE: In production environment, you should consider logging this to Azure
    # application insights.
    logging.error(f"\n [on_turn_error] unhandled error: {error}")
    traceback.print_exc()

# Send a message to the user
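One side note on the sample's logging: openai_model.py above emits each streamed delta at DEBUG level (the CHUNK line), while the sample configures logging at INFO, so the per-chunk output stays hidden. If you want to watch chunks arrive while testing, a minimal tweak is:

import logging

# DEBUG rather than INFO surfaces the per-chunk line in openai_model.py:
#   self._options.logger.debug(f"CHUNK ${delta}")
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)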