|
| 1 | +""" |
| 2 | +Copyright (c) Microsoft Corporation. All rights reserved. |
| 3 | +Licensed under the MIT License. |
| 4 | +""" |
| 5 | + |
| 6 | +from dataclasses import dataclass |
| 7 | +from typing import Any, Dict, List, Optional, Union |
| 8 | + |
| 9 | +import yaml |
| 10 | +from botbuilder.core import TurnContext |
| 11 | + |
| 12 | +from teams.ai.modelsv2.chat_completion_action import ChatCompletionAction |
| 13 | +from teams.ai.promptsv2.message import Message |
| 14 | +from teams.ai.promptsv2.prompt_functions import PromptFunctions |
| 15 | +from teams.ai.promptsv2.prompt_section_base import PromptSectionBase |
| 16 | +from teams.ai.promptsv2.rendered_prompt_section import RenderedPromptSection |
| 17 | +from teams.ai.tokenizers.tokenizer import Tokenizer |
| 18 | +from teams.state.memory import Memory |
| 19 | + |
| 20 | + |
| 21 | +@dataclass |
| 22 | +class ActionValue: |
| 23 | + description: Optional[str] = None |
| 24 | + parameters: Optional[Union[Dict[str, Any], Dict[str, Dict[str, Any]]]] = None |
| 25 | + |
| 26 | + |
| 27 | +@dataclass |
| 28 | +class ActionList: |
| 29 | + actions: Dict[str, ActionValue] |
| 30 | + |
| 31 | + |
| 32 | +class ActionAugmentationSection(PromptSectionBase): |
| 33 | + """ |
| 34 | + A prompt section that renders a list of actions to the prompt. |
| 35 | + """ |
| 36 | + |
| 37 | + _text: str |
| 38 | + _token_list: Optional[List[int]] = None |
| 39 | + _actions: Dict[str, ChatCompletionAction] = {} |
| 40 | + |
| 41 | + @property |
| 42 | + def actions(self) -> Dict[str, ChatCompletionAction]: |
| 43 | + """ |
| 44 | + Map of action names to actions. |
| 45 | + """ |
| 46 | + return self._actions |
| 47 | + |
| 48 | + def __init__(self, actions: List[ChatCompletionAction], call_to_action: str) -> None: |
| 49 | + """ |
| 50 | + Creates a new `ActionAugmentationSection` instance. |
| 51 | +
|
| 52 | + Args: |
| 53 | + actions (List[ChatCompletionAction]): List of actions to render. |
| 54 | + call_to_action (str): Text to display after the list of actions. |
| 55 | +
|
| 56 | + """ |
| 57 | + super().__init__(-1, True, "\n\n") |
| 58 | + |
| 59 | + # Convert actions to an ActionList |
| 60 | + action_list: ActionList = {"actions": {}} |
| 61 | + |
| 62 | + for action in actions: |
| 63 | + self._actions[action.name] = action |
| 64 | + action_list["actions"][action.name] = {} |
| 65 | + if action.description: |
| 66 | + action_list["actions"][action.name]["description"] = action.description |
| 67 | + if action.parameters: |
| 68 | + params = action.parameters |
| 69 | + action_list["actions"][action.name]["parameters"] = ( |
| 70 | + params.get("properties") |
| 71 | + if params.get("additional_properties") is None |
| 72 | + else params |
| 73 | + ) |
| 74 | + |
| 75 | + # Build augmentation text |
| 76 | + self._text = f"{yaml.dump(action_list)}\n\n{call_to_action}" |
| 77 | + |
| 78 | + async def render_as_messages( |
| 79 | + self, |
| 80 | + context: TurnContext, |
| 81 | + memory: Memory, |
| 82 | + functions: PromptFunctions, |
| 83 | + tokenizer: Tokenizer, |
| 84 | + max_tokens: int, |
| 85 | + ) -> RenderedPromptSection[List[Message[str]]]: |
| 86 | + """ |
| 87 | + Renders the prompt section as a list of `Message` objects. |
| 88 | +
|
| 89 | + Args: |
| 90 | + context (TurnContext): Context for the current turn of conversation. |
| 91 | + memory (Memory): Interface for accessing state variables. |
| 92 | + functions (PromptFunctions): Functions for rendering prompts. |
| 93 | + tokenizer (Tokenizer): Tokenizer to use for encoding/decoding text. |
| 94 | + max_tokens (int): Maximum number of tokens allowed for the rendered prompt. |
| 95 | +
|
| 96 | + Returns: |
| 97 | + RenderedPromptSection[List[Message[str]]]: The rendered prompt section. |
| 98 | +
|
| 99 | + """ |
| 100 | + # Tokenize on first use |
| 101 | + if not self._token_list: |
| 102 | + self._token_list = tokenizer.encode(self._text) |
| 103 | + |
| 104 | + # Check for max tokens |
| 105 | + if len(self._token_list) > max_tokens: |
| 106 | + trimmed = self._token_list[0:max_tokens] |
| 107 | + return RenderedPromptSection[List[Message[str]]]( |
| 108 | + output=[Message[str](role="system", content=tokenizer.decode(trimmed))], |
| 109 | + length=len(trimmed), |
| 110 | + too_long=True, |
| 111 | + ) |
| 112 | + return RenderedPromptSection[List[Message[str]]]( |
| 113 | + output=[Message[str](role="system", content=self._text)], |
| 114 | + length=len(self._token_list), |
| 115 | + too_long=False, |
| 116 | + ) |
0 commit comments