diff --git a/rewrite-kotlin/src/main/java/org/openrewrite/kotlin/internal/KotlinTreeParserVisitor.java b/rewrite-kotlin/src/main/java/org/openrewrite/kotlin/internal/KotlinTreeParserVisitor.java index 9247c95887..2b61764637 100644 --- a/rewrite-kotlin/src/main/java/org/openrewrite/kotlin/internal/KotlinTreeParserVisitor.java +++ b/rewrite-kotlin/src/main/java/org/openrewrite/kotlin/internal/KotlinTreeParserVisitor.java @@ -3641,15 +3641,51 @@ private JavaType.Primitive primitiveType(PsiElement psi) { } private JavaType.@Nullable Variable variableType(PsiElement psi, @Nullable FirElement parent) { - return psiElementAssociations.variableType(psi, parent); + JavaType.Variable psiType = psiElementAssociations.variableType(psi, parent); + return psiType == null ? null : psiType.withType(mapPrimitiveType(psiType.getType())); } private JavaType.@Nullable Method methodDeclarationType(PsiElement psi) { - return psiElementAssociations.methodDeclarationType(psi); + JavaType.Method psiType = psiElementAssociations.methodDeclarationType(psi); + return psiType == null ? null : psiType.withParameterTypes(mapPrimitiveTypes(psiType.getParameterTypes())); } private JavaType.@Nullable Method methodInvocationType(PsiElement psi) { - return psiElementAssociations.methodInvocationType(psi); + JavaType.Method psiType = psiElementAssociations.methodInvocationType(psi); + return psiType == null ? null : psiType.withParameterTypes(mapPrimitiveTypes(psiType.getParameterTypes())); + } + + private static List mapPrimitiveTypes(List types) { + return ListUtils.map(types, KotlinTreeParserVisitor::mapPrimitiveType); + } + + private static JavaType mapPrimitiveType(JavaType type) { + if (type instanceof JavaType.Class) { + String fullyQualifiedName = ((JavaType.Class) type).getFullyQualifiedName(); + switch (fullyQualifiedName) { + case "kotlin.Boolean": + return JavaType.Primitive.Boolean; + case "kotlin.Byte": + return JavaType.Primitive.Byte; + case "kotlin.Char": + return JavaType.Primitive.Char; + case "kotlin.Double": + return JavaType.Primitive.Double; + case "kotlin.Float": + return JavaType.Primitive.Float; + case "kotlin.Int": + return JavaType.Primitive.Int; + case "kotlin.Long": + return JavaType.Primitive.Long; + case "kotlin.Short": + return JavaType.Primitive.Short; + case "kotlin.String": + return JavaType.Primitive.String; + case "kotlin.Void": + return JavaType.Primitive.Void; + } + } + return type; } /*==================================================================== diff --git a/rewrite-kotlin/src/test/java/org/openrewrite/kotlin/MethodMatcherTest.java b/rewrite-kotlin/src/test/java/org/openrewrite/kotlin/MethodMatcherTest.java index 688372f16f..9246539dcb 100644 --- a/rewrite-kotlin/src/test/java/org/openrewrite/kotlin/MethodMatcherTest.java +++ b/rewrite-kotlin/src/test/java/org/openrewrite/kotlin/MethodMatcherTest.java @@ -16,6 +16,8 @@ package org.openrewrite.kotlin; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import org.openrewrite.DocumentExample; import org.openrewrite.ExecutionContext; import org.openrewrite.java.MethodMatcher; @@ -58,4 +60,39 @@ fun usesFunction() { ) ); } + + @ParameterizedTest + @ValueSource(strings = { + "java.lang.Math max(int, int)", + "java.lang.Math max(..)" + }) + void matchesMethodPatternOfJavaMethodWithAnyArgument(String methodPattern) { + rewriteRun( + spec -> spec.recipe(toRecipe(() -> new KotlinIsoVisitor<>() { + private final MethodMatcher methodMatcher = new MethodMatcher(methodPattern, true); + + @Override + public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext p) { + if (methodMatcher.matches(method)) { + return SearchResult.found(method); + } + return super.visitMethodInvocation(method, p); + } + })), + kotlin( + """ + import java.lang.Math + fun max(a: Int, b: Int) : Int { + return Math.max(a, b) + } + """, + """ + import java.lang.Math + fun max(a: Int, b: Int) : Int { + return /*~~>*/Math.max(a, b) + } + """ + ) + ); + } } diff --git a/rewrite-kotlin/src/test/java/org/openrewrite/kotlin/UseStaticImportTest.java b/rewrite-kotlin/src/test/java/org/openrewrite/kotlin/UseStaticImportTest.java index 285970b40b..cb169cda25 100644 --- a/rewrite-kotlin/src/test/java/org/openrewrite/kotlin/UseStaticImportTest.java +++ b/rewrite-kotlin/src/test/java/org/openrewrite/kotlin/UseStaticImportTest.java @@ -45,7 +45,7 @@ void noPriorImports() { @Test void withImports() { rewriteRun( - spec -> spec.recipe(new UseStaticImport("java.lang.Integer valueOf(kotlin.Int)")), + spec -> spec.recipe(new UseStaticImport("java.lang.Integer valueOf(int)")), kotlin( """ import java.util.Collections