fix: message content type #696

Merged · 4 commits · Apr 6, 2024
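This PR types ChatMessage.content as the MessageContent union (a plain string or an array of multi-modal parts) instead of any, and updates the call sites that need a plain string (token counting, output parsing, prompt building, engine responses) to go through the extractText helper from llamaindex/llm/utils.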
4 changes: 1 addition & 3 deletions examples/jsonExtract.ts
@@ -36,9 +36,7 @@ async function main() {
     ],
   });
 
-  const json = JSON.parse(response.message.content);
-
-  console.log(json);
+  console.log(response.message.content);
 }
 
 main().catch(console.error);
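With content now typed as MessageContent rather than string, JSON.parse(response.message.content) no longer type-checks without narrowing, so the example logs the content directly instead of parsing it.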
5 changes: 3 additions & 2 deletions examples/recipes/cost-analysis.ts
@@ -1,6 +1,7 @@
 import { encodingForModel } from "js-tiktoken";
 import { OpenAI } from "llamaindex";
 import { Settings } from "llamaindex/Settings";
+import { extractText } from "llamaindex/llm/utils";
 
 const encoding = encodingForModel("gpt-4-0125-preview");
 
@@ -13,7 +14,7 @@ let tokenCount = 0;
 Settings.callbackManager.on("llm-start", (event) => {
   const { messages } = event.detail.payload;
   tokenCount += messages.reduce((count, message) => {
-    return count + encoding.encode(message.content).length;
+    return count + encoding.encode(extractText(message.content)).length;
   }, 0);
   console.log("Token count:", tokenCount);
   // https://openai.com/pricing
@@ -22,7 +23,7 @@ Settings.callbackManager.on("llm-start", (event) => {
 });
 Settings.callbackManager.on("llm-end", (event) => {
   const { response } = event.detail.payload;
-  tokenCount += encoding.encode(response.message.content).length;
+  tokenCount += encoding.encode(extractText(response.message.content)).length;
   console.log("Token count:", tokenCount);
   // https://openai.com/pricing
   // $30.00 / 1M tokens
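The wrapper matters here because encoding.encode takes a plain string, while message.content may now be an array of multi-modal parts. A sketch of the shape extractText flattens, using a made-up message and image URL (and assuming ChatMessage is re-exported from the package root):

import type { ChatMessage } from "llamaindex";
import { extractText } from "llamaindex/llm/utils";

// Hypothetical multi-modal message: content is an array of parts,
// not a plain string, so it cannot be fed to encoding.encode directly.
const message: ChatMessage = {
  role: "user",
  content: [
    { type: "text", text: "What is in this image?" },
    { type: "image_url", image_url: { url: "https://example.com/photo.png" } },
  ],
};

// Only the text parts survive flattening:
console.log(extractText(message.content)); // "What is in this image?"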
4 changes: 3 additions & 1 deletion packages/core/src/ChatHistory.ts
@@ -3,6 +3,7 @@ import type { SummaryPrompt } from "./Prompt.js";
 import { defaultSummaryPrompt, messagesToHistoryStr } from "./Prompt.js";
 import { OpenAI } from "./llm/open_ai.js";
 import type { ChatMessage, LLM, MessageType } from "./llm/types.js";
+import { extractText } from "./llm/utils.js";
 
 /**
  * A ChatHistory is used to keep the state of back and forth chat messages
@@ -188,7 +189,8 @@ export class SummaryChatHistory extends ChatHistory {
 
     // get tokens of current request messages and the transient messages
     const tokens = requestMessages.reduce(
-      (count, message) => count + this.tokenizer(message.content).length,
+      (count, message) =>
+        count + this.tokenizer(extractText(message.content)).length,
       0,
     );
     if (tokens > this.tokensToSummarize) {
11 changes: 9 additions & 2 deletions packages/core/src/agent/openai/worker.ts
@@ -15,7 +15,11 @@ import {
   type LLMChatParamsBase,
   type OpenAIAdditionalChatOptions,
 } from "../../llm/index.js";
-import { streamConverter, streamReducer } from "../../llm/utils.js";
+import {
+  extractText,
+  streamConverter,
+  streamReducer,
+} from "../../llm/utils.js";
 import { ChatMemoryBuffer } from "../../memory/ChatMemoryBuffer.js";
 import type { ObjectRetriever } from "../../objects/base.js";
 import type { ToolOutput } from "../../tools/types.js";
@@ -162,7 +166,10 @@ export class OpenAIAgentWorker
   ): AgentChatResponse {
     task.extraState.newMemory.put(aiMessage);
 
-    return new AgentChatResponse(aiMessage.content, task.extraState.sources);
+    return new AgentChatResponse(
+      extractText(aiMessage.content),
+      task.extraState.sources,
+    );
   }
 
   private async _getStreamAiResponse(
7 changes: 5 additions & 2 deletions packages/core/src/agent/react/types.ts
@@ -1,4 +1,5 @@
 import type { ChatMessage } from "../../llm/index.js";
+import { extractText } from "../../llm/utils.js";
 
 export interface BaseReasoningStep {
   getContent(): string;
@@ -51,10 +52,12 @@ export abstract class BaseOutputParser {
   formatMessages(messages: ChatMessage[]): ChatMessage[] {
     if (messages) {
       if (messages[0].role === "system") {
-        messages[0].content = this.format(messages[0].content || "");
+        messages[0].content = this.format(
+          extractText(messages[0].content) || "",
+        );
       } else {
         messages[messages.length - 1].content = this.format(
-          messages[messages.length - 1].content || "",
+          extractText(messages[messages.length - 1].content) || "",
         );
       }
     }
9 changes: 5 additions & 4 deletions packages/core/src/agent/react/worker.ts
@@ -3,6 +3,7 @@ import type { ChatMessage } from "cohere-ai/api";
 import { Settings } from "../../Settings.js";
 import { AgentChatResponse } from "../../engines/chat/index.js";
 import { type ChatResponse, type LLM } from "../../llm/index.js";
+import { extractText } from "../../llm/utils.js";
 import { ChatMemoryBuffer } from "../../memory/ChatMemoryBuffer.js";
 import type { ObjectRetriever } from "../../objects/base.js";
 import { ToolOutput } from "../../tools/index.js";
@@ -34,7 +35,7 @@ function addUserStepToReasoning(
 ): void {
   if (step.stepState.isFirst) {
     memory.put({
-      content: step.input,
+      content: step.input ?? "",
       role: "user",
     });
     step.stepState.isFirst = false;
@@ -130,7 +131,7 @@ export class ReActAgentWorker implements AgentWorker<ChatParams> {
 
     try {
       reasoningStep = this.outputParser.parse(
-        messageContent,
+        extractText(messageContent),
         isStreaming,
       ) as ActionReasoningStep;
     } catch (e) {
@@ -144,7 +145,7 @@ export class ReActAgentWorker implements AgentWorker<ChatParams> {
     currentReasoning.push(reasoningStep);
 
     if (reasoningStep.isDone()) {
-      return [messageContent, currentReasoning, true];
+      return [extractText(messageContent), currentReasoning, true];
     }
 
     const actionReasoningStep = new ActionReasoningStep({
@@ -157,7 +158,7 @@ export class ReActAgentWorker implements AgentWorker<ChatParams> {
       throw new Error(`Expected ActionReasoningStep, got ${reasoningStep}`);
     }
 
-    return [messageContent, currentReasoning, false];
+    return [extractText(messageContent), currentReasoning, false];
   }
 
   async _processActions(
5 changes: 4 additions & 1 deletion packages/core/src/engines/chat/ContextChatEngine.ts
@@ -93,7 +93,10 @@ export class ContextChatEngine extends PromptMixin implements ChatEngine {
       messages: requestMessages.messages,
     });
     chatHistory.addMessage(response.message);
-    return new Response(response.message.content, requestMessages.nodes);
+    return new Response(
+      extractText(response.message.content),
+      requestMessages.nodes,
+    );
   }
 
   reset() {
10 changes: 7 additions & 3 deletions packages/core/src/engines/chat/SimpleChatEngine.ts
@@ -4,7 +4,11 @@ import { Response } from "../../Response.js";
 import { wrapEventCaller } from "../../internal/context/EventCaller.js";
 import type { ChatResponseChunk, LLM } from "../../llm/index.js";
 import { OpenAI } from "../../llm/index.js";
-import { streamConverter, streamReducer } from "../../llm/utils.js";
+import {
+  extractText,
+  streamConverter,
+  streamReducer,
+} from "../../llm/utils.js";
 import type {
   ChatEngine,
   ChatEngineParamsNonStreaming,
@@ -46,7 +50,7 @@ export class SimpleChatEngine implements ChatEngine {
       streamReducer({
        stream,
        initialValue: "",
-       reducer: (accumulator, part) => (accumulator += part.delta),
+       reducer: (accumulator, part) => accumulator + part.delta,
        finished: (accumulator) => {
          chatHistory.addMessage({ content: accumulator, role: "assistant" });
        },
@@ -59,7 +63,7 @@
       messages: await chatHistory.requestMessages(),
     });
     chatHistory.addMessage(response.message);
-    return new Response(response.message.content);
+    return new Response(extractText(response.message.content));
   }
 
   reset() {
3 changes: 2 additions & 1 deletion packages/core/src/evaluation/Correctness.ts
@@ -2,6 +2,7 @@ import { MetadataMode } from "../Node.js";
 import type { ServiceContext } from "../ServiceContext.js";
 import { llmFromSettingsOrContext } from "../Settings.js";
 import type { ChatMessage, LLM } from "../llm/types.js";
+import { extractText } from "../llm/utils.js";
 import { PromptMixin } from "../prompts/Mixin.js";
 import type { CorrectnessSystemPrompt } from "./prompts.js";
 import {
@@ -85,7 +86,7 @@ export class CorrectnessEvaluator extends PromptMixin implements BaseEvaluator {
     });
 
     const [score, reasoning] = this.parserFunction(
-      evalResponse.message.content,
+      extractText(evalResponse.message.content),
     );
 
     return {
19 changes: 12 additions & 7 deletions packages/core/src/llm/LLM.ts
@@ -15,7 +15,7 @@ import type {
   LLMMetadata,
   MessageType,
 } from "./types.js";
-import { wrapLLMEvent } from "./utils.js";
+import { extractText, wrapLLMEvent } from "./utils.js";
 
 export const ALL_AVAILABLE_LLAMADEUCE_MODELS = {
   "Llama-2-70b-chat-old": {
@@ -215,16 +215,15 @@ If a question does not make any sense, or is not factually coherent, explain why
 
     return {
       prompt: messages.reduce((acc, message, index) => {
+        const content = extractText(message.content);
         if (index % 2 === 0) {
           return (
-            `${acc}${
-              withBos ? BOS : ""
-            }${B_INST} ${message.content.trim()} ${E_INST}` +
+            `${acc}${withBos ? BOS : ""}${B_INST} ${content.trim()} ${E_INST}` +
             (withNewlines ? "\n" : "")
           );
         } else {
           return (
-            `${acc} ${message.content.trim()}` +
+            `${acc} ${content.trim()}` +
             (withNewlines ? "\n" : " ") +
             (withBos ? EOS : "")
           ); // Yes, the EOS comes after the space. This is not a mistake.
@@ -322,7 +321,10 @@ export class Portkey extends BaseLLM {
     } else {
       const bodyParams = additionalChatOptions || {};
       const response = await this.session.portkey.chatCompletions.create({
-        messages,
+        messages: messages.map((message) => ({
+          content: extractText(message.content),
+          role: message.role,
+        })),
         ...bodyParams,
       });
 
@@ -337,7 +339,10 @@
    params?: Record<string, any>,
  ): AsyncIterable<ChatResponseChunk> {
    const chunkStream = await this.session.portkey.chatCompletions.create({
-     messages,
+     messages: messages.map((message) => ({
+       content: extractText(message.content),
+       role: message.role,
+     })),
      ...params,
      stream: true,
    });
4 changes: 2 additions & 2 deletions packages/core/src/llm/anthropic.ts
@@ -10,7 +10,7 @@ import type {
 } from "llamaindex";
 import _ from "lodash";
 import { BaseLLM } from "./base.js";
-import { wrapLLMEvent } from "./utils.js";
+import { extractText, wrapLLMEvent } from "./utils.js";
 
 export class AnthropicSession {
   anthropic: SDKAnthropic;
@@ -138,7 +138,7 @@ export class Anthropic extends BaseLLM {
     }
 
     return {
-      content: message.content,
+      content: extractText(message.content),
       role: message.role,
     };
   });
7 changes: 5 additions & 2 deletions packages/core/src/llm/base.ts
@@ -9,7 +9,7 @@ import type {
   LLMCompletionParamsStreaming,
   LLMMetadata,
 } from "./types.js";
-import { streamConverter } from "./utils.js";
+import { extractText, streamConverter } from "./utils.js";
 
 export abstract class BaseLLM<
   AdditionalChatOptions extends Record<string, unknown> = Record<
@@ -44,7 +44,10 @@ export abstract class BaseLLM<
     const chatResponse = await this.chat({
       messages: [{ content: prompt, role: "user" }],
     });
-    return { text: chatResponse.message.content as string };
+    return {
+      text: extractText(chatResponse.message.content),
+      raw: chatResponse.raw,
+    };
   }
 
   abstract chat(
2 changes: 1 addition & 1 deletion packages/core/src/llm/open_ai.ts
@@ -308,7 +308,7 @@ export class OpenAI extends BaseLLM<OpenAIAdditionalChatOptions> {
       stream: false,
     });
 
-    const content = response.choices[0].message?.content ?? null;
+    const content = response.choices[0].message?.content ?? "";
 
     const kwargsOutput: Record<string, any> = {};
 
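Falling back to an empty string instead of null keeps the extracted content assignable to the stricter ChatMessage type this PR introduces.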
23 changes: 15 additions & 8 deletions packages/core/src/llm/types.ts
@@ -75,8 +75,7 @@ export type MessageType =
   | "tool";
 
 export interface ChatMessage {
-  // TODO: use MessageContent
-  content: any;
+  content: MessageContent;
   role: MessageType;
   additionalKwargs?: Record<string, any>;
 }
@@ -137,7 +136,7 @@ export interface LLMChatParamsNonStreaming<
 }
 
 export interface LLMCompletionParamsBase {
-  prompt: any;
+  prompt: MessageContent;
 }
 
 export interface LLMCompletionParamsStreaming extends LLMCompletionParamsBase {
@@ -149,11 +148,19 @@ export interface LLMCompletionParamsNonStreaming
   stream?: false | null;
 }
 
-export interface MessageContentDetail {
-  type: "text" | "image_url";
-  text?: string;
-  image_url?: { url: string };
-}
+export type MessageContentTextDetail = {
+  type: "text";
+  text: string;
+};
+
+export type MessageContentImageDetail = {
+  type: "image_url";
+  image_url: { url: string };
+};
+
+export type MessageContentDetail =
+  | MessageContentTextDetail
+  | MessageContentImageDetail;
 
 /**
  * Extended type for the content of a message that allows for multi-modal messages.
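Every call site in this PR funnels through extractText to recover a plain string. The helper itself is not part of this diff; a minimal sketch of the behavior those call sites assume (strings pass through, multi-modal arrays keep only their text parts; the join separator here is a guess) could look like:

import type { MessageContent, MessageContentTextDetail } from "./types.js";

// Flatten MessageContent to plain text: strings pass through unchanged;
// multi-modal arrays keep their "text" parts and drop image parts.
export function extractText(content: MessageContent): string {
  if (typeof content === "string") return content;
  return content
    .filter(
      (detail): detail is MessageContentTextDetail => detail.type === "text",
    )
    .map((detail) => detail.text)
    .join("\n");
}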