Skip to content

Commit

Permalink
fix: SageMaker customizable inputs and responses
Browse files Browse the repository at this point in the history
  • Loading branch information
curlyfro committed Mar 3, 2024
1 parent 0d5eaa5 commit 3030850
Show file tree
Hide file tree
Showing 10 changed files with 202 additions and 223 deletions.
1 change: 1 addition & 0 deletions src/Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
<PackageVersion Include="Microsoft.CodeAnalysis.CSharp" Version="4.8.0" />
<PackageVersion Include="Microsoft.CodeAnalysis.CSharp.Workspaces" Version="4.8.0" />
<PackageVersion Include="Microsoft.CodeAnalysis.PublicApiAnalyzers" Version="3.3.4" />
<PackageVersion Include="Microsoft.CSharp" Version="4.7.0" />
<PackageVersion Include="Microsoft.Data.Sqlite.Core" Version="8.0.2" />
<PackageVersion Include="Microsoft.NET.Test.Sdk" Version="17.9.0" />
<PackageVersion Include="Microsoft.SemanticKernel.Connectors.AI.OpenAI" Version="0.15.230531.5-preview" />
Expand Down
63 changes: 0 additions & 63 deletions src/Providers/Amazon.Bedrock/test/EmbeddedResource.cs

This file was deleted.

70 changes: 0 additions & 70 deletions src/Providers/Amazon.Sagemaker/src/Chat/SageMakerChatModel.cs

This file was deleted.

54 changes: 0 additions & 54 deletions src/Providers/Amazon.Sagemaker/src/Chat/SageMakerChatSettings.cs

This file was deleted.

20 changes: 0 additions & 20 deletions src/Providers/Amazon.Sagemaker/src/Internal/BedrockExtensions.cs

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,18 @@
<PackageReference Include="System.Text.Json" />
</ItemGroup>

<ItemGroup>
<Compile Remove="ISageMakerModel.cs" />
<Compile Remove="ISageMakerModel`2.cs" />
<Compile Remove="SageMakerBaseModel.cs" />
<Compile Remove="SageMakerRequest.cs" />
<Compile Remove="SageMakerResponse.cs" />
</ItemGroup>

<ItemGroup>
<PackageReference Include="Microsoft.CSharp" />
</ItemGroup>

<ItemGroup Label="Usings">
<Using Include="System.Net.Http" />
</ItemGroup>
Expand Down
57 changes: 57 additions & 0 deletions src/Providers/Amazon.Sagemaker/src/SageMakerModel.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
using System.Diagnostics;
using System.Text;
using System.Text.Json;

// ReSharper disable once CheckNamespace
namespace LangChain.Providers.Amazon.SageMaker;

public class SageMakerModel(
SageMakerProvider provider,
string endpointName)
: ChatModel(id: endpointName ?? throw new ArgumentNullException(nameof(endpointName), "SageMaker Endpoint Name is not defined"))
{
public override async Task<ChatResponse> GenerateAsync(
ChatRequest request,
ChatSettings? settings = null,
CancellationToken cancellationToken = default)
{
request = request ?? throw new ArgumentNullException(nameof(request));

var messages = request.Messages.ToList();

var watch = Stopwatch.StartNew();

var usedSettings = SageMakerSettings.Calculate(
requestSettings: settings,
modelSettings: Settings,
providerSettings: provider.ChatSettings);
usedSettings.InputParamers?.Add("endpointName", Id);

using StringContent jsonContent = new(
JsonSerializer.Serialize(usedSettings.InputParamers),
Encoding.UTF8,
usedSettings.ContentType!);

using var response = await provider.HttpClient.PostAsync(provider.Uri, jsonContent, cancellationToken)
.ConfigureAwait(false);

response.EnsureSuccessStatusCode();

dynamic output = usedSettings.TransformOutput!(response);
messages.Add(new Message(output.Result, MessageRole.Ai));

var usage = Usage.Empty with
{
Time = watch.Elapsed,
};
AddUsage(usage);
provider.AddUsage(usage);

return new ChatResponse
{
Messages = messages,
UsedSettings = usedSettings,
Usage = usage,
};
}
}
6 changes: 3 additions & 3 deletions src/Providers/Amazon.Sagemaker/src/SageMakerProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ namespace LangChain.Providers.Amazon.SageMaker;
///
/// </summary>
public class SageMakerProvider(
string apiGatewayEndpoint)
string apiGatewayRoute)
: Provider(id: "SageMaker")
{
#region Properties

public HttpClient HttpClient { get; } = new();
public Uri Uri { get; } = new(apiGatewayEndpoint ?? throw new ArgumentNullException(nameof(apiGatewayEndpoint), "API Gateway Endpoint is not defined"));
public Uri Uri { get; } = new(apiGatewayRoute ?? throw new ArgumentNullException(nameof(apiGatewayRoute), "API Gateway Endpoint is not defined"));

#endregion
}
Loading

0 comments on commit 3030850

Please sign in to comment.