From e14cdb08c9b39fab790504d5b89208a6db345ed5 Mon Sep 17 00:00:00 2001 From: Khoroshev Evgeniy Date: Tue, 5 Dec 2023 02:44:38 +0300 Subject: [PATCH] feat: redis message history (#86) Co-authored-by: Evgenii Khoroshev --- LangChain.sln | 7 ++ src/Directory.Packages.props | 1 + .../LangChain.Databases.Redis.csproj | 5 ++ .../RedisChatMessageHistory.cs | 81 +++++++++++++++++ .../LangChain.Core/Memory/BaseChatMemory.cs | 3 +- .../Memory/BaseChatMessageHistory.cs | 5 +- .../Memory/ChatMessageHistory.cs | 7 +- .../Memory/ConversationBufferMemory.cs | 38 +++----- .../LangChain.Core/Memory/MemoryExtensions.cs | 27 +++++- ...in.Databases.Redis.IntegrationTests.csproj | 11 +++ .../RedisChatMessageHistoryTests.cs | 88 +++++++++++++++++++ 11 files changed, 236 insertions(+), 37 deletions(-) create mode 100644 src/libs/Databases/LangChain.Databases.Redis/RedisChatMessageHistory.cs create mode 100644 src/tests/LangChain.Databases.Redis.IntegrationTests/LangChain.Databases.Redis.IntegrationTests.csproj create mode 100644 src/tests/LangChain.Databases.Redis.IntegrationTests/RedisChatMessageHistoryTests.cs diff --git a/LangChain.sln b/LangChain.sln index 1942b733..0c86063a 100644 --- a/LangChain.sln +++ b/LangChain.sln @@ -154,6 +154,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "LangChain.Utilities.Postgre EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "LangChain.Utilities.Postgres.IntegrationTests", "src\tests\LangChain.Utilities.Postgres.IntegrationTests\LangChain.Utilities.Postgres.IntegrationTests.csproj", "{A652E4C6-6988-40BD-A726-2F5A3783C129}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "LangChain.Databases.Redis.IntegrationTests", "src\tests\LangChain.Databases.Redis.IntegrationTests\LangChain.Databases.Redis.IntegrationTests.csproj", "{E19562A0-9AAA-4C75-BE78-648E7148A4CD}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -348,6 +350,10 @@ Global {A652E4C6-6988-40BD-A726-2F5A3783C129}.Debug|Any CPU.Build.0 = Debug|Any CPU {A652E4C6-6988-40BD-A726-2F5A3783C129}.Release|Any CPU.ActiveCfg = Release|Any CPU {A652E4C6-6988-40BD-A726-2F5A3783C129}.Release|Any CPU.Build.0 = Release|Any CPU + {E19562A0-9AAA-4C75-BE78-648E7148A4CD}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {E19562A0-9AAA-4C75-BE78-648E7148A4CD}.Debug|Any CPU.Build.0 = Debug|Any CPU + {E19562A0-9AAA-4C75-BE78-648E7148A4CD}.Release|Any CPU.ActiveCfg = Release|Any CPU + {E19562A0-9AAA-4C75-BE78-648E7148A4CD}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -407,6 +413,7 @@ Global {7D47EC2D-2F03-4284-A07D-E56486B885C6} = {788567AF-444A-488F-BCED-C3B9F03CC38D} {2A01AC56-7850-48FD-B32F-A7AAF0E86F84} = {788567AF-444A-488F-BCED-C3B9F03CC38D} {A652E4C6-6988-40BD-A726-2F5A3783C129} = {FDEE2E22-C239-4921-83B2-9797F765FD6A} + {E19562A0-9AAA-4C75-BE78-648E7148A4CD} = {FDEE2E22-C239-4921-83B2-9797F765FD6A} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {5C00D0F1-6138-4ED9-846B-97E43D6DFF1C} diff --git a/src/Directory.Packages.props b/src/Directory.Packages.props index c234e9ec..ab8a8ca4 100644 --- a/src/Directory.Packages.props +++ b/src/Directory.Packages.props @@ -39,6 +39,7 @@ + diff --git a/src/libs/Databases/LangChain.Databases.Redis/LangChain.Databases.Redis.csproj b/src/libs/Databases/LangChain.Databases.Redis/LangChain.Databases.Redis.csproj index a5cb2570..72b6fcee 100644 --- a/src/libs/Databases/LangChain.Databases.Redis/LangChain.Databases.Redis.csproj +++ b/src/libs/Databases/LangChain.Databases.Redis/LangChain.Databases.Redis.csproj @@ -11,6 +11,11 @@ + + + + + diff --git a/src/libs/Databases/LangChain.Databases.Redis/RedisChatMessageHistory.cs b/src/libs/Databases/LangChain.Databases.Redis/RedisChatMessageHistory.cs new file mode 100644 index 00000000..91c54882 --- /dev/null +++ b/src/libs/Databases/LangChain.Databases.Redis/RedisChatMessageHistory.cs @@ -0,0 +1,81 @@ +using System.Text.Json; +using LangChain.Memory; +using LangChain.Providers; +using StackExchange.Redis; + +namespace LangChain.Databases; + +/// +/// Chat message history stored in a Redis database. +/// +public class RedisChatMessageHistory : BaseChatMessageHistory +{ + private readonly string _sessionId; + private readonly string _keyPrefix; + private readonly TimeSpan? _ttl; + private readonly Lazy _multiplexer; + + /// + public RedisChatMessageHistory( + string sessionId, + string connectionString, + string keyPrefix = "message_store:", + TimeSpan? ttl = null) + { + _sessionId = sessionId; + _keyPrefix = keyPrefix; + _ttl = ttl; + + _multiplexer = new Lazy( + () => + { + var multiplexer = ConnectionMultiplexer.Connect(connectionString); + + return multiplexer; + }, + LazyThreadSafetyMode.ExecutionAndPublication); + } + + /// + /// Construct the record key to use + /// + private string Key => _keyPrefix + _sessionId; + + /// + /// Retrieve the messages from Redis + /// TODO: use async methods + /// + public override IReadOnlyList Messages + { + get + { + var database = _multiplexer.Value.GetDatabase(); + var values = database.ListRange(Key, start: 0, stop: -1); + var messages = values.Select(v => JsonSerializer.Deserialize(v.ToString())).Reverse(); + + return messages.ToList(); + } + } + + /// + /// Append the message to the record in Redis + /// + public override async Task AddMessage(Message message) + { + var database = _multiplexer.Value.GetDatabase(); + await database.ListLeftPushAsync(Key, JsonSerializer.Serialize(message)).ConfigureAwait(false); + if (_ttl.HasValue) + { + await database.KeyExpireAsync(Key, _ttl).ConfigureAwait(false); + } + } + + /// + /// Clear session memory from Redis + /// + public override async Task Clear() + { + var database = _multiplexer.Value.GetDatabase(); + await database.KeyDeleteAsync(Key).ConfigureAwait(false); + } +} \ No newline at end of file diff --git a/src/libs/LangChain.Core/Memory/BaseChatMemory.cs b/src/libs/LangChain.Core/Memory/BaseChatMemory.cs index 015a0b79..bcedf9b8 100644 --- a/src/libs/LangChain.Core/Memory/BaseChatMemory.cs +++ b/src/libs/LangChain.Core/Memory/BaseChatMemory.cs @@ -28,7 +28,6 @@ public override async Task SaveContext(InputValues inputValues, OutputValues out public override Task Clear() { - ChatHistory.Clear(); - return Task.CompletedTask; + return ChatHistory.Clear(); } } \ No newline at end of file diff --git a/src/libs/LangChain.Core/Memory/BaseChatMessageHistory.cs b/src/libs/LangChain.Core/Memory/BaseChatMessageHistory.cs index 1c0bd3b5..232ea99c 100644 --- a/src/libs/LangChain.Core/Memory/BaseChatMessageHistory.cs +++ b/src/libs/LangChain.Core/Memory/BaseChatMessageHistory.cs @@ -4,12 +4,9 @@ namespace LangChain.Memory; public abstract class BaseChatMessageHistory { - public IList Messages { get; set; } = new List(); - public async Task AddUserMessage(string message) { await AddMessage(message.AsHumanMessage()); - } public async Task AddAiMessage(string message) @@ -17,6 +14,8 @@ public async Task AddAiMessage(string message) await AddMessage(message.AsAiMessage()); } + public abstract IReadOnlyList Messages { get; } + public abstract Task AddMessage(Message message); public abstract Task Clear(); diff --git a/src/libs/LangChain.Core/Memory/ChatMessageHistory.cs b/src/libs/LangChain.Core/Memory/ChatMessageHistory.cs index ac60dfa7..14b23fb8 100644 --- a/src/libs/LangChain.Core/Memory/ChatMessageHistory.cs +++ b/src/libs/LangChain.Core/Memory/ChatMessageHistory.cs @@ -4,15 +4,18 @@ namespace LangChain.Memory; public class ChatMessageHistory : BaseChatMessageHistory { + private readonly List _messages = new List(); + public override IReadOnlyList Messages => _messages; + public override Task AddMessage(Message message) { - Messages.Add(message); + _messages.Add(message); return Task.CompletedTask; } public override Task Clear() { - Messages.Clear(); + _messages.Clear(); return Task.CompletedTask; } } \ No newline at end of file diff --git a/src/libs/LangChain.Core/Memory/ConversationBufferMemory.cs b/src/libs/LangChain.Core/Memory/ConversationBufferMemory.cs index 8b65895e..ab368ba9 100644 --- a/src/libs/LangChain.Core/Memory/ConversationBufferMemory.cs +++ b/src/libs/LangChain.Core/Memory/ConversationBufferMemory.cs @@ -1,5 +1,4 @@ using LangChain.Providers; -using System.Net.Mail; using LangChain.Schema; namespace LangChain.Memory; @@ -16,40 +15,28 @@ public ConversationBufferMemory(BaseChatMessageHistory chatHistory) : base(chatH ChatHistory = chatHistory; } - // note: buffer property can't be implemented because of Any type as return type public string BufferAsString => GetBufferString(BufferAsMessages); - public IList BufferAsMessages => ChatHistory.Messages; + public IReadOnlyList BufferAsMessages => ChatHistory.Messages; - public override List MemoryVariables => new List {MemoryKey}; + public override List MemoryVariables => new List { MemoryKey }; - private string GetBufferString( - IEnumerable messages) + private string GetBufferString(IEnumerable messages) { - List stringMessages = new List(); + var stringMessages = new List(); foreach (var m in messages) { - string role; - switch (m.Role) + string role = m.Role switch { - case MessageRole.Human: - role = HumanPrefix; - break; - case MessageRole.Ai: - role = AiPrefix; - break; - case MessageRole.System: - role = "System"; - break; - case MessageRole.FunctionCall: - role = "Function"; - break; - default: - throw new ArgumentException($"Unsupported message type: {m.GetType().Name}"); - } + MessageRole.Human => HumanPrefix, + MessageRole.Ai => AiPrefix, + MessageRole.System => "System", + MessageRole.FunctionCall => "Function", + _ => throw new ArgumentException($"Unsupported message type: {m.GetType().Name}") + }; string message = $"{role}: {m.Content}"; // TODO: Add special case for a function call @@ -60,9 +47,8 @@ private string GetBufferString( return string.Join("\n", stringMessages); } - public override OutputValues LoadMemoryVariables(InputValues? inputValues) { - return new OutputValues(new Dictionary {{MemoryKey, BufferAsString}}); + return new OutputValues(new Dictionary { { MemoryKey, BufferAsString } }); } } \ No newline at end of file diff --git a/src/libs/LangChain.Core/Memory/MemoryExtensions.cs b/src/libs/LangChain.Core/Memory/MemoryExtensions.cs index 8c947575..58e096ba 100644 --- a/src/libs/LangChain.Core/Memory/MemoryExtensions.cs +++ b/src/libs/LangChain.Core/Memory/MemoryExtensions.cs @@ -4,7 +4,9 @@ namespace LangChain.Memory; public static class MemoryExtensions { - public static IReadOnlyCollection WithHistory(this IReadOnlyCollection messages, BaseMemory? memory) + public static IReadOnlyCollection WithHistory( + this IReadOnlyCollection messages, + BaseMemory? memory) { if (memory == null) { @@ -22,9 +24,26 @@ public static IReadOnlyCollection WithHistory(this IReadOnlyCollection< } } - return new[] + var result = new Message[messages.Count + 1]; + result[0] = history.AsHumanMessage(); + messages.CopyTo(result, startIndex: 1); + + return result; + } + + private static void CopyTo(this IReadOnlyCollection source, T[] destination, int startIndex) + { + if (destination.Length > source.Count + startIndex) { - history.AsHumanMessage(), - }.Concat(messages).ToArray(); + throw new ArgumentException( + $"{nameof(destination)} required to have min length of {source.Count + startIndex}, but was {destination.Length}"); + } + + var i = 0; + foreach (var item in source) + { + destination[startIndex + i] = item; + i++; + } } } \ No newline at end of file diff --git a/src/tests/LangChain.Databases.Redis.IntegrationTests/LangChain.Databases.Redis.IntegrationTests.csproj b/src/tests/LangChain.Databases.Redis.IntegrationTests/LangChain.Databases.Redis.IntegrationTests.csproj new file mode 100644 index 00000000..a1aaa4fc --- /dev/null +++ b/src/tests/LangChain.Databases.Redis.IntegrationTests/LangChain.Databases.Redis.IntegrationTests.csproj @@ -0,0 +1,11 @@ + + + + net8.0 + + + + + + + diff --git a/src/tests/LangChain.Databases.Redis.IntegrationTests/RedisChatMessageHistoryTests.cs b/src/tests/LangChain.Databases.Redis.IntegrationTests/RedisChatMessageHistoryTests.cs new file mode 100644 index 00000000..b314ac01 --- /dev/null +++ b/src/tests/LangChain.Databases.Redis.IntegrationTests/RedisChatMessageHistoryTests.cs @@ -0,0 +1,88 @@ +using LangChain.Providers; + +namespace LangChain.Databases.Redis.IntegrationTests; + +/// +/// In order to run tests please run redis locally, e.g. with docker +/// docker run -p 6379:6379 redis +/// +[TestFixture] +[Explicit] +public class RedisChatMessageHistoryTests +{ + private readonly string _connectionString = "127.0.0.1:6379"; + + [Test] + public void GetMessages_EmptyHistory_Ok() + { + var sessionId = "GetMessages_EmptyHistory_Ok"; + var history = new RedisChatMessageHistory( + sessionId, + _connectionString, + ttl: TimeSpan.FromSeconds(30)); + + var existing = history.Messages; + + existing.Should().BeEmpty(); + } + + [Test] + public async Task AddMessage_Ok() + { + var sessionId = "RedisChatMessageHistoryTests_AddMessage_Ok"; + var history = new RedisChatMessageHistory( + sessionId, + _connectionString, + ttl: TimeSpan.FromSeconds(30)); + + var humanMessage = Message.Human("Hi, AI"); + await history.AddMessage(humanMessage); + var aiMessage = Message.Ai("Hi, human"); + await history.AddMessage(aiMessage); + + var actual = history.Messages; + + actual.Should().HaveCount(2); + + actual[0].Role.Should().Be(humanMessage.Role); + actual[0].Content.Should().BeEquivalentTo(humanMessage.Content); + + actual[1].Role.Should().Be(aiMessage.Role); + actual[1].Content.Should().BeEquivalentTo(aiMessage.Content); + } + + [Test] + public async Task Ttl_Ok() + { + var sessionId = "Ttl_Ok"; + var history = new RedisChatMessageHistory( + sessionId, + _connectionString, + ttl: TimeSpan.FromSeconds(2)); + + var humanMessage = Message.Human("Hi, AI"); + await history.AddMessage(humanMessage); + + await Task.Delay(2_500); + + var existing = history.Messages; + + existing.Should().BeEmpty(); + } + + [Test] + public async Task Clear_Ok() + { + var sessionId = "Ttl_Ok"; + var history = new RedisChatMessageHistory( + sessionId, + _connectionString, + ttl: TimeSpan.FromSeconds(30)); + + await history.Clear(); + + var existing = history.Messages; + + existing.Should().BeEmpty(); + } +} \ No newline at end of file