Skip to content

Commit

Permalink
Add provider abstraction and time provider
Browse files Browse the repository at this point in the history
  • Loading branch information
lalalune committed Feb 25, 2024
1 parent 779bbe1 commit 9853678
Show file tree
Hide file tree
Showing 13 changed files with 198 additions and 38 deletions.
2 changes: 1 addition & 1 deletion src/lib/__tests__/evaluation.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import dotenv from "dotenv";
import { createRuntime } from "../../test/createRuntime";
import { TEST_EVALUATOR, TEST_EVALUATOR_FAIL } from "../../test/testEvaluator";
import { composeContext } from "../context";
import { evaluationTemplate } from "../evaluation";
import { evaluationTemplate } from "../evaluators";
import summarization from "../evaluators/summarization";
import { getRelationship } from "../relationships";
import { BgentRuntime } from "../runtime";
Expand Down
4 changes: 2 additions & 2 deletions src/lib/__tests__/lore.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { createRuntime } from "../../test/createRuntime";
import { composeContext } from "../context";
import { addLore, getLore } from "../lore";
import { BgentRuntime } from "../runtime";
import { requestHandlerTemplate } from "../templates";
import { messageHandlerTemplate } from "../templates";
import { type Content } from "../types";

dotenv.config({ path: ".dev.vars" });
Expand Down Expand Up @@ -77,7 +77,7 @@ describe("Lore", () => {

const context = composeContext({
state,
template: requestHandlerTemplate,
template: messageHandlerTemplate,
});

// expect context to contain 'Test Lore Source' and 'Test Lore Content'
Expand Down
4 changes: 2 additions & 2 deletions src/lib/actions/continue.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { composeContext } from "../context";
import logger from "../logger";
import { embeddingZeroVector } from "../memory";
import { type BgentRuntime } from "../runtime";
import { requestHandlerTemplate } from "../templates";
import { messageHandlerTemplate } from "../templates";
import { Content, State, type Action, type Message } from "../types";
import { parseJSONObjectFromText } from "../utils";

Expand Down Expand Up @@ -42,7 +42,7 @@ export default {

const context = composeContext({
state,
template: requestHandlerTemplate,
template: messageHandlerTemplate,
});

if (runtime.debugMode) {
Expand Down
File renamed without changes.
3 changes: 2 additions & 1 deletion src/lib/index.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
export * from "./actions";
export * from "./actions/index";
export * from "./context";
export * from "./evaluation";
export * from "./evaluators";
export * from "./evaluators/index";
export * from "./goals";
export * from "./lore";
export * from "./memory";
export * from "./messages";
export * from "./providers";
export * from "./relationships";
export * from "./runtime";
export * from "./types";
20 changes: 20 additions & 0 deletions src/lib/providers.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import time from "./providers/time";
import { BgentRuntime } from "./runtime";
import { type Message, type Provider } from "./types";

export const defaultProviders: Provider[] = [time];

/**
* Formats provider outputs into a string which can be injected into the context.
* @param providers - An array of evaluator objects.
* @returns A string that concatenates the outputs of each provider.
*/
export async function getProviders(runtime: BgentRuntime, message: Message) {
const providerResults = await Promise.all(
runtime.providers.map(async (provider) => {
return await provider.get(runtime, message);
}),
);

return providerResults.join("\n");
}
89 changes: 89 additions & 0 deletions src/lib/providers/__tests__/time.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import { type UUID } from "crypto";
import dotenv from "dotenv";
import { createRuntime } from "../../../test/createRuntime";
import { composeContext } from "../../context";
import { BgentRuntime } from "../../runtime";

import { type Message, type State } from "../../types";
import timeProvider from "../time";

dotenv.config({ path: ".dev.vars" });

const zeroUuid = "00000000-0000-0000-0000-000000000000" as UUID;

describe("Time Provider", () => {
let runtime: BgentRuntime;
let user: { id: UUID };
let room_id: UUID;

beforeAll(async () => {
const setup = await createRuntime({
env: process.env as Record<string, string>,
providers: [timeProvider],
});
runtime = setup.runtime;
user = { id: setup.session.user?.id as UUID };
room_id = "some-room-id" as UUID; // Assume room_id is fetched or set up in your environment
});

test("Time provider should return the current time in the correct format", async () => {
const message: Message = {
senderId: user.id,
agentId: zeroUuid,
userIds: [user.id, zeroUuid],
content: { content: "" },
room_id: room_id,
};

const currentTimeResponse = await timeProvider.get(
runtime,
message,
{} as State,
);
expect(currentTimeResponse).toMatch(
/^The current time is: \d{1,2}:\d{2}:\d{2}\s?(AM|PM)$/,
);
});

test("Time provider should be integrated in the state and context correctly", async () => {
const message: Message = {
senderId: user.id,
agentId: zeroUuid,
userIds: [user.id, zeroUuid],
content: { content: "" },
room_id: room_id,
};

// Manually integrate the time provider's response into the state
const state = await runtime.composeState(message);

const contextTemplate = `Current Time: {{providers}}`;
const context = composeContext({
state: state,
template: contextTemplate,
});

const match = context.match(
new RegExp(
`^Current Time: The current time is: \\d{1,2}:\\d{2}:\\d{2}\\s?(AM|PM)$`,
),
);

expect(match).toBeTruthy();
});

test("Time provider should work independently", async () => {
const message: Message = {
senderId: user.id,
agentId: zeroUuid,
userIds: [user.id, zeroUuid],
content: { content: "" },
room_id: room_id,
};
const currentTimeResponse = await timeProvider.get(runtime, message);

expect(currentTimeResponse).toMatch(
/^The current time is: \d{1,2}:\d{2}:\d{2}\s?(AM|PM)$/,
);
});
});
1 change: 1 addition & 0 deletions src/lib/providers/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
export { default as summarization } from "./time";
12 changes: 12 additions & 0 deletions src/lib/providers/time.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import { type BgentRuntime } from "../runtime";
import { type Message, type Provider, type State } from "../types";

const time: Provider = {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
get: async (_runtime: BgentRuntime, _message: Message, _state?: State) => {
const currentTime = new Date().toLocaleTimeString("en-US");
return "The current time is: " + currentTime;
},
};

export default time;
78 changes: 51 additions & 27 deletions src/lib/runtime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,18 @@ import {
formatEvaluatorExamples,
formatEvaluatorNames,
formatEvaluators,
} from "./evaluation";
} from "./evaluators";
import logger from "./logger";
import { MemoryManager, embeddingZeroVector } from "./memory";
import { requestHandlerTemplate } from "./templates";
import { messageHandlerTemplate } from "./templates";
import {
Content,
Goal,
State,
type Action,
type Evaluator,
type Message,
Provider,
} from "./types";
import { parseJSONObjectFromText, parseJsonArrayFromText } from "./utils";

Expand All @@ -33,6 +34,7 @@ import { formatGoalsAsString, getGoals } from "./goals";
import { formatActors, formatMessages, getActorDetails } from "./messages";
import { type Actor, /*type Goal,*/ type Memory } from "./types";
import { getLore, formatLore } from "./lore";
import { defaultProviders, getProviders } from "./providers";

/**
* Represents the runtime environment for an agent, handling message processing,
Expand Down Expand Up @@ -79,6 +81,11 @@ export class BgentRuntime {
*/
evaluators: Evaluator[] = [];

/**
* Context providers used to provide context for message generation.
*/
providers: Provider[] = [];

/**
* Store messages that are sent and received by the agent.
*/
Expand Down Expand Up @@ -122,6 +129,7 @@ export class BgentRuntime {
* @param opts.flavor - Optional lore to inject into the default prompt.
* @param opts.actions - Optional custom actions.
* @param opts.evaluators - Optional custom evaluators.
* @param opts.providers - Optional context providers.
*/
constructor(opts: {
recentMessageCount?: number; // number of messages to hold in the recent message cache
Expand All @@ -132,6 +140,7 @@ export class BgentRuntime {
flavor?: string; // Optional lore to inject into the default prompt
actions?: Action[]; // Optional custom actions
evaluators?: Evaluator[]; // Optional custom evaluators
providers?: Provider[];
}) {
this.#recentMessageCount =
opts.recentMessageCount ?? this.#recentMessageCount;
Expand All @@ -152,6 +161,9 @@ export class BgentRuntime {
(opts.evaluators ?? defaultEvaluators).forEach((evaluator) => {
this.registerEvaluator(evaluator);
});
(opts.providers ?? defaultProviders).forEach((provider) => {
this.registerContextProvider(provider);
});
}

/**
Expand All @@ -178,6 +190,14 @@ export class BgentRuntime {
this.evaluators.push(evaluator);
}

/**
* Register a context provider to provide context for message generation.
* @param provider The context provider to register.
*/
registerContextProvider(provider: Provider) {
this.providers.push(provider);
}

/**
* Send a message to the OpenAI API for completion.
* @param opts - The options for the completion request.
Expand Down Expand Up @@ -322,7 +342,7 @@ export class BgentRuntime {

const context = composeContext({
state,
template: requestHandlerTemplate,
template: messageHandlerTemplate,
});

if (this.debugMode) {
Expand Down Expand Up @@ -518,30 +538,33 @@ export class BgentRuntime {
recentSummarizationsData,
goalsData,
loreData,
]: [Actor[], Memory[], Memory[], Goal[], Memory[]] = await Promise.all([
getActorDetails({ runtime: this, userIds: userIds! }),
this.messageManager.getMemoriesByIds({
userIds: userIds!,
count: recentMessageCount,
unique: false,
}),
this.summarizationManager.getMemoriesByIds({
userIds: userIds!,
count: recentSummarizationsCount,
}),
getGoals({
runtime: this,
count: 10,
onlyInProgress: true,
userIds: userIds!,
}),
getLore({
runtime: this,
message: (message.content as Content).content,
count: 5,
match_threshold: 0.5,
}),
]);
providers,
]: [Actor[], Memory[], Memory[], Goal[], Memory[], string] =
await Promise.all([
getActorDetails({ runtime: this, userIds: userIds! }),
this.messageManager.getMemoriesByIds({
userIds: userIds!,
count: recentMessageCount,
unique: false,
}),
this.summarizationManager.getMemoriesByIds({
userIds: userIds!,
count: recentSummarizationsCount,
}),
getGoals({
runtime: this,
count: 10,
onlyInProgress: true,
userIds: userIds!,
}),
getLore({
runtime: this,
message: (message.content as Content).content,
count: 5,
match_threshold: 0.5,
}),
getProviders(this, message),
]);

const goals = await formatGoalsAsString({ goals: goalsData });

Expand Down Expand Up @@ -600,6 +623,7 @@ export class BgentRuntime {
goals,
lore,
loreData,
providers,
goalsData,
flavor: this.flavor,
recentMessages,
Expand Down
4 changes: 3 additions & 1 deletion src/lib/templates.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export const requestHandlerTemplate = `{{flavor}}
export const messageHandlerTemplate = `{{flavor}}
# START MESSAGE EXAMPLES
json\`\`\`
Expand Down Expand Up @@ -39,6 +39,8 @@ json\`\`\`
{{actionNames}}
{{actions}}
{{providers}}
Current Scene Dialog:
{{recentMessages}}
Expand Down
9 changes: 9 additions & 0 deletions src/lib/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ export interface State {
actions?: string; // An optional string representation of actions and their descriptions, relevant to the current state.
actionsData?: Action[]; // An optional array of action objects relevant to the current state.
actionExamples?: string; // An optional string representation of examples of actions, for demonstration or testing.
providers?: string; // An optional string representation of available providers and their descriptions, relevant to the current state.
responseData?: Content; // An optional content object representing the agent's response in the current state.
[key: string]: unknown; // Allows for additional properties to be included dynamically.
}
Expand Down Expand Up @@ -167,3 +168,11 @@ export interface Evaluator {
name: string; // The name of the evaluator.
validate: Validator; // The function that validates whether the evaluator is applicable in the current context.
}

export interface Provider {
get: (
runtime: BgentRuntime,
message: Message,
state?: State,
) => Promise<unknown>;
}
Loading

0 comments on commit 9853678

Please sign in to comment.