Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 38 additions & 12 deletions rig-core/src/providers/mistral/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,14 @@ pub enum Message {
System {
content: String,
},
Tool {
/// The name of the tool that was called
name: String,
/// The content of the tool call
content: String,
/// The id of the tool call
tool_call_id: String,
},
}

impl Message {
Expand All @@ -107,21 +115,39 @@ impl TryFrom<message::Message> for Vec<Message> {
fn try_from(message: message::Message) -> Result<Self, Self::Error> {
match message {
message::Message::User { content } => {
let (_, other_content): (Vec<_>, Vec<_>) = content
.into_iter()
.partition(|content| matches!(content, message::UserContent::ToolResult(_)));

let messages = other_content
.into_iter()
.filter_map(|content| match content {
let mut tool_result_messages = Vec::new();
let mut other_messages = Vec::new();

for content_item in content {
match content_item {
message::UserContent::ToolResult(message::ToolResult {
id,
call_id,
content: tool_content,
}) => {
let call_id_key = call_id.unwrap_or_else(|| id.clone());
let content_text = tool_content
.into_iter()
.find_map(|content_item| match content_item {
message::ToolResultContent::Text(text) => Some(text.text),
message::ToolResultContent::Image(_) => None,
})
.unwrap_or_default();
tool_result_messages.push(Message::Tool {
name: id,
content: content_text,
tool_call_id: call_id_key,
});
}
message::UserContent::Text(message::Text { text }) => {
Some(Message::User { content: text })
other_messages.push(Message::User { content: text });
}
_ => None,
})
.collect::<Vec<_>>();
_ => {}
}
}

Ok(messages)
tool_result_messages.append(&mut other_messages);
Ok(tool_result_messages)
}
message::Message::Assistant { content, .. } => {
let (text_content, tool_calls) = content.into_iter().fold(
Expand Down