feat: add additional memory classes and do some cleanup #132

Merged: 6 commits, Feb 4, 2024
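At a glance: the standalone FileMemory sample is removed, the LoadMemory/UpdateMemory chain helpers now accept any BaseChatMemory, GroupChat swaps its internal ConversationBufferMemory for a MessageFormatter plus a filtered ChatMessageHistory, and the Memory sample walks through four memory strategies (ConversationBufferMemory, ConversationWindowBufferMemory, ConversationSummaryMemory, ConversationSummaryBufferMemory). A minimal sketch of the intended wiring, condensed from the updated sample in this diff; the hard-coded input string is illustrative, and the OPENAI_API_KEY variable and gpt-3.5-turbo model choice follow the sample:

```csharp
using LangChain.Memory;
using LangChain.Providers.OpenAI;
using static LangChain.Chains.Chain;

var apiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY")
    ?? throw new InvalidOperationException("OPENAI_API_KEY is not set.");
var model = new OpenAiModel(apiKey, "gpt-3.5-turbo");

// Any of the memory classes from the sample fits here; this one keeps only the last 3 exchanges.
var memory = new ConversationWindowBufferMemory(new ChatMessageHistory())
{
    WindowSize = 3
};

var template = @"
The following is a friendly conversation between a human and an AI.

{history}
Human: {input}
AI: ";

// One conversational turn: load history, fill the prompt, call the model, store the exchange.
var chain =
    Set("Hi, my name is Ada.", "input")
    | LoadMemory(memory, outputKey: "history")
    | Template(template)
    | LLM(model)
    | UpdateMemory(memory, requestKey: "input", responseKey: "text");

Console.WriteLine(await chain.Run("text"));
```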
7 changes: 0 additions & 7 deletions LangChain.sln
@@ -172,8 +172,6 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Providers.Automat
 EndProject
 Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Providers.Automatic1111.IntegrationTests", "src\tests\LangChain.Providers.Automatic1111.IntegrationTests\LangChain.Providers.Automatic1111.IntegrationTests.csproj", "{A6CF79BC-8365-46E8-9230-1A4AD615D40B}"
 EndProject
-Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Samples.FileMemory", "examples\LangChain.Samples.FileMemory\LangChain.Samples.FileMemory.csproj", "{BA701280-0BEB-4DA4-92B3-9C777082C2AF}"
-EndProject
 Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Providers.Bedrock.IntegrationTests", "src\tests\LangChain.Providers.Bedrock.IntegrationTests\LangChain.Providers.Bedrock.IntegrationTests.csproj", "{73C76E80-95C5-4C96-A319-4F32043C903E}"
 EndProject
 Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Providers.Bedrock", "src\libs\Providers\LangChain.Providers.Bedrock\LangChain.Providers.Bedrock.csproj", "{67985CCB-F606-41F8-9D36-513459F58882}"
@@ -408,10 +406,6 @@ Global
         {A6CF79BC-8365-46E8-9230-1A4AD615D40B}.Debug|Any CPU.Build.0 = Debug|Any CPU
         {A6CF79BC-8365-46E8-9230-1A4AD615D40B}.Release|Any CPU.ActiveCfg = Release|Any CPU
         {A6CF79BC-8365-46E8-9230-1A4AD615D40B}.Release|Any CPU.Build.0 = Release|Any CPU
-        {BA701280-0BEB-4DA4-92B3-9C777082C2AF}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
-        {BA701280-0BEB-4DA4-92B3-9C777082C2AF}.Debug|Any CPU.Build.0 = Debug|Any CPU
-        {BA701280-0BEB-4DA4-92B3-9C777082C2AF}.Release|Any CPU.ActiveCfg = Release|Any CPU
-        {BA701280-0BEB-4DA4-92B3-9C777082C2AF}.Release|Any CPU.Build.0 = Release|Any CPU
         {73C76E80-95C5-4C96-A319-4F32043C903E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
         {73C76E80-95C5-4C96-A319-4F32043C903E}.Debug|Any CPU.Build.0 = Debug|Any CPU
         {73C76E80-95C5-4C96-A319-4F32043C903E}.Release|Any CPU.ActiveCfg = Release|Any CPU
@@ -488,7 +482,6 @@ Global
         {4913844F-74EC-4E74-AE8A-EA825569E6BA} = {E55391DE-F8F3-4CC2-A0E3-2406C76E9C68}
         {BF4C7B87-0997-4208-84EF-D368DF7B9861} = {E55391DE-F8F3-4CC2-A0E3-2406C76E9C68}
         {A6CF79BC-8365-46E8-9230-1A4AD615D40B} = {FDEE2E22-C239-4921-83B2-9797F765FD6A}
-        {BA701280-0BEB-4DA4-92B3-9C777082C2AF} = {F17A86AE-A174-4B6C-BAA7-9D9A9704BE85}
         {73C76E80-95C5-4C96-A319-4F32043C903E} = {FDEE2E22-C239-4921-83B2-9797F765FD6A}
         {67985CCB-F606-41F8-9D36-513459F58882} = {E55391DE-F8F3-4CC2-A0E3-2406C76E9C68}
     EndGlobalSection

examples/LangChain.Samples.FileMemory/LangChain.Samples.FileMemory.csproj

This file was deleted.

54 changes: 0 additions & 54 deletions examples/LangChain.Samples.FileMemory/Program.cs

This file was deleted.
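The FileMemory sample's scenario appears to fold into the Memory sample below, where GetChatMessageHistory() is the extension point for swapping in a persistent history ("Other types of chat history work, too!"). A sketch of that swap; the file-backed history class named here is hypothetical and not confirmed by this diff:

```csharp
private static BaseChatMessageHistory GetChatMessageHistory()
{
    // Hypothetical file-backed implementation; any BaseChatMessageHistory works here.
    return new FileChatMessageHistory("conversation-history.json");
}
```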

177 changes: 170 additions & 7 deletions examples/LangChain.Samples.Memory/Program.cs
@@ -1,13 +1,176 @@
 using LangChain.Memory;
 using LangChain.Providers;
 using LangChain.Providers.OpenAI;
 using System.Runtime.Serialization;
 using static LangChain.Chains.Chain;
 
-var inMemoryHistory = new ChatMessageHistory();
-await inMemoryHistory.AddUserMessage("hi!");
-await inMemoryHistory.AddAiMessage("whats up?");
-foreach (var message in inMemoryHistory.Messages)
-{
-    Console.WriteLine(message.GetType().Name + ":" + message.Content);
-}
+internal class Program
+{
+    private static async Task Main(string[] args)
+    {
+        // Pull the API key from the environment, so it's never checked in with source
+        var apiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ??
+            throw new InvalidOperationException("OPENAI_API_KEY environment variable is not found.");
+
+        // Use a common, general-purpose LLM
+        var model = new OpenAiModel(apiKey, "gpt-3.5-turbo");
+
+        // Create a simple prompt template for the conversation to help the AI
+        var template = @"
+The following is a friendly conversation between a human and an AI.
+
+{history}
+Human: {input}
+AI: ";
+
+        // To have a conversation that remembers previous messages we need to use memory.
+        // Here we pick one of a number of different strategies for implementing memory.
+        var memory = PickMemoryStrategy(model);
+
+        // Build the chain that will be used for each turn in our conversation.
+        // This is just declaring the chain. Actual execution of the chain happens
+        // in the conversation loop below. On every pass through the loop, the user's
+        // input is added to the beginning of this chain to make a new chain.
+        var chain =
+            LoadMemory(memory, outputKey: "history")
+            | Template(template)
+            | LLM(model)
+            | UpdateMemory(memory, requestKey: "input", responseKey: "text");
+
+        Console.WriteLine();
+        Console.WriteLine("Start a conversation with the friendly AI!");
+        Console.WriteLine("(Enter 'exit' or hit Ctrl-C to end the conversation)");
+
+        // Run an endless loop of conversation
+        while (true)
+        {
+            Console.WriteLine();
+
+            Console.Write("Human: ");
+            var input = Console.ReadLine();
+            if (input == "exit")
+            {
+                break;
+            }
+
+            // Build a new chain by prepending the user's input to the original chain
+            var currentChain = Set(input, "input")
+                | chain;
+
+            // Get a response from the AI
+            var response = await currentChain.Run("text");
+
+            Console.Write("AI: ");
+            Console.WriteLine(response);
+        }
+    }
+
+    private static BaseChatMemory PickMemoryStrategy(IChatModel model)
+    {
+        // The memory will add prefixes to messages to indicate where they came from
+        // The prefixes specified here should match those used in our prompt template
+        MessageFormatter messageFormatter = new MessageFormatter
+        {
+            AiPrefix = "AI",
+            HumanPrefix = "Human"
+        };
+
+        BaseChatMessageHistory chatHistory = GetChatMessageHistory();
+
+        string memoryClassName = PromptForChoice(new[]
+        {
+            nameof(ConversationBufferMemory),
+            nameof(ConversationWindowBufferMemory),
+            nameof(ConversationSummaryMemory),
+            nameof(ConversationSummaryBufferMemory)
+        });
+
+        switch (memoryClassName)
+        {
+            case nameof(ConversationBufferMemory):
+                return GetConversationBufferMemory(chatHistory, messageFormatter);
+
+            case nameof(ConversationWindowBufferMemory):
+                return GetConversationWindowBufferMemory(chatHistory, messageFormatter);
+
+            case nameof(ConversationSummaryMemory):
+                return GetConversationSummaryMemory(chatHistory, messageFormatter, model);
+
+            case nameof(ConversationSummaryBufferMemory):
+                return GetConversationSummaryBufferMemory(chatHistory, messageFormatter, (IChatModelWithTokenCounting)model);
+
+            default:
+                throw new InvalidOperationException($"Unexpected memory class name: '{memoryClassName}'");
+        }
+    }
+
+    private static string PromptForChoice(string[] choiceTexts)
+    {
+        while (true)
+        {
+            Console.Clear();
+            Console.WriteLine("Select from the following options:");
+
+            int choiceNumber = 1;
+
+            foreach (string choiceText in choiceTexts)
+            {
+                Console.WriteLine($" {choiceNumber}: {choiceText}");
+                choiceNumber++;
+            }
+
+            Console.WriteLine();
+            Console.Write("Enter choice: ");
+
+            string? choiceEntry = Console.ReadLine();
+            if (int.TryParse(choiceEntry, out int choiceIndex) && choiceIndex >= 1 && choiceIndex <= choiceTexts.Length)
+            {
+                string choiceText = choiceTexts[choiceIndex - 1]; // menu numbering above is 1-based
+
+                Console.WriteLine();
+                Console.WriteLine($"You selected '{choiceText}'");
+
+                return choiceText;
+            }
+        }
+    }
+
+    private static BaseChatMessageHistory GetChatMessageHistory()
+    {
+        // Other types of chat history work, too!
+        return new ChatMessageHistory();
+    }
+
+    private static BaseChatMemory GetConversationBufferMemory(BaseChatMessageHistory chatHistory, MessageFormatter messageFormatter)
+    {
+        return new ConversationBufferMemory(chatHistory)
+        {
+            Formatter = messageFormatter
+        };
+    }
+
+    private static BaseChatMemory GetConversationWindowBufferMemory(BaseChatMessageHistory chatHistory, MessageFormatter messageFormatter)
+    {
+        return new ConversationWindowBufferMemory(chatHistory)
+        {
+            WindowSize = 3,
+            Formatter = messageFormatter
+        };
+    }
+
+    private static BaseChatMemory GetConversationSummaryMemory(BaseChatMessageHistory chatHistory, MessageFormatter messageFormatter, IChatModel model)
+    {
+        return new ConversationSummaryMemory(model, chatHistory)
+        {
+            Formatter = messageFormatter
+        };
+    }
+    private static BaseChatMemory GetConversationSummaryBufferMemory(BaseChatMessageHistory chatHistory, MessageFormatter messageFormatter, IChatModelWithTokenCounting model)
+    {
+        return new ConversationSummaryBufferMemory(model, chatHistory)
+        {
+            MaxTokenCount = 25,
+            Formatter = messageFormatter
+        };
+    }
+}
4 changes: 2 additions & 2 deletions src/libs/LangChain.Core/Chains/Chain.cs
@@ -149,15 +149,15 @@ public static StuffDocumentsChain CombineDocuments(
     /// <param name="responseKey"></param>
     /// <returns></returns>
     public static UpdateMemoryChain UpdateMemory(
-        ConversationBufferMemory memory,
+        BaseChatMemory memory,
         string requestKey = "text",
         string responseKey = "text")
     {
         return new UpdateMemoryChain(memory, requestKey, responseKey);
     }
 
     public static LoadMemoryChain LoadMemory(
-        ConversationBufferMemory memory,
+        BaseChatMemory memory,
         string outputKey = "text")
     {
         return new LoadMemoryChain(memory, outputKey);
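With the parameter widened from ConversationBufferMemory to BaseChatMemory, any memory implementation threads through these helpers unchanged. A small sketch, reusing the `model` and `template` shown in the Memory sample above:

```csharp
// A summary-based memory now plugs into the same chain helpers the plain buffer used.
BaseChatMemory memory = new ConversationSummaryMemory(model, new ChatMessageHistory());

var chain =
    LoadMemory(memory, outputKey: "history")
    | Template(template)
    | LLM(model)
    | UpdateMemory(memory, requestKey: "input", responseKey: "text");
```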
33 changes: 24 additions & 9 deletions src/libs/LangChain.Core/Chains/StackableChains/Agents/GroupChat.cs
@@ -18,7 +18,8 @@ public class GroupChat : BaseStackableChain
     private readonly string _outputKey;
 
     int _currentAgentId;
-    private readonly ConversationBufferMemory _conversationBufferMemory;
+    private readonly MessageFormatter _messageFormatter;
+    private readonly ChatMessageHistory _chatMessageHistory;
 
     /// <summary>
     ///
@@ -46,7 +47,20 @@ public GroupChat(
         _messagesLimit = messagesLimit;
         _inputKey = inputKey;
         _outputKey = outputKey;
-        _conversationBufferMemory = new ConversationBufferMemory(new ChatMessageHistory()) { AiPrefix = "", HumanPrefix = "", SystemPrefix = "", SaveHumanMessages = false };
+
+        _messageFormatter = new MessageFormatter
+        {
+            AiPrefix = "",
+            HumanPrefix = "",
+            SystemPrefix = ""
+        };
+
+        _chatMessageHistory = new ChatMessageHistory()
+        {
+            // Do not save human messages
+            IsMessageAccepted = x => (x.Role != MessageRole.Human)
+        };
+
         InputKeys = new[] { inputKey };
         OutputKeys = new[] { outputKey };
     }
@@ -55,27 +69,28 @@ public GroupChat(
     ///
     /// </summary>
     /// <returns></returns>
-    public IReadOnlyList<Message> History => _conversationBufferMemory.ChatHistory.Messages;
+    public IReadOnlyList<Message> History => _chatMessageHistory.Messages;
 
     /// <inheritdoc />
    protected override async Task<IChainValues> InternalCall(IChainValues values)
    {
        values = values ?? throw new ArgumentNullException(nameof(values));

-        await _conversationBufferMemory.Clear().ConfigureAwait(false);
+        await _chatMessageHistory.Clear().ConfigureAwait(false);
        foreach (var agent in _agents)
        {
            agent.SetHistory("");
        }
        var firstAgent = _agents[0];
        var firstAgentMessage = (string)values.Value[_inputKey];
-        await _conversationBufferMemory.ChatHistory.AddMessage(new Message($"{firstAgent.Name}: {firstAgentMessage}",
+        await _chatMessageHistory.AddMessage(new Message($"{firstAgent.Name}: {firstAgentMessage}",
            MessageRole.System)).ConfigureAwait(false);
        int messagesCount = 1;
        while (messagesCount<_messagesLimit)
        {
            var agent = GetNextAgent();
-            agent.SetHistory(_conversationBufferMemory.BufferAsString+"\n"+$"{agent.Name}:");
+            string bufferText = _messageFormatter.Format(_chatMessageHistory.Messages);
+            agent.SetHistory(bufferText + "\n" + $"{agent.Name}:");
            var res = await agent.CallAsync(values).ConfigureAwait(false);
            var message = (string)res.Value[agent.OutputKeys[0]];
            if (message.Contains(_stopPhrase))
@@ -85,13 +100,13 @@ await _conversationBufferMemory.ChatHistory.AddMessage(new Message($"{firstAgent
 
            if (!agent.IsObserver)
            {
-                await _conversationBufferMemory.ChatHistory.AddMessage(new Message($"{agent.Name}: {message}",
+                await _chatMessageHistory.AddMessage(new Message($"{agent.Name}: {message}",
                    MessageRole.System)).ConfigureAwait(false);
            }
        }
 
-        var result = _conversationBufferMemory.ChatHistory.Messages[^1];
-        messagesCount = _conversationBufferMemory.ChatHistory.Messages.Count;
+        var result = _chatMessageHistory.Messages[^1];
+        messagesCount = _chatMessageHistory.Messages.Count;
        if (ThrowOnLimit && messagesCount >= _messagesLimit)
        {
            throw new InvalidOperationException($"Message limit reached:{_messagesLimit}");
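GroupChat no longer keeps a ConversationBufferMemory; it records the transcript in a ChatMessageHistory that rejects human messages and renders it with a prefix-free MessageFormatter. A standalone sketch of that pattern, assuming the Message/MessageRole and memory types come from the same namespaces the sample code imports:

```csharp
using LangChain.Memory;
using LangChain.Providers;

// A history that silently drops human messages, mirroring GroupChat's constructor setup.
var history = new ChatMessageHistory
{
    IsMessageAccepted = m => m.Role != MessageRole.Human
};

await history.AddMessage(new Message("Agent1: hello", MessageRole.System));
await history.AddMessage(new Message("this one is ignored by the filter", MessageRole.Human));

// Prefix-free formatting, as GroupChat now uses to build each agent's shared transcript.
var formatter = new MessageFormatter
{
    AiPrefix = "",
    HumanPrefix = "",
    SystemPrefix = ""
};

Console.WriteLine(formatter.Format(history.Messages)); // prints only the system message
```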