-
Notifications
You must be signed in to change notification settings - Fork 790
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enable AIFunctionFactory.Create functions to get DI services #6141
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
|
||
using System; | ||
using System.Collections; | ||
using System.Collections.Generic; | ||
using System.Linq; | ||
using Microsoft.Shared.Collections; | ||
|
||
#pragma warning disable SA1111 // Closing parenthesis should be on line of last parameter | ||
#pragma warning disable SA1112 // Closing parenthesis should be on line of opening parenthesis | ||
#pragma warning disable SA1114 // Parameter list should follow declaration | ||
#pragma warning disable CA1710 // Identifiers should have correct suffix | ||
|
||
namespace Microsoft.Extensions.AI; | ||
|
||
/// <summary>Represents arguments to be used with <see cref="AIFunction.InvokeAsync"/>.</summary> | ||
/// <remarks> | ||
/// <see cref="AIFunction.InvokeAsync"/> may be invoked with arbitary <see cref="IEnumerable{T}"/> | ||
/// implementations. However, some <see cref="AIFunction"/> implementations may dynamically check | ||
/// the type of the arguments and use the concrete type to perform more specific operations. By | ||
/// checking for <see cref="AIFunctionArguments"/>, and implementation may optionally access | ||
/// additional context provided, such as any <see cref="IServiceProvider"/> that may be associated | ||
/// with the operation. | ||
/// </remarks> | ||
public class AIFunctionArguments : IReadOnlyDictionary<string, object?> | ||
{ | ||
private readonly IReadOnlyDictionary<string, object?> _arguments; | ||
|
||
/// <summary>Initializes a new instance of the <see cref="AIFunctionArguments"/> class.</summary> | ||
/// <param name="arguments">The arguments represented by this instance.</param> | ||
public AIFunctionArguments(IEnumerable<KeyValuePair<string, object?>>? arguments) | ||
{ | ||
if (arguments is IReadOnlyDictionary<string, object?> irod) | ||
{ | ||
_arguments = irod; | ||
} | ||
else if (arguments is null | ||
#if NET | ||
|| (Enumerable.TryGetNonEnumeratedCount(arguments, out int count) && count == 0) | ||
#endif | ||
) | ||
{ | ||
_arguments = EmptyReadOnlyDictionary<string, object?>.Instance; | ||
} | ||
else | ||
{ | ||
_arguments = arguments.ToDictionary( | ||
#if !NET | ||
x => x.Key, x => x.Value | ||
#endif | ||
); | ||
} | ||
} | ||
|
||
/// <summary>Gets any services associated with these arguments.</summary> | ||
public IServiceProvider? Services { get; init; } | ||
|
||
/// <inheritdoc /> | ||
public object? this[string key] => _arguments[key]; | ||
|
||
/// <inheritdoc /> | ||
public IEnumerable<string> Keys => _arguments.Keys; | ||
|
||
/// <inheritdoc /> | ||
public IEnumerable<object?> Values => _arguments.Values; | ||
|
||
/// <inheritdoc /> | ||
public int Count => _arguments.Count; | ||
|
||
/// <inheritdoc /> | ||
public bool ContainsKey(string key) => _arguments.ContainsKey(key); | ||
|
||
/// <inheritdoc /> | ||
public IEnumerator<KeyValuePair<string, object?>> GetEnumerator() => _arguments.GetEnumerator(); | ||
|
||
/// <inheritdoc /> | ||
public bool TryGetValue(string key, out object? value) => _arguments.TryGetValue(key, out value); | ||
|
||
/// <inheritdoc /> | ||
IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable)_arguments).GetEnumerator(); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,9 +14,14 @@ | |
using System.Text.Json.Serialization.Metadata; | ||
using System.Threading; | ||
using System.Threading.Tasks; | ||
using Microsoft.Extensions.DependencyInjection; | ||
using Microsoft.Shared.Collections; | ||
using Microsoft.Shared.Diagnostics; | ||
|
||
#pragma warning disable CA1031 // Do not catch general exception types | ||
#pragma warning disable S2302 // "nameof" should be used | ||
#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields | ||
|
||
namespace Microsoft.Extensions.AI; | ||
|
||
/// <summary>Provides factory methods for creating commonly used implementations of <see cref="AIFunction"/>.</summary> | ||
|
@@ -196,8 +201,8 @@ private ReflectionAIFunction(ReflectionAIFunctionDescriptor functionDescriptor, | |
object?[] args = paramMarshallers.Length != 0 ? new object?[paramMarshallers.Length] : []; | ||
|
||
IReadOnlyDictionary<string, object?> argDict = | ||
arguments is null || args.Length == 0 ? EmptyReadOnlyDictionary<string, object?>.Instance : | ||
arguments as IReadOnlyDictionary<string, object?> ?? | ||
arguments is null ? EmptyReadOnlyDictionary<string, object?>.Instance : | ||
arguments as IReadOnlyDictionary<string, object?> ?? // if arguments is an AIFunctionArguments, which is an IROD, use it as-is | ||
arguments. | ||
#if NET8_0_OR_GREATER | ||
ToDictionary(); | ||
|
@@ -248,6 +253,30 @@ public static ReflectionAIFunctionDescriptor GetOrCreate(MethodInfo method, AIFu | |
|
||
private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions serializerOptions) | ||
{ | ||
AIJsonSchemaCreateOptions schemaOptions = new() | ||
{ | ||
// This needs to be kept in sync with the shape of AIJsonSchemaCreateOptions. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should AIJsonSchemaCreateOptions expose a Clone method? Or be a record if that won't introduce other issues? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Given that If we don't go the record route, it would probably be best to extract this clone into an internal testable method, so you could write a reflection-based test like this to verify no one accidentally adds a property and forgets to update the clone method. |
||
TransformSchemaNode = key.SchemaOptions.TransformSchemaNode, | ||
IncludeParameter = parameterInfo => | ||
{ | ||
// Explicitly exclude from the schema parameters annotated as [FromServices] or [FromKeyedServices]. | ||
// These will be satisfied from sources other than arguments to InvokeAsync. | ||
if (parameterInfo.GetCustomAttribute<FromServicesAttribute>(inherit: true) is not null || | ||
parameterInfo.GetCustomAttribute<FromKeyedServicesAttribute>(inherit: true) is not null) | ||
{ | ||
return false; | ||
} | ||
|
||
// For all other parameters, delegate to whatever behavior is specified in the options. | ||
// If none is specified, include the parameter. | ||
return key.SchemaOptions.IncludeParameter?.Invoke(parameterInfo) ?? true; | ||
}, | ||
IncludeTypeInEnumSchemas = key.SchemaOptions.IncludeTypeInEnumSchemas, | ||
DisallowAdditionalProperties = key.SchemaOptions.DisallowAdditionalProperties, | ||
IncludeSchemaKeyword = key.SchemaOptions.IncludeSchemaKeyword, | ||
RequireAllProperties = key.SchemaOptions.RequireAllProperties, | ||
}; | ||
|
||
// Get marshaling delegates for parameters. | ||
ParameterInfo[] parameters = key.Method.GetParameters(); | ||
ParameterMarshallers = new Func<IReadOnlyDictionary<string, object?>, CancellationToken, object?>[parameters.Length]; | ||
|
@@ -268,7 +297,7 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions | |
Name, | ||
Description, | ||
serializerOptions, | ||
key.SchemaOptions); | ||
schemaOptions); | ||
} | ||
|
||
public string Name { get; } | ||
|
@@ -341,6 +370,47 @@ static bool IsAsyncMethod(MethodInfo method) | |
cancellationToken; | ||
} | ||
|
||
// For DI-based parameters, try to resolve from the service provider. | ||
if (parameter.GetCustomAttribute<FromServicesAttribute>(inherit: true) is { } fsAttr) | ||
{ | ||
return (arguments, _) => | ||
{ | ||
if ((arguments as AIFunctionArguments)?.Services is IServiceProvider services && | ||
services.GetService(parameterType) is object service) | ||
{ | ||
return service; | ||
} | ||
|
||
if (!parameter.HasDefaultValue) | ||
{ | ||
// No service could be resolved for the required parameter. | ||
Throw.ArgumentException(nameof(arguments), $"Unable to resolve service of type '{parameterType}' for parameter '{parameter.Name}'."); | ||
} | ||
|
||
// No service could be resolved. Return a default value if it's optional, otherwise throw. | ||
return parameter.DefaultValue; | ||
}; | ||
} | ||
else if (parameter.GetCustomAttribute<FromKeyedServicesAttribute>(inherit: true) is { } fksAttr) | ||
{ | ||
return (arguments, _) => | ||
{ | ||
if ((arguments as AIFunctionArguments)?.Services is IKeyedServiceProvider services && | ||
services.GetKeyedService(parameterType, fksAttr.Key) is object service) | ||
{ | ||
return service; | ||
} | ||
|
||
if (!parameter.HasDefaultValue) | ||
{ | ||
// No service could be resolved for the required parameter. | ||
Throw.ArgumentException(nameof(arguments), $"Unable to resolve service of type '{parameterType}' with key '{fksAttr.Key}' for parameter '{parameter.Name}'."); | ||
} | ||
|
||
return parameter.DefaultValue; | ||
}; | ||
} | ||
|
||
// For all other parameters, create a marshaller that tries to extract the value from the arguments dictionary. | ||
return (arguments, _) => | ||
{ | ||
|
@@ -359,7 +429,6 @@ static bool IsAsyncMethod(MethodInfo method) | |
|
||
object? MarshallViaJsonRoundtrip(object value) | ||
{ | ||
#pragma warning disable CA1031 // Do not catch general exception types | ||
try | ||
{ | ||
string json = JsonSerializer.Serialize(value, serializerOptions.GetTypeInfo(value.GetType())); | ||
|
@@ -370,7 +439,6 @@ static bool IsAsyncMethod(MethodInfo method) | |
// Eat any exceptions and fall back to the original value to force a cast exception later on. | ||
return value; | ||
} | ||
#pragma warning restore CA1031 | ||
} | ||
} | ||
|
||
|
@@ -482,9 +550,7 @@ private static MethodInfo GetMethodFromGenericMethodDefinition(Type specializedT | |
#if NET | ||
return (MethodInfo)specializedType.GetMemberWithSameMetadataDefinitionAs(genericMethodDefinition); | ||
#else | ||
#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields | ||
const BindingFlags All = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance; | ||
#pragma warning restore S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields | ||
return specializedType.GetMethods(All).First(m => m.MetadataToken == genericMethodDefinition.MetadataToken); | ||
#endif | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
|
||
using System; | ||
|
||
namespace Microsoft.Extensions.AI; | ||
|
||
/// <summary>Indicates that a parameter to an <see cref="AIFunction"/> should be sourced from an associated <see cref="IServiceProvider"/>.</summary> | ||
/// <remarks> | ||
/// <see cref="AIFunctionFactory"/> uses this attribute to determine whether a parameter's value should be sourced from | ||
/// an <see cref="IServiceProvider"/> instead of from the nominal arguments passed to the function. The <see cref="IServiceProvider"/> | ||
/// is extracted from the <see cref="AIFunctionArguments.Services"/> property of an <see cref="AIFunctionArguments"/> passed | ||
/// as the arguments to a call to <see cref="AIFunction.InvokeAsync"/>. | ||
/// </remarks> | ||
[AttributeUsage(AttributeTargets.Parameter)] | ||
public sealed class FromServicesAttribute : Attribute | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I know there's some hesitancy about duplicating the same name MVC uses. I don't have a better idea, though. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we have further reasons to think it wouldn't suffice to support passing in an There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would be worth starting an effort lowering all those attributes to the M.E.DI package? We could then take a dependency on the next preview package. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Not for a GA release. We wouldn't be able to add that until November. And FromServices is in an ASP.NET MVC namespace. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I left a comment on #6146 explaining why I like that proposal the best. Is there any major pushback to making the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
{ | ||
/// <summary>Initializes a new instance of the <see cref="FromServicesAttribute"/> class.</summary> | ||
public FromServicesAttribute() | ||
{ | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will it be confusing to pass an
IServiceProvider
into the constructor and not have it get used at all byFunctionInvokingChatClient.GetService(Type serviceType, object? serviceKey = null)
?I think it might be more consistent to rely on the
IServiceProvider
implementation from theinnerClient
and ensuring that the inner clients are initialized with the host's service provider whenever that's important. Another upside is that this would allowAIFunction
s to inject services provided by theinnerClient
likeChatClientMetadata
.This might be made easier if we changed the
IChatClient.GetService
method to anIChatClient.Services
property. We could get rid of a bunch of redundantChatClientExtensions
methods while we're at it.I've already said my piece about how I don't like falling back to DI for the logger when null is passed in the same argument list, so I won't repeat that argument here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
GetService has the same shape as with IServiceProvider, but it's not consulting any IServiceProvider in any of the implementations we have or have seen thus far. It's just a mechanism to ask the components in the pipeline for a particular object, e.g. if you have a pipeline of IChatClients that ends in an OnnxRuntimeGenAIClient, and for some reason you want to ask it for its Tokenizer, you can do so with GetService, and that call will be passed down through the pipeline until OnnxRuntimeGenAIClient says "yeah, I've got a Tokenizer, here". That also means it doesn't really work as a property.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Most inner clients aren't going to have one. e.g. the OpenAIChatClient doesn't have an IServiceProvider, the AzureAIInferenceChatClient doesn't have one, etc. In fact I don't think any implementations I've seen do. Some other middleware components might, but there's no guarantee FunctionInvocationChatClient will be wrapping one of those. I can't think of a better way than to store the IServiceProvider that's pass in from the builder, typically in UseFunctionInvocation().