Skip to content

Commit

Permalink
Bump version to 0.1.5-preview
Browse files Browse the repository at this point in the history
Add support for Claude 3 and Anthropic Messages API with streaming
  • Loading branch information
François Bouteruche committed May 16, 2024
1 parent 21ebac8 commit c80c446
Show file tree
Hide file tree
Showing 11 changed files with 444 additions and 16 deletions.
171 changes: 161 additions & 10 deletions src/Rockhead.Extensions.Tests/Anthropic/ClaudeTest.cs

Large diffs are not rendered by default.

68 changes: 67 additions & 1 deletion src/Rockhead.Extensions/AmazonBedrockRuntimeClientExtension.cs
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ async void BodyOnChunkReceived(object? sender, EventStreamEventReceivedArgs<Payl


/// <summary>
/// Invoke a Claude model (Instant V1, V2, V2.1) for text completion a response stream
/// Invoke a Claude model (Instant V1, V2, V2.1) for text completion with a response stream
/// </summary>
/// <param name="client">The Amazon Bedrock Runtime client object</param>
/// <param name="model">The Claude model to invoke</param>
Expand Down Expand Up @@ -405,6 +405,72 @@ async void BodyOnChunkReceived(object? sender, EventStreamEventReceivedArgs<Payl
}
}

/// <summary>
/// Invoke a Claude model (Instant V1, V2, V2.1, V3 Sonnet, V3 Haiku, V3 Opus) for text completion with a response stream
/// </summary>
/// <param name="client">The Amazon Bedrock Runtime client object</param>
/// <param name="model">The Claude model to invoke</param>
/// <param name="prompt">The input text to complete</param>
/// <param name="textGenerationConfig">The text generation configuration</param>
/// <param name="cancellationToken">A cancellation token</param>
/// <returns>An asynchronous enumeration of Claude model responses</returns>
public static async IAsyncEnumerable<IClaudeMessagesChunk> InvokeClaudeMessagesWithResponseStreamAsync(this AmazonBedrockRuntimeClient client, Model.Claude model, ClaudeMessage message, ClaudeMessagesConfig? messagesConfig = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
if (messagesConfig != null)
{
Validator.ValidateObject(messagesConfig, new ValidationContext(messagesConfig), true);
}
else
{
messagesConfig = new ClaudeMessagesConfig() { MaxTokens = 1000, Messages = new List<ClaudeMessage>() };
}
messagesConfig.Messages.Add(message);

JsonObject payload = JsonSerializer.SerializeToNode(messagesConfig)?.AsObject() ?? new();

InvokeModelWithResponseStreamResponse response = await client.InvokeModelWithResponseStreamAsync(new InvokeModelWithResponseStreamRequest()
{
ModelId = model.ModelId,
ContentType = "application/json",
Accept = "application/json",
Body = AWSSDKUtils.GenerateMemoryStreamFromString(payload.ToJsonString())
},
cancellationToken).ConfigureAwait(false);

Channel<IClaudeMessagesChunk> buffer = Channel.CreateUnbounded<IClaudeMessagesChunk>();
bool isStreaming = true;

response.Body.ChunkReceived += BodyOnChunkReceived;
response.Body.StartProcessing();

while ((!cancellationToken.IsCancellationRequested && isStreaming) || (!cancellationToken.IsCancellationRequested && buffer.Reader.Count > 0))
{
yield return await buffer.Reader.ReadAsync(cancellationToken).ConfigureAwait(false);
}
response.Body.ChunkReceived -= BodyOnChunkReceived;

yield break;

async void BodyOnChunkReceived(object? sender, EventStreamEventReceivedArgs<PayloadPart> e)
{
var message = new StreamReader(e.EventStreamEvent.Bytes).ReadToEnd();
e.EventStreamEvent.Bytes.Position = 0;
var streamResponse = await JsonSerializer.DeserializeAsync<IClaudeMessagesChunk>(e.EventStreamEvent.Bytes, cancellationToken: cancellationToken).ConfigureAwait(false);

if (streamResponse is null)
{
throw new NullReferenceException($"Unable to deserialize {nameof(e.EventStreamEvent.Bytes)} to {nameof(IClaudeMessagesChunk)}");
}

if (streamResponse is ClaudeMessagesMessageStopChunk)
{
isStreaming = false;
}

await buffer.Writer.WriteAsync(streamResponse, cancellationToken).ConfigureAwait(false);
}
}

/// <summary>
/// Invoke a Command v14 model (Text or Light Text) for text completion
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Text.Json.Serialization;
using System.Threading.Tasks;

namespace Rockhead.Extensions.Anthropic;

public class ClaudeMessagesContentBlockDeltaChunk : IClaudeMessagesChunk
{
[JsonPropertyName("index")]
public int? Index { get; init; }

[JsonPropertyName("delta")]
public BlockDeltaChunkContentBlock? Delta { get; init; }

public string? GetResponse()
{
return Delta?.Text;
}

public string? GetStopReason()
{
return null;
}

public class BlockDeltaChunkContentBlock
{
[JsonPropertyName("type")]
public string? Type { get; init; }

[JsonPropertyName("text")]
public string? Text { get; init; }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Text.Json.Serialization;
using System.Threading.Tasks;

namespace Rockhead.Extensions.Anthropic;

public class ClaudeMessagesContentBlockStartChunk : IClaudeMessagesChunk
{
[JsonPropertyName("index")]
public int? Index { get; init; }

[JsonPropertyName("content_block")]
public BlockStartChunkContentBlock? ContentBlock { get; init; }

public string? GetResponse()
{
return ContentBlock?.Text;
}

public string? GetStopReason()
{
return null;
}

public class BlockStartChunkContentBlock
{
[JsonPropertyName("type")]
public string? Type { get; init; }

[JsonPropertyName("text")]
public string? Text { get; init; }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Text.Json.Serialization;
using System.Threading.Tasks;

namespace Rockhead.Extensions.Anthropic;

public class ClaudeMessagesContentBlockStopChunk : IClaudeMessagesChunk
{
[JsonPropertyName("index")]
public int? Index { get; init; }

public string? GetResponse()
{
return String.Empty;
}

public string? GetStopReason()
{
return null;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Text.Json.Serialization;
using System.Threading.Tasks;
using static Rockhead.Extensions.Anthropic.ClaudeMessagesResponse;

namespace Rockhead.Extensions.Anthropic;

public class ClaudeMessagesMessageDeltaChunk : IClaudeMessagesChunk
{
[JsonPropertyName("delta")]
public MessageDelta? Delta { get; set; }

[JsonPropertyName("usage")] public MessageDeltaChunkUsage? Usage { get; init; }

public string? GetResponse()
{
return String.Empty;
}

public string? GetStopReason()
{
return Delta?.StopReason;
}

public class MessageDelta
{ [JsonPropertyName("stop_reason")] public string? StopReason { get; init; }

[JsonPropertyName("stop_sequence")] public string? StopSequence { get; init; }
}

public class MessageDeltaChunkUsage
{
[JsonPropertyName("output_tokens")] public int OutputTokens { get; init; }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Text.Json.Serialization;
using System.Threading.Tasks;

namespace Rockhead.Extensions.Anthropic;

public class ClaudeMessagesMessageStartChunk : IClaudeMessagesChunk
{
[JsonPropertyName("type")]
public string? Type { get; init; }

[JsonPropertyName("message")]
public ClaudeMessagesResponse? Message { get; init; }

public string? GetResponse()
{
return ((ClaudeTextContent?)Message?.Content.FirstOrDefault())?.Text;
}

public string? GetStopReason()
{
return null;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Text.Json.Serialization;
using System.Threading.Tasks;

namespace Rockhead.Extensions.Anthropic;

public class ClaudeMessagesMessageStopChunk : IClaudeMessagesChunk
{
[JsonPropertyName("amazon-bedrock-invocationMetrics")]
public AmazonBedrockInvocationMetrics? InvocationMetrics { get; init; }

public string? GetResponse()
{
return null;
}

public string? GetStopReason()
{
return null;
}

public class AmazonBedrockInvocationMetrics
{
[JsonPropertyName("inputTokenCount")] public int? InputTokenCount { get; init; }

[JsonPropertyName("outputTokenCount")] public int? OutputTokenCount { get; init; }

[JsonPropertyName("invocationLatency")] public int? InvocationLatency { get; init; }

[JsonPropertyName("firstByteLatency")] public int? FirstByteLatency { get; init; }
}
}
4 changes: 0 additions & 4 deletions src/Rockhead.Extensions/Anthropic/ClaudeMessagesResponse.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,6 @@ public class ClaudeMessagesResponse : ClaudeMessage, IFoundationModelResponse

[JsonPropertyName("type")] public string? Type { get; init; }

//[JsonPropertyName("role")] public string? Role { get; init; }

//[JsonPropertyName("content")] public IEnumerable<IClaudeContent>? Content { get; init; }

[JsonPropertyName("stop_reason")] public string? StopReason { get; init; }

[JsonPropertyName("stop_sequence")] public string? StopSequence { get; init; }
Expand Down
19 changes: 19 additions & 0 deletions src/Rockhead.Extensions/Anthropic/IClaudeMessagesChunk.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Text.Json.Serialization;
using System.Threading.Tasks;

namespace Rockhead.Extensions.Anthropic;

[JsonPolymorphic(TypeDiscriminatorPropertyName = "type")]
[JsonDerivedType(typeof(ClaudeMessagesMessageStartChunk), "message_start")]
[JsonDerivedType(typeof(ClaudeMessagesContentBlockStartChunk), "content_block_start")]
[JsonDerivedType(typeof(ClaudeMessagesContentBlockDeltaChunk), "content_block_delta")]
[JsonDerivedType(typeof(ClaudeMessagesContentBlockStopChunk), "content_block_stop")]
[JsonDerivedType(typeof(ClaudeMessagesMessageDeltaChunk), "message_delta")]
[JsonDerivedType(typeof(ClaudeMessagesMessageStopChunk), "message_stop")]
public interface IClaudeMessagesChunk : IFoundationModelResponse
{
}
2 changes: 1 addition & 1 deletion src/Rockhead.Extensions/Rockhead.Extensions.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
<PackageTags>Bedrock, GenerativeAI, GenAI, AI</PackageTags>
<PackageLicenseExpression>MIT</PackageLicenseExpression>
<PackageReadmeFile>Readme.md</PackageReadmeFile>
<Version>0.1.4-preview</Version>
<Version>0.1.5-preview</Version>
</PropertyGroup>

<ItemGroup>
Expand Down

0 comments on commit c80c446

Please sign in to comment.