diff --git a/src/Directory.Packages.props b/src/Directory.Packages.props index 18348f99..63f9e887 100644 --- a/src/Directory.Packages.props +++ b/src/Directory.Packages.props @@ -8,6 +8,7 @@ + diff --git a/src/Providers/Amazon.Bedrock/src/BedrockProvider.cs b/src/Providers/Amazon.Bedrock/src/BedrockProvider.cs index 5d9446dd..1a7e2027 100644 --- a/src/Providers/Amazon.Bedrock/src/BedrockProvider.cs +++ b/src/Providers/Amazon.Bedrock/src/BedrockProvider.cs @@ -1,4 +1,5 @@ using Amazon; +using Amazon.BedrockAgentRuntime; using Amazon.BedrockRuntime; namespace LangChain.Providers.Amazon.Bedrock; @@ -25,6 +26,7 @@ public BedrockProvider() : this(RegionEndpoint.USEast1) public BedrockProvider(RegionEndpoint region) : base(DefaultProviderId) { Api = new AmazonBedrockRuntimeClient(region); + AgentApi = new AmazonBedrockAgentRuntimeClient(region); } /// @@ -38,6 +40,7 @@ public BedrockProvider(string accessKeyId, string secretAccessKey, RegionEndpoin : base(DefaultProviderId) { Api = new AmazonBedrockRuntimeClient(accessKeyId, secretAccessKey, region ?? RegionEndpoint.USEast1); + AgentApi = new AmazonBedrockAgentRuntimeClient(accessKeyId, secretAccessKey, region ?? RegionEndpoint.USEast1); } #region Properties @@ -45,5 +48,8 @@ public BedrockProvider(string accessKeyId, string secretAccessKey, RegionEndpoin [CLSCompliant(false)] public AmazonBedrockRuntimeClient Api { get; } + [CLSCompliant(false)] + public AmazonBedrockAgentRuntimeClient AgentApi { get; } + #endregion } \ No newline at end of file diff --git a/src/Providers/Amazon.Bedrock/src/Chat/AmazonKnowledgeBaseChatModel.cs b/src/Providers/Amazon.Bedrock/src/Chat/AmazonKnowledgeBaseChatModel.cs new file mode 100644 index 00000000..49b813a5 --- /dev/null +++ b/src/Providers/Amazon.Bedrock/src/Chat/AmazonKnowledgeBaseChatModel.cs @@ -0,0 +1,80 @@ +using System.Diagnostics; +using Amazon.BedrockAgentRuntime; +using Amazon.BedrockAgentRuntime.Model; +using LangChain.Providers.Amazon.Bedrock.Internal; + +// ReSharper disable once CheckNamespace +namespace LangChain.Providers.Amazon.Bedrock; + +public abstract class AmazonKnowledgeBaseChatModel( + BedrockProvider provider, + string id) + : ChatModel(id) +{ + private readonly string _id = id; + + /// + /// Generates a chat response based on the provided `ChatRequest`. + /// + /// The `ChatRequest` containing the input messages and other parameters. + /// Optional `ChatSettings` to override the model's default settings. + /// A cancellation token to cancel the operation. + /// A `ChatResponse` containing the generated messages and usage information. + public override async Task GenerateAsync( + ChatRequest request, + ChatSettings? settings = null, + CancellationToken cancellationToken = default) + { + request = request ?? throw new ArgumentNullException(nameof(request)); + + var watch = Stopwatch.StartNew(); + var prompt = request.Messages.ToSimplePrompt(); + + var usedSettings = AmazonKnowledgeBaseChatSettings.Calculate( + requestSettings: settings, + modelSettings: Settings, + providerSettings: provider.ChatSettings); + + var retrieveAndGenerateRequest = new RetrieveAndGenerateRequest + { + Input = new RetrieveAndGenerateInput { Text = prompt }, + RetrieveAndGenerateConfiguration = new RetrieveAndGenerateConfiguration + { + Type = RetrieveAndGenerateType.KNOWLEDGE_BASE, + KnowledgeBaseConfiguration = new KnowledgeBaseRetrieveAndGenerateConfiguration + { + KnowledgeBaseId = usedSettings?.KnowledgeBaseId, + ModelArn = _id, + RetrievalConfiguration = new KnowledgeBaseRetrievalConfiguration + { + VectorSearchConfiguration = new KnowledgeBaseVectorSearchConfiguration + { + OverrideSearchType = usedSettings?.SelectedSearchType, + Filter = usedSettings?.Filter + } + } + } + } + }; + var response = await provider.AgentApi!.RetrieveAndGenerateAsync(retrieveAndGenerateRequest, cancellationToken) + .ConfigureAwait(false); + + var result = request.Messages.ToList(); + result.Add(response.Output.Text.AsAiMessage()); + usedSettings!.Citations = response.Citations; + + var usage = Usage.Empty with + { + Time = watch.Elapsed, + }; + AddUsage(usage); + provider.AddUsage(usage); + + return new ChatResponse + { + Messages = result, + UsedSettings = usedSettings, + Usage = usage, + }; + } +} \ No newline at end of file diff --git a/src/Providers/Amazon.Bedrock/src/Chat/Settings/AmazonKnowledgeBaseChatSettings.cs b/src/Providers/Amazon.Bedrock/src/Chat/Settings/AmazonKnowledgeBaseChatSettings.cs new file mode 100644 index 00000000..863710a0 --- /dev/null +++ b/src/Providers/Amazon.Bedrock/src/Chat/Settings/AmazonKnowledgeBaseChatSettings.cs @@ -0,0 +1,77 @@ +using Amazon.BedrockAgentRuntime; +using Amazon.BedrockAgentRuntime.Model; + +// ReSharper disable once CheckNamespace +namespace LangChain.Providers.Amazon.Bedrock; + + +public class AmazonKnowledgeBaseChatSettings : BedrockChatSettings +{ + public new static AmazonKnowledgeBaseChatSettings Default { get; } = new() + { + SelectedSearchType = "VECTOR", + KnowledgeBaseId = null + }; + + /// + /// Knowledge base id + /// + public required string? KnowledgeBaseId { get; init; } + + /// + /// Knowledge base search type + /// + public SearchType? SelectedSearchType { get; init; } + + /// + /// Knowledge base filter + /// + public RetrievalFilter? Filter { get; set; } + + /// + /// Knowledge base response citations + /// + public List? Citations { get; set; } + + /// + /// Calculate the settings to use for the request. + /// + /// + /// + /// + /// + /// + public new static AmazonKnowledgeBaseChatSettings Calculate( + ChatSettings? requestSettings, + ChatSettings? modelSettings, + ChatSettings? providerSettings) + { + var requestSettingsCasted = requestSettings as AmazonKnowledgeBaseChatSettings; + var modelSettingsCasted = modelSettings as AmazonKnowledgeBaseChatSettings; + var providerSettingsCasted = providerSettings as AmazonKnowledgeBaseChatSettings; + + return new AmazonKnowledgeBaseChatSettings + { + KnowledgeBaseId = + requestSettingsCasted?.KnowledgeBaseId ?? + modelSettingsCasted?.KnowledgeBaseId ?? + providerSettingsCasted?.KnowledgeBaseId ?? + Default.KnowledgeBaseId ?? + throw new InvalidOperationException("KnowledgeBaseId can not be null."), + + SelectedSearchType = + requestSettingsCasted?.SelectedSearchType ?? + modelSettingsCasted?.SelectedSearchType ?? + providerSettingsCasted?.SelectedSearchType ?? + Default.SelectedSearchType ?? + throw new InvalidOperationException("Default SelectedSearchType is not set."), + + Filter = + requestSettingsCasted?.Filter ?? + modelSettingsCasted?.Filter ?? + providerSettingsCasted?.Filter ?? + Default.Filter ?? + throw new InvalidOperationException("Default Filter is not set."), + }; + } +} \ No newline at end of file diff --git a/src/Providers/Amazon.Bedrock/src/LangChain.Providers.Amazon.Bedrock.csproj b/src/Providers/Amazon.Bedrock/src/LangChain.Providers.Amazon.Bedrock.csproj index 00d595f7..67ae6183 100644 --- a/src/Providers/Amazon.Bedrock/src/LangChain.Providers.Amazon.Bedrock.csproj +++ b/src/Providers/Amazon.Bedrock/src/LangChain.Providers.Amazon.Bedrock.csproj @@ -10,6 +10,7 @@ +