diff --git a/src/FairyBread/DefaultValidationErrorsHandler.cs b/src/FairyBread/DefaultValidationErrorsHandler.cs index d5ab420..d83f434 100644 --- a/src/FairyBread/DefaultValidationErrorsHandler.cs +++ b/src/FairyBread/DefaultValidationErrorsHandler.cs @@ -1,4 +1,6 @@ -using System.Collections.Generic; +using System; +using System.Collections.Generic; +using System.Linq; using FluentValidation; using FluentValidation.Results; using HotChocolate; @@ -12,13 +14,20 @@ public virtual void Handle( IMiddlewareContext context, IEnumerable invalidResults) { - foreach (var invalidResult in invalidResults) + if (context.ContextData.ContainsKey(WellKnownContextData.ValidatorDescriptorsParams)) { - foreach (var failure in invalidResult.Result.Errors) + throw new AggregateException(invalidResults.Select(x => new ValidationException(x.Result.Errors))); + } + else + { + foreach (var invalidResult in invalidResults) { - var errorBuilder = CreateErrorBuilder(context, invalidResult.ArgumentName, invalidResult.Validator, failure); - var error = errorBuilder.Build(); - context.ReportError(error); + foreach (var failure in invalidResult.Result.Errors) + { + var errorBuilder = CreateErrorBuilder(context, invalidResult.ArgumentName, invalidResult.Validator, failure); + var error = errorBuilder.Build(); + context.ReportError(error); + } } } } diff --git a/src/FairyBread/ValidationMiddlewareInjector.cs b/src/FairyBread/ValidationMiddlewareInjector.cs index 64dcebe..6f0a1c8 100644 --- a/src/FairyBread/ValidationMiddlewareInjector.cs +++ b/src/FairyBread/ValidationMiddlewareInjector.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Reflection; using HotChocolate; using HotChocolate.Configuration; using HotChocolate.Internal; @@ -15,6 +16,7 @@ namespace FairyBread internal class ValidationMiddlewareInjector : TypeInterceptor { private FieldMiddlewareDefinition? _validationFieldMiddlewareDef; + private FieldMiddlewareDefinition? _validationFieldMiddlewareDefParams; public override void OnBeforeCompleteType( ITypeCompletionContext completionContext, @@ -36,6 +38,7 @@ public override void OnBeforeCompleteType( // Don't add validation middleware unless: // 1. we have args var needsValidationMiddleware = false; + var needsValidationMiddlewareParams = false; foreach (var argDef in fieldDef.Arguments) { @@ -48,13 +51,28 @@ public override void OnBeforeCompleteType( } // 3. there's validators for it - List validatorDescs; + Dictionary> validatorDescs; + var usingArgs = true; try { - validatorDescs = DetermineValidatorsForArg(validatorRegistry, argDef); - if (validatorDescs.Count < 1) + var validatorDescsArgs = DetermineValidatorsForArg(validatorRegistry, argDef); + if (validatorDescsArgs.Any()) + validatorDescs = new Dictionary>() { { "Args", validatorDescsArgs } }; + else { - continue; + if (fieldDef.Arguments.Count == 1) // MutationConventions always use one argument + { + var type = fieldDef.ResolverMember as MethodInfo; + var parameters = type?.GetParameters(); + if (parameters is null) + continue; + validatorDescs = DetermineValidatorsForParameters(validatorRegistry, parameters); + if (!validatorDescs.Any()) + continue; + usingArgs = false; + } + else + continue; } } catch (Exception ex) @@ -74,9 +92,16 @@ public override void OnBeforeCompleteType( } } - validatorDescs.TrimExcess(); - needsValidationMiddleware = true; - argDef.ContextData[WellKnownContextData.ValidatorDescriptors] = validatorDescs.AsReadOnly(); + if (usingArgs) + { + needsValidationMiddleware = true; + argDef.ContextData[WellKnownContextData.ValidatorDescriptors] = validatorDescs.First().Value.AsReadOnly(); + } + else + { + needsValidationMiddlewareParams = true; + argDef.ContextData[WellKnownContextData.ValidatorDescriptorsParams] = validatorDescs; + } } if (needsValidationMiddleware) @@ -89,9 +114,77 @@ public override void OnBeforeCompleteType( fieldDef.MiddlewareDefinitions.Insert(0, _validationFieldMiddlewareDef); } + else if (needsValidationMiddlewareParams) + { + if (_validationFieldMiddlewareDefParams is null) + { + _validationFieldMiddlewareDefParams = new FieldMiddlewareDefinition( + FieldClassMiddlewareFactory.Create()); + } + + fieldDef.MiddlewareDefinitions.Add(_validationFieldMiddlewareDefParams); + } + } + } + + private static Dictionary> DetermineValidatorsForParameters(IValidatorRegistry validatorRegistry, ParameterInfo[] parameters) + { + var validators = new Dictionary>(); + + + foreach (var parameter in parameters) + { + var paramVals = new List(); + // If validation is explicitly disabled, return none so validation middleware won't be added + if (parameter.CustomAttributes.Any(x => x.AttributeType == typeof(DontValidateAttribute))) + { + continue; + } + + + // Include implicit validator/s first (if allowed) + if (!parameter.CustomAttributes.Any(x => x.AttributeType == typeof(DontValidateImplicitlyAttribute))) + { + // And if we can figure out the arg's runtime type + var argRuntimeType = parameter.ParameterType; + if (argRuntimeType is not null) + { + if (validatorRegistry.Cache.TryGetValue(argRuntimeType, out var implicitValidators) && + implicitValidators is not null) + { + paramVals.AddRange(implicitValidators); + } + } + } + + // Include explicit validator/s (that aren't already added implicitly) + var explicitValidators = parameter.GetCustomAttributes().Where(x => x.GetType() == typeof(ValidateAttribute)).Cast().ToList(); + if (explicitValidators.Any()) + { + var validatorTypes = explicitValidators.SelectMany(x => x.ValidatorTypes); + // TODO: Potentially check and throw if there's a validator being explicitly applied for the wrong runtime type + + foreach (var validatorType in validatorTypes) + { + if (paramVals.Any(v => v.ValidatorType == validatorType)) + { + continue; + } + + paramVals.Add(new ValidatorDescriptor(validatorType)); + } + } + + if (paramVals.Any()) + { + paramVals.TrimExcess(); + validators[parameter.Name] = paramVals; + } } + return validators; } + private static List DetermineValidatorsForArg( IValidatorRegistry validatorRegistry, ArgumentDefinition argDef) diff --git a/src/FairyBread/ValidationMiddlewareParams.cs b/src/FairyBread/ValidationMiddlewareParams.cs new file mode 100644 index 0000000..80539df --- /dev/null +++ b/src/FairyBread/ValidationMiddlewareParams.cs @@ -0,0 +1,125 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Text; +using System.Threading.Tasks; +using FluentValidation; +using HotChocolate.Resolvers; +using HotChocolate.Types; +using Microsoft.Extensions.DependencyInjection; + +namespace FairyBread +{ + internal class ValidationMiddlewareParams + { + private readonly FieldDelegate _next; + private readonly IValidatorProvider _validatorProvider; + private readonly IValidationErrorsHandler _validationErrorsHandler; + + public ValidationMiddlewareParams( + FieldDelegate next, + IValidatorProvider validatorProvider, + IValidationErrorsHandler validationErrorsHandler) + { + _next = next; + _validatorProvider = validatorProvider; + _validationErrorsHandler = validationErrorsHandler; + } + + public async Task InvokeAsync(IMiddlewareContext context) + { + context.ContextData[WellKnownContextData.ValidatorDescriptorsParams] = true; + var arguments = context.Selection.Field.Arguments; + + var invalidResults = new List(); + var type = context.Selection.Field.ResolverMember as MethodInfo; + var parameters = type.GetParameters(); + + + + foreach (var argument in context.Selection.Field.Arguments) + { + if (argument == null) + { + continue; + } + + var resolvedValidators = GetValidators(context, argument).GroupBy(x => x.param); + if (resolvedValidators.Count() > 0) + { + foreach (var resolvedValidatorGroup in resolvedValidators) + { + try + { + var value = context.ArgumentValue(resolvedValidatorGroup.Key); + if (value == null) + { + continue; + } + + foreach (var resolvedValidator in resolvedValidatorGroup) + { + var validationContext = new ValidationContext(value); + var validationResult = await resolvedValidator.resolver.Validator.ValidateAsync( + validationContext, + context.RequestAborted); + if (validationResult != null && + !validationResult.IsValid) + { + invalidResults.Add( + new ArgumentValidationResult( + resolvedValidatorGroup.Key, + resolvedValidator.resolver.Validator, + validationResult)); + } + } + } + finally + { + foreach (var resolvedValidator in resolvedValidatorGroup) + { + resolvedValidator.resolver.Scope?.Dispose(); + } + } + } + } + } + + if (invalidResults.Any()) + { + _validationErrorsHandler.Handle(context, invalidResults); + return; + } + + await _next(context); + } + + IEnumerable<(string param, ResolvedValidator resolver)> GetValidators(IMiddlewareContext context, IInputField argument) + { + if (!argument.ContextData.TryGetValue(WellKnownContextData.ValidatorDescriptorsParams, out var validatorDescriptorsRaw) + || validatorDescriptorsRaw is not Dictionary> validatorDescriptors) + { + yield break; + } + + foreach (var validatorDescriptor in validatorDescriptors) + { + foreach (var validatorDescriptorEntry in validatorDescriptor.Value) + { + if (validatorDescriptorEntry.RequiresOwnScope) + { + var scope = context.Services.CreateScope(); // disposed by middleware + var validator = (IValidator)scope.ServiceProvider.GetRequiredService(validatorDescriptorEntry.ValidatorType); + yield return (validatorDescriptor.Key, new ResolvedValidator(validator, scope)); + } + else + { + var validator = (IValidator)context.Services.GetRequiredService(validatorDescriptorEntry.ValidatorType); + yield return (validatorDescriptor.Key, new ResolvedValidator(validator)); + } + } + } + } + } +} diff --git a/src/FairyBread/WellKnownContextData.cs b/src/FairyBread/WellKnownContextData.cs index b8fcd38..c31b24e 100644 --- a/src/FairyBread/WellKnownContextData.cs +++ b/src/FairyBread/WellKnownContextData.cs @@ -15,5 +15,8 @@ internal static class WellKnownContextData public const string ValidatorDescriptors = Prefix + ".Validators"; + + public const string ValidatorDescriptorsParams = + Prefix + ".Validators.Params"; } }