diff --git a/src/main/java/org/openrewrite/kotlin/internal/KotlinTreeParserVisitor.java b/src/main/java/org/openrewrite/kotlin/internal/KotlinTreeParserVisitor.java index b003d206..d3dc5acb 100644 --- a/src/main/java/org/openrewrite/kotlin/internal/KotlinTreeParserVisitor.java +++ b/src/main/java/org/openrewrite/kotlin/internal/KotlinTreeParserVisitor.java @@ -787,7 +787,7 @@ public J visitLiteralStringTemplateEntry(KtLiteralStringTemplateEntry entry, Exe String value = maybeAdjustCRLF(entry); boolean quoted = entry.getPrevSibling().getNode().getElementType() == KtTokens.OPEN_QUOTE && - entry.getNextSibling().getNode().getElementType() == KtTokens.CLOSING_QUOTE; + entry.getNextSibling().getNode().getElementType() == KtTokens.CLOSING_QUOTE; String valueSource = quoted ? "\"" + value + "\"" : value; @@ -827,8 +827,8 @@ public J visitNullableType(KtNullableType nullableType, ExecutionContext data) { TypeTree typeTree = (TypeTree) requireNonNull(innerType).accept(this, data); Set consumedSpaces = new HashSet<>(); if (innerType.getNextSibling() != null && - isSpace(innerType.getNextSibling().getNode()) && - !(innerType instanceof KtNullableType)) { + isSpace(innerType.getNextSibling().getNode()) && + !(innerType instanceof KtNullableType)) { consumedSpaces.add(innerType.getNextSibling()); } @@ -2569,8 +2569,8 @@ public J visitImportDirective(KtImportDirective importDirective, ExecutionContex PsiElement first = PsiTreeUtil.skipWhitespacesAndCommentsForward(importPsi); PsiElement last = findLastChild(importDirective, psi -> !(psi instanceof KtImportAlias) && - !isSpace(psi.getNode()) && - psi.getNode().getElementType() != KtTokens.SEMICOLON); + !isSpace(psi.getNode()) && + psi.getNode().getElementType() != KtTokens.SEMICOLON); String text = nodeRangeText(getNodeOrNull(first), getNodeOrNull(last)); TypeTree reference = TypeTree.build(text, '`'); @@ -3098,7 +3098,7 @@ public J visitStringTemplateExpression(KtStringTemplateExpression expression, Ex KtStringTemplateEntry[] entries = expression.getEntries(); boolean hasStringTemplateEntry = Arrays.stream(entries).anyMatch(x -> x instanceof KtBlockStringTemplateEntry || - x instanceof KtSimpleNameStringTemplateEntry); + x instanceof KtSimpleNameStringTemplateEntry); if (hasStringTemplateEntry) { String delimiter = expression.getFirstChild().getText(); @@ -3139,8 +3139,8 @@ private static String getString(KtStringTemplateExpression expression, StringBui PsiElement openQuote = expression.getFirstChild(); PsiElement closingQuota = expression.getLastChild(); if (openQuote == null || closingQuota == null || - openQuote.getNode().getElementType() != KtTokens.OPEN_QUOTE || - closingQuota.getNode().getElementType() != KtTokens.CLOSING_QUOTE) { + openQuote.getNode().getElementType() != KtTokens.OPEN_QUOTE || + closingQuota.getNode().getElementType() != KtTokens.CLOSING_QUOTE) { throw new UnsupportedOperationException("This should never happen"); } @@ -3642,15 +3642,51 @@ private JavaType.Primitive primitiveType(PsiElement psi) { } private JavaType.@Nullable Variable variableType(PsiElement psi, @Nullable FirElement parent) { - return psiElementAssociations.variableType(psi, parent); + JavaType.Variable psiType = psiElementAssociations.variableType(psi, parent); + return psiType == null ? null : psiType.withType(mapPrimitiveType(psiType.getType())); } private JavaType.@Nullable Method methodDeclarationType(PsiElement psi) { - return psiElementAssociations.methodDeclarationType(psi); + JavaType.Method psiType = psiElementAssociations.methodDeclarationType(psi); + return psiType == null ? null : psiType.withParameterTypes(mapPrimitiveTypes(psiType.getParameterTypes())); } private JavaType.@Nullable Method methodInvocationType(PsiElement psi) { - return psiElementAssociations.methodInvocationType(psi); + JavaType.Method psiType = psiElementAssociations.methodInvocationType(psi); + return psiType == null ? null : psiType.withParameterTypes(mapPrimitiveTypes(psiType.getParameterTypes())); + } + + private static List mapPrimitiveTypes(List types) { + return ListUtils.map(types, KotlinTreeParserVisitor::mapPrimitiveType); + } + + private static JavaType mapPrimitiveType(JavaType type) { + if (type instanceof JavaType.Class) { + String fullyQualifiedName = ((JavaType.Class) type).getFullyQualifiedName(); + switch (fullyQualifiedName) { + case "kotlin.Boolean": + return JavaType.Primitive.Boolean; + case "kotlin.Byte": + return JavaType.Primitive.Byte; + case "kotlin.Char": + return JavaType.Primitive.Char; + case "kotlin.Double": + return JavaType.Primitive.Double; + case "kotlin.Float": + return JavaType.Primitive.Float; + case "kotlin.Int": + return JavaType.Primitive.Int; + case "kotlin.Long": + return JavaType.Primitive.Long; + case "kotlin.Short": + return JavaType.Primitive.Short; + case "kotlin.String": + return JavaType.Primitive.String; + case "kotlin.Void": + return JavaType.Primitive.Void; + } + } + return type; } /*==================================================================== @@ -3673,7 +3709,7 @@ private J.Identifier createIdentifier(String name, Space prefix, private J.Identifier createIdentifier(String name, Space prefix, @Nullable JavaType type, - JavaType.@Nullable Variable fieldType) { + JavaType.@Nullable Variable fieldType) { Markers markers = Markers.EMPTY; String updated = name; if (name.startsWith("`")) { @@ -3914,10 +3950,10 @@ private Space prefixAndInfix(@Nullable PsiElement element, @Nullable Set> mapParameters(@Nullable KtParameterList li ); superTypes.add(padRight(delegationCall, suffix(superTypeCallEntry))); } else if (superTypeListEntry instanceof KtSuperTypeEntry || - superTypeListEntry instanceof KtDelegatedSuperTypeEntry) { + superTypeListEntry instanceof KtDelegatedSuperTypeEntry) { TypeTree typeTree = (TypeTree) superTypeListEntry.accept(this, data); if (i == 0) { @@ -4207,7 +4243,7 @@ private Space toSpace(@Nullable PsiElement element) { if (elementType == KtTokens.WHITE_SPACE) { return Space.build(maybeAdjustCRLF(element), emptyList()); } else if (elementType == KtTokens.EOL_COMMENT || - elementType == KtTokens.BLOCK_COMMENT) { + elementType == KtTokens.BLOCK_COMMENT) { String nodeText = maybeAdjustCRLF(element); boolean isBlockComment = ((PsiComment) element).getTokenType() == KtTokens.BLOCK_COMMENT; String comment = isBlockComment ? nodeText.substring(2, nodeText.length() - 2) : nodeText.substring(2); diff --git a/src/test/java/org/openrewrite/kotlin/MethodMatcherTest.java b/src/test/java/org/openrewrite/kotlin/MethodMatcherTest.java index 688372f1..8c42f51b 100644 --- a/src/test/java/org/openrewrite/kotlin/MethodMatcherTest.java +++ b/src/test/java/org/openrewrite/kotlin/MethodMatcherTest.java @@ -16,6 +16,8 @@ package org.openrewrite.kotlin; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import org.openrewrite.DocumentExample; import org.openrewrite.ExecutionContext; import org.openrewrite.java.MethodMatcher; @@ -34,6 +36,7 @@ void matchesTopLevelFunction() { rewriteRun( spec -> spec.recipe(toRecipe(() -> new KotlinIsoVisitor<>() { private static final MethodMatcher methodMatcher = new MethodMatcher("openRewriteFile0Kt function(..)"); + @Override public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, ExecutionContext p) { if (methodMatcher.matches(method.getMethodType())) { @@ -58,4 +61,40 @@ fun usesFunction() { ) ); } + + @ParameterizedTest + @ValueSource(strings = { + "java.lang.Math max(int, int)", + "java.lang.Math max(..)" + }) + void matchesMethodPatternOfJavaMethodWithAnyArgument(String methodPattern) { + rewriteRun( + spec -> spec.recipe(toRecipe(() -> new KotlinIsoVisitor<>() { + private final MethodMatcher methodMatcher = new MethodMatcher(methodPattern, true); + + @Override + public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext p) { + if (methodMatcher.matches(method)) { + return SearchResult.found(method); + } + return super.visitMethodInvocation(method, p); + } + })), + kotlin( + """ + import java.lang.Math + fun max(a: Int, b: Int) : Int { + return Math.max(a, b) + } + """, + """ + import java.lang.Math + fun max(a: Int, b: Int) : Int { + return /*~~>*/Math.max(a, b) + } + """ + ) + ); + } + } diff --git a/src/test/java/org/openrewrite/kotlin/UseStaticImportTest.java b/src/test/java/org/openrewrite/kotlin/UseStaticImportTest.java index 285970b4..cb169cda 100644 --- a/src/test/java/org/openrewrite/kotlin/UseStaticImportTest.java +++ b/src/test/java/org/openrewrite/kotlin/UseStaticImportTest.java @@ -45,7 +45,7 @@ void noPriorImports() { @Test void withImports() { rewriteRun( - spec -> spec.recipe(new UseStaticImport("java.lang.Integer valueOf(kotlin.Int)")), + spec -> spec.recipe(new UseStaticImport("java.lang.Integer valueOf(int)")), kotlin( """ import java.util.Collections