Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions example/convex/_generated/api.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ export declare const components: {
vectors: Array<Array<number> | null>;
};
failPendingSteps?: boolean;
hideFromUserIdSearch?: boolean;
messages: Array<{
error?: string;
fileIds?: Array<string>;
Expand Down Expand Up @@ -836,6 +837,22 @@ export declare const components: {
}>;
}
>;
cloneThread: FunctionReference<
"action",
"internal",
{
batchSize?: number;
copyUserIdForVectorSearch?: boolean;
excludeToolMessages?: boolean;
insertAtOrder?: number;
limit?: number;
sourceThreadId: string;
statuses?: Array<"pending" | "success" | "failed">;
targetThreadId: string;
upToAndIncludingMessageId?: string;
},
number
>;
deleteByIds: FunctionReference<
"mutation",
"internal",
Expand Down
17 changes: 17 additions & 0 deletions src/component/_generated/api.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ export type Mounts = {
vectors: Array<Array<number> | null>;
};
failPendingSteps?: boolean;
hideFromUserIdSearch?: boolean;
messages: Array<{
error?: string;
fileIds?: Array<string>;
Expand Down Expand Up @@ -662,6 +663,22 @@ export type Mounts = {
}>;
}
>;
cloneThread: FunctionReference<
"action",
"public",
{
batchSize?: number;
copyUserIdForVectorSearch?: boolean;
excludeToolMessages?: boolean;
insertAtOrder?: number;
limit?: number;
sourceThreadId: string;
statuses?: Array<"pending" | "success" | "failed">;
targetThreadId: string;
upToAndIncludingMessageId?: string;
},
number
>;
deleteByIds: FunctionReference<
"mutation",
"public",
Expand Down
261 changes: 209 additions & 52 deletions src/component/messages.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import { api, internal } from "./_generated/api.js";
import type { Doc, Id } from "./_generated/dataModel.js";
import {
action,
internalMutation,
internalQuery,
mutation,
type MutationCtx,
Expand Down Expand Up @@ -130,6 +131,9 @@ const addMessagesArgs = {
failPendingSteps: v.optional(v.boolean()),
// A pending message to update. If the pending message failed, abort.
pendingMessageId: v.optional(v.id("messages")),
// if set to true, these messages will not show up in text or vector search
// results for the userId
hideFromUserIdSearch: v.optional(v.boolean()),
};
export const addMessages = mutation({
args: addMessagesArgs,
Expand All @@ -153,6 +157,7 @@ async function addMessagesHandler(
messages,
promptMessageId,
pendingMessageId,
hideFromUserIdSearch,
...rest
} = args;
const promptMessage = promptMessageId && (await ctx.db.get(promptMessageId));
Expand Down Expand Up @@ -219,7 +224,7 @@ async function addMessagesHandler(
vector: embeddings.vectors[i]!,
model: embeddings.model,
table: "messages",
userId,
userId: hideFromUserIdSearch ? undefined : userId,
threadId,
});
}
Expand All @@ -230,7 +235,7 @@ async function addMessagesHandler(
parentMessageId: promptMessageId,
userId,
tool: isTool(message.message),
text: extractText(message.message),
text: hideFromUserIdSearch ? undefined : extractText(message.message),
status: fail ? "failed" : (message.status ?? "success"),
error: fail ? error : message.error,
} satisfies Omit<
Expand Down Expand Up @@ -432,64 +437,216 @@ export const updateMessage = mutation({
},
});

export const listMessagesByThreadId = query({
const cloneMessageArgs = {
sourceThreadId: v.id("threads"),
targetThreadId: v.id("threads"),
// defaults to false, so searching for a message by userId will not find
// these copies
copyUserIdForVectorSearch: v.optional(v.boolean()),
// defaults to false, so tool calls & responses will be copied
excludeToolMessages: v.optional(v.boolean()),
// defaults to copying all messages, but you could just copy success messages.
statuses: v.optional(v.array(vMessageStatus)),
// stop at this message id
upToAndIncludingMessageId: v.optional(v.id("messages")),
// defaults to 0. the messages will be inserted starting at this order.
insertAtOrder: v.optional(v.number()),
};
export const cloneMessageBatch = internalMutation({
args: {
threadId: v.id("threads"),
excludeToolMessages: v.optional(v.boolean()),
/** What order to sort the messages in. To get the latest, use "desc". */
order: v.union(v.literal("asc"), v.literal("desc")),
paginationOpts: v.optional(paginationOptsValidator),
statuses: v.optional(v.array(vMessageStatus)),
upToAndIncludingMessageId: v.optional(v.id("messages")),
...cloneMessageArgs,
paginationOpts: paginationOptsValidator,
},
handler: async (ctx, args) => {
const statuses =
args.statuses ?? vMessageStatus.members.map((m) => m.value);
const last =
args.upToAndIncludingMessageId &&
(await ctx.db.get(args.upToAndIncludingMessageId));
assert(
!last || last.threadId === args.threadId,
"upToAndIncludingMessageId must be a message in the thread",
);
const toolOptions = args.excludeToolMessages ? [false] : [true, false];
const order = args.order ?? "desc";
const streams = toolOptions.flatMap((tool) =>
statuses.map((status) =>
stream(ctx.db, schema)
.query("messages")
.withIndex("threadId_status_tool_order_stepOrder", (q) => {
const qq = q
.eq("threadId", args.threadId)
.eq("status", status)
.eq("tool", tool);
if (last) {
return qq.lte("order", last.order);
}
return qq;
})
.order(order)
.filterWith(
// We allow all messages on the same order.
async (m) =>
!last || m.order < last.order || m.order === last.order,
),
),
);
const messages = await mergedStream(streams, [
"order",
"stepOrder",
]).paginate(
args.paginationOpts ?? {
numItems: DEFAULT_RECENT_MESSAGES,
cursor: null,
},
handler: async (
ctx,
args,
): Promise<{
numCopied: number;
continueCursor: string;
isDone: boolean;
}> => {
const orderOffset = args.insertAtOrder ?? 0;
const result = await listMessagesByThreadIdHandler(ctx, {
threadId: args.sourceThreadId,
excludeToolMessages: args.excludeToolMessages,
order: "desc",
paginationOpts: args.paginationOpts,
statuses: args.statuses,
upToAndIncludingMessageId: args.upToAndIncludingMessageId,
});

const existing =
result.page.length === 0
? []
: await mergedStream(
[true, false].flatMap((tool) =>
messageStatuses.map((status) =>
stream(ctx.db, schema)
.query("messages")
.withIndex("threadId_status_tool_order_stepOrder", (q) =>
q
.eq("threadId", args.targetThreadId)
.eq("status", status)
.eq("tool", tool)
.gte("order", result.page[0].order)
.lte("order", result.page[result.page.length - 1].order),
),
),
),
["order", "stepOrder"],
).collect();

await Promise.all(
result.page
.filter(
(m) =>
!existing.some(
(e) => e.order === m.order && e.stepOrder === m.stepOrder,
),
)
.map(async (m) => {
// update file refs
if (m.fileIds) {
await changeRefcount(ctx, [], m.fileIds);
}
let embeddingId: VectorTableId | undefined = undefined;
if (m.embeddingId) {
const vector = await ctx.db.get(m.embeddingId);
assert(vector, `Vector ${m.embeddingId} not found`);
const dimension = vector.vector.length;
validateVectorDimension(dimension);
embeddingId = await insertVector(ctx, dimension, {
...pick(vector, ["model", "table", "vector"]),
userId: args.copyUserIdForVectorSearch
? vector.userId
: undefined,
threadId: args.targetThreadId,
});
}
await ctx.db.insert("messages", {
...omit(m, [
"_id",
"_creationTime",
"threadId",
"order",
"embeddingId",
]),
embeddingId,
threadId: args.targetThreadId,
order: orderOffset + m.order,
});
}),
);
return {
numCopied: result.page.length,
continueCursor: result.continueCursor,
isDone: result.isDone,
};
},
});

export const cloneThread = action({
args: {
...cloneMessageArgs,
batchSize: v.optional(v.number()),
// how many messages to copy
limit: v.optional(v.number()),
},
returns: v.number(),
handler: async (ctx, args) => {
let cursor: string | null = null;
let copiedSoFar = 0;
while (copiedSoFar < (args.limit ?? Infinity)) {
const numToCopy = Math.min(
args.batchSize ?? DEFAULT_RECENT_MESSAGES,
args.limit ?? Infinity - copiedSoFar,
);
const result: {
numCopied: number;
continueCursor: string;
isDone: boolean;
} = await ctx.runMutation(internal.messages.cloneMessageBatch, {
...args,
paginationOpts: {
cursor,
numItems: numToCopy,
},
});
copiedSoFar += result.numCopied;
cursor = result.continueCursor;
if (result.isDone) {
break;
}
}
return copiedSoFar;
},
});

export const listMessagesByThreadIdArgs = {
threadId: v.id("threads"),
excludeToolMessages: v.optional(v.boolean()),
/** What order to sort the messages in. To get the latest, use "desc". */
order: v.union(v.literal("asc"), v.literal("desc")),
paginationOpts: v.optional(paginationOptsValidator),
statuses: v.optional(v.array(vMessageStatus)),
upToAndIncludingMessageId: v.optional(v.id("messages")),
};
export const listMessagesByThreadId = query({
args: listMessagesByThreadIdArgs,
handler: async (ctx, args) => {
const messages = await listMessagesByThreadIdHandler(ctx, args);
return { ...messages, page: messages.page.map(publicMessage) };
},
returns: vPaginationResult(vMessageDoc),
});

async function listMessagesByThreadIdHandler(
ctx: QueryCtx,
args: ObjectType<typeof listMessagesByThreadIdArgs>,
) {
const statuses = args.statuses ?? vMessageStatus.members.map((m) => m.value);
const last =
args.upToAndIncludingMessageId &&
(await ctx.db.get(args.upToAndIncludingMessageId));
assert(
!last || last.threadId === args.threadId,
"upToAndIncludingMessageId must be a message in the thread",
);
const toolOptions = args.excludeToolMessages ? [false] : [true, false];
const order = args.order ?? "desc";
const streams = toolOptions.flatMap((tool) =>
statuses.map((status) =>
stream(ctx.db, schema)
.query("messages")
.withIndex("threadId_status_tool_order_stepOrder", (q) => {
const qq = q
.eq("threadId", args.threadId)
.eq("status", status)
.eq("tool", tool);
if (last) {
return qq.lte("order", last.order);
}
return qq;
})
.order(order)
.filterWith(
// We allow all messages on the same order.
async (m) => !last || m.order <= last.order,
),
),
);
const messages = await mergedStream(streams, ["order", "stepOrder"]).paginate(
args.paginationOpts ?? {
numItems: DEFAULT_RECENT_MESSAGES,
cursor: null,
},
);
if (messages.page.length === 0) {
messages.isDone = true;
}
return messages;
}

export const getMessagesByIds = query({
args: { messageIds: v.array(v.id("messages")) },
handler: async (ctx, args) => {
Expand Down