|
| 1 | +import itertools |
1 | 2 | import typing as ta |
2 | 3 |
|
3 | 4 | from omlish import cached |
|
40 | 41 | def build_oai_request_msgs(mc_chat: Chat) -> ta.Sequence[pt.ChatCompletionMessage]: |
41 | 42 | oai_msgs: list[pt.ChatCompletionMessage] = [] |
42 | 43 |
|
43 | | - for mc_msg in mc_chat: |
44 | | - if isinstance(mc_msg, SystemMessage): |
45 | | - oai_msgs.append(pt.SystemChatCompletionMessage( |
46 | | - content=check.isinstance(mc_msg.c, str), |
47 | | - )) |
48 | | - |
49 | | - elif isinstance(mc_msg, AiMessage): |
50 | | - oai_msgs.append(pt.AssistantChatCompletionMessage( |
51 | | - content=check.isinstance(mc_msg.c, (str, None)), |
52 | | - )) |
53 | | - |
54 | | - elif isinstance(mc_msg, ToolUseMessage): |
55 | | - oai_msgs.append(pt.AssistantChatCompletionMessage( |
56 | | - tool_calls=[pt.AssistantChatCompletionMessage.ToolCall( |
57 | | - id=check.not_none(mc_msg.tu.id), |
58 | | - function=pt.AssistantChatCompletionMessage.ToolCall.Function( |
59 | | - arguments=check.not_none(mc_msg.tu.raw_args), |
60 | | - name=mc_msg.tu.name, |
61 | | - ), |
62 | | - )], |
63 | | - )) |
64 | | - |
65 | | - elif isinstance(mc_msg, UserMessage): |
66 | | - oai_msgs.append(pt.UserChatCompletionMessage( |
67 | | - content=render_content_str(mc_msg.c), |
68 | | - )) |
69 | | - |
70 | | - elif isinstance(mc_msg, ToolUseResultMessage): |
71 | | - tc: str |
72 | | - if isinstance(mc_msg.tur.c, str): |
73 | | - tc = mc_msg.tur.c |
74 | | - elif isinstance(mc_msg.tur.c, JsonContent): |
75 | | - tc = json.dumps_compact(mc_msg.tur.c) |
76 | | - else: |
77 | | - raise TypeError(mc_msg.tur.c) |
78 | | - oai_msgs.append(pt.ToolChatCompletionMessage( |
79 | | - tool_call_id=check.not_none(mc_msg.tur.id), |
80 | | - content=tc, |
81 | | - )) |
| 44 | + for _, g in itertools.groupby(mc_chat, lambda mc_m: isinstance(mc_m, AnyAiMessage)): |
| 45 | + mc_msgs = list(g) |
| 46 | + |
| 47 | + if isinstance(mc_msgs[0], AnyAiMessage): |
| 48 | + tups: list[tuple[AiMessage | None, list[ToolUseMessage]]] = [] |
| 49 | + for mc_msg in mc_msgs: |
| 50 | + if isinstance(mc_msg, AiMessage): |
| 51 | + tups.append((mc_msg, [])) |
| 52 | + |
| 53 | + elif isinstance(mc_msg, ToolUseMessage): |
| 54 | + if not tups: |
| 55 | + tups.append((None, [])) |
| 56 | + tups[-1][1].append(mc_msg) |
| 57 | + |
| 58 | + else: |
| 59 | + raise TypeError(mc_msg) |
| 60 | + |
| 61 | + for mc_ai_msg, mc_tu_msgs in tups: |
| 62 | + oai_msgs.append(pt.AssistantChatCompletionMessage( |
| 63 | + content=check.isinstance(mc_ai_msg.c, (str, None)) if mc_ai_msg is not None else None, |
| 64 | + tool_calls=[ |
| 65 | + pt.AssistantChatCompletionMessage.ToolCall( |
| 66 | + id=check.not_none(mc_tu_msg.tu.id), |
| 67 | + function=pt.AssistantChatCompletionMessage.ToolCall.Function( |
| 68 | + arguments=check.not_none(mc_tu_msg.tu.raw_args), |
| 69 | + name=mc_tu_msg.tu.name, |
| 70 | + ), |
| 71 | + ) |
| 72 | + for mc_tu_msg in mc_tu_msgs |
| 73 | + ] if mc_tu_msgs else None, |
| 74 | + )) |
82 | 75 |
|
83 | 76 | else: |
84 | | - raise TypeError(mc_msg) |
| 77 | + for mc_msg in mc_msgs: |
| 78 | + if isinstance(mc_msg, SystemMessage): |
| 79 | + oai_msgs.append(pt.SystemChatCompletionMessage( |
| 80 | + content=check.isinstance(mc_msg.c, str), |
| 81 | + )) |
| 82 | + |
| 83 | + elif isinstance(mc_msg, UserMessage): |
| 84 | + oai_msgs.append(pt.UserChatCompletionMessage( |
| 85 | + content=render_content_str(mc_msg.c), |
| 86 | + )) |
| 87 | + |
| 88 | + elif isinstance(mc_msg, ToolUseResultMessage): |
| 89 | + tc: str |
| 90 | + if isinstance(mc_msg.tur.c, str): |
| 91 | + tc = mc_msg.tur.c |
| 92 | + elif isinstance(mc_msg.tur.c, JsonContent): |
| 93 | + tc = json.dumps_compact(mc_msg.tur.c) |
| 94 | + else: |
| 95 | + raise TypeError(mc_msg.tur.c) |
| 96 | + oai_msgs.append(pt.ToolChatCompletionMessage( |
| 97 | + tool_call_id=check.not_none(mc_msg.tur.id), |
| 98 | + content=tc, |
| 99 | + )) |
| 100 | + |
| 101 | + else: |
| 102 | + raise TypeError(mc_msg) |
85 | 103 |
|
86 | 104 | return oai_msgs |
87 | 105 |
|
@@ -228,6 +246,7 @@ def oai_request(self) -> pt.ChatCompletionRequest: |
228 | 246 | messages=build_oai_request_msgs(self._chat), |
229 | 247 | top_p=1, |
230 | 248 | tools=tools or None, |
| 249 | + parallel_tool_calls=True if (tools and not (self._mandatory_kwargs or {}).get('stream')) else None, |
231 | 250 | frequency_penalty=0.0, |
232 | 251 | presence_penalty=0.0, |
233 | 252 | **po.kwargs, |
|
0 commit comments