Skip to content

Commit

Permalink
Added lore functionality, still need to fully test
Browse files Browse the repository at this point in the history
  • Loading branch information
lalalune committed Feb 21, 2024
1 parent af142bd commit 25fc369
Show file tree
Hide file tree
Showing 8 changed files with 266 additions and 57 deletions.
103 changes: 103 additions & 0 deletions src/lib/__tests__/lore.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import { type User } from "@supabase/supabase-js";
import { type UUID } from "crypto";
import dotenv from "dotenv";
import { createRuntime } from "../../test/createRuntime";
import { MemoryManager } from "../memory";
import { getRelationship } from "../relationships";
import { type Content, type Memory } from "../types";
import { getCachedEmbedding, writeCachedEmbedding } from "../../test/cache";
import { BgentRuntime } from "../runtime";
import { addLore, getLore } from "../lore";
import { composeContext } from "../context";
import { requestHandlerTemplate } from "../templates";

dotenv.config();
describe("Lore", () => {
const zeroUuid: UUID = "00000000-0000-0000-0000-000000000000";
let runtime: BgentRuntime;
let user: User;
let room_id: UUID;

beforeAll(async () => {
const result = await createRuntime({
env: process.env as Record<string, string>,
});
runtime = result.runtime;
user = result.session.user;

const data = await getRelationship({
runtime,
userA: user?.id as UUID,
userB: zeroUuid,
});

room_id = data?.room_id;
});

beforeEach(async () => {
await runtime.loreManager.removeAllMemoriesByUserIds([
user?.id as UUID,
zeroUuid,
]);
});

afterAll(async () => {
await runtime.loreManager.removeAllMemoriesByUserIds([
user?.id as UUID,
zeroUuid,
]);
});

test("Add and get lore", async () => {
const content: Content = { content: "Test", source: "/Test.md" };
await addLore({
runtime,
source: "/Test.md",
content: "Test",
user_id: user.id as UUID,
room_id,
});

const lore = await getLore({
runtime,
message: "Test",
});

expect(lore).toHaveLength(1);
expect(lore[0].content).toEqual(content);
});

// TODO: Test that the lore is in the context of the agent

test("Test that lore is in the context of the agent", async () => {
await addLore({
runtime,
source: "Test Lore Source",
content: "Test Lore Content",
user_id: user.id as UUID,
room_id,
});

const message = {
senderId: user.id as UUID,
agentId: zeroUuid,
userIds: [user.id as UUID, zeroUuid],
content: "Test",
room_id,
};

const state = await runtime.composeState(message);

// expect state.lore to exist
expect(state.lore).toHaveLength(1);

const context = composeContext({
state,
template: requestHandlerTemplate,
});

// expect context to contain 'Test Lore Source' and 'Test Lore Content'
expect(context).toContain("Test Lore Source");
expect(context).toContain("Test Lore Content");
});
});
26 changes: 25 additions & 1 deletion src/lib/actions/continue.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { composeContext } from "../context";
import logger from "../logger";
import { embeddingZeroVector } from "../memory";
import { type BgentRuntime } from "../runtime";
import { requestHandlerTemplate } from "../templates";
import { Content, State, type Action, type Message } from "../types";
Expand Down Expand Up @@ -110,7 +111,30 @@ export default {
return responseContent;
}

await runtime.saveResponseMessage(message, state, responseContent);
const _saveResponseMessage = async (
message: Message,
state: State,
responseContent: Content,
) => {
const { agentId, userIds, room_id } = message;

responseContent.content = responseContent.content?.trim();

if (responseContent.content) {
await runtime.messageManager.createMemory({
user_ids: userIds!,
user_id: agentId!,
content: responseContent,
room_id,
embedding: embeddingZeroVector,
});
await runtime.evaluate(message, { ...state, responseContent });
} else {
console.warn("Empty response, skipping");
}
};

await _saveResponseMessage(message, state, responseContent);

// if the action is CONTINUE, check if we are over maxContinuesInARow
// if so, then we should change the action to WAIT
Expand Down
4 changes: 4 additions & 0 deletions src/lib/constants.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
import { type UUID } from "crypto";

export const zeroUuid = "00000000-0000-0000-0000-000000000000" as UUID;
export const zeroUuidPlus1 = "00000000-0000-0000-0000-000000000001" as UUID;
1 change: 1 addition & 0 deletions src/lib/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ export * from "./evaluation";
export * from "./evaluators/index";
export * from "./goals";
export * from "./logger";
export * from "./lore";
export * from "./memory";
export * from "./messages";
export * from "./relationships";
Expand Down
64 changes: 64 additions & 0 deletions src/lib/lore.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import { UUID } from "crypto";
import { zeroUuid } from "./constants";
import { type BgentRuntime } from "./runtime";
import { Content, Memory } from "./types";

export async function addLore({
runtime,
source,
content,
embedContent,
user_id = zeroUuid,
room_id = zeroUuid,
}: {
runtime: BgentRuntime;
source: string;
content: string;
embedContent?: string;
user_id?: UUID;
room_id?: UUID;
}) {
const loreManager = runtime.loreManager;

const embedding = embedContent
? await runtime.embed(embedContent)
: await runtime.embed(content);

await loreManager.createMemory({
user_id,
user_ids: [user_id],
content: { content, source },
room_id,
embedding: embedding,
});
}

export async function getLore({
runtime,
message,
match_threshold,
count,
}: {
runtime: BgentRuntime;
message: string;
match_threshold?: number;
count?: number;
}) {
const loreManager = runtime.loreManager;
const embedding = await runtime.embed(message);
const lore = await loreManager.searchMemoriesByEmbedding(embedding, {
userIds: [zeroUuid],
match_threshold,
count,
});
return lore;
}

export const formatLore = (lore: Memory[]) => {
const messageStrings = lore.reverse().map((fragment: Memory) => {
const content = fragment.content as Content;
return `${content.content}\n${content.source ? " (Source: " + content.source + ")" : ""}`;
});
const finalMessageStrings = messageStrings.join("\n");
return finalMessageStrings;
};
114 changes: 62 additions & 52 deletions src/lib/runtime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import {
getMessageActors,
} from "./messages";
import { type Actor, /*type Goal,*/ type Memory } from "./types";
import { getLore } from "./lore";
export interface AgentRuntimeOpts {
recentMessageCount?: number; // number of messages to hold in the recent message cache
token: string; // JWT token, can be a JWT token if outside worker, or an OpenAI token if inside worker
Expand Down Expand Up @@ -69,6 +70,11 @@ export class BgentRuntime {
tableName: "summarizations",
});

loreManager: MemoryManager = new MemoryManager({
runtime: this,
tableName: "lore",
});

actions: Action[] = [];
evaluators: Evaluator[] = [];

Expand Down Expand Up @@ -201,21 +207,37 @@ export class BgentRuntime {
}

async handleRequest(message: Message, state?: State) {
console.log("handleRequest", message, state);
await this.saveRequestMessage(message, state as State);
console.log("handleRequest", message, state);
const _saveRequestMessage = async (message: Message, state: State) => {
const { content: senderContent, senderId, userIds, room_id } = message;

const _senderContent = (
(senderContent as Content).content ?? senderContent
)?.trim();
if (_senderContent) {
await this.messageManager.createMemory({
user_ids: userIds!,
user_id: senderId!,
content: {
content: _senderContent,
action: (message.content as Content)?.action ?? "null",
},
room_id,
embedding: embeddingZeroVector,
});
await this.evaluate(message, state);
}
};

await _saveRequestMessage(message, state as State);
if (!state) {
state = (await this.composeState(message)) as State;
}
console.log("handleRequest", message, state);

const context = composeContext({
state,
template: requestHandlerTemplate,
});

console.log("context", context);

if (this.debugMode) {
logger.log(context, {
title: "Response Context",
Expand Down Expand Up @@ -266,7 +288,30 @@ export class BgentRuntime {
};
}

await this.saveResponseMessage(message, state, responseContent);
const _saveResponseMessage = async (
message: Message,
state: State,
responseContent: Content,
) => {
const { agentId, userIds, room_id } = message;

responseContent.content = responseContent.content?.trim();

if (responseContent.content) {
await this.messageManager.createMemory({
user_ids: userIds!,
user_id: agentId!,
content: responseContent,
room_id,
embedding: embeddingZeroVector,
});
await this.evaluate(message, { ...state, responseContent });
} else {
console.warn("Empty response, skipping");
}
};

await _saveResponseMessage(message, state, responseContent);
await this.processActions(message, responseContent);

return responseContent;
Expand Down Expand Up @@ -297,50 +342,6 @@ export class BgentRuntime {
await action.handler(this, message);
}

async saveRequestMessage(message: Message, state: State) {
const { content: senderContent, senderId, userIds, room_id } = message;

const _senderContent = (
(senderContent as Content).content ?? senderContent
)?.trim();
if (_senderContent) {
await this.messageManager.createMemory({
user_ids: userIds!,
user_id: senderId!,
content: {
content: _senderContent,
action: (message.content as Content)?.action ?? "null",
},
room_id,
embedding: embeddingZeroVector,
});
await this.evaluate(message, state);
}
}

async saveResponseMessage(
message: Message,
state: State,
responseContent: Content,
) {
const { agentId, userIds, room_id } = message;

responseContent.content = responseContent.content?.trim();

if (responseContent.content) {
await this.messageManager.createMemory({
user_ids: userIds!,
user_id: agentId!,
content: responseContent,
room_id,
embedding: embeddingZeroVector,
});
await this.evaluate(message, { ...state, responseContent });
} else {
console.warn("Empty response, skipping");
}
}

async evaluate(message: Message, state: State) {
const evaluatorPromises = this.evaluators.map(
async (evaluator: Evaluator) => {
Expand Down Expand Up @@ -408,7 +409,8 @@ export class BgentRuntime {
recentMessagesData,
recentSummarizationsData,
goalsData,
]: [Actor[], Memory[], Memory[], Goal[]] = await Promise.all([
loreData,
]: [Actor[], Memory[], Memory[], Goal[], Memory[]] = await Promise.all([
getMessageActors({ runtime: this, userIds: userIds! }),
this.messageManager.getMemoriesByIds({
userIds: userIds!,
Expand All @@ -425,6 +427,12 @@ export class BgentRuntime {
onlyInProgress: true,
userIds: userIds!,
}),
getLore({
runtime: this,
message: message.content as string,
count: 5,
match_threshold: 0.5,
}),
]);

const goals = await formatGoalsAsString({ goals: goalsData });
Expand Down Expand Up @@ -464,6 +472,8 @@ export class BgentRuntime {
relevantSummarizationsData,
);

const lore = formatLore(loreData);

const senderName = actorsData?.find(
(actor: Actor) => actor.id === senderId,
)?.name;
Expand Down
Loading

0 comments on commit 25fc369

Please sign in to comment.