-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added lore functionality, still need to fully test
- Loading branch information
Showing
8 changed files
with
266 additions
and
57 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
import { type User } from "@supabase/supabase-js"; | ||
import { type UUID } from "crypto"; | ||
import dotenv from "dotenv"; | ||
import { createRuntime } from "../../test/createRuntime"; | ||
import { MemoryManager } from "../memory"; | ||
import { getRelationship } from "../relationships"; | ||
import { type Content, type Memory } from "../types"; | ||
import { getCachedEmbedding, writeCachedEmbedding } from "../../test/cache"; | ||
import { BgentRuntime } from "../runtime"; | ||
import { addLore, getLore } from "../lore"; | ||
import { composeContext } from "../context"; | ||
import { requestHandlerTemplate } from "../templates"; | ||
|
||
dotenv.config(); | ||
describe("Lore", () => { | ||
const zeroUuid: UUID = "00000000-0000-0000-0000-000000000000"; | ||
let runtime: BgentRuntime; | ||
let user: User; | ||
let room_id: UUID; | ||
|
||
beforeAll(async () => { | ||
const result = await createRuntime({ | ||
env: process.env as Record<string, string>, | ||
}); | ||
runtime = result.runtime; | ||
user = result.session.user; | ||
|
||
const data = await getRelationship({ | ||
runtime, | ||
userA: user?.id as UUID, | ||
userB: zeroUuid, | ||
}); | ||
|
||
room_id = data?.room_id; | ||
}); | ||
|
||
beforeEach(async () => { | ||
await runtime.loreManager.removeAllMemoriesByUserIds([ | ||
user?.id as UUID, | ||
zeroUuid, | ||
]); | ||
}); | ||
|
||
afterAll(async () => { | ||
await runtime.loreManager.removeAllMemoriesByUserIds([ | ||
user?.id as UUID, | ||
zeroUuid, | ||
]); | ||
}); | ||
|
||
test("Add and get lore", async () => { | ||
const content: Content = { content: "Test", source: "/Test.md" }; | ||
await addLore({ | ||
runtime, | ||
source: "/Test.md", | ||
content: "Test", | ||
user_id: user.id as UUID, | ||
room_id, | ||
}); | ||
|
||
const lore = await getLore({ | ||
runtime, | ||
message: "Test", | ||
}); | ||
|
||
expect(lore).toHaveLength(1); | ||
expect(lore[0].content).toEqual(content); | ||
}); | ||
|
||
// TODO: Test that the lore is in the context of the agent | ||
|
||
test("Test that lore is in the context of the agent", async () => { | ||
await addLore({ | ||
runtime, | ||
source: "Test Lore Source", | ||
content: "Test Lore Content", | ||
user_id: user.id as UUID, | ||
room_id, | ||
}); | ||
|
||
const message = { | ||
senderId: user.id as UUID, | ||
agentId: zeroUuid, | ||
userIds: [user.id as UUID, zeroUuid], | ||
content: "Test", | ||
room_id, | ||
}; | ||
|
||
const state = await runtime.composeState(message); | ||
|
||
// expect state.lore to exist | ||
expect(state.lore).toHaveLength(1); | ||
|
||
const context = composeContext({ | ||
state, | ||
template: requestHandlerTemplate, | ||
}); | ||
|
||
// expect context to contain 'Test Lore Source' and 'Test Lore Content' | ||
expect(context).toContain("Test Lore Source"); | ||
expect(context).toContain("Test Lore Content"); | ||
}); | ||
}); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
import { type UUID } from "crypto"; | ||
|
||
export const zeroUuid = "00000000-0000-0000-0000-000000000000" as UUID; | ||
export const zeroUuidPlus1 = "00000000-0000-0000-0000-000000000001" as UUID; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import { UUID } from "crypto"; | ||
import { zeroUuid } from "./constants"; | ||
import { type BgentRuntime } from "./runtime"; | ||
import { Content, Memory } from "./types"; | ||
|
||
export async function addLore({ | ||
runtime, | ||
source, | ||
content, | ||
embedContent, | ||
user_id = zeroUuid, | ||
room_id = zeroUuid, | ||
}: { | ||
runtime: BgentRuntime; | ||
source: string; | ||
content: string; | ||
embedContent?: string; | ||
user_id?: UUID; | ||
room_id?: UUID; | ||
}) { | ||
const loreManager = runtime.loreManager; | ||
|
||
const embedding = embedContent | ||
? await runtime.embed(embedContent) | ||
: await runtime.embed(content); | ||
|
||
await loreManager.createMemory({ | ||
user_id, | ||
user_ids: [user_id], | ||
content: { content, source }, | ||
room_id, | ||
embedding: embedding, | ||
}); | ||
} | ||
|
||
export async function getLore({ | ||
runtime, | ||
message, | ||
match_threshold, | ||
count, | ||
}: { | ||
runtime: BgentRuntime; | ||
message: string; | ||
match_threshold?: number; | ||
count?: number; | ||
}) { | ||
const loreManager = runtime.loreManager; | ||
const embedding = await runtime.embed(message); | ||
const lore = await loreManager.searchMemoriesByEmbedding(embedding, { | ||
userIds: [zeroUuid], | ||
match_threshold, | ||
count, | ||
}); | ||
return lore; | ||
} | ||
|
||
export const formatLore = (lore: Memory[]) => { | ||
const messageStrings = lore.reverse().map((fragment: Memory) => { | ||
const content = fragment.content as Content; | ||
return `${content.content}\n${content.source ? " (Source: " + content.source + ")" : ""}`; | ||
}); | ||
const finalMessageStrings = messageStrings.join("\n"); | ||
return finalMessageStrings; | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.