Skip to content

Commit dcd5452

Browse files
committed
Add parallel tool call support to OpenAI integration and AI driver
Implemented grouping of AI messages to generate a single assistant message with multiple tool calls, added `parallel_tool_calls` flag in OpenAI request construction, and refactored tool execution to batch parallel results. Updated CLI formatting, added necessary imports, and introduced comprehensive tests for request message generation, request flag behavior, and streaming with parallel tools. Adjusted driver logic to handle parallel tool messages without sequential loops.
1 parent 866f5a5 commit dcd5452

6 files changed

Lines changed: 302 additions & 57 deletions

File tree

omlish/text/docwrap/cli.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def _main(argv: ta.Sequence[str] | None = None) -> None:
7575

7676
parser.add_argument('-s', '--start-line', type=int)
7777
parser.add_argument('-e', '--end-line', type=int)
78+
7879
parser.add_argument('-i', '--in-place', action='store_true')
7980

8081
args = parser.parse_args(argv)

ommlds/minichain/backends/openai/format.py

Lines changed: 59 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import itertools
12
import typing as ta
23

34
from omlish import cached
@@ -40,48 +41,65 @@
4041
def build_oai_request_msgs(mc_chat: Chat) -> ta.Sequence[pt.ChatCompletionMessage]:
4142
oai_msgs: list[pt.ChatCompletionMessage] = []
4243

43-
for mc_msg in mc_chat:
44-
if isinstance(mc_msg, SystemMessage):
45-
oai_msgs.append(pt.SystemChatCompletionMessage(
46-
content=check.isinstance(mc_msg.c, str),
47-
))
48-
49-
elif isinstance(mc_msg, AiMessage):
50-
oai_msgs.append(pt.AssistantChatCompletionMessage(
51-
content=check.isinstance(mc_msg.c, (str, None)),
52-
))
53-
54-
elif isinstance(mc_msg, ToolUseMessage):
55-
oai_msgs.append(pt.AssistantChatCompletionMessage(
56-
tool_calls=[pt.AssistantChatCompletionMessage.ToolCall(
57-
id=check.not_none(mc_msg.tu.id),
58-
function=pt.AssistantChatCompletionMessage.ToolCall.Function(
59-
arguments=check.not_none(mc_msg.tu.raw_args),
60-
name=mc_msg.tu.name,
61-
),
62-
)],
63-
))
64-
65-
elif isinstance(mc_msg, UserMessage):
66-
oai_msgs.append(pt.UserChatCompletionMessage(
67-
content=render_content_str(mc_msg.c),
68-
))
69-
70-
elif isinstance(mc_msg, ToolUseResultMessage):
71-
tc: str
72-
if isinstance(mc_msg.tur.c, str):
73-
tc = mc_msg.tur.c
74-
elif isinstance(mc_msg.tur.c, JsonContent):
75-
tc = json.dumps_compact(mc_msg.tur.c)
76-
else:
77-
raise TypeError(mc_msg.tur.c)
78-
oai_msgs.append(pt.ToolChatCompletionMessage(
79-
tool_call_id=check.not_none(mc_msg.tur.id),
80-
content=tc,
81-
))
44+
for _, g in itertools.groupby(mc_chat, lambda mc_m: isinstance(mc_m, AnyAiMessage)):
45+
mc_msgs = list(g)
46+
47+
if isinstance(mc_msgs[0], AnyAiMessage):
48+
tups: list[tuple[AiMessage | None, list[ToolUseMessage]]] = []
49+
for mc_msg in mc_msgs:
50+
if isinstance(mc_msg, AiMessage):
51+
tups.append((mc_msg, []))
52+
53+
elif isinstance(mc_msg, ToolUseMessage):
54+
if not tups:
55+
tups.append((None, []))
56+
tups[-1][1].append(mc_msg)
57+
58+
else:
59+
raise TypeError(mc_msg)
60+
61+
for mc_ai_msg, mc_tu_msgs in tups:
62+
oai_msgs.append(pt.AssistantChatCompletionMessage(
63+
content=check.isinstance(mc_ai_msg.c, (str, None)) if mc_ai_msg is not None else None,
64+
tool_calls=[
65+
pt.AssistantChatCompletionMessage.ToolCall(
66+
id=check.not_none(mc_tu_msg.tu.id),
67+
function=pt.AssistantChatCompletionMessage.ToolCall.Function(
68+
arguments=check.not_none(mc_tu_msg.tu.raw_args),
69+
name=mc_tu_msg.tu.name,
70+
),
71+
)
72+
for mc_tu_msg in mc_tu_msgs
73+
] if mc_tu_msgs else None,
74+
))
8275

8376
else:
84-
raise TypeError(mc_msg)
77+
for mc_msg in mc_msgs:
78+
if isinstance(mc_msg, SystemMessage):
79+
oai_msgs.append(pt.SystemChatCompletionMessage(
80+
content=check.isinstance(mc_msg.c, str),
81+
))
82+
83+
elif isinstance(mc_msg, UserMessage):
84+
oai_msgs.append(pt.UserChatCompletionMessage(
85+
content=render_content_str(mc_msg.c),
86+
))
87+
88+
elif isinstance(mc_msg, ToolUseResultMessage):
89+
tc: str
90+
if isinstance(mc_msg.tur.c, str):
91+
tc = mc_msg.tur.c
92+
elif isinstance(mc_msg.tur.c, JsonContent):
93+
tc = json.dumps_compact(mc_msg.tur.c)
94+
else:
95+
raise TypeError(mc_msg.tur.c)
96+
oai_msgs.append(pt.ToolChatCompletionMessage(
97+
tool_call_id=check.not_none(mc_msg.tur.id),
98+
content=tc,
99+
))
100+
101+
else:
102+
raise TypeError(mc_msg)
85103

86104
return oai_msgs
87105

@@ -228,6 +246,7 @@ def oai_request(self) -> pt.ChatCompletionRequest:
228246
messages=build_oai_request_msgs(self._chat),
229247
top_p=1,
230248
tools=tools or None,
249+
parallel_tool_calls=True if (tools and not (self._mandatory_kwargs or {}).get('stream')) else None,
231250
frequency_penalty=0.0,
232251
presence_penalty=0.0,
233252
**po.kwargs,

ommlds/minichain/backends/openai/tests/test_chat.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from omlish.http import all as http
99
from omlish.secrets.tests.harness import HarnessSecrets
1010

11+
from .....backends.openai import protocol as pt
1112
from ....chat.choices.services import ChatChoicesRequest
1213
from ....chat.choices.services import ChatChoicesService
1314
from ....chat.messages import Message
@@ -22,9 +23,87 @@
2223
from ....tools.types import ToolDtype
2324
from ....tools.types import ToolParam
2425
from ....tools.types import ToolSpec
26+
from ....tools.types import ToolUse
2527
from ....tools.types import ToolUseResult
2628
from ....wrappers.updateoptions import UpdateOptionsService
2729
from ..chat import OpenaiChatChoicesService
30+
from ..format import OpenaiChatRequestHandler
31+
from ..format import build_oai_request_msgs
32+
33+
34+
def test_openai_parallel_tool_call_request_messages():
35+
chat: list[Message] = [
36+
UserMessage('Use both tools.'),
37+
ToolUseMessage(ToolUse(
38+
id='call_1',
39+
name='first_tool',
40+
args={'value': 1},
41+
raw_args='{"value":1}',
42+
)),
43+
ToolUseMessage(ToolUse(
44+
id='call_2',
45+
name='second_tool',
46+
args={'value': 2},
47+
raw_args='{"value":2}',
48+
)),
49+
ToolUseResultMessage(ToolUseResult(
50+
id='call_1',
51+
name='first_tool',
52+
c='first result',
53+
)),
54+
ToolUseResultMessage(ToolUseResult(
55+
id='call_2',
56+
name='second_tool',
57+
c='second result',
58+
)),
59+
]
60+
61+
oai_msgs = list(build_oai_request_msgs(chat))
62+
63+
assert len(oai_msgs) == 4
64+
assistant_msg = check.isinstance(oai_msgs[1], pt.AssistantChatCompletionMessage)
65+
assert assistant_msg.content is None
66+
tool_calls = list(check.not_none(assistant_msg.tool_calls))
67+
assert [tc.id for tc in tool_calls] == ['call_1', 'call_2']
68+
assert [tc.function.name for tc in tool_calls] == ['first_tool', 'second_tool']
69+
assert [tc.function.arguments for tc in tool_calls] == ['{"value":1}', '{"value":2}']
70+
71+
tool_result_msgs = [
72+
check.isinstance(oai_msg, pt.ToolChatCompletionMessage)
73+
for oai_msg in oai_msgs[2:]
74+
]
75+
assert [m.tool_call_id for m in tool_result_msgs] == ['call_1', 'call_2']
76+
77+
78+
def test_openai_parallel_tool_calls_enabled_with_tools():
79+
tool_spec = ToolSpec(
80+
'get_weather',
81+
params=[
82+
ToolParam(
83+
'location',
84+
type=ToolDtype.of(str),
85+
),
86+
],
87+
)
88+
89+
req = OpenaiChatRequestHandler(
90+
[UserMessage('What is the weather in Seattle?')],
91+
Tool(tool_spec),
92+
model='gpt-test',
93+
).oai_request()
94+
95+
assert req.tools is not None
96+
assert req.parallel_tool_calls is True
97+
98+
99+
def test_openai_parallel_tool_calls_omitted_without_tools():
100+
req = OpenaiChatRequestHandler(
101+
[UserMessage('Hi!')],
102+
model='gpt-test',
103+
).oai_request()
104+
105+
assert req.tools is None
106+
assert req.parallel_tool_calls is None
28107

29108

30109
@pytest.mark.online

ommlds/minichain/backends/openai/tests/test_stream.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,42 @@ async def test_openai_stream_tools(harness):
8585
[
8686
SystemMessage("You are a helpful agent. Use any tools available to you to answer the user's questions."),
8787
UserMessage('What is the weather in Seattle?'),
88-
UserMessage(''),
88+
],
89+
[
90+
Tool(tool_spec),
91+
],
92+
)
93+
94+
async with (await llm.invoke(foo_req)).v as it:
95+
async for o in it:
96+
print(o)
97+
print(it.outputs)
98+
99+
100+
@pytest.mark.asyncs('asyncio')
101+
@pytest.mark.online
102+
async def test_openai_stream_parallel_tools(harness):
103+
tool_spec = ToolSpec(
104+
'get_weather',
105+
params=[
106+
ToolParam(
107+
'location',
108+
type=ToolDtype.of(str),
109+
desc='The location to get the weather for.',
110+
),
111+
],
112+
desc='Gets the weather in the given location.',
113+
)
114+
115+
llm = OpenaiChatChoicesStreamService(
116+
ApiKey(harness[HarnessSecrets].get_or_skip('openai_api_key').reveal()),
117+
)
118+
119+
foo_req: ChatChoicesStreamRequest
120+
foo_req = ChatChoicesStreamRequest(
121+
[
122+
SystemMessage("You are a helpful agent. Use any tools available to you to answer the user's questions."),
123+
UserMessage('What is the weather in Seattle? Also, what is the weather in San Francisco?'),
89124
],
90125
[
91126
Tool(tool_spec),

ommlds/minichain/drivers/ai/tools.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -47,28 +47,29 @@ async def generate_ai_chat(self, args: GenerateAiChatArgs) -> Chat:
4747
dc.replace(args, chat=[*args.chat, *out]),
4848
)
4949

50-
cont = False
50+
out.extend(new)
5151

52-
for msg in new:
53-
out.append(msg)
52+
tool_use_messages = [
53+
msg
54+
for msg in new
55+
if isinstance(msg, ToolUseMessage)
56+
]
5457

55-
if isinstance(msg, ToolUseMessage):
56-
use = msg.tu
58+
for msg in tool_use_messages:
59+
use = msg.tu
5760

58-
tce = self._catalog.by_name[check.non_empty_str(use.name)]
61+
tce = self._catalog.by_name[check.non_empty_str(use.name)]
5962

60-
trr = await self._executor.execute_tool_use(ToolUseExecution(
61-
msg.tu,
62-
tce,
63-
))
63+
trr = await self._executor.execute_tool_use(ToolUseExecution(
64+
msg.tu,
65+
tce,
66+
))
6467

65-
trm = ToolUseResultMessage(trr)
68+
trm = ToolUseResultMessage(trr)
6669

67-
trm = check.isinstance(check.single(self._mt.transform(trm)), ToolUseResultMessage)
70+
trm = check.isinstance(check.single(self._mt.transform(trm)), ToolUseResultMessage)
6871

69-
out.append(trm)
72+
out.append(trm)
7073

71-
cont = True
72-
73-
if not cont:
74+
if not tool_use_messages:
7475
return out

0 commit comments

Comments
 (0)