Skip to content

Commit

Permalink
Experiment with local capture
Browse files Browse the repository at this point in the history
  • Loading branch information
justindbaur committed Sep 24, 2023
1 parent 13724ea commit eb9baae
Show file tree
Hide file tree
Showing 13 changed files with 425 additions and 210 deletions.
4 changes: 3 additions & 1 deletion example/UnitTest1.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ public async Task Test2()
{
var pretend = Pretend.For<IMyInterface>();

var local = "Value";

pretend
.Setup(i => i.Greeting("Test"))
.Setup(i => i.Greeting(local))
.Returns("Thing");


Expand Down
2 changes: 1 addition & 1 deletion src/Pretender.SourceGenerator/PretenderSourceGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
for (var i = 0; i < setups.Length; i++)
{
var setup = setups[i];
members.Add(setup!.GetMethodDeclaration(i));
members.AddRange(setup!.GetMembers(i));
}

var classDeclaration = SyntaxFactory.ClassDeclaration("SetupInterceptors")
Expand Down
228 changes: 222 additions & 6 deletions src/Pretender.SourceGenerator/SetupArgument.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,100 @@
using System.Diagnostics;
using System;
using System.Linq;
using System.Collections.Immutable;

namespace Pretender.SourceGenerator
{
internal class SetupArgument(IArgumentOperation argumentOperation, int index)
internal class ArgumentTracker
{
private readonly List<ILocalReferenceOperation> _neededLocals = new();
private readonly Stack<HashSet<ILocalSymbol>> _trackedLocals = new();

public ArgumentTracker()
{
_trackedLocals = new Stack<HashSet<ILocalSymbol>>();
_trackedLocals.Push(new HashSet<ILocalSymbol>(SymbolEqualityComparer.Default));
}

public ImmutableArray<ILocalReferenceOperation> NeededLocals => _neededLocals.ToImmutableArray();
public bool NeedsCapturer { get; private set; }
public void SetNeedsCapturer()
{
NeedsCapturer = true;
}

public bool TryTrackLocal(ILocalReferenceOperation localReferenceOperation)
{
var currentScope = _trackedLocals.Peek();
if (currentScope.Contains(localReferenceOperation.Local))
{
// This is being tracked as created during the current scope, ignore it
return false;
}

_neededLocals.Add(localReferenceOperation);
return true;
}

public void LocalDefined(ILocalSymbol local)
{
var currentScope = _trackedLocals.Peek();
currentScope.Add(local);
}

public void LocalsDefined(IEnumerable<ILocalSymbol> locals)
{
var currentScope = _trackedLocals.Peek();
foreach (var local in locals)
{
currentScope.Add(local);
}
}

// TODO: could create an IDisposable for this
public void EnterScope()
{
_trackedLocals.Push(new(SymbolEqualityComparer.Default));
}

public void ExitScope()
{
_trackedLocals.Pop();
}
}


internal class SetupArgument
{
private static readonly IdentifierNameSyntax CallInfoIdentifier = IdentifierName("callInfo");
private static readonly IdentifierNameSyntax ArgumentsPropertyIdentifier = IdentifierName("Arguments");

private readonly int _index = index;
private readonly int _index;

public SetupArgument(IArgumentOperation argumentOperation, int index, List<Diagnostic> diagnostics)
{
var argOperationValue = argumentOperation.Value;
var tracker = new ArgumentTracker();
// Walk the operation tree to find all locals
Visit(argOperationValue, tracker);

RequiredLocals = tracker.NeededLocals;
NeedsCapturer = tracker.NeedsCapturer;

ArgumentOperation = argumentOperation;
_index = index;
}


public ImmutableArray<ILocalReferenceOperation> RequiredLocals { get; }

public IArgumentOperation ArgumentOperation { get; } = argumentOperation;
public IArgumentOperation ArgumentOperation { get; }

public bool NeedsCapturer { get; }
public bool IsLiteral => ArgumentOperation.Value is ILiteralOperation;
public bool IsInvocation => ArgumentOperation.Value is IInvocationOperation;
public bool IsLocalReference => ArgumentOperation.Value is ILocalReferenceOperation;

public ITypeSymbol ParameterType => ArgumentOperation.Parameter!.Type;
public string ArgumentLocalName => $"{ArgumentOperation.Parameter!.Name}_arg";

Expand Down Expand Up @@ -75,6 +155,7 @@ public bool TryEmitInvocationStatements(out StatementSyntax[] statements)
else if (false) // Is non-scope capturing
{
// This is a lot more work but also very powerful in terms of speed
// We need to rewrite the delegate and replace all local references with our getter
allArgumentsSafe = false;
}
else
Expand Down Expand Up @@ -139,6 +220,96 @@ public bool TryEmitInvocationStatements(out StatementSyntax[] statements)
}
}

// Returns true if the visited operation captured a local
private static bool Visit(IOperation? operation, ArgumentTracker tracker)
{
if (operation == null)
{
return false;
}

// TODO: Handle most operations
switch (operation.Kind)
{
case OperationKind.Block:
var block = (IBlockOperation)operation;
return VisitMany(block.Operations, tracker);
case OperationKind.VariableDeclarationGroup:
var variableDeclarationGroup = (IVariableDeclarationGroupOperation)operation;
return VisitMany(variableDeclarationGroup.Declarations, tracker);

case OperationKind.Return:
var returnOp = (IReturnOperation)operation;
return Visit(returnOp.ReturnedValue, tracker);
case OperationKind.Literal:
// Literals are the best, they are easy and the end of the line
return false;
case OperationKind.Invocation:
var invocation = (IInvocationOperation)operation;
// The instance could be a local itself
return Visit(invocation.Instance, tracker)
| VisitMany(invocation.Arguments, tracker);
case OperationKind.LocalReference:
var local = (ILocalReferenceOperation)operation;
tracker.TryTrackLocal(local);
return true;
case OperationKind.ParameterReference:
return false;
case OperationKind.Binary:
var binary = (IBinaryOperation)operation;
return Visit(binary.LeftOperand, tracker) | Visit(binary.RightOperand, tracker);
case OperationKind.AnonymousFunction:
// TODO: I'm not sure if this belongs in here or DelegateCreation but lets go with here for now
tracker.EnterScope();
var anonymousFunction = (IAnonymousFunctionOperation)operation;
var found = Visit(anonymousFunction.Body, tracker);
tracker.ExitScope();
return found;
case OperationKind.DelegateCreation:
var delegateCreation = (IDelegateCreationOperation)operation;
// TODO: Now that we are in a delegate should we ignore their locals somehow?
return Visit(delegateCreation.Target, tracker);
case OperationKind.VariableInitializer:
var variableInitializer = (IVariableInitializerOperation)operation;
tracker.LocalsDefined(variableInitializer.Locals);
// TODO: Not sure if this is right
Visit(variableInitializer.Value, tracker);
return true;
case OperationKind.VariableDeclaration:
var variableDeclaration = (IVariableDeclarationOperation)operation;
return VisitMany(variableDeclaration.Declarators, tracker)
| Visit(variableDeclaration.Initializer, tracker);
case OperationKind.VariableDeclarator:
var variableDeclarator = (IVariableDeclaratorOperation)operation;
tracker.LocalDefined(variableDeclarator.Symbol);
// TODO: IgnoredArguments property?
return Visit(variableDeclarator.Initializer, tracker);


case OperationKind.Argument:
var argument = (IArgumentOperation)operation;
return Visit(argument.Value, tracker);


}

throw new NotImplementedException($"Can't visit operation '{operation.Kind}'");
}

private static bool VisitMany(IEnumerable<IOperation> operations, ArgumentTracker tracker)
{
var foundLocal = false;
foreach (var operation in operations)
{
if (Visit(operation, tracker))
{
foundLocal = true;
}
}

return foundLocal;
}

private bool TryGetMatcherAttributeType(IInvocationOperation invocationOperation, out INamedTypeSymbol matcherType)
{
var allAttributes = invocationOperation.TargetMethod.GetAttributes();
Expand Down Expand Up @@ -185,13 +356,58 @@ private bool TryGetMatcherAttributeType(IInvocationOperation invocationOperation
return true;
}

public IfStatementSyntax EmitLiteralIfCheck()
public StatementSyntax[] EmitLocalIfCheck(int index)
{
Debug.Assert(IsLocalReference, "Shouldn't have been called.");

var localOperation = (ILocalReferenceOperation)ArgumentOperation.Value;

var variableName = $"{ArgumentOperation.Parameter!.Name}_local";

var statements = new StatementSyntax[3];
statements[0] = EmitArgumentAccessor();

// This is for calling the UnsafeAccessor method that doesn't seem to work for my needs
//statements[1] = LocalDeclarationStatement(VariableDeclaration(ParseTypeName("var"))
// .AddVariables(VariableDeclarator(variableName)
// .WithInitializer(EqualsValueClause(InvocationExpression(
// MemberAccessExpression(
// SyntaxKind.SimpleMemberAccessExpression,
// IdentifierName($"Setup{index}Accessor"),
// IdentifierName(((ILocalReferenceOperation)ArgumentOperation.Value).Local.Name)
// )
// )
// .AddArgumentListArguments(Argument(IdentifierName("target")))))));


//statements[1] = LocalDeclarationStatement(VariableDeclaration(localOperation.Local.Type.AsUnknownTypeSyntax())
// .AddVariables(VariableDeclarator(variableName)
// .WithInitializer(EqualsValueClause(
// MemberAccessExpression(
// SyntaxKind.SimpleMemberAccessExpression,
// ParenthesizedExpression(CastExpression(ParseTypeName("dynamic"), IdentifierName("target"))),
// IdentifierName(localOperation.Local.Name))))
// )
// );

// var arg_local = target.GetType().GetField("local").GetValue(target);

// TODO: This really sucks, but neither other way works
statements[1] = ExpressionStatement(
ParseExpression($"var {variableName} = target.GetType().GetField(\"{localOperation.Local.Name}\")!.GetValue(target)")
);

statements[2] = EmitIfCheck(IdentifierName(variableName));

return statements;
}

public IfStatementSyntax EmitIfCheck(ExpressionSyntax right)
{
Debug.Assert(IsLiteral, "This should only be called if you have already checked it's a literal operation.");
var binaryExpression = BinaryExpression(
SyntaxKind.NotEqualsExpression,
IdentifierName(ArgumentLocalName),
((ILiteralOperation)ArgumentOperation.Value).ToLiteralExpression()
right
);

return CreateArgumentCheck(binaryExpression);
Expand Down
Loading

0 comments on commit eb9baae

Please sign in to comment.