Skip to content

Commit

Permalink
Add evaluation checker, fixes for if recent message count is 1
Browse files Browse the repository at this point in the history
  • Loading branch information
lalalune committed Feb 20, 2024
1 parent c332f68 commit b87664b
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 8 deletions.
51 changes: 49 additions & 2 deletions src/lib/__tests__/evaluation.test.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import { User } from "@supabase/supabase-js";
import { UUID } from "crypto";
import dotenv from "dotenv";
import { composeContext } from "../context";
import { evaluationTemplate } from "../evaluation";
import { createRuntime } from "../../test/createRuntime";
import { TEST_EVALUATOR, TEST_EVALUATOR_FAIL } from "../../test/testEvaluator";
import { composeContext } from "../context";
import { evaluationTemplate } from "../evaluation";
import summarization from "../evaluators/summarization";
import { getRelationship } from "../relationships";
import { BgentRuntime } from "../runtime";
import { Message } from "../types";
Expand Down Expand Up @@ -106,4 +107,50 @@ describe("Evaluation Process", () => {

expect(result?.includes("TEST_EVALUATOR")).toBe(true);
});

// Registers two evaluators on a fresh runtime, evaluates a single message,
// and checks that "TEST_EVALUATOR" is reported in the evaluation result.
// NOTE(review): TEST_EVALUATOR_FAIL is presumably an evaluator that should
// not be selected — confirm against test/testEvaluator.
test("Run the evaluation process", async () => {
  const created = await createRuntime({
    env: process.env as Record<string, string>,
    evaluators: [TEST_EVALUATOR, TEST_EVALUATOR_FAIL],
  });
  const runtime = created.runtime;

  // Message sent by the test user to the agent (zeroUuid) in the shared room.
  const senderId = user.id as UUID;
  const message: Message = {
    senderId,
    agentId: zeroUuid,
    userIds: [senderId, zeroUuid],
    content: "Please run the test evaluator",
    room_id,
  };

  // State must be composed first; evaluate() takes both message and state.
  const composedState = await runtime.composeState(message);
  const evaluationResult = await runtime.evaluate(message, composedState);

  expect(evaluationResult?.includes("TEST_EVALUATOR")).toBe(true);
});

// Builds a runtime with recentMessageCount set to 1 and verifies that the
// summarization evaluator is listed both in the composed state and in the
// evaluation prompt rendered from that state.
// NOTE(review): this relies on the summarization evaluator being registered
// by the runtime without passing it explicitly — confirm in createRuntime.
test("Test that summarization appears in evaluation handler", async () => {
  const created = await createRuntime({
    env: process.env as Record<string, string>,
    recentMessageCount: 1,
  });
  const runtime = created.runtime;

  // Message sent by the test user to the agent (zeroUuid) in the shared room.
  const senderId = user.id as UUID;
  const message: Message = {
    senderId,
    agentId: zeroUuid,
    userIds: [senderId, zeroUuid],
    content: "Test message for evaluation",
    room_id,
  };

  const state = await runtime.composeState(message);
  const renderedPrompt = composeContext({
    state,
    template: evaluationTemplate,
  });

  // The rendered evaluation prompt must contain the summarization evaluator's name.
  expect(renderedPrompt).toContain(summarization.name);

  // state.evaluatorNames must also contain the summarization evaluator's name.
  expect(state.evaluatorNames).toContain(summarization.name);
});
});
14 changes: 10 additions & 4 deletions src/lib/evaluators/summarization.ts
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ async function handler(runtime: BgentRuntime, message: Message) {

const actionNames = runtime.actions.map((a: Action) => a.name).join(", ");
console.log("actionNames", actionNames);

const actions = runtime.actions
.map((a: Action) => `${a.name}: ${a.description}`)
.join("\n");
Expand Down Expand Up @@ -175,11 +175,17 @@ export default {
name: "SUMMARIZE",
validate: async (
// eslint-disable-next-line @typescript-eslint/no-unused-vars
_runtime: BgentRuntime,
runtime: BgentRuntime,
// eslint-disable-next-line @typescript-eslint/no-unused-vars
_message: Message,
message: Message,
): Promise<boolean> => {
return await Promise.resolve(true);
const messageCount = (await runtime.messageManager.countMemoriesByUserIds(
message.userIds,
)) as number;

const reflectionCount = Math.ceil(runtime.getRecentMessageCount() / 2);

return messageCount % reflectionCount === 0;
},
description:
"Extract factual information about the people in the conversation, the current events in the world, and anything else that might be important to remember.",
Expand Down
13 changes: 11 additions & 2 deletions src/lib/runtime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,11 @@ export class BgentRuntime {
const resolvedEvaluators = await Promise.all(evaluatorPromises);
const evaluatorsData = resolvedEvaluators.filter(Boolean);

// if there are no evaluators this frame, return
if (evaluatorsData.length === 0) {
return [];
}

const evaluators = formatEvaluators(evaluatorsData as Evaluator[]);
const evaluatorNames = formatEvaluatorNames(evaluatorsData as Evaluator[]);
const evaluatorConditions = formatEvaluatorConditions(
Expand Down Expand Up @@ -389,8 +394,12 @@ export class BgentRuntime {
const { senderId, agentId, userIds, room_id } = message;

const recentMessageCount = this.getRecentMessageCount();
const recentSummarizationsCount = this.getRecentMessageCount() / 2;
const relevantSummarizationsCount = this.getRecentMessageCount() / 2;
const recentSummarizationsCount = Math.ceil(
this.getRecentMessageCount() / 2,
);
const relevantSummarizationsCount = Math.ceil(
this.getRecentMessageCount() / 2,
);

const [
actorsData,
Expand Down

0 comments on commit b87664b

Please sign in to comment.