Skip to content
Open
100 changes: 61 additions & 39 deletions src/Analyzers/MSTest.Analyzers.CodeFixes/Helpers/FixtureMethodFixer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,88 +5,110 @@

using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.CSharp.Syntax;

namespace MSTest.Analyzers.Helpers;

internal static class FixtureMethodFixer
{
private const SyntaxNode? VoidReturnTypeNode = null;

public static async Task<Solution> FixSignatureAsync(Document document, SyntaxNode root, SyntaxNode node,
bool isParameterLess, bool shouldBeStatic, CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();

SemanticModel? semanticModel = await document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);
if (node is not MethodDeclarationSyntax methodDeclaration)
{
return document.Project.Solution;
}

var methodSymbol = (IMethodSymbol?)semanticModel.GetDeclaredSymbol(node, cancellationToken);
if (methodSymbol is null)
SemanticModel semanticModel = await document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);

if (semanticModel.GetDeclaredSymbol(node, cancellationToken) is not IMethodSymbol methodSymbol)
{
return document.Project.Solution;
}

var wellKnownTypeProvider = WellKnownTypeProvider.GetOrCreate(semanticModel.Compilation);
var syntaxGenerator = SyntaxGenerator.GetGenerator(document);

SyntaxNode fixedMethodDeclarationNode = syntaxGenerator.MethodDeclaration(
methodSymbol.Name,
GetParameters(syntaxGenerator, isParameterLess, wellKnownTypeProvider),
typeParameters: null,
GetReturnType(syntaxGenerator, methodSymbol, wellKnownTypeProvider),
Accessibility.Public,
GetModifiers(methodSymbol, shouldBeStatic),
GetStatements(node, syntaxGenerator));
MethodDeclarationSyntax fixedMethodDeclaration = methodDeclaration
.WithParameterList(GetParameterList(isParameterLess, wellKnownTypeProvider))
.WithReturnType(GetReturnType(methodSymbol, wellKnownTypeProvider))
.WithModifiers(GetModifiers(methodDeclaration, shouldBeStatic))
.WithTypeParameterList(null);

// Copy the attributes from the old method to the new method.
fixedMethodDeclarationNode = syntaxGenerator.AddAttributes(fixedMethodDeclarationNode, syntaxGenerator.GetAttributes(node));
if (fixedMethodDeclaration.Body is null)
{
fixedMethodDeclaration = fixedMethodDeclaration
.WithBody(SyntaxFactory.Block())
.WithSemicolonToken(default);
}
else
{
SyntaxList<StatementSyntax> statements = fixedMethodDeclaration.Body.Statements;
IEnumerable<StatementSyntax> filteredStatements = statements
.Where(s => !s.IsKind(SyntaxKind.ReturnStatement) && !s.IsKind(SyntaxKind.YieldReturnStatement));

if (statements.Count != filteredStatements.Count())
{
fixedMethodDeclaration = fixedMethodDeclaration.WithBody(
fixedMethodDeclaration.Body.WithStatements(SyntaxFactory.List(filteredStatements)));
}
}

return document.WithSyntaxRoot(root.ReplaceNode(node, fixedMethodDeclarationNode)).Project.Solution;
return document.WithSyntaxRoot(root.ReplaceNode(node, fixedMethodDeclaration)).Project.Solution;
}

private static IEnumerable<SyntaxNode> GetStatements(SyntaxNode node, SyntaxGenerator syntaxGenerator)
=> syntaxGenerator.GetStatements(node)
.Where(x => !x.IsKind(SyntaxKind.ReturnStatement) && !x.IsKind(SyntaxKind.YieldReturnStatement));

private static DeclarationModifiers GetModifiers(IMethodSymbol methodSymbol, bool shouldBeStatic)
private static SyntaxTokenList GetModifiers(MethodDeclarationSyntax methodDeclaration, bool shouldBeStatic)
{
DeclarationModifiers newModifiers = methodSymbol.IsAsync
? DeclarationModifiers.Async
: DeclarationModifiers.None;
SyntaxTokenList modifiers = SyntaxFactory.TokenList(
methodDeclaration.Modifiers.Where(m =>
!m.IsKind(SyntaxKind.PublicKeyword) &&
!m.IsKind(SyntaxKind.PrivateKeyword) &&
!m.IsKind(SyntaxKind.ProtectedKeyword) &&
!m.IsKind(SyntaxKind.InternalKeyword) &&
!m.IsKind(SyntaxKind.AbstractKeyword) &&
!m.IsKind(SyntaxKind.StaticKeyword)));

SyntaxTokenList result = SyntaxFactory.TokenList(SyntaxFactory.Token(SyntaxKind.PublicKeyword));
if (shouldBeStatic)
{
result = result.Add(SyntaxFactory.Token(SyntaxKind.StaticKeyword));
}

return newModifiers.WithIsStatic(shouldBeStatic);
return result.AddRange(modifiers);
}

private static SyntaxNode? GetReturnType(SyntaxGenerator syntaxGenerator, IMethodSymbol methodSymbol, WellKnownTypeProvider wellKnownTypeProvider)
private static TypeSyntax GetReturnType(IMethodSymbol methodSymbol, WellKnownTypeProvider wellKnownTypeProvider)
{
if (SymbolEqualityComparer.Default.Equals(methodSymbol.ReturnType.OriginalDefinition, wellKnownTypeProvider.GetOrCreateTypeByMetadataName(WellKnownTypeNames.SystemThreadingTasksValueTask1)))
{
return syntaxGenerator.IdentifierName("ValueTask");
return SyntaxFactory.IdentifierName("ValueTask");
}

if (methodSymbol.IsAsync
|| SymbolEqualityComparer.Default.Equals(methodSymbol.ReturnType.OriginalDefinition, wellKnownTypeProvider.GetOrCreateTypeByMetadataName(WellKnownTypeNames.SystemThreadingTasksTask1)))
{
return syntaxGenerator.IdentifierName("Task");
return SyntaxFactory.IdentifierName("Task");
}

// For all other cases return void.
return VoidReturnTypeNode;
// Default to void
return SyntaxFactory.PredefinedType(SyntaxFactory.Token(SyntaxKind.VoidKeyword));
}

private static IEnumerable<SyntaxNode> GetParameters(SyntaxGenerator syntaxGenerator, bool isParameterLess,
WellKnownTypeProvider wellKnownTypeProvider)
private static ParameterListSyntax GetParameterList(bool isParameterLess, WellKnownTypeProvider wellKnownTypeProvider)
{
if (isParameterLess
|| !wellKnownTypeProvider.TryGetOrCreateTypeByMetadataName(
WellKnownTypeNames.MicrosoftVisualStudioTestToolsUnitTestingTestContext,
out INamedTypeSymbol? testContextTypeSymbol))
out _))
{
return [];
return SyntaxFactory.ParameterList();
}

SyntaxNode testContextType = syntaxGenerator.TypeExpression(testContextTypeSymbol);
SyntaxNode testContextParameter = syntaxGenerator.ParameterDeclaration("testContext", testContextType);
return [testContextParameter];
ParameterSyntax parameter = SyntaxFactory
.Parameter(SyntaxFactory.Identifier("testContext"))
.WithType(SyntaxFactory.IdentifierName("TestContext"));

return SyntaxFactory.ParameterList(SyntaxFactory.SingletonSeparatedList(parameter));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -711,4 +711,51 @@ public static void ClassInitialize(TestContext testContext)

await VerifyCS.VerifyAnalyzerAsync(code);
}

[TestMethod]
public async Task WhenClassInitializeHasComments_CommentsArePreserved()
{
string code = """
using Microsoft.VisualStudio.TestTools.UnitTesting;

[TestClass]
public class MyTestClass
{
[ClassInitialize]
internal static void {|#0:ClassInitialize|}(TestContext testContext)
{
InitializeClass();

// Class initialization comments;
// Setup code here
}

private static void InitializeClass() { }
}
""";

string fixedCode = """
using Microsoft.VisualStudio.TestTools.UnitTesting;

[TestClass]
public class MyTestClass
{
[ClassInitialize]
public static void ClassInitialize(TestContext testContext)
{
InitializeClass();

// Class initialization comments;
// Setup code here
}

private static void InitializeClass() { }
}
""";

await VerifyCS.VerifyCodeFixAsync(
code,
VerifyCS.Diagnostic().WithLocation(0).WithArguments("ClassInitialize"),
fixedCode);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -552,4 +552,51 @@ public void TestCleanup()

await VerifyCS.VerifyAnalyzerAsync(code);
}

[TestMethod]
public async Task WhenTestCleanupHasComments_CommentsArePreserved()
{
string code = """
using Microsoft.VisualStudio.TestTools.UnitTesting;

[TestClass]
public class MyTestClass
{
[TestCleanup]
internal void {|#0:TestCleanup|}()
{
CleanupCode();

// Cleanup comments;
// More comments
}

private void CleanupCode() { }
}
""";

string fixedCode = """
using Microsoft.VisualStudio.TestTools.UnitTesting;

[TestClass]
public class MyTestClass
{
[TestCleanup]
public void TestCleanup()
{
CleanupCode();

// Cleanup comments;
// More comments
}

private void CleanupCode() { }
}
""";

await VerifyCS.VerifyCodeFixAsync(
code,
VerifyCS.Diagnostic().WithLocation(0).WithArguments("TestCleanup"),
fixedCode);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,51 @@ await VerifyCS.VerifyCodeFixAsync(
fixedCode);
}

[TestMethod]
public async Task WhenTestInitializeHasComments_CommentsArePreserved()
{
string code = """
using Microsoft.VisualStudio.TestTools.UnitTesting;

[TestClass]
public class MyTestClass
{
[TestInitialize]
public void {|#0:TestSetup|}(TestContext tc)
{
SomeCode();

// Some comments;
}

private void SomeCode() { }
}
""";

string fixedCode = """
using Microsoft.VisualStudio.TestTools.UnitTesting;

[TestClass]
public class MyTestClass
{
[TestInitialize]
public void TestSetup()
{
SomeCode();

// Some comments;
}

private void SomeCode() { }
}
""";

await VerifyCS.VerifyCodeFixAsync(
code,
VerifyCS.Diagnostic().WithLocation(0).WithArguments("TestSetup"),
fixedCode);
}

[TestMethod]
public async Task WhenTestInitializeIsNotOnClass_Diagnostic()
{
Expand Down
Loading