From bb1abc43ecc5cd06134b40efc1d27701be90204b Mon Sep 17 00:00:00 2001 From: Peter James Date: Sat, 3 Feb 2024 12:32:04 -0800 Subject: [PATCH 1/5] Add additional memory classes and cleanup --- .../LangChain.Samples.FileMemory/Program.cs | 17 +-- src/libs/LangChain.Core/Chains/Chain.cs | 4 +- .../StackableChains/Agents/GroupChat.cs | 32 ++++-- .../Agents/ReActAgentExecutorChain.cs | 21 +++- .../Chains/StackableChains/LoadMemoryChain.cs | 22 ++-- .../LangChain.Core/Memory/BaseChatMemory.cs | 63 +++++----- .../Memory/BaseChatMemoryInput.cs | 27 ----- .../Memory/BaseChatMessageHistory.cs | 68 +++++++++-- src/libs/LangChain.Core/Memory/BaseMemory.cs | 29 +++-- .../Memory/BufferMemoryInput.cs | 15 --- .../Memory/ChatMessageHistory.cs | 22 +++- .../Memory/ConversationBufferMemory.cs | 83 +++----------- .../Memory/ConversationSummaryBufferMemory.cs | 108 ++++++++++++++++++ .../Memory/ConversationSummaryMemory.cs | 74 ++++++++++++ .../Memory/ConversationWindowBufferMemory.cs | 61 ++++++++++ .../Memory/FileChatMessageHistory.cs | 33 ++---- .../LangChain.Core/Memory/MessageFormatter.cs | 61 ++++++++++ .../Memory/MessageSummarizer.cs | 58 ++++++++++ .../MessageHistoryTests.cs | 7 +- 19 files changed, 591 insertions(+), 214 deletions(-) delete mode 100644 src/libs/LangChain.Core/Memory/BaseChatMemoryInput.cs delete mode 100644 src/libs/LangChain.Core/Memory/BufferMemoryInput.cs create mode 100644 src/libs/LangChain.Core/Memory/ConversationSummaryBufferMemory.cs create mode 100644 src/libs/LangChain.Core/Memory/ConversationSummaryMemory.cs create mode 100644 src/libs/LangChain.Core/Memory/ConversationWindowBufferMemory.cs create mode 100644 src/libs/LangChain.Core/Memory/MessageFormatter.cs create mode 100644 src/libs/LangChain.Core/Memory/MessageSummarizer.cs diff --git a/examples/LangChain.Samples.FileMemory/Program.cs b/examples/LangChain.Samples.FileMemory/Program.cs index 6fbd0552..c257883e 100644 --- a/examples/LangChain.Samples.FileMemory/Program.cs +++ b/examples/LangChain.Samples.FileMemory/Program.cs @@ -1,4 +1,5 @@ using LangChain.Memory; +using LangChain.Providers; using LangChain.Providers.OpenAI; using static LangChain.Chains.Chain; @@ -17,21 +18,24 @@ The following is a friendly conversation between a human and an AI. // To have a conversation thar remembers previous messages we need to use memory. // For memory to work properly we need to specify AI and Human prefixes. -// Since in our template we have "AI:" and "Human:" we need to specify them here. Pay attention to spaces after prefixes. -var conversationBufferMemory = new ConversationBufferMemory(new FileChatMessageHistory("messages.json")) +// Since in our template we have "AI:" and "Human:" we need to specify those prefixes here. +var memory = new ConversationBufferMemory(new FileChatMessageHistory("messages.json")) { - AiPrefix = "AI: ", - HumanPrefix = "Human: " + Formatter = new MessageFormatter + { + AiPrefix = "AI", + HumanPrefix = "Human" + } }; // build chain. Notice that we don't set input key here. It will be set in the loop var chain = // load history. at first it will be empty, but UpdateMemory will update it every iteration - LoadMemory(conversationBufferMemory, outputKey: "history") + LoadMemory(memory, outputKey: "history") | Template(template) | LLM(model) // update memory with new request from Human and response from AI - | UpdateMemory(conversationBufferMemory, requestKey: "input", responseKey: "text"); + | UpdateMemory(memory, requestKey: "input", responseKey: "text"); // run an endless loop of conversation while (true) @@ -48,7 +52,6 @@ The following is a friendly conversation between a human and an AI. // get response from AI var res = await chatChain.Run("text"); - Console.Write("AI: "); Console.WriteLine(res); } \ No newline at end of file diff --git a/src/libs/LangChain.Core/Chains/Chain.cs b/src/libs/LangChain.Core/Chains/Chain.cs index 7a47456f..8a50a99e 100644 --- a/src/libs/LangChain.Core/Chains/Chain.cs +++ b/src/libs/LangChain.Core/Chains/Chain.cs @@ -149,7 +149,7 @@ public static StuffDocumentsChain CombineDocuments( /// /// public static UpdateMemoryChain UpdateMemory( - ConversationBufferMemory memory, + BaseChatMemory memory, string requestKey = "text", string responseKey = "text") { @@ -157,7 +157,7 @@ public static UpdateMemoryChain UpdateMemory( } public static LoadMemoryChain LoadMemory( - ConversationBufferMemory memory, + BaseChatMemory memory, string outputKey = "text") { return new LoadMemoryChain(memory, outputKey); diff --git a/src/libs/LangChain.Core/Chains/StackableChains/Agents/GroupChat.cs b/src/libs/LangChain.Core/Chains/StackableChains/Agents/GroupChat.cs index f52bb755..bb93fd85 100644 --- a/src/libs/LangChain.Core/Chains/StackableChains/Agents/GroupChat.cs +++ b/src/libs/LangChain.Core/Chains/StackableChains/Agents/GroupChat.cs @@ -18,6 +18,8 @@ public class GroupChat : BaseStackableChain private readonly string _outputKey; int _currentAgentId; + private readonly MessageFormatter _messageFormatter; + private readonly ChatMessageHistory _chatMessageHistory; private readonly ConversationBufferMemory _conversationBufferMemory; /// @@ -46,7 +48,20 @@ public GroupChat( _messagesLimit = messagesLimit; _inputKey = inputKey; _outputKey = outputKey; - _conversationBufferMemory = new ConversationBufferMemory(new ChatMessageHistory()) { AiPrefix = "", HumanPrefix = "", SystemPrefix = "", SaveHumanMessages = false }; + + _messageFormatter = new MessageFormatter + { + AiPrefix = "", + HumanPrefix = "", + SystemPrefix = "" + }; + + _chatMessageHistory = new ChatMessageHistory() + { + // Do not save human messages + IsMessageAccepted = x => (x.Role != MessageRole.Human) + }; + InputKeys = new[] { inputKey }; OutputKeys = new[] { outputKey }; } @@ -55,27 +70,28 @@ public GroupChat( /// /// /// - public IReadOnlyList History => _conversationBufferMemory.ChatHistory.Messages; + public IReadOnlyList History => _chatMessageHistory.Messages; /// protected override async Task InternalCall(IChainValues values) { values = values ?? throw new ArgumentNullException(nameof(values)); - await _conversationBufferMemory.Clear().ConfigureAwait(false); + await _chatMessageHistory.Clear().ConfigureAwait(false); foreach (var agent in _agents) { agent.SetHistory(""); } var firstAgent = _agents[0]; var firstAgentMessage = (string)values.Value[_inputKey]; - await _conversationBufferMemory.ChatHistory.AddMessage(new Message($"{firstAgent.Name}: {firstAgentMessage}", + await _chatMessageHistory.AddMessage(new Message($"{firstAgent.Name}: {firstAgentMessage}", MessageRole.System)).ConfigureAwait(false); int messagesCount = 1; while (messagesCount<_messagesLimit) { var agent = GetNextAgent(); - agent.SetHistory(_conversationBufferMemory.BufferAsString+"\n"+$"{agent.Name}:"); + string bufferText = _messageFormatter.Format(_chatMessageHistory.Messages); + agent.SetHistory(bufferText + "\n" + $"{agent.Name}:"); var res = await agent.CallAsync(values).ConfigureAwait(false); var message = (string)res.Value[agent.OutputKeys[0]]; if (message.Contains(_stopPhrase)) @@ -85,13 +101,13 @@ await _conversationBufferMemory.ChatHistory.AddMessage(new Message($"{firstAgent if (!agent.IsObserver) { - await _conversationBufferMemory.ChatHistory.AddMessage(new Message($"{agent.Name}: {message}", + await _chatMessageHistory.AddMessage(new Message($"{agent.Name}: {message}", MessageRole.System)).ConfigureAwait(false); } } - var result = _conversationBufferMemory.ChatHistory.Messages[^1]; - messagesCount = _conversationBufferMemory.ChatHistory.Messages.Count; + var result = _chatMessageHistory.Messages[^1]; + messagesCount = _chatMessageHistory.Messages.Count; if (ThrowOnLimit && messagesCount >= _messagesLimit) { throw new InvalidOperationException($"Message limit reached:{_messagesLimit}"); diff --git a/src/libs/LangChain.Core/Chains/StackableChains/Agents/ReActAgentExecutorChain.cs b/src/libs/LangChain.Core/Chains/StackableChains/Agents/ReActAgentExecutorChain.cs index 468a8b94..73e09fad 100644 --- a/src/libs/LangChain.Core/Chains/StackableChains/Agents/ReActAgentExecutorChain.cs +++ b/src/libs/LangChain.Core/Chains/StackableChains/Agents/ReActAgentExecutorChain.cs @@ -46,6 +46,8 @@ Always add [END] after final answer private readonly IChatModel _model; private readonly string _reActPrompt; private readonly int _maxActions; + private readonly MessageFormatter _messageFormatter; + private readonly ChatMessageHistory _chatMessageHistory; private readonly ConversationBufferMemory _conversationBufferMemory; /// @@ -71,8 +73,23 @@ public ReActAgentExecutorChain( InputKeys = new[] { inputKey }; OutputKeys = new[] { outputKey }; - _conversationBufferMemory = new ConversationBufferMemory(new ChatMessageHistory()) { AiPrefix = "", HumanPrefix = "", SystemPrefix = "", SaveHumanMessages = false }; + _messageFormatter = new MessageFormatter + { + AiPrefix = "", + HumanPrefix = "", + SystemPrefix = "" + }; + + _chatMessageHistory = new ChatMessageHistory() + { + // Do not save human messages + IsMessageAccepted = x => (x.Role != MessageRole.Human) + }; + _conversationBufferMemory = new ConversationBufferMemory(_chatMessageHistory) + { + Formatter = _messageFormatter + }; } private string _userInput = string.Empty; @@ -86,7 +103,7 @@ private void InitializeChain() Set(() => _userInput, "input") | Set(tools, "tools") | Set(toolNames, "tool_names") - | Set(() => _conversationBufferMemory.BufferAsString, "history") + | LoadMemory(_conversationBufferMemory, outputKey: "history") | Template(_reActPrompt) | Chain.LLM(_model).UseCache(_useCache) | UpdateMemory(_conversationBufferMemory, requestKey: "input", responseKey: "text") diff --git a/src/libs/LangChain.Core/Chains/StackableChains/LoadMemoryChain.cs b/src/libs/LangChain.Core/Chains/StackableChains/LoadMemoryChain.cs index e7889e10..46e7920f 100644 --- a/src/libs/LangChain.Core/Chains/StackableChains/LoadMemoryChain.cs +++ b/src/libs/LangChain.Core/Chains/StackableChains/LoadMemoryChain.cs @@ -1,29 +1,35 @@ using LangChain.Abstractions.Schema; using LangChain.Chains.HelperChains; using LangChain.Memory; +using LangChain.Schema; namespace LangChain.Chains.StackableChains; -public class LoadMemoryChain: BaseStackableChain +public class LoadMemoryChain : BaseStackableChain { - - private readonly ConversationBufferMemory _chatMemory; + private readonly BaseChatMemory _chatMemory; private readonly string _outputKey; - public LoadMemoryChain(ConversationBufferMemory chatMemory,string outputKey) + public LoadMemoryChain(BaseChatMemory chatMemory, string outputKey) { - _chatMemory = chatMemory; _outputKey = outputKey; - OutputKeys = new[] {_outputKey}; + OutputKeys = new[] { _outputKey }; } protected override Task InternalCall(IChainValues values) { values = values ?? throw new ArgumentNullException(nameof(values)); - - values.Value[_outputKey] = _chatMemory.BufferAsString; + + string memoryVariableName = _chatMemory.MemoryVariables.FirstOrDefault(); + if (memoryVariableName == null) + { + throw new Exception("Missing memory variable name"); + } + + OutputValues outputValues = _chatMemory.LoadMemoryVariables(null); + values.Value[_outputKey] = outputValues.Value[memoryVariableName]; return Task.FromResult(values); } } \ No newline at end of file diff --git a/src/libs/LangChain.Core/Memory/BaseChatMemory.cs b/src/libs/LangChain.Core/Memory/BaseChatMemory.cs index 2d36887a..14a70612 100644 --- a/src/libs/LangChain.Core/Memory/BaseChatMemory.cs +++ b/src/libs/LangChain.Core/Memory/BaseChatMemory.cs @@ -2,54 +2,53 @@ namespace LangChain.Memory; -/// -public abstract class BaseChatMemory( - BaseChatMessageHistory chatHistory) - : BaseMemory +/// +/// Abstract base class for chat memory. +/// +/// NOTE: LangChain's return_messages property is not implemented due to differences between Python and C# +/// +public abstract class BaseChatMemory : BaseMemory { - /// - /// - /// - public BaseChatMessageHistory ChatHistory { get; set; } = chatHistory; - - /// - /// - /// + public BaseChatMessageHistory ChatHistory { get; } + public string? OutputKey { get; set; } - - /// - /// - /// + public string? InputKey { get; set; } - - // note: return type can't be implemented because of Any type as return type in Buffer property - - /// - /// This used just to save user message as input and AI message as output - /// - /// - /// + + protected BaseChatMemory() + { + ChatHistory = new ChatMessageHistory(); + } + + protected BaseChatMemory(BaseChatMessageHistory chatHistory) + { + ChatHistory = chatHistory ?? throw new ArgumentNullException(nameof(chatHistory)); + } + + /// public override async Task SaveContext(InputValues inputValues, OutputValues outputValues) { inputValues = inputValues ?? throw new ArgumentNullException(nameof(inputValues)); outputValues = outputValues ?? throw new ArgumentNullException(nameof(outputValues)); - - var inputKey = inputValues.Value.Keys.FirstOrDefault(); - if (inputKey != null) + + // If the InputKey is not specified, there must only be one input value + var inputKey = InputKey ?? inputValues.Value.Keys.Single(); + if (inputKey is not null) { await ChatHistory.AddUserMessage(inputValues.Value[inputKey].ToString() ?? string.Empty).ConfigureAwait(false); } - - var outputKey = outputValues.Value.Keys.FirstOrDefault(); - if (outputKey != null) + + // If the OutputKey is not specified, there must only be one output value + var outputKey = OutputKey ?? outputValues.Value.Keys.Single(); + if (outputKey is not null) { await ChatHistory.AddAiMessage(outputValues.Value[outputKey].ToString() ?? string.Empty).ConfigureAwait(false); } } /// - public override Task Clear() + public override async Task Clear() { - return ChatHistory.Clear(); + await ChatHistory.Clear().ConfigureAwait(false); } } \ No newline at end of file diff --git a/src/libs/LangChain.Core/Memory/BaseChatMemoryInput.cs b/src/libs/LangChain.Core/Memory/BaseChatMemoryInput.cs deleted file mode 100644 index 4e429cae..00000000 --- a/src/libs/LangChain.Core/Memory/BaseChatMemoryInput.cs +++ /dev/null @@ -1,27 +0,0 @@ -namespace LangChain.Memory; - -/// -/// -/// -public class BaseChatMemoryInput -{ - /// - /// - /// - public BaseChatMessageHistory ChatHistory { get; set; } = new ChatMessageHistory(); - - /// - /// - /// - public string InputKey { get; set; } = string.Empty; - - /// - /// - /// - public string MemoryKey { get; set; } = string.Empty; - - /// - /// - /// - public bool ReturnMessages { get; set; } -} \ No newline at end of file diff --git a/src/libs/LangChain.Core/Memory/BaseChatMessageHistory.cs b/src/libs/LangChain.Core/Memory/BaseChatMessageHistory.cs index 4523f7e3..bb8cf22d 100644 --- a/src/libs/LangChain.Core/Memory/BaseChatMessageHistory.cs +++ b/src/libs/LangChain.Core/Memory/BaseChatMessageHistory.cs @@ -1,23 +1,52 @@ using LangChain.Providers; +using System.Numerics; +using System.Reflection.Emit; namespace LangChain.Memory; /// +/// Abstract base class for storing chat message history. /// +/// Implementations should over-ride the AddMessages method to handle bulk addition +/// of messages. +/// +/// The default implementation of AddMessages will correctly call AddMessage, so +/// it is not necessary to implement both methods. +/// +/// When used for updating history, users should favor usage of `AddMessages` +/// over `AddMessage` or other variants like `AddUserMessage` and `AddAiMessage` +/// to avoid unnecessary round-trips to the underlying persistence layer. /// public abstract class BaseChatMessageHistory { /// + /// A list of messages stored in-memory. + /// + public abstract IReadOnlyList Messages { get; } + + /// + /// Convenience method for adding a human message string to the store. + /// + /// Please note that this is a convenience method. Code should favor the + /// bulk AddMessages interface instead to save on round-trips to the underlying + /// persistence layer. /// + /// This method may be deprecated in a future release. /// - /// + /// The human message to add public async Task AddUserMessage(string message) { await AddMessage(message.AsHumanMessage()).ConfigureAwait(false); } /// + /// Convenience method for adding an AI message string to the store. /// + /// Please note that this is a convenience method. Code should favor the bulk + /// AddMessages interface instead to save on round-trips to the underlying + /// persistence layer. + /// + /// This method may be deprecated in a future release. /// /// public async Task AddAiMessage(string message) @@ -26,20 +55,45 @@ public async Task AddAiMessage(string message) } /// - /// + /// Add a message object to the store. /// - public abstract IReadOnlyList Messages { get; } + /// A message object to store + public abstract Task AddMessage(Message message); /// + /// Add a list of messages. /// + /// Implementations should override this method to handle bulk addition of messages + /// in an efficient manner to avoid unnecessary round-trips to the underlying store. /// - /// - /// - public abstract Task AddMessage(Message message); + /// A list of message objects to store. + public virtual async Task AddMessages(IEnumerable messages) + { + messages = messages ?? throw new ArgumentNullException(nameof(messages)); + + foreach (var message in messages) + { + await AddMessage(message).ConfigureAwait(false); + } + } /// + /// Replace the list of messages. /// + /// Implementations should override this method to handle bulk addition of messages + /// in an efficient manner to avoid unnecessary round-trips to the underlying store. + /// + /// A list of message objects to store. + public virtual async Task SetMessages(IEnumerable messages) + { + messages = messages ?? throw new ArgumentNullException(nameof(messages)); + + await Clear().ConfigureAwait(false); + await AddMessages(messages).ConfigureAwait(false); + } + + /// + /// Remove all messages from the store /// - /// public abstract Task Clear(); } \ No newline at end of file diff --git a/src/libs/LangChain.Core/Memory/BaseMemory.cs b/src/libs/LangChain.Core/Memory/BaseMemory.cs index f26dff3b..a0cfbc97 100644 --- a/src/libs/LangChain.Core/Memory/BaseMemory.cs +++ b/src/libs/LangChain.Core/Memory/BaseMemory.cs @@ -3,34 +3,39 @@ namespace LangChain.Memory; /// +/// Abstract base class for memory in Chains. /// +/// Memory refers to state in Chains. Memory can be used to store information about +/// past executions of a Chain and inject that information into the inputs of +/// future executions of the Chain. For example, for conversational Chains Memory +/// can be used to store conversations and automatically add them to future model +/// prompts so that the model has the necessary context to respond coherently to +/// the latest input. /// public abstract class BaseMemory { /// - /// + /// Return key-value pairs given the text input to the chain. + /// + public abstract List MemoryVariables { get; } + + /// + /// The string keys this memory class will add to chain inputs. /// /// /// public abstract OutputValues LoadMemoryVariables(InputValues? inputValues); - - /// - /// - /// - public abstract List MemoryVariables { get; } - + /// - /// + /// Save the context of this chain run to memory. /// /// /// - /// public abstract Task SaveContext(InputValues inputValues, OutputValues outputValues); - + /// - /// + /// Clear memory contents. /// - /// public abstract Task Clear(); } \ No newline at end of file diff --git a/src/libs/LangChain.Core/Memory/BufferMemoryInput.cs b/src/libs/LangChain.Core/Memory/BufferMemoryInput.cs deleted file mode 100644 index c5b99604..00000000 --- a/src/libs/LangChain.Core/Memory/BufferMemoryInput.cs +++ /dev/null @@ -1,15 +0,0 @@ -namespace LangChain.Memory; - -/// -public sealed class BufferMemoryInput : BaseChatMemoryInput -{ - /// - /// - /// - public string AiPrefix { get; set; } = string.Empty; - - /// - /// - /// - public string HumanPrefix { get; set; } = string.Empty; -} \ No newline at end of file diff --git a/src/libs/LangChain.Core/Memory/ChatMessageHistory.cs b/src/libs/LangChain.Core/Memory/ChatMessageHistory.cs index 8b303563..c47049a9 100644 --- a/src/libs/LangChain.Core/Memory/ChatMessageHistory.cs +++ b/src/libs/LangChain.Core/Memory/ChatMessageHistory.cs @@ -2,18 +2,32 @@ namespace LangChain.Memory; -/// +/// +/// In memory implementation of chat message history. +/// +/// Stores messages in an in memory list. +/// public class ChatMessageHistory : BaseChatMessageHistory { private readonly List _messages = new List(); - + + /// + /// Used to inspect and filter messages on their way to the history store + /// NOTE: This is not a feature of python langchain + /// + public Predicate IsMessageAccepted { get; set; } = (x => true); + /// public override IReadOnlyList Messages => _messages; /// public override Task AddMessage(Message message) - { - _messages.Add(message); + { + if (IsMessageAccepted(message)) + { + _messages.Add(message); + } + return Task.CompletedTask; } diff --git a/src/libs/LangChain.Core/Memory/ConversationBufferMemory.cs b/src/libs/LangChain.Core/Memory/ConversationBufferMemory.cs index 28a9c8bf..0cc12ed4 100644 --- a/src/libs/LangChain.Core/Memory/ConversationBufferMemory.cs +++ b/src/libs/LangChain.Core/Memory/ConversationBufferMemory.cs @@ -3,88 +3,41 @@ namespace LangChain.Memory; -/// +/// +/// Buffer for storing conversation memory. +/// +/// NOTE: LangChain's buffer property is not implemented here +/// public class ConversationBufferMemory : BaseChatMemory { - /// - /// - /// - public string HumanPrefix { get; set; } = "Human: "; - - /// - /// - /// - public string AiPrefix { get; set; } = "AI: "; - - /// - /// - /// - public string SystemPrefix { get; set; } = "System: "; + public MessageFormatter Formatter { get; set; } = new MessageFormatter(); - /// - /// - /// public string MemoryKey { get; set; } = "history"; - /// - /// - /// - public bool SaveHumanMessages { get; set; } = true; - /// - public ConversationBufferMemory(BaseChatMessageHistory chatHistory) : base(chatHistory) - { - ChatHistory = chatHistory; - } - - // note: buffer property can't be implemented because of Any type as return type + public override List MemoryVariables => new List { MemoryKey }; /// - /// + /// Initializes new buffered memory instance /// - public string BufferAsString => GetBufferString(BufferAsMessages); + public ConversationBufferMemory() + : base() + { + } /// - /// + /// Initializes new buffered memory instance with provided history store /// - public IReadOnlyList BufferAsMessages => ChatHistory.Messages; - - /// - public override List MemoryVariables => new List { MemoryKey }; - - private string GetBufferString(IEnumerable messages) + /// History backing store + public ConversationBufferMemory(BaseChatMessageHistory chatHistory) + : base(chatHistory) { - var stringMessages = new List(); - - foreach (var m in messages) - { - - if (m.Role==MessageRole.Human&&!SaveHumanMessages) - { - continue; - } - - string role = m.Role switch - { - MessageRole.Human => HumanPrefix, - MessageRole.Ai => AiPrefix, - MessageRole.System => SystemPrefix, - MessageRole.FunctionCall => "Function", - _ => throw new ArgumentException($"Unsupported message type: {m.GetType().Name}") - }; - - string message = $"{role}{m.Content}"; - // TODO: Add special case for a function call - - stringMessages.Add(message); - } - - return string.Join("\n", stringMessages); } /// public override OutputValues LoadMemoryVariables(InputValues? inputValues) { - return new OutputValues(new Dictionary { { MemoryKey, BufferAsString } }); + string bufferText = Formatter.Format(ChatHistory.Messages); + return new OutputValues(new Dictionary { { MemoryKey, bufferText } }); } } \ No newline at end of file diff --git a/src/libs/LangChain.Core/Memory/ConversationSummaryBufferMemory.cs b/src/libs/LangChain.Core/Memory/ConversationSummaryBufferMemory.cs new file mode 100644 index 00000000..7487c0f0 --- /dev/null +++ b/src/libs/LangChain.Core/Memory/ConversationSummaryBufferMemory.cs @@ -0,0 +1,108 @@ +using LangChain.Providers; +using LangChain.Schema; + +namespace LangChain.Memory; + +/// +/// Buffer with summarizer for storing conversation memory. +/// +public class ConversationSummaryBufferMemory : BaseChatMemory +{ + private IChatModelWithTokenCounting Model { get; } + private MessageSummarizer Summarizer { get; } + + private string SummaryText { get; set; } = string.Empty; + + public MessageFormatter Formatter { get; set; } = new MessageFormatter(); + public string MemoryKey { get; set; } = "history"; + public int MaxTokenCount { get; set; } = 2000; + + /// + public override List MemoryVariables => new List { MemoryKey }; + + /// + /// Initializes new memory instance with provided model and a default history store + /// + /// Model to use for summarization + /// + public ConversationSummaryBufferMemory(IChatModelWithTokenCounting model) + : base() + { + Model = model ?? throw new ArgumentNullException(nameof(model)); + Summarizer = new MessageSummarizer(model); + } + + /// + /// Initializes new memory instance with provided model and history store + /// + /// Model to use for summarization + /// History backing store + /// + public ConversationSummaryBufferMemory(IChatModelWithTokenCounting model, BaseChatMessageHistory chatHistory) + : base(chatHistory) + { + Model = model ?? throw new ArgumentNullException(nameof(model)); + Summarizer = new MessageSummarizer(model); + } + + /// + public override OutputValues LoadMemoryVariables(InputValues? inputValues) + { + string bufferText = Formatter.Format(GetMessages()); + return new OutputValues(new Dictionary { { MemoryKey, bufferText } }); + } + + /// + public override async Task SaveContext(InputValues inputValues, OutputValues outputValues) + { + await base.SaveContext(inputValues, outputValues).ConfigureAwait(false); + + // Maintain max token size of messages + await PruneMessages().ConfigureAwait(false); + } + + /// + public override async Task Clear() + { + await base.Clear().ConfigureAwait(false); + SummaryText = string.Empty; + } + + /// + /// Prune messages if they exceed the max token limit + /// + /// + private async Task PruneMessages() + { + List prunedMessages = new List(); + + int tokenCount = Model.CountTokens(ChatHistory.Messages); + if (tokenCount > MaxTokenCount) + { + Queue queue = new Queue(ChatHistory.Messages); + + while (tokenCount > MaxTokenCount) + { + Message prunedMessage = queue.Dequeue(); + prunedMessages.Add(prunedMessage); + + tokenCount = Model.CountTokens(queue); + } + + SummaryText = await Summarizer.Summarize(prunedMessages, SummaryText).ConfigureAwait(false); + + await ChatHistory.SetMessages(queue).ConfigureAwait(false); + } + } + + private List GetMessages() + { + List messages = new List + { + SummaryText.AsSystemMessage() + }; + messages.AddRange(ChatHistory.Messages); + + return messages; + } +} \ No newline at end of file diff --git a/src/libs/LangChain.Core/Memory/ConversationSummaryMemory.cs b/src/libs/LangChain.Core/Memory/ConversationSummaryMemory.cs new file mode 100644 index 00000000..fe333f02 --- /dev/null +++ b/src/libs/LangChain.Core/Memory/ConversationSummaryMemory.cs @@ -0,0 +1,74 @@ +using LangChain.Providers; +using LangChain.Schema; +using System.Security.Cryptography; + +namespace LangChain.Memory; + +/// +/// Conversation summarizer to chat memory. +/// +public class ConversationSummaryMemory : BaseChatMemory +{ + public MessageFormatter Formatter { get; set; } = new MessageFormatter(); + + public string MemoryKey { get; set; } = "history"; + + /// + public override List MemoryVariables => new List { MemoryKey }; + + private IChatModel Model { get; } + private MessageSummarizer Summarizer { get; } + private string SummaryText { get; set; } = string.Empty; + + /// + /// Initializes new summarizing memory instance with provided model + /// + /// Model to use for summarization + /// + public ConversationSummaryMemory(IChatModel model) + : base() + { + Model = model ?? throw new ArgumentNullException(nameof(model)); + Summarizer = new MessageSummarizer(model); + } + + /// + /// Initializes new summarizing memory instance with provided model and history store + /// + /// Model to use for summarization + /// History backing store + /// + public ConversationSummaryMemory(IChatModel model, BaseChatMessageHistory chatHistory) + : base(chatHistory) + { + Model = model ?? throw new ArgumentNullException(nameof(model)); + Summarizer = new MessageSummarizer(model); + } + + /// + public override OutputValues LoadMemoryVariables(InputValues? inputValues) + { + return new OutputValues(new Dictionary { { MemoryKey, SummaryText } }); + } + + /// + public override async Task SaveContext(InputValues inputValues, OutputValues outputValues) + { + // Save non-summarized values to the history + await base.SaveContext(inputValues, outputValues).ConfigureAwait(false); + + // Since we are in SaveContext, can assume there are at least two messages (human + ai) + var newMessages = ChatHistory.Messages + .Skip(ChatHistory.Messages.Count - 2) + .Take(2); + + SummaryText = await Summarizer.Summarize(newMessages, SummaryText).ConfigureAwait(false); + } + + /// + public override async Task Clear() + { + await base.Clear().ConfigureAwait(false); + SummaryText = string.Empty; + } +} \ No newline at end of file diff --git a/src/libs/LangChain.Core/Memory/ConversationWindowBufferMemory.cs b/src/libs/LangChain.Core/Memory/ConversationWindowBufferMemory.cs new file mode 100644 index 00000000..1a28111b --- /dev/null +++ b/src/libs/LangChain.Core/Memory/ConversationWindowBufferMemory.cs @@ -0,0 +1,61 @@ +using LangChain.Providers; +using LangChain.Schema; + +namespace LangChain.Memory; + +/// +/// Buffer for storing conversation memory. +/// +/// NOTE: LangChain's buffer property is not implemented here +/// +public class ConversationWindowBufferMemory : BaseChatMemory +{ + public MessageFormatter Formatter { get; set; } = new MessageFormatter(); + + public string MemoryKey { get; set; } = "history"; + + /// + /// Number of messages to store in buffer. + /// + /// This is actually the number of Human+AI pairs of messages. + /// This is the 'k' property in python langchain + /// + public int WindowSize { get; set; } = 5; + + /// + public override List MemoryVariables => new List { MemoryKey }; + + /// + /// Initializes new windowed buffer memory instance + /// + public ConversationWindowBufferMemory() + : base() + { + } + + /// + /// Initializes new windowed buffer memory instance with provided history store + /// + /// History backing store + public ConversationWindowBufferMemory(BaseChatMessageHistory chatHistory) + : base(chatHistory) + { + } + + /// + public override OutputValues LoadMemoryVariables(InputValues? inputValues) + { + string bufferText = Formatter.Format(GetMessages()); + return new OutputValues(new Dictionary { { MemoryKey, bufferText } }); + } + + private List GetMessages() + { + int numMessages = Math.Min(ChatHistory.Messages.Count, WindowSize * 2); + + return ChatHistory.Messages + .Skip(ChatHistory.Messages.Count - numMessages) + .Take(numMessages) + .ToList(); + } +} \ No newline at end of file diff --git a/src/libs/LangChain.Core/Memory/FileChatMessageHistory.cs b/src/libs/LangChain.Core/Memory/FileChatMessageHistory.cs index bc9d2f8b..a6896bdc 100644 --- a/src/libs/LangChain.Core/Memory/FileChatMessageHistory.cs +++ b/src/libs/LangChain.Core/Memory/FileChatMessageHistory.cs @@ -4,36 +4,28 @@ namespace LangChain.Memory; /// -/// Stores history in a local file. +/// Chat message history that stores history in a local file. /// public class FileChatMessageHistory : BaseChatMessageHistory { private string MessagesFilePath { get; } - private List _messages; + private List _messages = new List(); /// - public override IReadOnlyList Messages - { - get - { - if (_messages is null) - { - LoadMessages().Wait(); - } + public override IReadOnlyList Messages => _messages; - return _messages; - } - } - /// - /// + /// Initializes new history instance with provided file path /// - /// Path to local history file + /// path of the local file to store the messages /// public FileChatMessageHistory(string messagesFilePath) { MessagesFilePath = messagesFilePath ?? throw new ArgumentNullException(nameof(messagesFilePath)); + + // Blocking call in the constructor creates a simpler implementation + LoadMessages().Wait(); } /// @@ -53,19 +45,16 @@ public override async Task Clear() private async Task SaveMessages() { string json = JsonSerializer.Serialize(_messages); - await Task.Run(() => File.WriteAllText(MessagesFilePath, json)); + await Task.Run(() => File.WriteAllText(MessagesFilePath, json)).ConfigureAwait(false); } private async Task LoadMessages() { if (File.Exists(MessagesFilePath)) { - string json = await Task.Run(() => File.ReadAllText(MessagesFilePath)); + string json = await Task.Run(() => File.ReadAllText(MessagesFilePath)).ConfigureAwait(false); _messages = JsonSerializer.Deserialize>(json); } - else - { - _messages = new List(); - } } + } \ No newline at end of file diff --git a/src/libs/LangChain.Core/Memory/MessageFormatter.cs b/src/libs/LangChain.Core/Memory/MessageFormatter.cs new file mode 100644 index 00000000..e24f1e4c --- /dev/null +++ b/src/libs/LangChain.Core/Memory/MessageFormatter.cs @@ -0,0 +1,61 @@ +using LangChain.Providers; + +namespace LangChain.Memory; + +public class MessageFormatter +{ + public string HumanPrefix { get; set; } = "Human"; + public string AiPrefix { get; set; } = "AI"; + public string SystemPrefix { get; set; } = "System"; + public string FunctionCallPrefix { get; set; } = "Function"; + public string FunctionResultPrefix { get; set; } = "Result"; + public string ChatPrefix { get; set; } = "Chat"; + + private string GetPrefix(MessageRole role) + { + switch (role) + { + case MessageRole.System: + return SystemPrefix; + + case MessageRole.Human: + return HumanPrefix; + + case MessageRole.Ai: + return AiPrefix; + + case MessageRole.FunctionCall: + return FunctionCallPrefix; + + case MessageRole.FunctionResult: + return FunctionResultPrefix; + + case MessageRole.Chat: + return ChatPrefix; + + default: + throw new ArgumentException("Unrecognized message role", nameof(role)); + } + } + + public string Format(Message message) + { + string messagePrefix = GetPrefix(message.Role); + + return string.IsNullOrEmpty(messagePrefix) ? message.Content : $"{messagePrefix}: {message.Content}"; + } + + public string Format(IEnumerable messages) + { + messages = messages ?? throw new ArgumentNullException(nameof(messages)); + + List formattedMessages = new List(); + + foreach (Message message in messages) + { + formattedMessages.Add(Format(message)); + } + + return string.Join("\n", messages); + } +} \ No newline at end of file diff --git a/src/libs/LangChain.Core/Memory/MessageSummarizer.cs b/src/libs/LangChain.Core/Memory/MessageSummarizer.cs new file mode 100644 index 00000000..c909d190 --- /dev/null +++ b/src/libs/LangChain.Core/Memory/MessageSummarizer.cs @@ -0,0 +1,58 @@ +using LangChain.Providers; +using static LangChain.Chains.Chain; + +namespace LangChain.Memory; + +public class MessageSummarizer +{ + private const string SummaryPrompt = @" +Progressively summarize the lines of conversation provided, adding onto the previous summary returning a new summary. + +EXAMPLE +Current summary: +The human asks what the AI thinks of artificial intelligence.The AI thinks artificial intelligence is a force for good. + +New lines of conversation: +Human: Why do you think artificial intelligence is a force for good? +AI: Because artificial intelligence will help humans reach their full potential. + +New summary: +The human asks what the AI thinks of artificial intelligence. The AI thinks artificial intelligence is a force for good because it will help humans reach their full potential. +END OF EXAMPLE + +Current summary: +{summary} + + New lines of conversation: +{new_lines} + +New summary:"; + + private IChatModel Model { get; } + private MessageFormatter Formatter { get; } + + public MessageSummarizer(IChatModel model) + { + Model = model ?? throw new ArgumentNullException(nameof(model)); + Formatter = new MessageFormatter(); + } + + public MessageSummarizer(IChatModel model, MessageFormatter formatter) + { + Model = model ?? throw new ArgumentNullException(nameof(model)); + Formatter = formatter ?? throw new ArgumentNullException(nameof(formatter)); + } + + public async Task Summarize(IEnumerable newMessages, string existingSummary) + { + string newLines = Formatter.Format(newMessages); + + var chain = + Set(existingSummary, outputKey: "summary") + | Set(newLines, outputKey: "new_lines") + | Template(SummaryPrompt) + | LLM(Model); + + return await chain.Run("text").ConfigureAwait(false); + } +} \ No newline at end of file diff --git a/src/tests/LangChain.Providers.LLamaSharp.IntegrationTests/MessageHistoryTests.cs b/src/tests/LangChain.Providers.LLamaSharp.IntegrationTests/MessageHistoryTests.cs index 15d0627e..90e0c152 100644 --- a/src/tests/LangChain.Providers.LLamaSharp.IntegrationTests/MessageHistoryTests.cs +++ b/src/tests/LangChain.Providers.LLamaSharp.IntegrationTests/MessageHistoryTests.cs @@ -21,13 +21,14 @@ public void TestHistory() Human: {message} AI: "; - var memory = new ConversationBufferMemory(new ChatMessageHistory()); + var history = new ChatMessageHistory(); + var memory = new ConversationBufferMemory(history); var message = Set("hi, i am Jimmy", "message"); var chain = message - | Set(() => memory.BufferAsString, outputKey: "chat_history") // get lates messages from buffer every time + | LoadMemory(memory, outputKey: "chat_history") // get lates messages from buffer every time | Template(promptText, outputKey: "prompt") | LLM(model, inputKey: "prompt", outputKey: "text") | UpdateMemory(memory, requestKey: "message", responseKey: "text"); // save the messages to the buffer @@ -41,7 +42,7 @@ public void TestHistory() var res=chain.Run().Result; // call the chain for the second time. // prompt will contain previous messages and a question about the name. - Assert.AreEqual(4,memory.BufferAsMessages.Count); + Assert.AreEqual(4, history.Messages.Count); res.Value["text"].ToString()?.ToLower().Trim().Contains("jimmy").Should().BeTrue(); } From eec6c46f970489e6e46e7499a25d4339410dde35 Mon Sep 17 00:00:00 2001 From: Peter James Date: Sat, 3 Feb 2024 13:16:52 -0800 Subject: [PATCH 2/5] Fix a couple of bugs --- src/libs/LangChain.Core/Memory/MessageFormatter.cs | 2 +- src/libs/LangChain.Core/Memory/MessageSummarizer.cs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/libs/LangChain.Core/Memory/MessageFormatter.cs b/src/libs/LangChain.Core/Memory/MessageFormatter.cs index e24f1e4c..baff45f1 100644 --- a/src/libs/LangChain.Core/Memory/MessageFormatter.cs +++ b/src/libs/LangChain.Core/Memory/MessageFormatter.cs @@ -56,6 +56,6 @@ public string Format(IEnumerable messages) formattedMessages.Add(Format(message)); } - return string.Join("\n", messages); + return string.Join("\n", formattedMessages); } } \ No newline at end of file diff --git a/src/libs/LangChain.Core/Memory/MessageSummarizer.cs b/src/libs/LangChain.Core/Memory/MessageSummarizer.cs index c909d190..d9acbcb4 100644 --- a/src/libs/LangChain.Core/Memory/MessageSummarizer.cs +++ b/src/libs/LangChain.Core/Memory/MessageSummarizer.cs @@ -23,7 +23,7 @@ END OF EXAMPLE Current summary: {summary} - New lines of conversation: +New lines of conversation: {new_lines} New summary:"; From b0ff297164f36a179032ada836f4f308a7125ddb Mon Sep 17 00:00:00 2001 From: Peter James Date: Sat, 3 Feb 2024 15:11:02 -0800 Subject: [PATCH 3/5] Update memory example app to try out different forms of memory --- LangChain.sln | 7 - .../LangChain.Samples.FileMemory.csproj | 14 -- .../LangChain.Samples.FileMemory/Program.cs | 57 ------ examples/LangChain.Samples.Memory/Program.cs | 177 +++++++++++++++++- 4 files changed, 170 insertions(+), 85 deletions(-) delete mode 100644 examples/LangChain.Samples.FileMemory/LangChain.Samples.FileMemory.csproj delete mode 100644 examples/LangChain.Samples.FileMemory/Program.cs diff --git a/LangChain.sln b/LangChain.sln index f0e875a9..2e218ebf 100644 --- a/LangChain.sln +++ b/LangChain.sln @@ -172,8 +172,6 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Providers.Automat EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Providers.Automatic1111.IntegrationTests", "src\tests\LangChain.Providers.Automatic1111.IntegrationTests\LangChain.Providers.Automatic1111.IntegrationTests.csproj", "{A6CF79BC-8365-46E8-9230-1A4AD615D40B}" EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Samples.FileMemory", "examples\LangChain.Samples.FileMemory\LangChain.Samples.FileMemory.csproj", "{BA701280-0BEB-4DA4-92B3-9C777082C2AF}" -EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -404,10 +402,6 @@ Global {A6CF79BC-8365-46E8-9230-1A4AD615D40B}.Debug|Any CPU.Build.0 = Debug|Any CPU {A6CF79BC-8365-46E8-9230-1A4AD615D40B}.Release|Any CPU.ActiveCfg = Release|Any CPU {A6CF79BC-8365-46E8-9230-1A4AD615D40B}.Release|Any CPU.Build.0 = Release|Any CPU - {BA701280-0BEB-4DA4-92B3-9C777082C2AF}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {BA701280-0BEB-4DA4-92B3-9C777082C2AF}.Debug|Any CPU.Build.0 = Debug|Any CPU - {BA701280-0BEB-4DA4-92B3-9C777082C2AF}.Release|Any CPU.ActiveCfg = Release|Any CPU - {BA701280-0BEB-4DA4-92B3-9C777082C2AF}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -476,7 +470,6 @@ Global {4913844F-74EC-4E74-AE8A-EA825569E6BA} = {E55391DE-F8F3-4CC2-A0E3-2406C76E9C68} {BF4C7B87-0997-4208-84EF-D368DF7B9861} = {E55391DE-F8F3-4CC2-A0E3-2406C76E9C68} {A6CF79BC-8365-46E8-9230-1A4AD615D40B} = {FDEE2E22-C239-4921-83B2-9797F765FD6A} - {BA701280-0BEB-4DA4-92B3-9C777082C2AF} = {F17A86AE-A174-4B6C-BAA7-9D9A9704BE85} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {5C00D0F1-6138-4ED9-846B-97E43D6DFF1C} diff --git a/examples/LangChain.Samples.FileMemory/LangChain.Samples.FileMemory.csproj b/examples/LangChain.Samples.FileMemory/LangChain.Samples.FileMemory.csproj deleted file mode 100644 index 7b2b281a..00000000 --- a/examples/LangChain.Samples.FileMemory/LangChain.Samples.FileMemory.csproj +++ /dev/null @@ -1,14 +0,0 @@ - - - - Exe - net8.0 - enable - enable - - - - - - - diff --git a/examples/LangChain.Samples.FileMemory/Program.cs b/examples/LangChain.Samples.FileMemory/Program.cs deleted file mode 100644 index c257883e..00000000 --- a/examples/LangChain.Samples.FileMemory/Program.cs +++ /dev/null @@ -1,57 +0,0 @@ -using LangChain.Memory; -using LangChain.Providers; -using LangChain.Providers.OpenAI; -using static LangChain.Chains.Chain; - -var apiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? - throw new InvalidOperationException("OPENAI_API_KEY environment variable is not found."); - -var model = new OpenAiModel(apiKey, "gpt-3.5-turbo"); - - -// create simple template for conversation for AI to know what piece of text it is looking at -var template = @" -The following is a friendly conversation between a human and an AI. -{history} -Human: {input} -AI:"; - -// To have a conversation thar remembers previous messages we need to use memory. -// For memory to work properly we need to specify AI and Human prefixes. -// Since in our template we have "AI:" and "Human:" we need to specify those prefixes here. -var memory = new ConversationBufferMemory(new FileChatMessageHistory("messages.json")) -{ - Formatter = new MessageFormatter - { - AiPrefix = "AI", - HumanPrefix = "Human" - } -}; - -// build chain. Notice that we don't set input key here. It will be set in the loop -var chain = - // load history. at first it will be empty, but UpdateMemory will update it every iteration - LoadMemory(memory, outputKey: "history") - | Template(template) - | LLM(model) - // update memory with new request from Human and response from AI - | UpdateMemory(memory, requestKey: "input", responseKey: "text"); - -// run an endless loop of conversation -while (true) -{ - Console.Write("Human: "); - var input = Console.ReadLine(); - if (input == "exit") - break; - - // build a new chain using previous chain but with new input every time - var chatChain = Set(input, "input") - | chain; - - // get response from AI - var res = await chatChain.Run("text"); - - Console.Write("AI: "); - Console.WriteLine(res); -} \ No newline at end of file diff --git a/examples/LangChain.Samples.Memory/Program.cs b/examples/LangChain.Samples.Memory/Program.cs index 07b55699..0f211d1f 100644 --- a/examples/LangChain.Samples.Memory/Program.cs +++ b/examples/LangChain.Samples.Memory/Program.cs @@ -1,13 +1,176 @@ using LangChain.Memory; +using LangChain.Providers; +using LangChain.Providers.OpenAI; +using System.Runtime.Serialization; +using static LangChain.Chains.Chain; -var inMemoryHistory = new ChatMessageHistory(); +internal class Program +{ + private static async Task Main(string[] args) + { + // Pull the API key from the environment, so it's never checked in with source + var apiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? + throw new InvalidOperationException("OPENAI_API_KEY environment variable is not found."); -await inMemoryHistory.AddUserMessage("hi!"); + // Use a common, general-purpose LLM + var model = new OpenAiModel(apiKey, "gpt-3.5-turbo"); -await inMemoryHistory.AddAiMessage("whats up?"); + // Create a simple prompt template for the conversation to help the AI + var template = @" +The following is a friendly conversation between a human and an AI. -foreach (var message in inMemoryHistory.Messages) -{ - Console.WriteLine(message.GetType().Name + ":" + message.Content); -} +{history} +Human: {input} +AI: "; + + // To have a conversation that remembers previous messages we need to use memory. + // Here we pick one of a number of different strategies for implementing memory. + var memory = PickMemoryStrategy(model); + + // Build the chain that will be used for each turn in our conversation. + // This is just declaring the chain. Actual execution of the chain happens + // in the conversation loop below. On every pass through the loop, the user's + // input is added to the beginning of this chain to make a new chain. + var chain = + LoadMemory(memory, outputKey: "history") + | Template(template) + | LLM(model) + | UpdateMemory(memory, requestKey: "input", responseKey: "text"); + + Console.WriteLine(); + Console.WriteLine("Start a conversation with the friendly AI!"); + Console.WriteLine("(Enter 'exit' or hit Ctrl-C to end the conversation)"); + + // Run an endless loop of conversation + while (true) + { + Console.WriteLine(); + + Console.Write("Human: "); + var input = Console.ReadLine(); + if (input == "exit") + { + break; + } + + // Build a new chain by prepending the user's input to the original chain + var currentChain = Set(input, "input") + | chain; + + // Get a response from the AI + var response = await currentChain.Run("text"); + + Console.Write("AI: "); + Console.WriteLine(response); + } + } + + private static BaseChatMemory PickMemoryStrategy(IChatModel model) + { + // The memory will add prefixes to messages to indicate where they came from + // The prefixes specified here should match those used in our prompt template + MessageFormatter messageFormatter = new MessageFormatter + { + AiPrefix = "AI", + HumanPrefix = "Human" + }; + + BaseChatMessageHistory chatHistory = GetChatMessageHistory(); + + string memoryClassName = PromptForChoice(new[] + { + nameof(ConversationBufferMemory), + nameof(ConversationWindowBufferMemory), + nameof(ConversationSummaryMemory), + nameof(ConversationSummaryBufferMemory) + }); + + switch (memoryClassName) + { + case nameof(ConversationBufferMemory): + return GetConversationBufferMemory(chatHistory, messageFormatter); + case nameof(ConversationWindowBufferMemory): + return GetConversationWindowBufferMemory(chatHistory, messageFormatter); + + case nameof(ConversationSummaryMemory): + return GetConversationSummaryMemory(chatHistory, messageFormatter, model); + + case nameof(ConversationSummaryBufferMemory): + return GetConversationSummaryBufferMemory(chatHistory, messageFormatter, (IChatModelWithTokenCounting)model); + + default: + throw new Exception($"Unexpected memory class name: '{memoryClassName}'"); + } + } + + private static string PromptForChoice(string[] choiceTexts) + { + while (true) + { + Console.Clear(); + Console.WriteLine("Select from the following options:"); + + int choiceNumber = 1; + + foreach (string choiceText in choiceTexts) + { + Console.WriteLine($" {choiceNumber}: {choiceText}"); + choiceNumber++; + } + + Console.WriteLine(); + Console.Write("Enter choice: "); + + string choiceEntry = Console.ReadLine(); + if (int.TryParse(choiceEntry, out int choiceIndex)) + { + string choiceText = choiceTexts[choiceIndex]; + + Console.WriteLine(); + Console.WriteLine($"You selected '{choiceText}'"); + + return choiceText; + } + } + } + + private static BaseChatMessageHistory GetChatMessageHistory() + { + // Other types of chat history work, too! + return new ChatMessageHistory(); + } + + private static BaseChatMemory GetConversationBufferMemory(BaseChatMessageHistory chatHistory, MessageFormatter messageFormatter) + { + return new ConversationBufferMemory(chatHistory) + { + Formatter = messageFormatter + }; + } + + private static BaseChatMemory GetConversationWindowBufferMemory(BaseChatMessageHistory chatHistory, MessageFormatter messageFormatter) + { + return new ConversationWindowBufferMemory(chatHistory) + { + WindowSize = 3, + Formatter = messageFormatter + }; + } + + private static BaseChatMemory GetConversationSummaryMemory(BaseChatMessageHistory chatHistory, MessageFormatter messageFormatter, IChatModel model) + { + return new ConversationSummaryMemory(model, chatHistory) + { + Formatter = messageFormatter + }; + } + private static BaseChatMemory GetConversationSummaryBufferMemory(BaseChatMessageHistory chatHistory, MessageFormatter messageFormatter, IChatModelWithTokenCounting model) + { + return new ConversationSummaryBufferMemory(model, chatHistory) + { + MaxTokenCount = 25, + Formatter = messageFormatter + }; + } +} From c23c09a94f0eb16708924bdb74d9ad274427cfcf Mon Sep 17 00:00:00 2001 From: Peter James Date: Sat, 3 Feb 2024 23:54:36 -0800 Subject: [PATCH 4/5] Address PR feedback --- examples/LangChain.Samples.Memory/Program.cs | 2 +- .../Chains/StackableChains/LoadMemoryChain.cs | 2 +- .../Memory/BaseChatMessageHistory.cs | 2 -- .../Memory/FileChatMessageHistory.cs | 23 +++++++++++++------ 4 files changed, 18 insertions(+), 11 deletions(-) diff --git a/examples/LangChain.Samples.Memory/Program.cs b/examples/LangChain.Samples.Memory/Program.cs index 0f211d1f..e6fb0447 100644 --- a/examples/LangChain.Samples.Memory/Program.cs +++ b/examples/LangChain.Samples.Memory/Program.cs @@ -100,7 +100,7 @@ private static BaseChatMemory PickMemoryStrategy(IChatModel model) return GetConversationSummaryBufferMemory(chatHistory, messageFormatter, (IChatModelWithTokenCounting)model); default: - throw new Exception($"Unexpected memory class name: '{memoryClassName}'"); + throw new InvalidOperationException($"Unexpected memory class name: '{memoryClassName}'"); } } diff --git a/src/libs/LangChain.Core/Chains/StackableChains/LoadMemoryChain.cs b/src/libs/LangChain.Core/Chains/StackableChains/LoadMemoryChain.cs index 46e7920f..f36b3f56 100644 --- a/src/libs/LangChain.Core/Chains/StackableChains/LoadMemoryChain.cs +++ b/src/libs/LangChain.Core/Chains/StackableChains/LoadMemoryChain.cs @@ -25,7 +25,7 @@ protected override Task InternalCall(IChainValues values) string memoryVariableName = _chatMemory.MemoryVariables.FirstOrDefault(); if (memoryVariableName == null) { - throw new Exception("Missing memory variable name"); + throw new InvalidOperationException("Missing memory variable name"); } OutputValues outputValues = _chatMemory.LoadMemoryVariables(null); diff --git a/src/libs/LangChain.Core/Memory/BaseChatMessageHistory.cs b/src/libs/LangChain.Core/Memory/BaseChatMessageHistory.cs index bb8cf22d..c8f85d8b 100644 --- a/src/libs/LangChain.Core/Memory/BaseChatMessageHistory.cs +++ b/src/libs/LangChain.Core/Memory/BaseChatMessageHistory.cs @@ -1,6 +1,4 @@ using LangChain.Providers; -using System.Numerics; -using System.Reflection.Emit; namespace LangChain.Memory; diff --git a/src/libs/LangChain.Core/Memory/FileChatMessageHistory.cs b/src/libs/LangChain.Core/Memory/FileChatMessageHistory.cs index a6896bdc..d4af1ce2 100644 --- a/src/libs/LangChain.Core/Memory/FileChatMessageHistory.cs +++ b/src/libs/LangChain.Core/Memory/FileChatMessageHistory.cs @@ -20,26 +20,36 @@ public class FileChatMessageHistory : BaseChatMessageHistory /// /// path of the local file to store the messages /// - public FileChatMessageHistory(string messagesFilePath) + private FileChatMessageHistory(string messagesFilePath) { MessagesFilePath = messagesFilePath ?? throw new ArgumentNullException(nameof(messagesFilePath)); + } - // Blocking call in the constructor creates a simpler implementation - LoadMessages().Wait(); + /// + /// Create new history instance with provided file path + /// + /// path of the local file to store the messages + /// + public static async Task CreateAsync(string path, CancellationToken cancellationToken = default) + { + FileChatMessageHistory chatHistory = new FileChatMessageHistory(path); + await chatHistory.LoadMessages().ConfigureAwait(false); + + return chatHistory; } - + /// public override async Task AddMessage(Message message) { _messages.Add(message); - await SaveMessages(); + await SaveMessages().ConfigureAwait(false); } /// public override async Task Clear() { _messages.Clear(); - await SaveMessages(); + await SaveMessages().ConfigureAwait(false); } private async Task SaveMessages() @@ -56,5 +66,4 @@ private async Task LoadMessages() _messages = JsonSerializer.Deserialize>(json); } } - } \ No newline at end of file From bab5fb218d1338ab880810e7d23496caf64bf0a1 Mon Sep 17 00:00:00 2001 From: Peter James Date: Sun, 4 Feb 2024 00:08:13 -0800 Subject: [PATCH 5/5] Address PR feedback --- .../LangChain.Core/Chains/StackableChains/Agents/GroupChat.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/libs/LangChain.Core/Chains/StackableChains/Agents/GroupChat.cs b/src/libs/LangChain.Core/Chains/StackableChains/Agents/GroupChat.cs index bb93fd85..e8781734 100644 --- a/src/libs/LangChain.Core/Chains/StackableChains/Agents/GroupChat.cs +++ b/src/libs/LangChain.Core/Chains/StackableChains/Agents/GroupChat.cs @@ -20,7 +20,6 @@ public class GroupChat : BaseStackableChain int _currentAgentId; private readonly MessageFormatter _messageFormatter; private readonly ChatMessageHistory _chatMessageHistory; - private readonly ConversationBufferMemory _conversationBufferMemory; /// ///