Skip to content

Commit

Permalink
Feature/compilation safe null traversal (#133)
Browse files Browse the repository at this point in the history
* Add test cases that should work

* Add EnsureSafeAccess wrapping expression for null checks

This will ensure that all data access via props, fields, indexes and calls will be safe

* Refactor ValueGetters to enable deeper type matches

List access in compiled contexts need to be generic otherwise we can't infer types for further access

Co-authored-by: Alex McAuliffe <[email protected]>
  • Loading branch information
Romanx and Alex McAuliffe committed Aug 22, 2022
1 parent c92acb1 commit ad7f7d8
Show file tree
Hide file tree
Showing 8 changed files with 249 additions and 72 deletions.
17 changes: 9 additions & 8 deletions src/Stubble.Compilation/Contexts/CompilerContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using Stubble.Compilation.Helpers;
using Stubble.Compilation.Settings;
using Stubble.Core.Contexts;
using Stubble.Core.Exceptions;
using Stubble.Core.Interfaces;
using static Stubble.Compilation.Helpers.ExpressionConstants;
using static Stubble.Compilation.Helpers.ExpressionHelpers;

namespace Stubble.Compilation.Contexts
{
Expand Down Expand Up @@ -174,7 +174,8 @@ public Expression Lookup(string name)
instance = context?.SourceData;
}

value = TryEnumerationConversionIfRequired(value);
value = TryEnumerationConversionIfRequired(value);
value = EnsureSafeAccess(value);

cache[name] = value;
}
Expand Down Expand Up @@ -263,8 +264,8 @@ public IEnumerable<Expression> GetNestedSourceData()
yield return parent.SourceData;
parent = parent.ParentContext;
}
}

}

/// <summary>
/// Gets a value from the registry using the initalized value getters
/// </summary>
Expand All @@ -276,19 +277,19 @@ protected RegistryResult GetValueFromRegistry(Type value, Expression instance, s
{
foreach (var entry in CompilerSettings.ValueGetters)
{
if (!entry.Key.IsAssignableFrom(value))
if (entry.TypeMatchCheck(value) is false)
{
continue;
}

var outputVal = entry.Value(value, instance, key, CompilerSettings.IgnoreCaseOnKeyLookup);
if (outputVal != null)
var outputVal = entry.ValueGetterMethod(value, instance, key, CompilerSettings.IgnoreCaseOnKeyLookup);
if (outputVal is not null)
{
return new RegistryResult(outputVal.Type, outputVal);
}
}

return default(RegistryResult);
return default;
}

/// <summary>
Expand Down
77 changes: 77 additions & 0 deletions src/Stubble.Compilation/Helpers/ExpressionHelpers.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// <copyright file="ExpressionHelpers.cs" company="Stubble Authors">
// Copyright (c) Stubble Authors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
// </copyright>

using System;
using System.Collections.Generic;
using System.Linq.Expressions;

#nullable enable

namespace Stubble.Compilation.Helpers
{
/// <summary>
/// A static class containing expression helpers.
/// </summary>
internal static class ExpressionHelpers
{
/// <summary>
/// Wraps a nested member expression null checking at every step.
/// </summary>
/// <param name="expression">The member expression to check.</param>
/// <returns>The original expression null checked.</returns>
public static Expression EnsureSafeAccess(Expression expression)
{
var expressions = new Queue<Expression>();
var item = expression;
while (item?.NodeType is ExpressionType.MemberAccess or ExpressionType.Call)
{
expressions.Enqueue(item);
item = item switch
{
MethodCallExpression mce => mce.Object,
MemberExpression me => me.Expression,
_ => throw new InvalidOperationException("Invalid type found"),
};
}

if (expressions.Count is 0)
{
return expression;
}

var first = expressions.Dequeue();
var result = WrapWithNullCheckIfRequired(first, first);

while (expressions.Count > 0)
{
var memberExpression = expressions.Dequeue();
result = WrapWithNullCheckIfRequired(memberExpression, result);
}

return result!;

static Expression WrapWithNullCheckIfRequired(Expression item, Expression @base)
{
var parent = item switch
{
MethodCallExpression mce => mce.Object,
MemberExpression me => me.Expression,
_ => throw new InvalidOperationException("Invalid type found"),
};

// Parent is null if it's static and if it's a value type it can't be null
if (parent is null || parent.Type.IsValueType)
{
return @base;
}

return Expression.Condition(
Expression.NotEqual(parent, Expression.Default(parent.Type)),
@base,
Expression.Default(@base.Type));
}
}
}
}
6 changes: 3 additions & 3 deletions src/Stubble.Compilation/Settings/CompilerSettings.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public class CompilerSettings : BaseSettings
/// <param name="sectionBlacklistTypes">The blacklisted section types</param>
/// <param name="encodingFunction">The encoding function for escaping strings</param>
public CompilerSettings(
Dictionary<Type, DefaultSettings.ValueGetterDelegate> valueGetters,
List<ValueGetter> valueGetters,
Dictionary<Type, List<LambdaExpression>> truthyChecks,
Dictionary<Type, EnumerationConverter> enumerationConverters,
TokenRendererPipeline<CompilerContext> rendererPipeline,
Expand All @@ -58,7 +58,7 @@ public class CompilerSettings : BaseSettings
Expression<Func<string, string>> encodingFunction)
: base(templateLoader, partialLoader, maxRecursionDepth, ignoreCaseOnLookup, parser, defaultTags, parserPipeline, sectionBlacklistTypes)
{
ValueGetters = valueGetters.ToImmutableDictionary();
ValueGetters = valueGetters.ToImmutableArray();
TruthyChecks = truthyChecks.ToImmutableDictionary(k => k.Key, v => v.Value.ToImmutableList());
EnumerationConverters = enumerationConverters.ToImmutableDictionary();
RendererPipeline = rendererPipeline;
Expand All @@ -69,7 +69,7 @@ public class CompilerSettings : BaseSettings
/// <summary>
/// Gets a map of Types to Value getter functions
/// </summary>
public ImmutableDictionary<Type, DefaultSettings.ValueGetterDelegate> ValueGetters { get; }
public ImmutableArray<ValueGetter> ValueGetters { get; }

/// <summary>
/// Gets a readonly list of TruthyChecks
Expand Down
24 changes: 19 additions & 5 deletions src/Stubble.Compilation/Settings/CompilerSettingsBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Linq.Expressions;
using Stubble.Compilation.Class;
using Stubble.Compilation.Contexts;
Expand Down Expand Up @@ -33,8 +35,8 @@ public class CompilerSettingsBuilder : BaseSettingsBuilder<CompilerSettingsBuild
/// <summary>
/// Gets or sets a map of Types to Value getter functions
/// </summary>
protected internal Dictionary<Type, DefaultSettings.ValueGetterDelegate> ValueGetters { get; set; }
= new Dictionary<Type, DefaultSettings.ValueGetterDelegate>();
protected internal List<ValueGetter> ValueGetters { get; set; }
= new List<ValueGetter>();

/// <summary>
/// Gets or sets a readonly list of TruthyChecks
Expand Down Expand Up @@ -75,7 +77,9 @@ public CompilerSettingsBuilder SetCompilationSettings(CompilationSettings settin
/// <returns>The built compilation settings</returns>
public override CompilerSettings BuildSettings()
{
var mergedGetters = DefaultSettings.DefaultValueGetters().MergeLeft(ValueGetters);
var mergedGetters = MergeGetters(
DefaultSettings.DefaultValueGetters(),
ValueGetters);

return new CompilerSettings(
mergedGetters,
Expand Down Expand Up @@ -120,9 +124,9 @@ public CompilerSettingsBuilder AddTruthyCheck<T>(Expression<Func<T, bool>> expr)
/// <typeparam name="T">The type to get the value from</typeparam>
/// <param name="func">The getter function for the type</param>
/// <returns>The builder for chaining calls</returns>
public CompilerSettingsBuilder AddValueGetter<T>(DefaultSettings.ValueGetterDelegate func)
public CompilerSettingsBuilder AddValueGetter<T>(ValueGetterDelegate func)
{
ValueGetters[typeof(T)] = func;
ValueGetters.Add(new ValueGetter(typeof(T), static type => type == typeof(T), func));
return this;
}

Expand Down Expand Up @@ -154,6 +158,16 @@ public CompilerSettingsBuilder SetEncodingFunction(Expression<Func<string, strin
{
EncodingFunction = encodingFunction;
return this;
}

private static List<ValueGetter> MergeGetters(IEnumerable<ValueGetter> baseGetters, params IEnumerable<ValueGetter>[] getters)
{
var map = baseGetters.ToDictionary(k => k.Key, v => v);
var others = getters
.Select(getter => getter.ToDictionary(k => k.Key, v => v))
.ToArray();

return map.MergeLeft(others).Values.ToList();
}
}
}
Loading

0 comments on commit ad7f7d8

Please sign in to comment.