From e57676ed4662558d088082e66ab27a3c00462776 Mon Sep 17 00:00:00 2001 From: hchen Date: Mon, 14 Aug 2023 12:00:01 -0500 Subject: [PATCH] Add Agent hook. --- .../Agents/AgentHookBase.cs | 40 ++++++++++++++++ .../Agents/Enums/AgentRole.cs | 9 ++++ .../BotSharp.Abstraction/Agents/IAgentHook.cs | 28 +++++++++++ .../Agents/IAgentService.cs | 8 ++++ ...ionHookBase.cs => ConversationHookBase.cs} | 32 ++++++------- .../IConversationCompletionHook.cs | 24 ---------- .../Conversations/IConversationHook.cs | 29 +++++++++++ .../IConversationStateService.cs | 4 +- .../Conversations/Models/RoleDialogModel.cs | 11 +++++ .../Agents/Services/AgentService.LoadAgent.cs | 47 ++++++++++++++++++ .../Services/ConversationService.cs | 48 +++++++++++-------- .../Services/ConversationStateService.cs | 47 ++++++++++++++---- .../Providers/ChatCompletionProvider.cs | 7 +++ .../Controllers/WebhookController.cs | 10 +++- 14 files changed, 273 insertions(+), 71 deletions(-) create mode 100644 src/Infrastructure/BotSharp.Abstraction/Agents/AgentHookBase.cs create mode 100644 src/Infrastructure/BotSharp.Abstraction/Agents/Enums/AgentRole.cs create mode 100644 src/Infrastructure/BotSharp.Abstraction/Agents/IAgentHook.cs rename src/Infrastructure/BotSharp.Abstraction/Conversations/{ConversationCompletionHookBase.cs => ConversationHookBase.cs} (51%) delete mode 100644 src/Infrastructure/BotSharp.Abstraction/Conversations/IConversationCompletionHook.cs create mode 100644 src/Infrastructure/BotSharp.Abstraction/Conversations/IConversationHook.cs create mode 100644 src/Infrastructure/BotSharp.Core/Agents/Services/AgentService.LoadAgent.cs diff --git a/src/Infrastructure/BotSharp.Abstraction/Agents/AgentHookBase.cs b/src/Infrastructure/BotSharp.Abstraction/Agents/AgentHookBase.cs new file mode 100644 index 000000000..34a657b1e --- /dev/null +++ b/src/Infrastructure/BotSharp.Abstraction/Agents/AgentHookBase.cs @@ -0,0 +1,40 @@ +namespace BotSharp.Abstraction.Agents; + +public abstract class AgentHookBase : IAgentHook +{ + protected Agent _agent; + public Agent Agent => _agent; + + public void SetAget(Agent agent) + { + _agent = agent; + } + + public virtual bool OnAgentLoading(ref string id) + { + return true; + } + + public virtual bool OnInstructionLoaded(ref string instruction) + { + _agent.Instruction = instruction; + return true; + } + + public virtual bool OnFunctionsLoaded(ref string functions) + { + _agent.Functions = functions; + return true; + } + + public virtual bool OnSamplesLoaded(ref string samples) + { + _agent.Samples = samples; + return true; + } + + public virtual Agent OnAgentLoaded() + { + return _agent; + } +} diff --git a/src/Infrastructure/BotSharp.Abstraction/Agents/Enums/AgentRole.cs b/src/Infrastructure/BotSharp.Abstraction/Agents/Enums/AgentRole.cs new file mode 100644 index 000000000..f9547d5a7 --- /dev/null +++ b/src/Infrastructure/BotSharp.Abstraction/Agents/Enums/AgentRole.cs @@ -0,0 +1,9 @@ +namespace BotSharp.Abstraction.Agents.Enums; + +public class AgentRole +{ + public const string System = "system"; + public const string Assistant = "assistant"; + public const string User = "user"; + public const string Function = "function"; +} diff --git a/src/Infrastructure/BotSharp.Abstraction/Agents/IAgentHook.cs b/src/Infrastructure/BotSharp.Abstraction/Agents/IAgentHook.cs new file mode 100644 index 000000000..b91fc073e --- /dev/null +++ b/src/Infrastructure/BotSharp.Abstraction/Agents/IAgentHook.cs @@ -0,0 +1,28 @@ +namespace BotSharp.Abstraction.Agents; + +public interface IAgentHook +{ + Agent Agent { get; } + void SetAget(Agent agent); + + /// + /// Triggered before loading, you can change the returned id to switch agent. + /// + /// Agent Id + /// + bool OnAgentLoading(ref string id); + + + bool OnInstructionLoaded(ref string instruction); + + bool OnFunctionsLoaded(ref string functions); + + bool OnSamplesLoaded(ref string samples); + + /// + /// Triggered when agent is loaded completely. + /// + /// + /// + Agent OnAgentLoaded(); +} diff --git a/src/Infrastructure/BotSharp.Abstraction/Agents/IAgentService.cs b/src/Infrastructure/BotSharp.Abstraction/Agents/IAgentService.cs index 6051dd2ce..9f119bb13 100644 --- a/src/Infrastructure/BotSharp.Abstraction/Agents/IAgentService.cs +++ b/src/Infrastructure/BotSharp.Abstraction/Agents/IAgentService.cs @@ -7,6 +7,14 @@ public interface IAgentService { Task CreateAgent(Agent agent); Task> GetAgents(); + + /// + /// Load agent configurations and triggher hooks + /// + /// + /// + Task LoadAgent(string id); + Task GetAgent(string id); Task DeleteAgent(string id); Task UpdateAgent(Agent agent); diff --git a/src/Infrastructure/BotSharp.Abstraction/Conversations/ConversationCompletionHookBase.cs b/src/Infrastructure/BotSharp.Abstraction/Conversations/ConversationHookBase.cs similarity index 51% rename from src/Infrastructure/BotSharp.Abstraction/Conversations/ConversationCompletionHookBase.cs rename to src/Infrastructure/BotSharp.Abstraction/Conversations/ConversationHookBase.cs index 2995a9db6..4e8de8b35 100644 --- a/src/Infrastructure/BotSharp.Abstraction/Conversations/ConversationCompletionHookBase.cs +++ b/src/Infrastructure/BotSharp.Abstraction/Conversations/ConversationHookBase.cs @@ -1,9 +1,8 @@ using BotSharp.Abstraction.Conversations.Models; -using BotSharp.Abstraction.MLTasks; namespace BotSharp.Abstraction.Conversations; -public abstract class ConversationCompletionHookBase : IConversationCompletionHook +public abstract class ConversationHookBase : IConversationHook { protected Agent _agent; public Agent Agent => _agent; @@ -14,44 +13,39 @@ public abstract class ConversationCompletionHookBase : IConversationCompletionHo protected List _dialogs; public List Dialogs => _dialogs; - protected IChatCompletion _chatCompletion; - public IChatCompletion ChatCompletion => _chatCompletion; - - public IConversationCompletionHook SetAgent(Agent agent) + public IConversationHook SetAgent(Agent agent) { _agent = agent; return this; } - public IConversationCompletionHook SetConversation(Conversation conversation) + public IConversationHook SetConversation(Conversation conversation) { _conversation = conversation; return this; } - public IConversationCompletionHook SetDialogs(List dialogs) + public virtual Task OnStateLoaded(ConversationState state) { - _dialogs = dialogs; - return this; + return Task.CompletedTask; } - public IConversationCompletionHook SetChatCompletion(IChatCompletion chatCompletion) + public virtual Task OnStateChanged(string name, string preValue, string currentValue) { - _chatCompletion = chatCompletion; - return this; + return Task.CompletedTask; } - public virtual Task OnStateLoaded(ConversationState state, Action? onAgentSwitched = null) + public virtual Task BeforeCompletion() { return Task.CompletedTask; } - public virtual Task BeforeCompletion() + public virtual Task OnFunctionExecuting(RoleDialogModel message) { return Task.CompletedTask; } - public virtual Task OnFunctionExecuting(string name, string args) + public virtual Task OnFunctionExecuted(RoleDialogModel message) { return Task.CompletedTask; } @@ -60,4 +54,10 @@ public virtual Task AfterCompletion(RoleDialogModel message) { return Task.CompletedTask; } + + public virtual Task OnDialogsLoaded(List dialogs) + { + _dialogs = dialogs; + return Task.CompletedTask; + } } diff --git a/src/Infrastructure/BotSharp.Abstraction/Conversations/IConversationCompletionHook.cs b/src/Infrastructure/BotSharp.Abstraction/Conversations/IConversationCompletionHook.cs deleted file mode 100644 index 056423255..000000000 --- a/src/Infrastructure/BotSharp.Abstraction/Conversations/IConversationCompletionHook.cs +++ /dev/null @@ -1,24 +0,0 @@ -using BotSharp.Abstraction.Conversations.Models; -using BotSharp.Abstraction.MLTasks; - -namespace BotSharp.Abstraction.Conversations; - -public interface IConversationCompletionHook -{ - Agent Agent { get; } - IConversationCompletionHook SetAgent(Agent agent); - - Conversation Conversation { get; } - IConversationCompletionHook SetConversation(Conversation conversation); - - List Dialogs { get; } - IConversationCompletionHook SetDialogs(List dialogs); - - IChatCompletion ChatCompletion { get; } - IConversationCompletionHook SetChatCompletion(IChatCompletion chatCompletion); - - Task OnStateLoaded(ConversationState state, Action? onAgentSwitched = null); - Task BeforeCompletion(); - Task OnFunctionExecuting(string name, string args); - Task AfterCompletion(RoleDialogModel message); -} diff --git a/src/Infrastructure/BotSharp.Abstraction/Conversations/IConversationHook.cs b/src/Infrastructure/BotSharp.Abstraction/Conversations/IConversationHook.cs new file mode 100644 index 000000000..6b2fdccb3 --- /dev/null +++ b/src/Infrastructure/BotSharp.Abstraction/Conversations/IConversationHook.cs @@ -0,0 +1,29 @@ +using BotSharp.Abstraction.Conversations.Models; +using BotSharp.Abstraction.MLTasks; + +namespace BotSharp.Abstraction.Conversations; + +public interface IConversationHook +{ + Agent Agent { get; } + IConversationHook SetAgent(Agent agent); + + Conversation Conversation { get; } + IConversationHook SetConversation(Conversation conversation); + + List Dialogs { get; } + /// + /// Triggered when dialog history is loaded + /// + /// + /// + Task OnDialogsLoaded(List dialogs); + + Task OnStateLoaded(ConversationState state); + Task OnStateChanged(string name, string preValue, string currentValue); + + Task BeforeCompletion(); + Task OnFunctionExecuting(RoleDialogModel message); + Task OnFunctionExecuted(RoleDialogModel message); + Task AfterCompletion(RoleDialogModel message); +} diff --git a/src/Infrastructure/BotSharp.Abstraction/Conversations/IConversationStateService.cs b/src/Infrastructure/BotSharp.Abstraction/Conversations/IConversationStateService.cs index 1478ffba6..3cac14f85 100644 --- a/src/Infrastructure/BotSharp.Abstraction/Conversations/IConversationStateService.cs +++ b/src/Infrastructure/BotSharp.Abstraction/Conversations/IConversationStateService.cs @@ -7,7 +7,9 @@ namespace BotSharp.Abstraction.Conversations; /// public interface IConversationStateService { - ConversationState Load(string conversationId); + void SetConversation(string conversationId); + ConversationState Load(); string GetState(string name); + void SetState(string name, string value); void Save(); } diff --git a/src/Infrastructure/BotSharp.Abstraction/Conversations/Models/RoleDialogModel.cs b/src/Infrastructure/BotSharp.Abstraction/Conversations/Models/RoleDialogModel.cs index fa0a19388..d9b047240 100644 --- a/src/Infrastructure/BotSharp.Abstraction/Conversations/Models/RoleDialogModel.cs +++ b/src/Infrastructure/BotSharp.Abstraction/Conversations/Models/RoleDialogModel.cs @@ -19,6 +19,17 @@ public class RoleDialogModel /// public string? ExecutionResult { get; set; } + /// + /// When function callback has been executed, system will pass result to LLM again, + /// Set this property to True to stop calling LLM. + /// + public bool StopSubsequentInteraction { get;set; } + + /// + /// Channel name + /// + public string Channel { get; set; } + public RoleDialogModel(string role, string text) { Role = role; diff --git a/src/Infrastructure/BotSharp.Core/Agents/Services/AgentService.LoadAgent.cs b/src/Infrastructure/BotSharp.Core/Agents/Services/AgentService.LoadAgent.cs new file mode 100644 index 000000000..5d9d7b4f7 --- /dev/null +++ b/src/Infrastructure/BotSharp.Core/Agents/Services/AgentService.LoadAgent.cs @@ -0,0 +1,47 @@ +using BotSharp.Abstraction.Agents.Models; + +namespace BotSharp.Core.Agents.Services; + +public partial class AgentService +{ + public async Task LoadAgent(string id) + { + var hooks = _services.GetServices(); + + // Before agent is loaded. + foreach (var hook in hooks) + { + hook.OnAgentLoading(ref id); + } + + var agent = await GetAgent(id); + + // After agent is loaded + foreach (var hook in hooks) + { + hook.SetAget(agent); + + if (!string.IsNullOrEmpty(agent.Instruction)) + { + var instruction = agent.Instruction; + hook.OnInstructionLoaded(ref instruction); + } + + if (!string.IsNullOrEmpty(agent.Functions)) + { + var functions = agent.Functions; + hook.OnFunctionsLoaded(ref functions); + } + + if (!string.IsNullOrEmpty(agent.Samples)) + { + var samples = agent.Samples; + hook.OnSamplesLoaded(ref samples); + } + + hook.OnAgentLoaded(); + } + + return agent; + } +} diff --git a/src/Infrastructure/BotSharp.Core/Conversations/Services/ConversationService.cs b/src/Infrastructure/BotSharp.Core/Conversations/Services/ConversationService.cs index c7269efd9..d733872bc 100644 --- a/src/Infrastructure/BotSharp.Core/Conversations/Services/ConversationService.cs +++ b/src/Infrastructure/BotSharp.Core/Conversations/Services/ConversationService.cs @@ -1,3 +1,4 @@ +using BotSharp.Abstraction.Agents.Enums; using BotSharp.Abstraction.Conversations; using BotSharp.Abstraction.Conversations.Models; using BotSharp.Abstraction.Functions; @@ -116,10 +117,7 @@ public async Task SendMessage(string agentId, string conversationId, RoleD } public async Task SendMessage(string agentId, string conversationId, List wholeDialogs, Func onMessageReceived) - { - var agent = await _services.GetRequiredService() - .GetAgent(agentId); - + { var converation = await GetConversation(conversationId); // Create conversation if this conversation not exists @@ -133,13 +131,18 @@ public async Task SendMessage(string agentId, string conversationId, List< converation = await NewConversation(sess); } - // load state + // conversation state var stateService = _services.GetRequiredService(); - var state = stateService.Load(conversationId); - state["agentId"] = agentId; - + stateService.SetConversation(conversationId); + stateService.Load(); + stateService.SetState("agentId", agentId); + + // load agent + var agentService = _services.GetRequiredService(); + var agent = await agentService.LoadAgent(agentId); + // Get relevant domain knowledge - if (_settings.EnableKnowledgeBase) + /*if (_settings.EnableKnowledgeBase) { var knowledge = _services.GetRequiredService(); agent.Knowledges = await knowledge.GetKnowledges(new KnowledgeRetrievalModel @@ -147,21 +150,19 @@ public async Task SendMessage(string agentId, string conversationId, List< AgentId = agentId, Question = string.Join("\n", wholeDialogs.Select(x => x.Content)) }); - } + }*/ var chatCompletion = GetChatCompletion(); - var hooks = _services.GetServices().ToList(); + var hooks = _services.GetServices().ToList(); // Before chat completion hook foreach (var hook in hooks) { hook.SetAgent(agent) - .SetConversation(converation) - .SetDialogs(wholeDialogs) - .SetChatCompletion(chatCompletion); + .SetConversation(converation); - await hook.OnStateLoaded(state, onAgentSwitched: x => agent = x); + await hook.OnDialogsLoaded(wholeDialogs); await hook.BeforeCompletion(); } @@ -172,7 +173,7 @@ public async Task SendMessage(string agentId, string conversationId, List< // Before executing functions foreach (var hook in hooks) { - await hook.OnFunctionExecuting(msg.FunctionName, msg.Content); + await hook.OnFunctionExecuting(msg); } // Save states var jo = JsonSerializer.Deserialize(msg.Content); @@ -180,11 +181,7 @@ public async Task SendMessage(string agentId, string conversationId, List< { foreach (JsonProperty property in root.EnumerateObject()) { - string propertyName = property.Name; - string propertyValue = property.Value.ToString(); - - _logger.LogInformation($"Set conversation state: {propertyName} - {propertyValue}"); - state[propertyName] = propertyValue; + stateService.SetState(property.Name, property.Value.ToString()); } } } @@ -197,6 +194,15 @@ public async Task SendMessage(string agentId, string conversationId, List< } } await onMessageReceived(msg); + + if (msg.Role == AgentRole.Function) + { + // After functions have been executed + foreach (var hook in hooks) + { + await hook.OnFunctionExecuted(msg); + } + } }); return result; diff --git a/src/Infrastructure/BotSharp.Core/Conversations/Services/ConversationStateService.cs b/src/Infrastructure/BotSharp.Core/Conversations/Services/ConversationStateService.cs index 7bd064122..b11ed9d24 100644 --- a/src/Infrastructure/BotSharp.Core/Conversations/Services/ConversationStateService.cs +++ b/src/Infrastructure/BotSharp.Core/Conversations/Services/ConversationStateService.cs @@ -1,4 +1,6 @@ using BotSharp.Abstraction.Conversations.Models; +using Microsoft.EntityFrameworkCore.Metadata.Internal; +using Microsoft.Extensions.Logging; using System.IO; namespace BotSharp.Core.Conversations.Services; @@ -8,27 +10,44 @@ namespace BotSharp.Core.Conversations.Services; /// public class ConversationStateService : IConversationStateService, IDisposable { + private readonly ILogger _logger; + private readonly IServiceProvider _services; private ConversationState _state; private MyDatabaseSettings _dbSettings; private string _conversationId; private string _file; - public ConversationStateService(MyDatabaseSettings dbSettings) + public ConversationStateService(ILogger logger, + IServiceProvider services, + MyDatabaseSettings dbSettings) { + _logger = logger; + _services = services; _dbSettings = dbSettings; } public void SetState(string name, string value) { - _state[name] = value; + var hooks = _services.GetServices(); + string preValue = _state.ContainsKey(name) ? _state[name] : ""; + if (!_state.ContainsKey(name) || _state[name] != value) + { + var currentValue = value; + _state[name] = currentValue; + _logger.LogInformation($"Set state: {name} - {value}"); + foreach (var hook in hooks) + { + hook.OnStateChanged(name, preValue, currentValue).Wait(); + } + } } - public void Dispose() + public void SetConversation(string conversationId) { - Save(); + _conversationId = conversationId; } - public ConversationState Load(string conversationId) + public ConversationState Load() { if (_state != null) { @@ -36,7 +55,6 @@ public ConversationState Load(string conversationId) } _state = new ConversationState(); - _conversationId = conversationId; _file = GetStorageFile(_conversationId); @@ -45,10 +63,17 @@ public ConversationState Load(string conversationId) var dict = File.ReadAllLines(_file); foreach (var line in dict) { - _state[line.Split(':')[0]] = line.Split(':')[1]; + _state[line.Split('=')[0]] = line.Split('=')[1]; } } + _logger.LogInformation($"Loaded state {_conversationId}"); + var hooks = _services.GetServices(); + foreach (var hook in hooks) + { + hook.OnStateLoaded(_state).Wait(); + } + return _state; } @@ -58,9 +83,10 @@ public void Save() foreach (var dic in _state) { - states.Add($"{dic.Key}:{dic.Value}"); + states.Add($"{dic.Key}={dic.Value}"); } File.WriteAllLines(_file, states); + _logger.LogInformation($"Saved state {_conversationId}"); } private string GetStorageFile(string conversationId) @@ -81,4 +107,9 @@ public string GetState(string name) } return _state[name]; } + + public void Dispose() + { + Save(); + } } diff --git a/src/Plugins/BotSharp.Plugin.AzureOpenAI/Providers/ChatCompletionProvider.cs b/src/Plugins/BotSharp.Plugin.AzureOpenAI/Providers/ChatCompletionProvider.cs index a018cfb5f..e9506f3dd 100644 --- a/src/Plugins/BotSharp.Plugin.AzureOpenAI/Providers/ChatCompletionProvider.cs +++ b/src/Plugins/BotSharp.Plugin.AzureOpenAI/Providers/ChatCompletionProvider.cs @@ -117,6 +117,13 @@ public async Task GetChatCompletionsAsync(Agent agent, List> Messages([FromRoute] string age }); // Go to LLM - var result = await conv.SendMessage(agentId, senderId, new RoleDialogModel("user", input), async msg => + var result = await conv.SendMessage(agentId, senderId, new RoleDialogModel("user", input) { + Channel = "messenger" + }, async msg => + { + if (msg.Role == AgentRole.Function) + { + + } content = msg.Content; }, async fn => {