Skip to content

Commit

Permalink
Aspect based locking to prevent parallel model operations. Spelling f…
Browse files Browse the repository at this point in the history
…ixes.
  • Loading branch information
edgett committed Jun 18, 2024
1 parent dabf8e0 commit f3f62a8
Show file tree
Hide file tree
Showing 9 changed files with 80 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

<div class="chat-messages">
@* Display each prompt in a card. Display each response in a card. *@
@foreach (var promptWithResponse in Controller!.WebsocketChatMessages)
@foreach (var promptWithResponse in Controller!.WebSocketChatMessages)
{
<FluentCard>
<FluentStack>
Expand Down
8 changes: 4 additions & 4 deletions PalmHill.BlazorChat/Client/Services/ChatService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ ILogger<ChatService> logger
/// <summary>
/// The list of chat messages. Containing a prompt and its response.
/// </summary>
public List<WebSocketChatMessage> WebsocketChatMessages { get; private set; } = new List<WebSocketChatMessage>();
public List<WebSocketChatMessage> WebSocketChatMessages { get; private set; } = new List<WebSocketChatMessage>();

/// <summary>
/// The WebSocketChatService that handles the WebSocket connection.
Expand Down Expand Up @@ -142,7 +142,7 @@ public async Task SendToWebSocketChat()
var prompt = new WebSocketChatMessage();
prompt.ConversationId = ConversationId;
prompt.Prompt = UserInput;
WebsocketChatMessages.Add(prompt);
WebSocketChatMessages.Add(prompt);
UserInput = string.Empty;
StateHasChanged();
await SendInferenceRequest();
Expand All @@ -160,7 +160,7 @@ public async Task AskDocumentApi()
var prompt = new WebSocketChatMessage();
prompt.Prompt = UserInput;
prompt.ConversationId = ConversationId;
WebsocketChatMessages.Add(prompt);
WebSocketChatMessages.Add(prompt);
UserInput = string.Empty;
StateHasChanged();

Expand Down Expand Up @@ -258,7 +258,7 @@ private void setupWebSocketChatConnection()
WebSocketChatConnection = new WebSocketChatService(
ConversationId,
_navigationManager.ToAbsoluteUri("/chathub?customUserId=user1"),
WebsocketChatMessages,
WebSocketChatMessages,
_localStorageService
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

namespace PalmHill.BlazorChat.Server
{
public static class ChatCancelation
public static class ChatCancellation
{
public static ConcurrentDictionary<Guid, CancellationTokenSource> CancelationTokens { get; private set; } = new ConcurrentDictionary<Guid, CancellationTokenSource>();
public static ConcurrentDictionary<Guid, CancellationTokenSource> CancellationTokens { get; private set; } = new ConcurrentDictionary<Guid, CancellationTokenSource>();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="AspectInjector" Version="2.8.2" />
<PackageReference Include="Microsoft.AspNetCore.Components.WebAssembly.Server" Version="8.0.6" />
<PackageReference Include="Swashbuckle.AspNetCore" Version="6.6.2" />
</ItemGroup>
Expand Down
44 changes: 44 additions & 0 deletions PalmHill.BlazorChat/Server/SerialExecutionAspect.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
namespace PalmHill.BlazorChat.Server
{
using AspectInjector.Broker;
using System;
using System.Collections.Concurrent;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;

[Aspect(Scope.Global)]
public class SerialExecutionAspect
{
private static readonly ConcurrentDictionary<string, SemaphoreSlim> _semaphores = new ConcurrentDictionary<string, SemaphoreSlim>();

[Advice(Kind.Around, Targets = Target.Method)]
public object? Handle(
[Argument(Source.Name)] string methodName,
[Argument(Source.Target)] Func<object[], object> method,
[Argument(Source.Arguments)] object[] args,
[Argument(Source.Metadata)] MethodBase methodBase)
{
var attribute = methodBase.GetCustomAttribute<SerialExecutionAttribute>();
var key = attribute?.Key ?? string.Empty;
var semaphore = _semaphores.GetOrAdd(key, _ => new SemaphoreSlim(1, 1));

semaphore.Wait();
try
{
var result = method(args);
if (result is Task task)
{
task.GetAwaiter().GetResult(); // Ensure the task completes
return task;
}
return result;
}
finally
{
semaphore.Release();
}
}
}

}
18 changes: 18 additions & 0 deletions PalmHill.BlazorChat/Server/SerialExecutionAttribute.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
namespace PalmHill.BlazorChat.Server
{
using AspectInjector.Broker;

[AttributeUsage(AttributeTargets.Method)]
[Injection(typeof(SerialExecutionAspect))]
public class SerialExecutionAttribute : Attribute
{
public string Key { get; }

public SerialExecutionAttribute(string key)
{
Key = key;
}
}


}
6 changes: 4 additions & 2 deletions PalmHill.BlazorChat/Server/SignalR/WebSocketChat.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Azure.AI.OpenAI;
using LLama;
using Microsoft.AspNetCore.Components;
using Microsoft.AspNetCore.SignalR;
using Microsoft.SemanticKernel.ChatCompletion;
using PalmHill.BlazorChat.Server.WebApi;
Expand Down Expand Up @@ -39,7 +40,7 @@ public async Task InferenceRequest(InferenceRequest chatConversation)
{
var conversationId = chatConversation.Id;
var cancellationTokenSource = new CancellationTokenSource();
ChatCancelation.CancelationTokens[conversationId] = cancellationTokenSource;
ChatCancellation.CancellationTokens[conversationId] = cancellationTokenSource;

try
{
Expand Down Expand Up @@ -70,7 +71,7 @@ public async Task InferenceRequest(InferenceRequest chatConversation)
finally
{
//ThreadLock.InferenceLock.Release();
ChatCancelation.CancelationTokens.TryRemove(conversationId, out _);
ChatCancellation.CancellationTokens.TryRemove(conversationId, out _);

}
}
Expand All @@ -83,6 +84,7 @@ public async Task InferenceRequest(InferenceRequest chatConversation)
/// <param name="messageId">The unique identifier for the message.</param>
/// <param name="chatConversation">The chat conversation to use for inference.</param>
/// <returns>A Task that represents the asynchronous operation.</returns>
[SerialExecution("ModelOperation")]
private async Task DoInferenceAndRespondToClient(ISingleClientProxy respondToClient, InferenceRequest chatConversation, CancellationToken cancellationToken)
{
// Create a context for the model and a chat session for the conversation
Expand Down
9 changes: 5 additions & 4 deletions PalmHill.BlazorChat/Server/WebApi/ApiChatController.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public async Task<ActionResult<string>> Chat([FromBody] InferenceRequest convers

var conversationId = conversation.Id;
var cancellationTokenSource = new CancellationTokenSource();
ChatCancelation.CancelationTokens[conversationId] = cancellationTokenSource;
ChatCancellation.CancellationTokens[conversationId] = cancellationTokenSource;

try
{
Expand All @@ -80,7 +80,7 @@ public async Task<ActionResult<string>> Chat([FromBody] InferenceRequest convers
finally
{
//ThreadLock.InferenceLock.Release();
ChatCancelation.CancelationTokens.TryRemove(conversationId, out _);
ChatCancellation.CancellationTokens.TryRemove(conversationId, out _);
}

_logger.LogError(errorText);
Expand All @@ -98,7 +98,7 @@ public async Task<ActionResult<ChatMessage>> Ask(InferenceRequest chatConversati

var conversationId = chatConversation.Id;
var cancellationTokenSource = new CancellationTokenSource();
ChatCancelation.CancelationTokens[conversationId] = cancellationTokenSource;
ChatCancellation.CancellationTokens[conversationId] = cancellationTokenSource;

var question = chatConversation.ChatMessages.LastOrDefault()?.Message;
if (question == null)
Expand Down Expand Up @@ -137,7 +137,7 @@ public async Task<ActionResult<ChatMessage>> Ask(InferenceRequest chatConversati
[HttpDelete("cancel/{conversationId}", Name = "CancelChat")]
public async Task<bool> CancelChat(Guid conversationId)
{
var cancelToken = ChatCancelation.CancelationTokens[conversationId];
var cancelToken = ChatCancellation.CancellationTokens[conversationId];
if (cancelToken == null)
{
return false;
Expand All @@ -154,6 +154,7 @@ public async Task<bool> CancelChat(Guid conversationId)
/// </summary>
/// <param name="conversation">The chat conversation for which to perform inference.</param>
/// <returns>Returns the inference result as a string.</returns>
[SerialExecution("ModelOperation")]
private async Task<string> DoInference(InferenceRequest conversation, CancellationToken cancellationToken)
{

Expand Down
1 change: 1 addition & 0 deletions PalmHill.BlazorChat/Server/WebApi/AttachmentController.cs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ public async Task<ActionResult<AttachmentInfo>> AddAttachment([FromForm] FileUpl
return attachmentInfo;
}

[SerialExecution("ModelOperation")]
private async Task DoImportAsync(string? userId, AttachmentInfo attachmentInfo)
{
try
Expand Down

0 comments on commit f3f62a8

Please sign in to comment.