diff --git a/Directory.Packages.props b/Directory.Packages.props index 5ab0ff95a..6f88580f0 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -40,6 +40,7 @@ + diff --git a/samples/AzureFunctionsApp/AzureFunctionsApp.csproj b/samples/AzureFunctionsApp/AzureFunctionsApp.csproj index 25d824faf..19e9c4626 100644 --- a/samples/AzureFunctionsApp/AzureFunctionsApp.csproj +++ b/samples/AzureFunctionsApp/AzureFunctionsApp.csproj @@ -19,6 +19,8 @@ + + diff --git a/samples/AzureFunctionsApp/Entities/ShoppingCart.cs b/samples/AzureFunctionsApp/Entities/ShoppingCart.cs new file mode 100644 index 000000000..acf00371d --- /dev/null +++ b/samples/AzureFunctionsApp/Entities/ShoppingCart.cs @@ -0,0 +1,307 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Net; +using Microsoft.Azure.Functions.Worker; +using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.DurableTask; +using Microsoft.DurableTask.Client; +using Microsoft.DurableTask.Client.Entities; +using Microsoft.DurableTask.Entities; +using Microsoft.Extensions.Logging; + +namespace AzureFunctionsApp.Entities; + +/// +/// This sample demonstrates strongly-typed entity invocation using proxy interfaces. +/// Instead of calling entities using string-based operation names, you define an interface +/// that represents the entity's operations and use it to invoke operations in a type-safe manner. +/// + +/// +/// Entity proxy interface for the shopping cart entity (orchestration use). +/// Defines the operations that can be performed on a shopping cart from orchestrations. +/// +public interface IShoppingCartProxy : IEntityProxy +{ + /// + /// Adds an item to the shopping cart. + /// + /// The item to add. + /// The total number of items in the cart. + Task AddItem(CartItem item); + + /// + /// Removes an item from the shopping cart. + /// + /// The ID of the item to remove. + /// True if the item was removed, false if not found. + Task RemoveItem(string itemId); + + /// + /// Gets the total price of all items in the cart. + /// + /// The total price. + Task GetTotalPrice(); + + /// + /// Clears all items from the cart. + /// + Task Clear(); +} + +/// +/// Client-side proxy interface for the shopping cart entity. +/// Client operations are fire-and-forget (cannot return results). +/// +public interface IShoppingCartClientProxy : IEntityProxy +{ + /// + /// Signals the entity to add an item to the shopping cart. + /// + /// The item to add. + Task AddItem(CartItem item); + + /// + /// Signals the entity to clear all items from the cart. + /// + Task Clear(); +} + +/// +/// Represents an item in the shopping cart. +/// +public record CartItem(string Id, string Name, decimal Price, int Quantity); + +/// +/// Shopping cart state. +/// +public record ShoppingCartState +{ + public List Items { get; init; } = new(); +} + +/// +/// Shopping cart entity implementation. +/// +[DurableTask(nameof(ShoppingCart))] +public class ShoppingCart : TaskEntity +{ + readonly ILogger logger; + + public ShoppingCart(ILogger logger) + { + this.logger = logger; + } + + public int AddItem(CartItem item) + { + CartItem? existing = this.State.Items.FirstOrDefault(i => i.Id == item.Id); + if (existing != null) + { + this.State.Items.Remove(existing); + this.State.Items.Add(existing with { Quantity = existing.Quantity + item.Quantity }); + } + else + { + this.State.Items.Add(item); + } + + this.logger.LogInformation("Added item {ItemId} to cart {CartId}. Total items: {Count}", item.Id, this.Context.Id.Key, this.State.Items.Count); + return this.State.Items.Count; + } + + public bool RemoveItem(string itemId) + { + CartItem? item = this.State.Items.FirstOrDefault(i => i.Id == itemId); + if (item != null) + { + this.State.Items.Remove(item); + this.logger.LogInformation("Removed item {ItemId} from cart {CartId}", itemId, this.Context.Id.Key); + return true; + } + + this.logger.LogWarning("Item {ItemId} not found in cart {CartId}", itemId, this.Context.Id.Key); + return false; + } + + public decimal GetTotalPrice() + { + decimal total = this.State.Items.Sum(i => i.Price * i.Quantity); + this.logger.LogInformation("Cart {CartId} total price: {Total:C}", this.Context.Id.Key, total); + return total; + } + + public void Clear() + { + int count = this.State.Items.Count; + this.State.Items.Clear(); + this.logger.LogInformation("Cleared {Count} items from cart {CartId}", count, this.Context.Id.Key); + } +} + +/// +/// Orchestration that demonstrates strongly-typed entity invocation. +/// +public static class ShoppingCartOrchestration +{ + [Function(nameof(ProcessShoppingCartOrder))] + public static async Task ProcessShoppingCartOrder( + [OrchestrationTrigger] TaskOrchestrationContext context, + string cartId) + { + ILogger logger = context.CreateReplaySafeLogger(nameof(ProcessShoppingCartOrder)); + + // Create a strongly-typed proxy for the shopping cart entity + EntityInstanceId entityId = new(nameof(ShoppingCart), cartId); + IShoppingCartProxy cart = context.Entities.CreateProxy(entityId); + + // Add some items to the cart using strongly-typed method calls + logger.LogInformation("Adding items to cart {CartId}", cartId); + await cart.AddItem(new CartItem("ITEM001", "Laptop", 999.99m, 1)); + await cart.AddItem(new CartItem("ITEM002", "Mouse", 29.99m, 2)); + await cart.AddItem(new CartItem("ITEM003", "Keyboard", 79.99m, 1)); + + // Get the total price + decimal totalPrice = await cart.GetTotalPrice(); + logger.LogInformation("Cart {CartId} total: {Total:C}", cartId, totalPrice); + + // Simulate order processing + if (totalPrice > 1000m) + { + logger.LogInformation("Applying discount for cart {CartId}", cartId); + totalPrice *= 0.9m; // 10% discount + } + + // Clear the cart after order is processed + await cart.Clear(); + logger.LogInformation("Cart {CartId} cleared after order processing", cartId); + + return new OrderResult(cartId, totalPrice, context.CurrentUtcDateTime); + } + + public record OrderResult(string CartId, decimal TotalPrice, DateTime OrderDate); +} + +/// +/// HTTP APIs for the shopping cart sample. +/// +public static class ShoppingCartApis +{ + /// + /// Start an orchestration to process a shopping cart order. + /// Usage: POST /api/shopping-cart/{cartId}/process + /// + [Function("ShoppingCart_ProcessOrder")] + public static async Task ProcessOrderAsync( + [HttpTrigger(AuthorizationLevel.Anonymous, "post", Route = "shopping-cart/{cartId}/process")] + HttpRequestData request, + [DurableClient] DurableTaskClient client, + string cartId) + { + string instanceId = await client.ScheduleNewOrchestrationInstanceAsync( + nameof(ShoppingCartOrchestration.ProcessShoppingCartOrder), + cartId); + + return client.CreateCheckStatusResponse(request, instanceId); + } + + /// + /// Add an item to a shopping cart using strongly-typed proxy from client. + /// Usage: POST /api/shopping-cart/{cartId}/items?id={id}&name={name}&price={price}&quantity={quantity} + /// + [Function("ShoppingCart_AddItem")] + public static async Task AddItemAsync( + [HttpTrigger(AuthorizationLevel.Anonymous, "post", Route = "shopping-cart/{cartId}/items")] + HttpRequestData request, + [DurableClient] DurableTaskClient client, + string cartId) + { + string? itemId = request.Query["id"]; + string? name = request.Query["name"]; + if (!decimal.TryParse(request.Query["price"], out decimal price) || price <= 0) + { + HttpResponseData badRequest = request.CreateResponse(HttpStatusCode.BadRequest); + await badRequest.WriteStringAsync("Invalid price"); + return badRequest; + } + + if (!int.TryParse(request.Query["quantity"], out int quantity) || quantity <= 0) + { + HttpResponseData badRequest = request.CreateResponse(HttpStatusCode.BadRequest); + await badRequest.WriteStringAsync("Invalid quantity"); + return badRequest; + } + + if (string.IsNullOrEmpty(itemId) || string.IsNullOrEmpty(name)) + { + HttpResponseData badRequest = request.CreateResponse(HttpStatusCode.BadRequest); + await badRequest.WriteStringAsync("Item ID and name are required"); + return badRequest; + } + + // Use strongly-typed proxy for client-side entity invocation + EntityInstanceId entityId = new(nameof(ShoppingCart), cartId); + IShoppingCartClientProxy cart = client.Entities.CreateProxy(entityId); + + // Signal the entity to add the item (fire-and-forget) + await cart.AddItem(new CartItem(itemId, name, price, quantity)); + + HttpResponseData response = request.CreateResponse(HttpStatusCode.Accepted); + await response.WriteStringAsync($"Item {itemId} added to cart {cartId}"); + return response; + } + + /// + /// Get the current state of a shopping cart. + /// Usage: GET /api/shopping-cart/{cartId} + /// + [Function("ShoppingCart_Get")] + public static async Task GetAsync( + [HttpTrigger(AuthorizationLevel.Anonymous, "get", Route = "shopping-cart/{cartId}")] + HttpRequestData request, + [DurableClient] DurableTaskClient client, + string cartId) + { + EntityInstanceId entityId = new(nameof(ShoppingCart), cartId); + EntityMetadata? entity = await client.Entities.GetEntityAsync(entityId); + + if (entity is null) + { + return request.CreateResponse(HttpStatusCode.NotFound); + } + + HttpResponseData response = request.CreateResponse(HttpStatusCode.OK); + await response.WriteAsJsonAsync(new + { + cartId, + entity.State.Items, + TotalPrice = entity.State.Items.Sum(i => i.Price * i.Quantity), + ItemCount = entity.State.Items.Count, + }); + + return response; + } + + /// + /// Clear a shopping cart using strongly-typed proxy. + /// Usage: DELETE /api/shopping-cart/{cartId} + /// + [Function("ShoppingCart_Clear")] + public static async Task ClearAsync( + [HttpTrigger(AuthorizationLevel.Anonymous, "delete", Route = "shopping-cart/{cartId}")] + HttpRequestData request, + [DurableClient] DurableTaskClient client, + string cartId) + { + EntityInstanceId entityId = new(nameof(ShoppingCart), cartId); + IShoppingCartClientProxy cart = client.Entities.CreateProxy(entityId); + + // Signal the entity to clear the cart + await cart.Clear(); + + HttpResponseData response = request.CreateResponse(HttpStatusCode.Accepted); + await response.WriteStringAsync($"Cart {cartId} cleared"); + return response; + } +} diff --git a/samples/AzureFunctionsApp/Entities/shopping-cart.http b/samples/AzureFunctionsApp/Entities/shopping-cart.http new file mode 100644 index 000000000..2a1a740f8 --- /dev/null +++ b/samples/AzureFunctionsApp/Entities/shopping-cart.http @@ -0,0 +1,14 @@ +### Add item to shopping cart using strongly-typed proxy +POST http://localhost:7071/api/shopping-cart/cart1/items?id=ITEM001&name=Laptop&price=999.99&quantity=1 + +### Add another item +POST http://localhost:7071/api/shopping-cart/cart1/items?id=ITEM002&name=Mouse&price=29.99&quantity=2 + +### Get shopping cart state +GET http://localhost:7071/api/shopping-cart/cart1 + +### Process order (runs orchestration that uses strongly-typed proxy) +POST http://localhost:7071/api/shopping-cart/cart1/process + +### Clear shopping cart +DELETE http://localhost:7071/api/shopping-cart/cart1 diff --git a/src/Abstractions/Abstractions.csproj b/src/Abstractions/Abstractions.csproj index db8be76ab..86188e632 100644 --- a/src/Abstractions/Abstractions.csproj +++ b/src/Abstractions/Abstractions.csproj @@ -13,6 +13,7 @@ + diff --git a/src/Abstractions/Entities/IEntityProxy.cs b/src/Abstractions/Entities/IEntityProxy.cs new file mode 100644 index 000000000..e9fa88773 --- /dev/null +++ b/src/Abstractions/Entities/IEntityProxy.cs @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Microsoft.DurableTask.Entities; + +/// +/// Marker interface for entity proxy interfaces. +/// +/// +/// This interface is used to mark interfaces that represent entity operations. +/// Entity proxy interfaces should define methods that correspond to operations +/// that can be invoked on entities. These interfaces are used with +/// to create strongly-typed +/// proxies for entity invocation. +/// +public interface IEntityProxy +{ +} diff --git a/src/Abstractions/Entities/TaskOrchestrationEntityProxyExtensions.cs b/src/Abstractions/Entities/TaskOrchestrationEntityProxyExtensions.cs new file mode 100644 index 000000000..1f76cf036 --- /dev/null +++ b/src/Abstractions/Entities/TaskOrchestrationEntityProxyExtensions.cs @@ -0,0 +1,178 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Reflection; + +namespace Microsoft.DurableTask.Entities; + +/// +/// Extension methods for creating strongly-typed entity proxies. +/// +public static class TaskOrchestrationEntityProxyExtensions +{ + /// + /// Creates a strongly-typed proxy for invoking entity operations. + /// + /// The entity proxy interface type. Must extend . + /// The entity feature. + /// The entity instance ID. + /// A strongly-typed proxy for the entity. + /// + /// + /// The proxy interface should define methods that correspond to entity operations. + /// Each method invocation will be translated to a call or signal to the entity, depending on the return type: + /// + /// + /// Methods returning or will use CallEntityAsync. + /// Methods returning void will use SignalEntityAsync (fire-and-forget). + /// + /// + /// Example: + /// + /// public interface ICounter : IEntityProxy + /// { + /// Task<int> Add(int value); + /// Task<int> Get(); + /// void Reset(); + /// } + /// + /// var counter = context.Entities.CreateProxy<ICounter>(new EntityInstanceId("Counter", "myCounter")); + /// int result = await counter.Add(5); + /// + /// + /// + public static TEntityProxy CreateProxy( + this TaskOrchestrationEntityFeature feature, + EntityInstanceId id) + where TEntityProxy : class, IEntityProxy + { + Check.NotNull(feature); + return EntityProxy.Create(feature, id); + } + + /// + /// Creates a strongly-typed proxy for invoking entity operations. + /// + /// The entity proxy interface type. Must extend . + /// The entity feature. + /// The entity name. + /// The entity key. + /// A strongly-typed proxy for the entity. + public static TEntityProxy CreateProxy( + this TaskOrchestrationEntityFeature feature, + string entityName, + string entityKey) + where TEntityProxy : class, IEntityProxy + { + return CreateProxy(feature, new EntityInstanceId(entityName, entityKey)); + } + + /// + /// Proxy implementation for entity invocation. + /// + /// The entity proxy interface type. + class EntityProxy : DispatchProxy + where TEntityProxy : class, IEntityProxy + { + TaskOrchestrationEntityFeature feature = null!; + EntityInstanceId id; + + /// + /// Creates a proxy instance. + /// + /// The entity feature. + /// The entity instance ID. + /// The proxy instance. + public static TEntityProxy Create(TaskOrchestrationEntityFeature entityFeature, EntityInstanceId entityId) + { + object proxy = Create>(); + ((EntityProxy)proxy).Initialize(entityFeature, entityId); + return (TEntityProxy)proxy; + } + + /// + protected override object? Invoke(MethodInfo? targetMethod, object?[]? args) + { + if (targetMethod is null) + { + throw new ArgumentNullException(nameof(targetMethod)); + } + + // Get the operation name from the method name + string operationName = targetMethod.Name; + + // Determine input - if there's exactly one parameter, use it; otherwise use args array or null + object? input = args?.Length switch + { + 0 => null, + 1 => args[0], + _ => args, + }; + + Type returnType = targetMethod.ReturnType; + + // Handle void methods - these are fire-and-forget signals + if (returnType == typeof(void)) + { + // Fire and forget - we can't await this in a sync method, so we need to return immediately + // This will schedule the signal but not wait for it + Task signalTask = this.feature.SignalEntityAsync(this.id, operationName, input); + + // For void methods, we complete synchronously but the signal is scheduled + // This matches the behavior of SignalEntityAsync which returns a Task + // that completes when the signal is scheduled, not when it's processed + signalTask.ConfigureAwait(false).GetAwaiter().GetResult(); + return null; + } + + // Handle Task (non-generic) - call without expecting a result + if (returnType == typeof(Task)) + { + return this.feature.CallEntityAsync(this.id, operationName, input); + } + + // Handle Task - call with a result + if (returnType.IsGenericType && returnType.GetGenericTypeDefinition() == typeof(Task<>)) + { + Type resultType = returnType.GetGenericArguments()[0]; + MethodInfo? callMethod = typeof(TaskOrchestrationEntityFeature) + .GetMethods() + .Where(m => m.Name == nameof(TaskOrchestrationEntityFeature.CallEntityAsync) && + m.IsGenericMethod && + m.GetGenericArguments().Length == 1) + .Select(m => m.MakeGenericMethod(resultType)) + .FirstOrDefault(m => + { + ParameterInfo[] parameters = m.GetParameters(); + return parameters.Length == 4 && + parameters[0].ParameterType == typeof(EntityInstanceId) && + parameters[1].ParameterType == typeof(string) && + parameters[2].ParameterType == typeof(object) && + parameters[3].ParameterType == typeof(CallEntityOptions); + }); + + if (callMethod is null) + { + throw new InvalidOperationException($"Could not find CallEntityAsync method for return type {returnType}"); + } + + return callMethod.Invoke(this.feature, new object?[] { this.id, operationName, input, null }); + } + + throw new NotSupportedException( + $"Method '{targetMethod.Name}' has unsupported return type '{returnType.Name}'. " + + "Entity proxy methods must return void, Task, or Task."); + } + + /// + /// Initializes the proxy. + /// + /// The entity feature. + /// The entity instance ID. + void Initialize(TaskOrchestrationEntityFeature entityFeature, EntityInstanceId entityId) + { + this.feature = entityFeature; + this.id = entityId; + } + } +} diff --git a/src/Client/Core/Entities/DurableEntityClientProxyExtensions.cs b/src/Client/Core/Entities/DurableEntityClientProxyExtensions.cs new file mode 100644 index 000000000..a580266f7 --- /dev/null +++ b/src/Client/Core/Entities/DurableEntityClientProxyExtensions.cs @@ -0,0 +1,153 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Reflection; +using Microsoft.DurableTask.Entities; + +namespace Microsoft.DurableTask.Client.Entities; + +/// +/// Extension methods for creating strongly-typed entity proxies on the client side. +/// +public static class DurableEntityClientProxyExtensions +{ + /// + /// Creates a strongly-typed proxy for invoking entity operations from a client. + /// + /// The entity proxy interface type. Must extend . + /// The durable entity client. + /// The entity instance ID. + /// A strongly-typed proxy for the entity. + /// + /// + /// The proxy interface should define methods that correspond to entity operations. + /// All method invocations will use SignalEntityAsync (fire-and-forget) since clients + /// cannot wait for entity operation results. + /// + /// + /// Example: + /// + /// public interface ICounter : IEntityProxy + /// { + /// Task Add(int value); + /// Task Reset(); + /// } + /// + /// var counter = client.Entities.CreateProxy<ICounter>(new EntityInstanceId("Counter", "myCounter")); + /// await counter.Add(5); + /// + /// + /// + public static TEntityProxy CreateProxy( + this DurableEntityClient client, + EntityInstanceId id) + where TEntityProxy : class, IEntityProxy + { + Check.NotNull(client); + return EntityClientProxy.Create(client, id); + } + + /// + /// Creates a strongly-typed proxy for invoking entity operations from a client. + /// + /// The entity proxy interface type. Must extend . + /// The durable entity client. + /// The entity name. + /// The entity key. + /// A strongly-typed proxy for the entity. + public static TEntityProxy CreateProxy( + this DurableEntityClient client, + string entityName, + string entityKey) + where TEntityProxy : class, IEntityProxy + { + return CreateProxy(client, new EntityInstanceId(entityName, entityKey)); + } + + /// + /// Proxy implementation for client-side entity invocation. + /// + /// The entity proxy interface type. + class EntityClientProxy : DispatchProxy + where TEntityProxy : class, IEntityProxy + { + DurableEntityClient client = null!; + EntityInstanceId id; + + /// + /// Creates a proxy instance. + /// + /// The durable entity client. + /// The entity instance ID. + /// The proxy instance. + public static TEntityProxy Create(DurableEntityClient entityClient, EntityInstanceId entityId) + { + object proxy = Create>(); + ((EntityClientProxy)proxy).Initialize(entityClient, entityId); + return (TEntityProxy)proxy; + } + + /// + protected override object? Invoke(MethodInfo? targetMethod, object?[]? args) + { + if (targetMethod is null) + { + throw new ArgumentNullException(nameof(targetMethod)); + } + + // Get the operation name from the method name + string operationName = targetMethod.Name; + + // Determine input - if there's exactly one parameter, use it; otherwise use args array or null + object? input = args?.Length switch + { + 0 => null, + 1 => args[0], + _ => args, + }; + + Type returnType = targetMethod.ReturnType; + + // Client proxies can only signal entities (fire-and-forget) + // They cannot wait for results since clients don't have orchestration context + + // Handle void methods + if (returnType == typeof(void)) + { + Task signalTask = this.client.SignalEntityAsync(this.id, operationName, input); + signalTask.ConfigureAwait(false).GetAwaiter().GetResult(); + return null; + } + + // Handle Task (fire-and-forget from client perspective) + if (returnType == typeof(Task)) + { + return this.client.SignalEntityAsync(this.id, operationName, input); + } + + // Task is not supported from clients as they cannot receive results + if (returnType.IsGenericType && returnType.GetGenericTypeDefinition() == typeof(Task<>)) + { + throw new NotSupportedException( + $"Method '{targetMethod.Name}' returns Task, which is not supported for client-side entity proxies. " + + "Clients can only signal entities (fire-and-forget). Use Task (non-generic) or void return type instead. " + + "To get entity state, use client.Entities.GetEntityAsync()."); + } + + throw new NotSupportedException( + $"Method '{targetMethod.Name}' has unsupported return type '{returnType.Name}'. " + + "Client-side entity proxy methods must return void or Task."); + } + + /// + /// Initializes the proxy. + /// + /// The durable entity client. + /// The entity instance ID. + void Initialize(DurableEntityClient entityClient, EntityInstanceId entityId) + { + this.client = entityClient; + this.id = entityId; + } + } +} diff --git a/test/Abstractions.Tests/Entities/TaskOrchestrationEntityProxyExtensionsTests.cs b/test/Abstractions.Tests/Entities/TaskOrchestrationEntityProxyExtensionsTests.cs new file mode 100644 index 000000000..14b694142 --- /dev/null +++ b/test/Abstractions.Tests/Entities/TaskOrchestrationEntityProxyExtensionsTests.cs @@ -0,0 +1,191 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Microsoft.DurableTask.Entities; +using Moq; + +namespace Microsoft.DurableTask.Tests.Entities; + +public class TaskOrchestrationEntityProxyExtensionsTests +{ + [Fact] + public void CreateProxy_NullFeature_ThrowsArgumentNullException() + { + // Arrange, Act, Assert + TaskOrchestrationEntityFeature feature = null!; + EntityInstanceId id = new("TestEntity", "key1"); + + Action act = () => feature.CreateProxy(id); + + act.Should().Throw(); + } + + [Fact] + public async Task CallMethod_WithTaskResult_CallsCallEntityAsync() + { + // Arrange + Mock mockFeature = new(); + EntityInstanceId id = new("TestEntity", "key1"); + int expectedResult = 42; + + mockFeature + .Setup(f => f.CallEntityAsync(id, "GetValue", null, null)) + .ReturnsAsync(expectedResult); + + ITestEntityProxy proxy = mockFeature.Object.CreateProxy(id); + + // Act + int result = await proxy.GetValue(); + + // Assert + result.Should().Be(expectedResult); + mockFeature.Verify( + f => f.CallEntityAsync(id, "GetValue", null, null), + Times.Once); + } + + [Fact] + public async Task CallMethod_WithTaskResultAndInput_CallsCallEntityAsyncWithInput() + { + // Arrange + Mock mockFeature = new(); + EntityInstanceId id = new("TestEntity", "key1"); + int input = 5; + int expectedResult = 47; + + mockFeature + .Setup(f => f.CallEntityAsync(id, "Add", input, null)) + .ReturnsAsync(expectedResult); + + ITestEntityProxy proxy = mockFeature.Object.CreateProxy(id); + + // Act + int result = await proxy.Add(input); + + // Assert + result.Should().Be(expectedResult); + mockFeature.Verify( + f => f.CallEntityAsync(id, "Add", input, null), + Times.Once); + } + + [Fact] + public async Task CallMethod_WithTaskNoResult_CallsCallEntityAsync() + { + // Arrange + Mock mockFeature = new(); + EntityInstanceId id = new("TestEntity", "key1"); + + mockFeature + .Setup(f => f.CallEntityAsync(id, "Reset", null, null)) + .Returns(Task.CompletedTask); + + ITestEntityProxy proxy = mockFeature.Object.CreateProxy(id); + + // Act + await proxy.Reset(); + + // Assert + mockFeature.Verify( + f => f.CallEntityAsync(id, "Reset", null, null), + Times.Once); + } + + [Fact] + public void CallMethod_VoidReturn_CallsSignalEntityAsync() + { + // Arrange + Mock mockFeature = new(); + EntityInstanceId id = new("TestEntity", "key1"); + + mockFeature + .Setup(f => f.SignalEntityAsync(id, "Delete", null, null)) + .Returns(Task.CompletedTask); + + ITestEntityProxy proxy = mockFeature.Object.CreateProxy(id); + + // Act + proxy.Delete(); + + // Assert + mockFeature.Verify( + f => f.SignalEntityAsync(id, "Delete", null, null), + Times.Once); + } + + [Fact] + public async Task CallMethod_WithMultipleParameters_PassesParametersAsArray() + { + // Arrange + Mock mockFeature = new(); + EntityInstanceId id = new("TestEntity", "key1"); + string param1 = "test"; + int param2 = 42; + + mockFeature + .Setup(f => f.CallEntityAsync( + id, + "Combine", + It.Is(arr => arr != null && arr.Length == 2 && (string)arr[0]! == param1 && (int)arr[1]! == param2), + null)) + .ReturnsAsync("result"); + + ITestEntityProxy proxy = mockFeature.Object.CreateProxy(id); + + // Act + string result = await proxy.Combine(param1, param2); + + // Assert + result.Should().Be("result"); + mockFeature.Verify( + f => f.CallEntityAsync( + id, + "Combine", + It.Is(arr => arr != null && arr.Length == 2), + null), + Times.Once); + } + + [Fact] + public void CreateProxy_WithEntityNameAndKey_CreatesProxyWithCorrectId() + { + // Arrange + Mock mockFeature = new(); + string entityName = "TestEntity"; + string entityKey = "key1"; + + mockFeature + .Setup(f => f.SignalEntityAsync( + It.Is(id => id.Name == entityName.ToLowerInvariant() && id.Key == entityKey), + "Delete", + null, + null)) + .Returns(Task.CompletedTask); + + // Act + ITestEntityProxy proxy = mockFeature.Object.CreateProxy(entityName, entityKey); + proxy.Delete(); + + // Assert + mockFeature.Verify( + f => f.SignalEntityAsync( + It.Is(id => id.Name == entityName.ToLowerInvariant() && id.Key == entityKey), + "Delete", + null, + null), + Times.Once); + } + + public interface ITestEntityProxy : IEntityProxy + { + Task GetValue(); + + Task Add(int value); + + Task Reset(); + + void Delete(); + + Task Combine(string str, int num); + } +} diff --git a/test/Client/Core.Tests/Entities/DurableEntityClientProxyExtensionsTests.cs b/test/Client/Core.Tests/Entities/DurableEntityClientProxyExtensionsTests.cs new file mode 100644 index 000000000..a3e749e0d --- /dev/null +++ b/test/Client/Core.Tests/Entities/DurableEntityClientProxyExtensionsTests.cs @@ -0,0 +1,186 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Microsoft.DurableTask.Client.Entities; +using Microsoft.DurableTask.Entities; +using Moq; + +namespace Microsoft.DurableTask.Client.Tests.Entities; + +public class DurableEntityClientProxyExtensionsTests +{ + [Fact] + public void CreateProxy_NullClient_ThrowsArgumentNullException() + { + // Arrange, Act, Assert + DurableEntityClient client = null!; + EntityInstanceId id = new("TestEntity", "key1"); + + Action act = () => client.CreateProxy(id); + + act.Should().Throw(); + } + + [Fact] + public async Task CallMethod_WithTask_CallsSignalEntityAsync() + { + // Arrange + Mock mockClient = new("test"); + EntityInstanceId id = new("TestEntity", "key1"); + + mockClient + .Setup(c => c.SignalEntityAsync(id, "Reset", null, null, default)) + .Returns(Task.CompletedTask); + + ITestEntityProxy proxy = mockClient.Object.CreateProxy(id); + + // Act + await proxy.Reset(); + + // Assert + mockClient.Verify( + c => c.SignalEntityAsync(id, "Reset", null, null, default), + Times.Once); + } + + [Fact] + public async Task CallMethod_WithTaskAndInput_CallsSignalEntityAsyncWithInput() + { + // Arrange + Mock mockClient = new("test"); + EntityInstanceId id = new("TestEntity", "key1"); + int input = 5; + + mockClient + .Setup(c => c.SignalEntityAsync(id, "Add", input, null, default)) + .Returns(Task.CompletedTask); + + ITestEntityProxy proxy = mockClient.Object.CreateProxy(id); + + // Act + await proxy.Add(input); + + // Assert + mockClient.Verify( + c => c.SignalEntityAsync(id, "Add", input, null, default), + Times.Once); + } + + [Fact] + public void CallMethod_VoidReturn_CallsSignalEntityAsync() + { + // Arrange + Mock mockClient = new("test"); + EntityInstanceId id = new("TestEntity", "key1"); + + mockClient + .Setup(c => c.SignalEntityAsync(id, "Delete", null, null, default)) + .Returns(Task.CompletedTask); + + ITestEntityProxy proxy = mockClient.Object.CreateProxy(id); + + // Act + proxy.Delete(); + + // Assert + mockClient.Verify( + c => c.SignalEntityAsync(id, "Delete", null, null, default), + Times.Once); + } + + [Fact] + public void CallMethod_WithTaskOfT_ThrowsNotSupportedException() + { + // Arrange + Mock mockClient = new("test"); + EntityInstanceId id = new("TestEntity", "key1"); + + ITestEntityProxy proxy = mockClient.Object.CreateProxy(id); + + // Act + Func act = async () => await proxy.GetValue(); + + // Assert + act.Should().ThrowAsync() + .WithMessage("*returns Task*not supported for client-side entity proxies*"); + } + + [Fact] + public async Task CallMethod_WithMultipleParameters_PassesParametersAsArray() + { + // Arrange + Mock mockClient = new("test"); + EntityInstanceId id = new("TestEntity", "key1"); + string param1 = "test"; + int param2 = 42; + + mockClient + .Setup(c => c.SignalEntityAsync( + id, + "Combine", + It.Is(arr => arr != null && arr.Length == 2 && (string)arr[0]! == param1 && (int)arr[1]! == param2), + null, + default)) + .Returns(Task.CompletedTask); + + ITestEntityProxy proxy = mockClient.Object.CreateProxy(id); + + // Act + await proxy.Combine(param1, param2); + + // Assert + mockClient.Verify( + c => c.SignalEntityAsync( + id, + "Combine", + It.Is(arr => arr != null && arr.Length == 2), + null, + default), + Times.Once); + } + + [Fact] + public void CreateProxy_WithEntityNameAndKey_CreatesProxyWithCorrectId() + { + // Arrange + Mock mockClient = new("test"); + string entityName = "TestEntity"; + string entityKey = "key1"; + + mockClient + .Setup(c => c.SignalEntityAsync( + It.Is(id => id.Name == entityName.ToLowerInvariant() && id.Key == entityKey), + "Delete", + null, + null, + default)) + .Returns(Task.CompletedTask); + + // Act + ITestEntityProxy proxy = mockClient.Object.CreateProxy(entityName, entityKey); + proxy.Delete(); + + // Assert + mockClient.Verify( + c => c.SignalEntityAsync( + It.Is(id => id.Name == entityName.ToLowerInvariant() && id.Key == entityKey), + "Delete", + null, + null, + default), + Times.Once); + } + + public interface ITestEntityProxy : IEntityProxy + { + Task GetValue(); + + Task Add(int value); + + Task Reset(); + + void Delete(); + + Task Combine(string str, int num); + } +}