[PY] feat: streaming support for Tools Augmentation #2215

Draft · wants to merge 6 commits into main

Changes from 2 commits
python/packages/ai/teams/ai/clients/llm_client.py (42 changes: 31 additions & 11 deletions)

@@ -155,22 +155,22 @@ def before_completion(
if context != ctx:
    return

nonlocal streamer

# Check for a streaming response
if streaming:
    nonlocal is_streaming
    is_streaming = True

    streamer = memory.get("temp.streamer")
    if not streamer:
        streamer = StreamingResponse(context)
        memory.set("temp.streamer", streamer)

        if self._enable_feedback_loop is not None:
            streamer.set_feedback_loop(self._enable_feedback_loop)

        streamer.set_generated_by_ai_label(True)

        if self._start_streaming_message:
            streamer.queue_informative_update(self._start_streaming_message)

def chunk_received(
    ctx: TurnContext,

@@ -181,6 +181,13 @@ def chunk_received(

    nonlocal streamer
    if (context != ctx) or (streamer is None):
        return

    if chunk.delta and (
        (chunk.delta.action_calls and len(chunk.delta.action_calls) > 0) or
        chunk.delta.action_call_id or
        getattr(chunk.delta, "tool_calls", None)
    ):
        return

    text = chunk.delta.content if (chunk.delta and chunk.delta.content) else ""

Contributor: we should keep it consistent with JS/C# and check chunk.delta.content instead
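A minimal sketch of the reviewer's suggestion, assuming the JS/C# approach of gating on text content (names and shapes follow the diff above; this is illustrative, not the merged fix):

    # Sketch: skip chunks that carry no text content, mirroring JS/C#,
    # instead of enumerating every tool-call field on the delta.
    if not chunk.delta or not chunk.delta.content:
        return

    text = chunk.delta.content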
citations = (
@@ -284,7 +291,20 @@ def chunk_received(

if streamer is not None:
await streamer.end_stream()

Contributor: also need to remove this line

Collaborator (author): does this not align with this line from Steve's updates?

    await streamer.endStream();

    else {
        if (response.status == 'success') {
            // Delete message from response to avoid sending it twice
            delete response.message;
        }

        // End the stream and remove pointer from memory
        // - We're not listening for the response received event because we can't await the completion of events.
        await streamer.endStream();
        memory.deleteValue('temp.streamer');
    }

Collaborator (author): oh I may have the wrong place

Contributor: need to also remove L281-283 (check for is_streaming and res.status == "success")
# Tool call handling
# Keep the streamer around during tool calls, letting them return as normal messages minus the content.
# When the tool call completes, reattach to the streamer for continued streaming to the client.
if res.message and isinstance(res.message.action_calls, list):
    res.message.content = ""
else:
    if res.status == "success":
        res.message = None

Contributor: could we also add in the comments that Steve has for JS so we remember the purpose of these edge cases in the future

await streamer.end_stream()
memory.delete("temp.streamer")
return res
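
For reference, a sketch of how Steve's JS comments could carry over to the Python above, so the purpose of these edge cases stays documented (an illustration assuming the memory.delete semantics already shown in this diff, not the merged code):

    if res.message and isinstance(res.message.action_calls, list):
        # Tool call in progress: blank the content so the tool result
        # returns as a normal message while the streamer stays attached.
        res.message.content = ""
    else:
        if res.status == "success":
            # Delete message from response to avoid sending it twice
            res.message = None

        # End the stream and remove pointer from memory
        # - We're not listening for the response received event because
        #   we can't await the completion of events.
        await streamer.end_stream()
        memory.delete("temp.streamer")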

except Exception as err: # pylint: disable=broad-except
return PromptResponse(status="error", error=str(err))
finally:
python/packages/ai/teams/ai/models/openai_model.py (29 changes: 28 additions & 1 deletion)

@@ -271,7 +271,34 @@ async def complete_prompt(
if delta.content:
    message_content += delta.content

# Handle tool calls during streaming
if is_tools_aug and delta.tool_calls:
    if not hasattr(message, "action_calls") or message.action_calls is None:
        message.action_calls = []

    for curr_tool_call in delta.tool_calls:
        index = curr_tool_call.index

        # Ensure the action_calls array is long enough
        while index >= len(message.action_calls):
            message.action_calls.append(
                ActionCall(
                    id="",
                    function=ActionFunction(name="", arguments=""),
                    type="function",  # /python/packages/ai/teams/ai/prompts/message.py#L123
                )
            )

        # Update the action_call at the specific index
        if curr_tool_call.id:
            message.action_calls[index].id = curr_tool_call.id
        if curr_tool_call.function:
            if curr_tool_call.function.name:
                message.action_calls[index].function.name += curr_tool_call.function.name
            if curr_tool_call.function.arguments:
                message.action_calls[index].function.arguments += curr_tool_call.function.arguments
        if curr_tool_call.type == "function":  # Type must always match the expected value
            message.action_calls[index].type = curr_tool_call.type

Contributor: extra comment

Contributor: are we supposed to set the type regardless?

Collaborator (author): I can experiment with leaving it empty; however, to my understanding, if it's anything other than empty, it has to be "function"

Contributor: Yup, did it end up throwing an error?
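
For context, a self-contained sketch of the index-based accumulation pattern the loop above implements; the delta classes below are simplified stand-ins for the OpenAI streaming types, not the library's actual ones:

    from dataclasses import dataclass, field
    from typing import List, Optional

    @dataclass
    class FakeFunctionDelta:
        name: Optional[str] = None
        arguments: Optional[str] = None

    @dataclass
    class FakeToolCallDelta:
        index: int
        id: Optional[str] = None
        function: Optional[FakeFunctionDelta] = None

    @dataclass
    class Function:
        name: str = ""
        arguments: str = ""

    @dataclass
    class Call:
        id: str = ""
        function: Function = field(default_factory=Function)
        type: str = "function"

    def accumulate(calls: List[Call], deltas: List[FakeToolCallDelta]) -> None:
        # Pad the list so every streamed index has a slot, then merge fragments.
        for d in deltas:
            while d.index >= len(calls):
                calls.append(Call())
            if d.id:
                calls[d.index].id = d.id
            if d.function:
                if d.function.name:
                    calls[d.index].function.name += d.function.name
                if d.function.arguments:
                    calls[d.index].function.arguments += d.function.arguments

    # Fragments for one tool call arrive across several chunks and are stitched together.
    calls: List[Call] = []
    accumulate(calls, [
        FakeToolCallDelta(0, id="call_1", function=FakeFunctionDelta(name="lights_on")),
        FakeToolCallDelta(0, function=FakeFunctionDelta(arguments='{"room": ')),
        FakeToolCallDelta(0, function=FakeFunctionDelta(arguments='"kitchen"}')),
    ])
    assert calls[0].function.arguments == '{"room": "kitchen"}'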

if self._options.logger is not None:
    self._options.logger.debug(f"CHUNK {delta}")
python/samples/04.ai.c.actionMapping.lightBot/src/bot.py (23 changes: 17 additions & 6 deletions)

@@ -9,6 +9,7 @@
import sys
import time
import traceback
import logging
from typing import Any, Dict, List

from botbuilder.core import MemoryStorage, TurnContext
@@ -24,6 +25,9 @@
from config import Config
from state import AppTurnState

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

config = Config()

if config.OPENAI_KEY is None and config.AZURE_OPENAI_KEY is None:
@@ -45,6 +49,8 @@
default_model="gpt-4o",
api_version="2023-03-15-preview",
endpoint=config.AZURE_OPENAI_ENDPOINT,
logger=logger,
stream=True,
)
)

@@ -84,7 +90,8 @@ async def on_lights_on(
    state: AppTurnState,
):
    state.conversation.lights_on = True
    logging.info("[lights on]")
    # await context.send_activity("[lights on]")
    return "the lights are now on"


@@ -94,7 +101,8 @@ async def on_lights_off(
    state: AppTurnState,
):
    state.conversation.lights_on = False
    logging.info("[lights off]")
    # await context.send_activity("[lights off]")
    return "the lights are now off"


@@ -104,7 +112,8 @@ async def on_pause(
    _state: AppTurnState,
):
    time_ms = int(context.data["time"]) if context.data["time"] else 1000
    logging.info(f"[pausing for {time_ms / 1000} seconds]")
    # await context.send_activity(f"[pausing for {time_ms / 1000} seconds]")
    time.sleep(time_ms / 1000)
    return "done pausing"

@@ -114,16 +123,18 @@ async def on_lights_status(
    _context: ActionTurnContext[Dict[str, Any]],
    state: AppTurnState,
):
    light_status = "the lights are on" if state.conversation.lights_on else "the lights are off"
    logging.info(f"{light_status}")
    return light_status


@app.error
async def on_error(context: TurnContext, error: Exception):
    # This check writes out errors to the console log vs. app insights.
    # NOTE: In a production environment, you should consider logging this to
    # Azure Application Insights.
    logging.info(f"\n [on_turn_error] unhandled error: {error}, {sys.stderr}")
    traceback.print_exc()

    # Send a message to the user
    # await context.send_activity("The bot encountered an error or bug.")

Contributor: forgot to remove commented-out lines in file

Contributor: @BMS-geodev I think you accidentally uncommented instead of removing for the latest updates