Skip to content

Commit

Permalink
Many changes and fixes, tests almost passing
Browse files Browse the repository at this point in the history
  • Loading branch information
lalalune committed Mar 18, 2024
1 parent 502fcfc commit f6b7e52
Show file tree
Hide file tree
Showing 10 changed files with 269 additions and 163 deletions.
41 changes: 16 additions & 25 deletions src/lib/__tests__/goals.test.ts
Original file line number Diff line number Diff line change
@@ -1,52 +1,38 @@
import { type User } from "../../test/types";
import { type UUID } from "crypto";
import dotenv from "dotenv";
import { createRuntime } from "../../test/createRuntime";
import { type User } from "../../test/types";
import { zeroUuid } from "../constants";
import { createGoal, getGoals, updateGoal } from "../goals";
import { BgentRuntime } from "../runtime";
import { GoalStatus, type Goal } from "../types";
import { getRelationship } from "../relationships";
import { zeroUuid } from "../constants";

dotenv.config({ path: ".dev.vars" });
describe("Goals", () => {
let runtime: BgentRuntime;
let user: User;
let room_id = beforeAll(async () => {
beforeAll(async () => {
const result = await createRuntime({
env: process.env as Record<string, string>,
});
runtime = result.runtime;
user = result.session.user;
const data = await getRelationship({
runtime,
userA: user?.id as UUID,
userB: zeroUuid,
});

if (!data) {
throw new Error("Relationship not found");
}

room_id = data.room_id;

await runtime.databaseAdapter.removeAllMemoriesByRoomId(room_id, "goals");
await runtime.databaseAdapter.removeAllMemoriesByRoomId(zeroUuid, "goals");
});

beforeEach(async () => {
await runtime.databaseAdapter.removeAllMemoriesByRoomId(room_id, "goals");
await runtime.databaseAdapter.removeAllMemoriesByRoomId(zeroUuid, "goals");
});

afterAll(async () => {
await runtime.databaseAdapter.removeAllMemoriesByRoomId(room_id, "goals");
await runtime.databaseAdapter.removeAllMemoriesByRoomId(zeroUuid, "goals");
});

// TODO: Write goal tests here
test("createGoal - successfully creates a new goal", async () => {
const newGoal: Goal = {
name: "Test Create Goal",
status: GoalStatus.IN_PROGRESS,
room_id,
room_id: zeroUuid,
user_id: user?.id as UUID,
objectives: [
{
Expand All @@ -56,6 +42,8 @@ describe("Goals", () => {
],
};

console.log("newGoal", newGoal);

await createGoal({
runtime,
goal: newGoal,
Expand All @@ -64,9 +52,11 @@ describe("Goals", () => {
// Verify the goal is created in the database
const goals = await getGoals({
runtime,
room_id,
userId: user?.id as UUID,
room_id: zeroUuid,
onlyInProgress: false,
});

const createdGoal = goals.find((goal: Goal) => goal.name === newGoal.name);

expect(createdGoal).toBeDefined();
Expand All @@ -79,7 +69,7 @@ describe("Goals", () => {
const newGoal: Goal = {
name: "Test Create Goal",
status: GoalStatus.IN_PROGRESS,
room_id,
room_id: zeroUuid,
user_id: user?.id as UUID,
objectives: [
{
Expand All @@ -97,9 +87,10 @@ describe("Goals", () => {
// retrieve the goal from the database
let goals = await getGoals({
runtime,
room_id,
room_id: zeroUuid,
onlyInProgress: false,
});
console.log("goals", goals);
const existingGoal = goals.find(
(goal: Goal) => goal.name === newGoal.name,
) as Goal;
Expand All @@ -112,7 +103,7 @@ describe("Goals", () => {
// Verify the goal's status is updated in the database
goals = await getGoals({
runtime,
room_id,
room_id: zeroUuid,
onlyInProgress: false,
});

Expand Down
24 changes: 23 additions & 1 deletion src/lib/__tests__/messages.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import { formatActors, formatMessages, getActorDetails } from "../messages";
import { type BgentRuntime } from "../runtime";
import { type Actor, type Content, type Memory } from "../types";
import { formatFacts } from "../evaluators/fact";
import { createRelationship, getRelationship } from "../relationships";
import { zeroUuid } from "../constants";

describe("Messages Library", () => {
let runtime: BgentRuntime, user: User, actors: Actor[];
Expand All @@ -22,9 +24,29 @@ describe("Messages Library", () => {
});

test("getActorDetails should return actors based on given room_id", async () => {
// create a room and add a user to it
const userA = user?.id as UUID;
const userB = zeroUuid;

await createRelationship({
runtime,
userA,
userB,
});

const relationship = await getRelationship({
runtime,
userA,
userB,
});

if (!relationship?.room_id) {
throw new Error("Room not found");
}

const result = await getActorDetails({
runtime,
room_id: "00000000-0000-0000-0000-000000000000",
room_id: relationship?.room_id as UUID,
});
expect(result.length).toBeGreaterThan(0);
result.forEach((actor: Actor) => {
Expand Down
2 changes: 1 addition & 1 deletion src/lib/__tests__/providers.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ describe("TestProvider", () => {
providers: [TestProvider],
});
runtime = setup.runtime;
room_id = "some-room-id" as UUID;
room_id = zeroUuid;
});

test("TestProvider should return 'Hello Test'", async () => {
Expand Down
53 changes: 37 additions & 16 deletions src/lib/__tests__/runtime.test.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import dotenv from "dotenv";
import { createRuntime } from "../../test/createRuntime";
import { type UUID } from "crypto";
import { getRelationship } from "../relationships";
import { createRelationship, getRelationship } from "../relationships";
import { getCachedEmbedding, writeCachedEmbedding } from "../../test/cache";
import { BgentRuntime } from "../runtime";
import { type User } from "../../test/types";
Expand All @@ -13,7 +13,7 @@ dotenv.config({ path: ".dev.vars" });
describe("Agent Runtime", () => {
let user: User;
let runtime: BgentRuntime;
let room_id: UUID;
let room_id: UUID = zeroUuid;

// Helper function to clear memories
async function clearMemories() {
Expand All @@ -31,17 +31,21 @@ describe("Agent Runtime", () => {
];

for (const { userId, content } of memories) {
const embedding = getCachedEmbedding(content.content);
const memory = await runtime.messageManager.addEmbeddingToMemory({
user_id: userId,
content,
room_id,
embedding,
});
if (!embedding) {
writeCachedEmbedding(content.content, memory.embedding as number[]);
try {
const embedding = getCachedEmbedding(content.content);
const memory = await runtime.messageManager.addEmbeddingToMemory({
user_id: userId,
content,
room_id,
embedding,
});
if (!embedding) {
writeCachedEmbedding(content.content, memory.embedding as number[]);
}
await runtime.messageManager.createMemory(memory);
} catch (error) {
console.error("Error creating memory", error);
}
await runtime.messageManager.createMemory(memory);
}
}

Expand All @@ -54,17 +58,28 @@ describe("Agent Runtime", () => {
runtime = result.runtime;
user = result.session.user;

const data = await getRelationship({
let data = await getRelationship({
runtime,
userA: user?.id as UUID,
userB: zeroUuid,
});

if (!data) {
throw new Error("Relationship not found");
await createRelationship({
runtime,
userA: user?.id as UUID,
userB: zeroUuid,
});
data = await getRelationship({
runtime,
userA: user?.id as UUID,
userB: zeroUuid,
});
}

room_id = data?.room_id;
console.log("data", data);

room_id = data?.room_id as UUID;
await clearMemories(); // Clear memories before each test
});

Expand All @@ -84,7 +99,13 @@ describe("Agent Runtime", () => {
});

test("Memory lifecycle: create, retrieve, and destroy", async () => {
await createMemories(); // Create new memories
try {
await createMemories(); // Create new memories
} catch (error) {
console.error("Error creating memories", error);
}

console.log("room_id", room_id);

const message: Message = {
userId: user.id as UUID,
Expand Down
37 changes: 23 additions & 14 deletions src/lib/adapters/sqlite.ts
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ export class SqliteDatabaseAdapter extends DatabaseAdapter {
}): Promise<Memory[]> {
let sql = `
SELECT *
FROM ${params.tableName}
WHERE room_id = ? AND vss_search(embedding, ?)
FROM memories
WHERE type = ${params.tableName} AND room_id = ? AND vss_search(embedding, ?)
ORDER BY vss_search(embedding, ?) DESC
LIMIT ?
`;
Expand Down Expand Up @@ -88,14 +88,16 @@ export class SqliteDatabaseAdapter extends DatabaseAdapter {
}): Promise<[]> {
const sql = `
SELECT *
FROM ${opts.query_table_name}
WHERE vss_search(${opts.query_field_name}, ?)
FROM memories
WHERE type = ?
AND vss_search(${opts.query_field_name}, ?)
ORDER BY vss_search(${opts.query_field_name}, ?) DESC
LIMIT ?
`;
return this.db
.prepare(sql)
.all(
JSON.stringify(opts.query_table_name),
JSON.stringify(opts.query_input),
JSON.stringify(opts.query_input),
opts.query_match_count,
Expand Down Expand Up @@ -134,19 +136,26 @@ export class SqliteDatabaseAdapter extends DatabaseAdapter {
unique?: boolean;
tableName: string;
}): Promise<Memory[]> {
let sql = `SELECT * FROM ${params.tableName} WHERE room_id = ?`;
const queryParams = [JSON.stringify(params.room_id)];
if (!params.tableName) {
throw new Error("tableName is required");
}
if (!params.room_id) {
throw new Error("room_id is required");
}
let sql = `SELECT * FROM memories WHERE type = ${params.tableName} AND room_id = ${params.room_id}`;

if (params.unique) {
sql += " AND unique = 1";
}

if (params.count) {
sql += " LIMIT ?";
queryParams.push(params.count.toString());
sql += ` LIMIT ${params.count}`;
}

return this.db.prepare(sql).all(...queryParams) as Memory[];
console.log("sql");
console.log(sql);

return this.db.prepare(sql).all() as Memory[];
}

async searchMemoriesByEmbedding(
Expand All @@ -161,8 +170,8 @@ export class SqliteDatabaseAdapter extends DatabaseAdapter {
): Promise<Memory[]> {
let sql = `
SELECT *
FROM ${params.tableName}
WHERE vss_search(embedding, ?)
FROM memories
WHERE type = ${params.tableName} AND vss_search(embedding, ?)
ORDER BY vss_search(embedding, ?) DESC
`;
const queryParams = [JSON.stringify(embedding), JSON.stringify(embedding)];
Expand All @@ -189,7 +198,7 @@ export class SqliteDatabaseAdapter extends DatabaseAdapter {
tableName: string,
unique = false,
): Promise<void> {
const sql = `INSERT INTO memories (id, type, created_at, content, embedding, user_id, room_id, unique) VALUES (?, ?, ?, ?, ?, ?, ?, ?)`;
const sql = `INSERT INTO memories (id, type, created_at, content, embedding, user_id, room_id, \`unique\`) VALUES (?, ?, ?, ?, ?, ?, ?, ?)`;
this.db
.prepare(sql)
.run(
Expand All @@ -213,7 +222,7 @@ export class SqliteDatabaseAdapter extends DatabaseAdapter {
room_id: UUID,
tableName: string,
): Promise<void> {
const sql = `DELETE FROM memories WHERE tableName = ? AND room_id = ?`;
const sql = `DELETE FROM memories WHERE type = ? AND room_id = ?`;
this.db.prepare(sql).run(tableName, JSON.stringify(room_id));
}

Expand All @@ -226,7 +235,7 @@ export class SqliteDatabaseAdapter extends DatabaseAdapter {
throw new Error("tableName is required");
}

let sql = `SELECT COUNT(*) as count FROM memories WHERE tableName = ? AND room_id = ?`;
let sql = `SELECT COUNT(*) as count FROM memories WHERE type = ? AND room_id = ?`;
const queryParams = [tableName, JSON.stringify(room_id)] as string[];

if (unique) {
Expand Down
Loading

0 comments on commit f6b7e52

Please sign in to comment.