Skip to content

Commit

Permalink
feat: added Amazon Knowledgebase
Browse files Browse the repository at this point in the history
  • Loading branch information
curlyfro committed May 13, 2024
1 parent f9fadd1 commit 36fb73c
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
<PackageVersion Include="Anthropic.SDK" Version="3.2.0" />
<PackageVersion Include="Anyscale" Version="1.0.2" />
<PackageVersion Include="Aspose.PDF" Version="24.4.0" />
<PackageVersion Include="AWSSDK.BedrockAgentRuntime" Version="3.7.306.2" />
<PackageVersion Include="AWSSDK.BedrockRuntime" Version="3.7.302.1" />
<PackageVersion Include="AWSSDK.Kendra" Version="3.7.301.21" />
<PackageVersion Include="AWSSDK.OpenSearchService" Version="3.7.305.8" />
Expand Down
6 changes: 6 additions & 0 deletions src/Providers/Amazon.Bedrock/src/BedrockProvider.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Amazon;
using Amazon.BedrockAgentRuntime;
using Amazon.BedrockRuntime;

namespace LangChain.Providers.Amazon.Bedrock;
Expand All @@ -25,6 +26,7 @@ public BedrockProvider() : this(RegionEndpoint.USEast1)
public BedrockProvider(RegionEndpoint region) : base(DefaultProviderId)
{
Api = new AmazonBedrockRuntimeClient(region);
AgentApi = new AmazonBedrockAgentRuntimeClient(region);
}

/// <summary>
Expand All @@ -38,12 +40,16 @@ public BedrockProvider(string accessKeyId, string secretAccessKey, RegionEndpoin
: base(DefaultProviderId)
{
Api = new AmazonBedrockRuntimeClient(accessKeyId, secretAccessKey, region ?? RegionEndpoint.USEast1);
AgentApi = new AmazonBedrockAgentRuntimeClient(accessKeyId, secretAccessKey, region ?? RegionEndpoint.USEast1);
}

#region Properties

[CLSCompliant(false)]
public AmazonBedrockRuntimeClient Api { get; }

[CLSCompliant(false)]
public AmazonBedrockAgentRuntimeClient AgentApi { get; }

#endregion
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
using System.Diagnostics;
using Amazon.BedrockAgentRuntime;
using Amazon.BedrockAgentRuntime.Model;
using LangChain.Providers.Amazon.Bedrock.Internal;

// ReSharper disable once CheckNamespace
namespace LangChain.Providers.Amazon.Bedrock;

public abstract class AmazonKnowledgeBaseChatModel(
BedrockProvider provider,
string id)
: ChatModel(id)
{
private readonly string _id = id;

/// <summary>
/// Generates a chat response based on the provided `ChatRequest`.
/// </summary>
/// <param name="request">The `ChatRequest` containing the input messages and other parameters.</param>
/// <param name="settings">Optional `ChatSettings` to override the model's default settings.</param>
/// <param name="cancellationToken">A cancellation token to cancel the operation.</param>
/// <returns>A `ChatResponse` containing the generated messages and usage information.</returns>
public override async Task<ChatResponse> GenerateAsync(
ChatRequest request,
ChatSettings? settings = null,
CancellationToken cancellationToken = default)
{
request = request ?? throw new ArgumentNullException(nameof(request));

var watch = Stopwatch.StartNew();
var prompt = request.Messages.ToSimplePrompt();

var usedSettings = AmazonKnowledgeBaseChatSettings.Calculate(
requestSettings: settings,
modelSettings: Settings,
providerSettings: provider.ChatSettings);

var retrieveAndGenerateRequest = new RetrieveAndGenerateRequest
{
Input = new RetrieveAndGenerateInput { Text = prompt },
RetrieveAndGenerateConfiguration = new RetrieveAndGenerateConfiguration
{
Type = RetrieveAndGenerateType.KNOWLEDGE_BASE,
KnowledgeBaseConfiguration = new KnowledgeBaseRetrieveAndGenerateConfiguration
{
KnowledgeBaseId = usedSettings?.KnowledgeBaseId,
ModelArn = _id,
RetrievalConfiguration = new KnowledgeBaseRetrievalConfiguration
{
VectorSearchConfiguration = new KnowledgeBaseVectorSearchConfiguration
{
OverrideSearchType = usedSettings?.SelectedSearchType,
Filter = usedSettings?.Filter
}
}
}
}
};
var response = await provider.AgentApi!.RetrieveAndGenerateAsync(retrieveAndGenerateRequest, cancellationToken)
.ConfigureAwait(false);

var result = request.Messages.ToList();
result.Add(response.Output.Text.AsAiMessage());
usedSettings!.Citations = response.Citations;

var usage = Usage.Empty with
{
Time = watch.Elapsed,
};
AddUsage(usage);
provider.AddUsage(usage);

return new ChatResponse
{
Messages = result,
UsedSettings = usedSettings,
Usage = usage,
};
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
using Amazon.BedrockAgentRuntime;
using Amazon.BedrockAgentRuntime.Model;

// ReSharper disable once CheckNamespace
namespace LangChain.Providers.Amazon.Bedrock;


public class AmazonKnowledgeBaseChatSettings : BedrockChatSettings
{
public new static AmazonKnowledgeBaseChatSettings Default { get; } = new()
{
SelectedSearchType = "VECTOR",
KnowledgeBaseId = null
};

/// <summary>
/// Knowledge base id
/// </summary>
public required string? KnowledgeBaseId { get; init; }

/// <summary>
/// Knowledge base search type
/// </summary>
public SearchType? SelectedSearchType { get; init; }

Check warning on line 24 in src/Providers/Amazon.Bedrock/src/Chat/Settings/AmazonKnowledgeBaseChatSettings.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

Type of 'AmazonKnowledgeBaseChatSettings.SelectedSearchType' is not CLS-compliant

Check warning on line 24 in src/Providers/Amazon.Bedrock/src/Chat/Settings/AmazonKnowledgeBaseChatSettings.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

Type of 'AmazonKnowledgeBaseChatSettings.SelectedSearchType' is not CLS-compliant

/// <summary>
/// Knowledge base filter
/// </summary>
public RetrievalFilter? Filter { get; set; }

Check warning on line 29 in src/Providers/Amazon.Bedrock/src/Chat/Settings/AmazonKnowledgeBaseChatSettings.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

Type of 'AmazonKnowledgeBaseChatSettings.Filter' is not CLS-compliant

Check warning on line 29 in src/Providers/Amazon.Bedrock/src/Chat/Settings/AmazonKnowledgeBaseChatSettings.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

Type of 'AmazonKnowledgeBaseChatSettings.Filter' is not CLS-compliant

/// <summary>
/// Knowledge base response citations
/// </summary>
public List<Citation>? Citations { get; set; }

Check warning on line 34 in src/Providers/Amazon.Bedrock/src/Chat/Settings/AmazonKnowledgeBaseChatSettings.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

Type of 'AmazonKnowledgeBaseChatSettings.Citations' is not CLS-compliant

Check warning on line 34 in src/Providers/Amazon.Bedrock/src/Chat/Settings/AmazonKnowledgeBaseChatSettings.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

Change 'Citations' to be read-only by removing the property setter (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca2227)

Check warning on line 34 in src/Providers/Amazon.Bedrock/src/Chat/Settings/AmazonKnowledgeBaseChatSettings.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

Change 'List<Citation>' in 'AmazonKnowledgeBaseChatSettings.Citations' to use 'Collection<T>', 'ReadOnlyCollection<T>' or 'KeyedCollection<K,V>' (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1002)

Check warning on line 34 in src/Providers/Amazon.Bedrock/src/Chat/Settings/AmazonKnowledgeBaseChatSettings.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

Type of 'AmazonKnowledgeBaseChatSettings.Citations' is not CLS-compliant

Check warning on line 34 in src/Providers/Amazon.Bedrock/src/Chat/Settings/AmazonKnowledgeBaseChatSettings.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

Change 'Citations' to be read-only by removing the property setter (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca2227)

Check warning on line 34 in src/Providers/Amazon.Bedrock/src/Chat/Settings/AmazonKnowledgeBaseChatSettings.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

Change 'List<Citation>' in 'AmazonKnowledgeBaseChatSettings.Citations' to use 'Collection<T>', 'ReadOnlyCollection<T>' or 'KeyedCollection<K,V>' (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1002)

/// <summary>
/// Calculate the settings to use for the request.
/// </summary>
/// <param name="requestSettings"></param>
/// <param name="modelSettings"></param>
/// <param name="providerSettings"></param>
/// <returns></returns>
/// <exception cref="InvalidOperationException"></exception>
public new static AmazonKnowledgeBaseChatSettings Calculate(
ChatSettings? requestSettings,
ChatSettings? modelSettings,
ChatSettings? providerSettings)
{
var requestSettingsCasted = requestSettings as AmazonKnowledgeBaseChatSettings;
var modelSettingsCasted = modelSettings as AmazonKnowledgeBaseChatSettings;
var providerSettingsCasted = providerSettings as AmazonKnowledgeBaseChatSettings;

return new AmazonKnowledgeBaseChatSettings
{
KnowledgeBaseId =
requestSettingsCasted?.KnowledgeBaseId ??
modelSettingsCasted?.KnowledgeBaseId ??
providerSettingsCasted?.KnowledgeBaseId ??
Default.KnowledgeBaseId ??
throw new InvalidOperationException("KnowledgeBaseId can not be null."),

SelectedSearchType =
requestSettingsCasted?.SelectedSearchType ??
modelSettingsCasted?.SelectedSearchType ??
providerSettingsCasted?.SelectedSearchType ??
Default.SelectedSearchType ??
throw new InvalidOperationException("Default SelectedSearchType is not set."),

Filter =
requestSettingsCasted?.Filter ??
modelSettingsCasted?.Filter ??
providerSettingsCasted?.Filter ??
Default.Filter ??
throw new InvalidOperationException("Default Filter is not set."),
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="AWSSDK.BedrockAgentRuntime" />
<PackageReference Include="AWSSDK.BedrockRuntime" />
</ItemGroup>

Expand Down

0 comments on commit 36fb73c

Please sign in to comment.