-
-
Notifications
You must be signed in to change notification settings - Fork 94
Commit
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
using System.Diagnostics; | ||
using Amazon.BedrockAgentRuntime; | ||
using Amazon.BedrockAgentRuntime.Model; | ||
using LangChain.Providers.Amazon.Bedrock.Internal; | ||
|
||
// ReSharper disable once CheckNamespace | ||
namespace LangChain.Providers.Amazon.Bedrock; | ||
|
||
public abstract class AmazonKnowledgeBaseChatModel( | ||
BedrockProvider provider, | ||
string id) | ||
: ChatModel(id) | ||
{ | ||
private readonly string _id = id; | ||
|
||
/// <summary> | ||
/// Generates a chat response based on the provided `ChatRequest`. | ||
/// </summary> | ||
/// <param name="request">The `ChatRequest` containing the input messages and other parameters.</param> | ||
/// <param name="settings">Optional `ChatSettings` to override the model's default settings.</param> | ||
/// <param name="cancellationToken">A cancellation token to cancel the operation.</param> | ||
/// <returns>A `ChatResponse` containing the generated messages and usage information.</returns> | ||
public override async Task<ChatResponse> GenerateAsync( | ||
ChatRequest request, | ||
ChatSettings? settings = null, | ||
CancellationToken cancellationToken = default) | ||
{ | ||
request = request ?? throw new ArgumentNullException(nameof(request)); | ||
|
||
var watch = Stopwatch.StartNew(); | ||
var prompt = request.Messages.ToSimplePrompt(); | ||
|
||
var usedSettings = AmazonKnowledgeBaseChatSettings.Calculate( | ||
requestSettings: settings, | ||
modelSettings: Settings, | ||
providerSettings: provider.ChatSettings); | ||
|
||
var retrieveAndGenerateRequest = new RetrieveAndGenerateRequest | ||
{ | ||
Input = new RetrieveAndGenerateInput { Text = prompt }, | ||
RetrieveAndGenerateConfiguration = new RetrieveAndGenerateConfiguration | ||
{ | ||
Type = RetrieveAndGenerateType.KNOWLEDGE_BASE, | ||
KnowledgeBaseConfiguration = new KnowledgeBaseRetrieveAndGenerateConfiguration | ||
{ | ||
KnowledgeBaseId = usedSettings?.KnowledgeBaseId, | ||
ModelArn = _id, | ||
RetrievalConfiguration = new KnowledgeBaseRetrievalConfiguration | ||
{ | ||
VectorSearchConfiguration = new KnowledgeBaseVectorSearchConfiguration | ||
{ | ||
OverrideSearchType = usedSettings?.SelectedSearchType, | ||
Filter = usedSettings?.Filter | ||
} | ||
} | ||
} | ||
} | ||
}; | ||
var response = await provider.AgentApi!.RetrieveAndGenerateAsync(retrieveAndGenerateRequest, cancellationToken) | ||
.ConfigureAwait(false); | ||
|
||
var result = request.Messages.ToList(); | ||
result.Add(response.Output.Text.AsAiMessage()); | ||
usedSettings!.Citations = response.Citations; | ||
|
||
var usage = Usage.Empty with | ||
{ | ||
Time = watch.Elapsed, | ||
}; | ||
AddUsage(usage); | ||
provider.AddUsage(usage); | ||
|
||
return new ChatResponse | ||
{ | ||
Messages = result, | ||
UsedSettings = usedSettings, | ||
Usage = usage, | ||
}; | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
using Amazon.BedrockAgentRuntime; | ||
using Amazon.BedrockAgentRuntime.Model; | ||
|
||
// ReSharper disable once CheckNamespace | ||
namespace LangChain.Providers.Amazon.Bedrock; | ||
|
||
|
||
public class AmazonKnowledgeBaseChatSettings : BedrockChatSettings | ||
{ | ||
public new static AmazonKnowledgeBaseChatSettings Default { get; } = new() | ||
{ | ||
SelectedSearchType = "VECTOR", | ||
KnowledgeBaseId = null | ||
}; | ||
|
||
/// <summary> | ||
/// Knowledge base id | ||
/// </summary> | ||
public required string? KnowledgeBaseId { get; init; } | ||
|
||
/// <summary> | ||
/// Knowledge base search type | ||
/// </summary> | ||
public SearchType? SelectedSearchType { get; init; } | ||
Check warning on line 24 in src/Providers/Amazon.Bedrock/src/Chat/Settings/AmazonKnowledgeBaseChatSettings.cs GitHub Actions / Build and test / Build, test and publish
|
||
|
||
/// <summary> | ||
/// Knowledge base filter | ||
/// </summary> | ||
public RetrievalFilter? Filter { get; set; } | ||
Check warning on line 29 in src/Providers/Amazon.Bedrock/src/Chat/Settings/AmazonKnowledgeBaseChatSettings.cs GitHub Actions / Build and test / Build, test and publish
|
||
|
||
/// <summary> | ||
/// Knowledge base response citations | ||
/// </summary> | ||
public List<Citation>? Citations { get; set; } | ||
Check warning on line 34 in src/Providers/Amazon.Bedrock/src/Chat/Settings/AmazonKnowledgeBaseChatSettings.cs GitHub Actions / Build and test / Build, test and publish
Check warning on line 34 in src/Providers/Amazon.Bedrock/src/Chat/Settings/AmazonKnowledgeBaseChatSettings.cs GitHub Actions / Build and test / Build, test and publish
Check warning on line 34 in src/Providers/Amazon.Bedrock/src/Chat/Settings/AmazonKnowledgeBaseChatSettings.cs GitHub Actions / Build and test / Build, test and publish
Check warning on line 34 in src/Providers/Amazon.Bedrock/src/Chat/Settings/AmazonKnowledgeBaseChatSettings.cs GitHub Actions / Build and test / Build, test and publish
Check warning on line 34 in src/Providers/Amazon.Bedrock/src/Chat/Settings/AmazonKnowledgeBaseChatSettings.cs GitHub Actions / Build and test / Build, test and publish
Check warning on line 34 in src/Providers/Amazon.Bedrock/src/Chat/Settings/AmazonKnowledgeBaseChatSettings.cs GitHub Actions / Build and test / Build, test and publish
|
||
|
||
/// <summary> | ||
/// Calculate the settings to use for the request. | ||
/// </summary> | ||
/// <param name="requestSettings"></param> | ||
/// <param name="modelSettings"></param> | ||
/// <param name="providerSettings"></param> | ||
/// <returns></returns> | ||
/// <exception cref="InvalidOperationException"></exception> | ||
public new static AmazonKnowledgeBaseChatSettings Calculate( | ||
ChatSettings? requestSettings, | ||
ChatSettings? modelSettings, | ||
ChatSettings? providerSettings) | ||
{ | ||
var requestSettingsCasted = requestSettings as AmazonKnowledgeBaseChatSettings; | ||
var modelSettingsCasted = modelSettings as AmazonKnowledgeBaseChatSettings; | ||
var providerSettingsCasted = providerSettings as AmazonKnowledgeBaseChatSettings; | ||
|
||
return new AmazonKnowledgeBaseChatSettings | ||
{ | ||
KnowledgeBaseId = | ||
requestSettingsCasted?.KnowledgeBaseId ?? | ||
modelSettingsCasted?.KnowledgeBaseId ?? | ||
providerSettingsCasted?.KnowledgeBaseId ?? | ||
Default.KnowledgeBaseId ?? | ||
throw new InvalidOperationException("KnowledgeBaseId can not be null."), | ||
|
||
SelectedSearchType = | ||
requestSettingsCasted?.SelectedSearchType ?? | ||
modelSettingsCasted?.SelectedSearchType ?? | ||
providerSettingsCasted?.SelectedSearchType ?? | ||
Default.SelectedSearchType ?? | ||
throw new InvalidOperationException("Default SelectedSearchType is not set."), | ||
|
||
Filter = | ||
requestSettingsCasted?.Filter ?? | ||
modelSettingsCasted?.Filter ?? | ||
providerSettingsCasted?.Filter ?? | ||
Default.Filter ?? | ||
throw new InvalidOperationException("Default Filter is not set."), | ||
}; | ||
} | ||
} |