-
Notifications
You must be signed in to change notification settings - Fork 198
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[PY] feat: streaming support for Tools Augmentation #2215
base: main
Are you sure you want to change the base?
Changes from 2 commits
76a9214
a48e047
745ba65
02985ee
e476731
663f482
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -155,22 +155,22 @@ def before_completion( | |||
if context != ctx: | ||||
return | ||||
|
||||
nonlocal streamer | ||||
|
||||
# Check for a streaming response | ||||
if streaming: | ||||
nonlocal is_streaming | ||||
is_streaming = True | ||||
|
||||
nonlocal streamer | ||||
streamer = StreamingResponse(context) | ||||
memory.set("temp.streamer", streamer) | ||||
streamer = memory.get("temp.streamer") | ||||
if not streamer: | ||||
streamer = StreamingResponse(context) | ||||
memory.set("temp.streamer", streamer) | ||||
|
||||
if self._enable_feedback_loop is not None: | ||||
streamer.set_feedback_loop(self._enable_feedback_loop) | ||||
if self._enable_feedback_loop is not None: | ||||
streamer.set_feedback_loop(self._enable_feedback_loop) | ||||
|
||||
streamer.set_generated_by_ai_label(True) | ||||
streamer.set_generated_by_ai_label(True) | ||||
|
||||
if self._start_streaming_message: | ||||
streamer.queue_informative_update(self._start_streaming_message) | ||||
if self._start_streaming_message: | ||||
streamer.queue_informative_update(self._start_streaming_message) | ||||
|
||||
def chunk_received( | ||||
ctx: TurnContext, | ||||
|
@@ -181,6 +181,13 @@ def chunk_received( | |||
nonlocal streamer | ||||
if (context != ctx) or (streamer is None): | ||||
return | ||||
|
||||
if chunk.delta and ( | ||||
(chunk.delta.action_calls and len(chunk.delta.action_calls) > 0) or | ||||
chunk.delta.action_call_id or | ||||
getattr(chunk.delta, "tool_calls", None) | ||||
): | ||||
return | ||||
|
||||
text = chunk.delta.content if (chunk.delta and chunk.delta.content) else "" | ||||
citations = ( | ||||
|
@@ -284,7 +291,20 @@ def chunk_received( | |||
|
||||
if streamer is not None: | ||||
await streamer.end_stream() | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. also need to remove this line There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. does this not align with this line from Steve's updates?
else {
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. oh I may have the wrong place |
||||
|
||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. need to also remove L281-283 (check for |
||||
# Tool call handling | ||||
# Keep the streamer around during tool calls, letting them return as normal messages minus the content. | ||||
# When the tool call completes, reattach to the streamer for continued streaming to the client. | ||||
if res.message and isinstance(res.message.action_calls, list): | ||||
lilyydu marked this conversation as resolved.
Show resolved
Hide resolved
|
||||
res.message.content = "" | ||||
else: | ||||
if res.status == "success": | ||||
res.message = None | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. could we also add in the comments that Steve has for JS so we remember the purpose of these edge cases in the future |
||||
|
||||
await streamer.end_stream() | ||||
memory.delete("temp.streamer") | ||||
return res | ||||
|
||||
except Exception as err: # pylint: disable=broad-except | ||||
return PromptResponse(status="error", error=str(err)) | ||||
finally: | ||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -271,7 +271,34 @@ async def complete_prompt( | |
if delta.content: | ||
message_content += delta.content | ||
|
||
# TODO: Handle tool calls | ||
# Handle tool calls during streaming | ||
if is_tools_aug and delta.tool_calls: | ||
if not hasattr(message, "action_calls") or message.action_calls is None: | ||
message.action_calls = [] | ||
|
||
for curr_tool_call in delta.tool_calls: | ||
index = curr_tool_call.index | ||
|
||
# Ensure the action_calls array is long enough | ||
while index >= len(message.action_calls): | ||
lilyydu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
message.action_calls.append( | ||
ActionCall( | ||
id="", | ||
function=ActionFunction(name="", arguments=""), | ||
type="function", # /python/packages/ai/teams/ai/prompts/message.py#L123 | ||
lilyydu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) | ||
) | ||
|
||
# Update the action_call at the specific index | ||
if curr_tool_call.id: | ||
message.action_calls[index].id = curr_tool_call.id | ||
if curr_tool_call.function: | ||
if curr_tool_call.function.name: | ||
message.action_calls[index].function.name += curr_tool_call.function.name | ||
if curr_tool_call.function.arguments: | ||
message.action_calls[index].function.arguments += curr_tool_call.function.arguments | ||
if curr_tool_call.type == "function": # Type must always match the expected value | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. extra comment There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. are we supposed to set the type regardless? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I can experiment with leaving it empty, however to my understanding if its anything other than empty, it has to be "function" There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yup, did it end up throwing an error? |
||
message.action_calls[index].type = curr_tool_call.type | ||
|
||
if self._options.logger is not None: | ||
self._options.logger.debug(f"CHUNK ${delta}") | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,6 +9,7 @@ | |
import sys | ||
import time | ||
import traceback | ||
import logging | ||
from typing import Any, Dict, List | ||
|
||
from botbuilder.core import MemoryStorage, TurnContext | ||
|
@@ -24,6 +25,9 @@ | |
from config import Config | ||
from state import AppTurnState | ||
|
||
logging.basicConfig(level=logging.INFO) | ||
logger = logging.getLogger(__name__) | ||
|
||
config = Config() | ||
|
||
if config.OPENAI_KEY is None and config.AZURE_OPENAI_KEY is None: | ||
|
@@ -45,6 +49,8 @@ | |
default_model="gpt-4o", | ||
api_version="2023-03-15-preview", | ||
endpoint=config.AZURE_OPENAI_ENDPOINT, | ||
logger=logger, | ||
stream=True, | ||
) | ||
) | ||
|
||
|
@@ -84,7 +90,8 @@ async def on_lights_on( | |
state: AppTurnState, | ||
): | ||
state.conversation.lights_on = True | ||
await context.send_activity("[lights on]") | ||
logging.info("[lights on]") | ||
# await context.send_activity("[lights on]") | ||
return "the lights are now on" | ||
|
||
|
||
|
@@ -94,7 +101,8 @@ async def on_lights_off( | |
state: AppTurnState, | ||
): | ||
state.conversation.lights_on = False | ||
await context.send_activity("[lights off]") | ||
logging.info("[lights off]") | ||
# await context.send_activity("[lights off]") | ||
return "the lights are now off" | ||
|
||
|
||
|
@@ -104,7 +112,8 @@ async def on_pause( | |
_state: AppTurnState, | ||
): | ||
time_ms = int(context.data["time"]) if context.data["time"] else 1000 | ||
await context.send_activity(f"[pausing for {time_ms / 1000} seconds]") | ||
logging.info(f"[pausing for {time_ms / 1000} seconds]") | ||
# await context.send_activity(f"[pausing for {time_ms / 1000} seconds]") | ||
time.sleep(time_ms / 1000) | ||
return "done pausing" | ||
|
||
|
@@ -114,16 +123,18 @@ async def on_lights_status( | |
_context: ActionTurnContext[Dict[str, Any]], | ||
state: AppTurnState, | ||
): | ||
return "the lights are on" if state.conversation.lights_on else "the lights are off" | ||
light_status = "the lights are on" if state.conversation.lights_on else "the lights are off" | ||
logging.info(f"{light_status}") | ||
return light_status | ||
|
||
|
||
@app.error | ||
async def on_error(context: TurnContext, error: Exception): | ||
# This check writes out errors to console log .vs. app insights. | ||
# NOTE: In production environment, you should consider logging this to Azure | ||
# application insights. | ||
print(f"\n [on_turn_error] unhandled error: {error}", file=sys.stderr) | ||
logging.info(f"\n [on_turn_error] unhandled error: {error}, {sys.stderr}") | ||
traceback.print_exc() | ||
|
||
# Send a message to the user | ||
await context.send_activity("The bot encountered an error or bug.") | ||
# await context.send_activity("The bot encountered an error or bug.") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. forgot to remove commented out lines in file There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @BMS-geodev I think you accidentally uncommented instead of removing for latest updates |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we should keep it consistent with JS/C# and check
chunk.delta.content
instead