diff --git a/src/agents/cj/actions/__tests__/introduce.test.ts b/src/agents/cj/actions/__tests__/introduce.test.ts index fd0915a..27bdbed 100644 --- a/src/agents/cj/actions/__tests__/introduce.test.ts +++ b/src/agents/cj/actions/__tests__/introduce.test.ts @@ -28,7 +28,7 @@ describe("Introduce Action", () => { ); const data = await getRelationship({ - supabase: runtime.supabase, + runtime, userA: user?.id as UUID, userB: zeroUuid, }); diff --git a/src/agents/cj/actions/introduce.ts b/src/agents/cj/actions/introduce.ts index 72ef105..16006e8 100644 --- a/src/agents/cj/actions/introduce.ts +++ b/src/agents/cj/actions/introduce.ts @@ -85,7 +85,7 @@ const handler = async (runtime: BgentRuntime, message: Message) => { if (responseData?.userA && responseData.userB) { await createRelationship({ - supabase: runtime.supabase, + runtime, userA: responseData.userA, userB: responseData.userB, }); diff --git a/src/agents/cj/evaluators/__tests__/details.test.ts b/src/agents/cj/evaluators/__tests__/details.test.ts index e8e1df5..8e73f8a 100644 --- a/src/agents/cj/evaluators/__tests__/details.test.ts +++ b/src/agents/cj/evaluators/__tests__/details.test.ts @@ -27,7 +27,7 @@ describe("User Details", () => { ); const data = await getRelationship({ - supabase: runtime.supabase, + runtime, userA: user?.id as UUID, userB: zeroUuid, }); diff --git a/src/agents/cj/evaluators/__tests__/profile.test.ts b/src/agents/cj/evaluators/__tests__/profile.test.ts index 98b8d61..9841d14 100644 --- a/src/agents/cj/evaluators/__tests__/profile.test.ts +++ b/src/agents/cj/evaluators/__tests__/profile.test.ts @@ -61,7 +61,7 @@ describe("User Profile", () => { ); const data = await getRelationship({ - supabase: runtime.supabase, + runtime, userA: user?.id as UUID, userB: zeroUuid, }); diff --git a/src/agents/cj/evaluators/profile.ts b/src/agents/cj/evaluators/profile.ts index eda2afa..7182deb 100644 --- a/src/agents/cj/evaluators/profile.ts +++ b/src/agents/cj/evaluators/profile.ts @@ -137,7 +137,7 @@ const handler = async (runtime: BgentRuntime, message: Message) => { // find the room_id in 'relationships' where user_a is the agent and user_b is the user, OR vice versa const relationshipRecord = await getRelationship({ - supabase: runtime.supabase, + runtime, userA, userB, }); diff --git a/src/lib/__tests__/actions.test.ts b/src/lib/__tests__/actions.test.ts index 404c8dd..4698b33 100644 --- a/src/lib/__tests__/actions.test.ts +++ b/src/lib/__tests__/actions.test.ts @@ -4,7 +4,6 @@ import dotenv from "dotenv"; import { createRuntime } from "../../test/createRuntime"; import { getRelationship } from "../relationships"; import { type BgentRuntime } from "../runtime"; -import { populateMemories } from "../../test/populateMemories"; dotenv.config(); @@ -13,7 +12,7 @@ const zeroUuid = "00000000-0000-0000-0000-000000000000" as UUID; describe("Actions", () => { let user: User; let runtime: BgentRuntime; - let room_id: UUID; + // let room_id: UUID; afterAll(async () => { await cleanup(); @@ -24,13 +23,13 @@ describe("Actions", () => { user = setup.session.user; runtime = setup.runtime; - const data = await getRelationship({ - supabase: runtime.supabase, - userA: user?.id as UUID, - userB: zeroUuid, - }); + // const data = await getRelationship({ + // runtime, + // userA: user?.id as UUID, + // userB: zeroUuid, + // }); - room_id = data?.room_id; + // room_id = data?.room_id; await cleanup(); }); @@ -46,9 +45,96 @@ describe("Actions", () => { ]); } - // TODO: 1. Test that actions are being loaded into context properly + describe("Actions", () => { + let user: User; + let runtime: BgentRuntime; + let room_id: UUID; + + afterAll(async () => { + await cleanup(); + }); + + beforeAll(async () => { + const setup = await createRuntime(); + user = setup.session.user; + runtime = setup.runtime; - // TODO: 2. Test that actions are validated properply, for example we know that the continue action is always valid + const data = await getRelationship({ + runtime, + userA: user?.id as UUID, + userB: zeroUuid, + }); + + room_id = data?.room_id; + + await cleanup(); + }); - // TODO 3. Test that action handlers are being called properly + async function cleanup() { + await runtime.summarizationManager.removeAllMemoriesByUserIds([ + user?.id as UUID, + zeroUuid, + ]); + await runtime.messageManager.removeAllMemoriesByUserIds([ + user?.id as UUID, + zeroUuid, + ]); + } + + // Test that actions are being loaded into context properly + test("Actions are loaded into context", async () => { + const actions = runtime.getActions(); + expect(actions).toBeDefined(); + expect(actions.length).toBeGreaterThan(0); + // Ensure the CONTINUE action is part of the loaded actions + const continueAction = actions.find( + (action) => action.name === "CONTINUE", + ); + expect(continueAction).toBeDefined(); + }); + + // Test that actions are validated properly + test("Continue action is always valid", async () => { + const continueAction = runtime + .getActions() + .find((action) => action.name === "CONTINUE"); + expect(continueAction).toBeDefined(); + if (continueAction && continueAction.validate) { + const isValid = await continueAction.validate(runtime, { + agentId: zeroUuid, + senderId: user.id as UUID, + userIds: [user.id as UUID, zeroUuid], + content: "Test message", + room_id: room_id, + }); + expect(isValid).toBeTruthy(); + } else { + throw new Error( + "Continue action or its validation function is undefined", + ); + } + }); + + // Test that action handlers are being called properly + test("Continue action handler is called", async () => { + const continueAction = runtime + .getActions() + .find((action) => action.name === "CONTINUE"); + expect(continueAction).toBeDefined(); + if (continueAction && continueAction.handler) { + const mockMessage = { + agentId: zeroUuid, + senderId: user.id as UUID, + userIds: [user.id as UUID, zeroUuid], + content: "Test message for CONTINUE action", + room_id: room_id, + }; + const response = await continueAction.handler(runtime, mockMessage); + expect(response).toBeDefined(); + // Further assertions can be made based on the expected outcome of the CONTINUE action handler + } else { + throw new Error("Continue action or its handler function is undefined"); + } + }, 20000); + }); }); diff --git a/src/lib/__tests__/memory.test.ts b/src/lib/__tests__/memory.test.ts index f3b0744..662965b 100644 --- a/src/lib/__tests__/memory.test.ts +++ b/src/lib/__tests__/memory.test.ts @@ -21,7 +21,7 @@ describe("Memory", () => { user = result.session.user; const data = await getRelationship({ - supabase: runtime.supabase, + runtime, userA: user?.id as UUID, userB: zeroUuid, }); @@ -208,7 +208,7 @@ describe("Memory - Basic tests", () => { user = result.session.user; const data = await getRelationship({ - supabase: runtime.supabase, + runtime, userA: user?.id as UUID, userB: zeroUuid, }); @@ -298,7 +298,7 @@ describe("Memory - Extended Tests", () => { user = result.session.user; const data = await getRelationship({ - supabase: runtime.supabase, + runtime, userA: user.id as UUID, userB: zeroUuid, }); diff --git a/src/lib/__tests__/messages.test.ts b/src/lib/__tests__/messages.test.ts index 44ef9f4..81d89c0 100644 --- a/src/lib/__tests__/messages.test.ts +++ b/src/lib/__tests__/messages.test.ts @@ -18,14 +18,14 @@ describe("Messages Library", () => { runtime = setup.runtime; user = setup.session.user; actors = await getMessageActors({ - supabase: runtime.supabase, + runtime, userIds: [user.id as UUID, "00000000-0000-0000-0000-000000000000"], }); }); test("getMessageActors should return actors based on given userIds", async () => { const result = await getMessageActors({ - supabase: runtime.supabase, + runtime, userIds: [user.id as UUID, "00000000-0000-0000-0000-000000000000"], }); expect(result.length).toBeGreaterThan(0); diff --git a/src/lib/__tests__/relationships.test.ts b/src/lib/__tests__/relationships.test.ts index 20c99ec..aa0aa60 100644 --- a/src/lib/__tests__/relationships.test.ts +++ b/src/lib/__tests__/relationships.test.ts @@ -28,7 +28,7 @@ describe("Relationships Module", () => { const userB = zeroUuid; const relationship = await createRelationship({ - supabase: runtime.supabase, + runtime, userA, userB, }); @@ -40,10 +40,10 @@ describe("Relationships Module", () => { const userA = user?.id as UUID; const userB = zeroUuid; - await createRelationship({ supabase: runtime.supabase, userA, userB }); + await createRelationship({ runtime, userA, userB }); const relationship = await getRelationship({ - supabase: runtime.supabase, + runtime, userA, userB, }); @@ -56,10 +56,10 @@ describe("Relationships Module", () => { const userA = user?.id as UUID; const userB = zeroUuid; - await createRelationship({ supabase: runtime.supabase, userA, userB }); + await createRelationship({ runtime, userA, userB }); const relationships = await getRelationships({ - supabase: runtime.supabase, + runtime, userId: userA, }); expect(relationships).toBeDefined(); diff --git a/src/lib/__tests__/runtime.test.ts b/src/lib/__tests__/runtime.test.ts index a4aa9aa..680d2a9 100644 --- a/src/lib/__tests__/runtime.test.ts +++ b/src/lib/__tests__/runtime.test.ts @@ -53,7 +53,7 @@ describe("Agent Runtime", () => { user = result.session.user; const data = await getRelationship({ - supabase: runtime.supabase, + runtime, userA: user?.id as UUID, userB: zeroUuid, }); diff --git a/src/lib/actions/__tests__/continue.test.ts b/src/lib/actions/__tests__/continue.test.ts index 42f4363..111b282 100644 --- a/src/lib/actions/__tests__/continue.test.ts +++ b/src/lib/actions/__tests__/continue.test.ts @@ -49,7 +49,7 @@ describe("User Profile", () => { runtime = setup.runtime; const data = await getRelationship({ - supabase: runtime.supabase, + runtime, userA: user.id, userB: zeroUuid, }); diff --git a/src/lib/actions/__tests__/ignore.test.ts b/src/lib/actions/__tests__/ignore.test.ts index 492668c..d22454e 100644 --- a/src/lib/actions/__tests__/ignore.test.ts +++ b/src/lib/actions/__tests__/ignore.test.ts @@ -32,7 +32,7 @@ describe("Ignore action tests", () => { runtime = setup.runtime; const data = await getRelationship({ - supabase: runtime.supabase, + runtime, userA: user?.id as UUID, userB: zeroUuid, }); diff --git a/src/lib/actions/__tests__/wait.test.ts b/src/lib/actions/__tests__/wait.test.ts index e7bdc15..f9750ae 100644 --- a/src/lib/actions/__tests__/wait.test.ts +++ b/src/lib/actions/__tests__/wait.test.ts @@ -32,7 +32,7 @@ describe("Wait Action Behavior", () => { runtime = setup.runtime; const data = await getRelationship({ - supabase: runtime.supabase, + runtime, userA: user?.id as UUID, userB: zeroUuid, }); diff --git a/src/lib/evaluators/__tests__/summarization.test.ts b/src/lib/evaluators/__tests__/summarization.test.ts index af32c0e..b6c193b 100644 --- a/src/lib/evaluators/__tests__/summarization.test.ts +++ b/src/lib/evaluators/__tests__/summarization.test.ts @@ -34,7 +34,7 @@ describe("Factual Summarization", () => { runtime = setup.runtime; const data = await getRelationship({ - supabase: runtime.supabase, + runtime, userA: user?.id as UUID, userB: zeroUuid, }); diff --git a/src/lib/evaluators/summarization.ts b/src/lib/evaluators/summarization.ts index fe92179..5f78c53 100644 --- a/src/lib/evaluators/summarization.ts +++ b/src/lib/evaluators/summarization.ts @@ -147,8 +147,7 @@ async function handler(runtime: BgentRuntime, message: Message) { const { userIds, senderId, agentId, room_id } = state; - const actors = - (await getMessageActors({ supabase: runtime.supabase, userIds })) ?? []; + const actors = (await getMessageActors({ runtime, userIds })) ?? []; const senderName = actors?.find( (actor: Actor) => actor.id === senderId, diff --git a/src/lib/goals.ts b/src/lib/goals.ts index 33c4a64..b07686d 100644 --- a/src/lib/goals.ts +++ b/src/lib/goals.ts @@ -1,26 +1,29 @@ -import { type SupabaseClient } from "@supabase/supabase-js"; -import { type Goal, type Objective } from "./types"; import { type UUID } from "crypto"; +import { BgentRuntime } from "./runtime"; +import { type Goal, type Objective } from "./types"; export const getGoals = async ({ - supabase, + runtime, userIds, userId = null, onlyInProgress = true, count = 5, }: { - supabase: SupabaseClient; + runtime: BgentRuntime; userIds: string[]; userId?: string | null; onlyInProgress?: boolean; count?: number; }) => { - const { data: goals, error } = await supabase.rpc("get_goals_by_user_ids", { - query_user_ids: userIds, - query_user_id: userId, - only_in_progress: onlyInProgress, - row_count: count, - }); + const { data: goals, error } = await runtime.supabase.rpc( + "get_goals_by_user_ids", + { + query_user_ids: userIds, + query_user_id: userId, + only_in_progress: onlyInProgress, + row_count: count, + }, + ); if (error) { throw new Error(error.message); @@ -41,16 +44,16 @@ export const formatGoalsAsString = async ({ goals }: { goals: Goal[] }) => { }; export const updateGoals = async ({ - supabase, + runtime, userIds, goals, }: { - supabase: SupabaseClient; + runtime: BgentRuntime; userIds: UUID[]; goals: Goal[]; }) => { for (const goal of goals) { - await supabase + await runtime.supabase .from("goals") .update(goal) .match({ id: goal.id }) @@ -59,17 +62,17 @@ export const updateGoals = async ({ }; export const createGoal = async ({ - supabase, + runtime, goal, userIds, userId, }: { - supabase: SupabaseClient; + runtime: BgentRuntime; goal: Goal; userIds: string[]; userId: string; }) => { - const { error } = await supabase + const { error } = await runtime.supabase .from("goals") .upsert({ ...goal, user_ids: userIds, user_id: userId }); @@ -79,38 +82,41 @@ export const createGoal = async ({ }; export const cancelGoal = async ({ - supabase, + runtime, goalId, }: { - supabase: SupabaseClient; + runtime: BgentRuntime; goalId: UUID; }) => { - await supabase + await runtime.supabase .from("goals") .update({ status: "FAILED" }) .match({ id: goalId }); }; export const finishGoal = async ({ - supabase, + runtime, goalId, }: { - supabase: SupabaseClient; + runtime: BgentRuntime; goalId: UUID; }) => { - await supabase.from("goals").update({ status: "DONE" }).match({ id: goalId }); + await runtime.supabase + .from("goals") + .update({ status: "DONE" }) + .match({ id: goalId }); }; export const finishGoalObjective = async ({ - supabase, + runtime, goalId, objectiveId, }: { - supabase: SupabaseClient; + runtime: BgentRuntime; goalId: UUID; objectiveId: string; }) => { - const { data: goal, error } = await supabase + const { data: goal, error } = await runtime.supabase .from("goals") .select("*") .match({ id: goalId }) @@ -127,7 +133,7 @@ export const finishGoalObjective = async ({ return objective; }); - await supabase + await runtime.supabase .from("goals") .update({ objectives: updatedObjectives }) .match({ id: goalId }); diff --git a/src/lib/messages.ts b/src/lib/messages.ts index 6d894f9..7ad3afb 100644 --- a/src/lib/messages.ts +++ b/src/lib/messages.ts @@ -1,15 +1,15 @@ -import { type SupabaseClient } from "@supabase/supabase-js"; import { type UUID } from "crypto"; +import { BgentRuntime } from "./runtime"; import { type Actor, type Content, type Memory } from "./types"; export async function getMessageActors({ - supabase, + runtime, userIds, }: { - supabase: SupabaseClient; + runtime: BgentRuntime; userIds: UUID[]; }) { - const response = await supabase + const response = await runtime.supabase .from("accounts") .select("*") .in("id", userIds); diff --git a/src/lib/relationships.ts b/src/lib/relationships.ts index e508d67..859ff6f 100644 --- a/src/lib/relationships.ts +++ b/src/lib/relationships.ts @@ -1,17 +1,17 @@ -import { type SupabaseClient } from "@supabase/supabase-js"; import { type UUID } from "crypto"; +import { type BgentRuntime } from "./runtime"; import { type Relationship } from "./types"; export async function createRelationship({ - supabase, + runtime, userA, userB, }: { - supabase: SupabaseClient; + runtime: BgentRuntime; userA: UUID; userB: UUID; }): Promise { - const { error } = await supabase.from("relationships").upsert({ + const { error } = await runtime.supabase.from("relationships").upsert({ user_a: userA, user_b: userB, user_id: userA, @@ -25,15 +25,15 @@ export async function createRelationship({ } export async function getRelationship({ - supabase, + runtime, userA, userB, }: { - supabase: SupabaseClient; + runtime: BgentRuntime; userA: string; userB: string; }) { - const { data, error } = await supabase.rpc("get_relationship", { + const { data, error } = await runtime.supabase.rpc("get_relationship", { usera: userA, userb: userB, }); @@ -46,13 +46,13 @@ export async function getRelationship({ } export async function getRelationships({ - supabase, + runtime, userId, }: { - supabase: SupabaseClient; + runtime: BgentRuntime; userId: string; }) { - const { data, error } = await supabase + const { data, error } = await runtime.supabase .from("relationships") .select("*") .or(`user_a.eq.${userId},user_b.eq.${userId}`) @@ -66,13 +66,13 @@ export async function getRelationships({ } export async function formatRelationships({ - supabase, + runtime, userId, }: { - supabase: SupabaseClient; + runtime: BgentRuntime; userId: string; }) { - const relationships = await getRelationships({ supabase, userId }); + const relationships = await getRelationships({ runtime, userId }); const formattedRelationships = relationships.map( (relationship: Relationship) => { diff --git a/src/lib/runtime.ts b/src/lib/runtime.ts index b3b7c9d..1d31101 100644 --- a/src/lib/runtime.ts +++ b/src/lib/runtime.ts @@ -415,7 +415,6 @@ export class BgentRuntime { async composeState(message: Message) { const { senderId, agentId, userIds, room_id } = message; - const { supabase } = this; const recentMessageCount = this.getRecentMessageCount(); const recentSummarizationsCount = this.getRecentMessageCount() / 2; const relevantSummarizationsCount = this.getRecentMessageCount() / 2; @@ -426,7 +425,7 @@ export class BgentRuntime { recentSummarizationsData, goalsData, ]: [Actor[], Memory[], Memory[], Goal[]] = await Promise.all([ - getMessageActors({ supabase, userIds: userIds! }), + getMessageActors({ runtime: this, userIds: userIds! }), this.messageManager.getMemoriesByIds({ userIds: userIds!, count: recentMessageCount, @@ -437,7 +436,7 @@ export class BgentRuntime { count: recentSummarizationsCount, }), getGoals({ - supabase, + runtime: this, count: 10, onlyInProgress: true, userIds: userIds!,