().also { module.putUserData(libraryVersionsKey, it) }
+ libraryVersions.clear()
+
val presentationManager = LibraryPresentationManager.getInstance()
val context = LibrariesValidatorContextImpl(module)
@@ -108,6 +124,24 @@ class MinecraftFacetDetector : StartupActivity {
.forEachLibrary forEach@{ library ->
MINECRAFT_LIBRARY_KINDS.forEach { kind ->
if (presentationManager.isLibraryOfKind(library, context.librariesContainer, setOf(kind))) {
+ val libraryFiles =
+ context.librariesContainer.getLibraryFiles(library, OrderRootType.CLASSES).toList()
+ LibraryDetectionManager.getInstance().processProperties(
+ libraryFiles,
+ object : LibraryDetectionManager.LibraryPropertiesProcessor {
+ override fun ?> processProperties(
+ kind: LibraryKind,
+ properties: P
+ ): Boolean {
+ return if (properties is LibraryVersionProperties) {
+ libraryVersions[kind] = properties.versionString ?: return true
+ false
+ } else {
+ true
+ }
+ }
+ }
+ )
platformKinds.add(kind)
}
}
diff --git a/src/main/kotlin/insight/ColorLineMarkerProvider.kt b/src/main/kotlin/insight/ColorLineMarkerProvider.kt
index eecdde674..22be92403 100644
--- a/src/main/kotlin/insight/ColorLineMarkerProvider.kt
+++ b/src/main/kotlin/insight/ColorLineMarkerProvider.kt
@@ -134,7 +134,7 @@ class ColorLineMarkerProvider : LineMarkerProvider {
return@handler
}
- val c = ColorChooser.chooseColor(editor.component, "Choose Color", color, false)
+ val c = ColorChooser.chooseColor(psiElement.project, editor.component, "Choose Color", color, false)
?: return@handler
when (workElement) {
is ULiteralExpression -> workElement.setColor(c.rgb and 0xFFFFFF)
diff --git a/src/main/kotlin/platform/bungeecord/creator/BungeeCordProjectCreator.kt b/src/main/kotlin/platform/bungeecord/creator/BungeeCordProjectCreator.kt
index 6125ad121..b37245511 100644
--- a/src/main/kotlin/platform/bungeecord/creator/BungeeCordProjectCreator.kt
+++ b/src/main/kotlin/platform/bungeecord/creator/BungeeCordProjectCreator.kt
@@ -152,8 +152,8 @@ class BungeeCordDependenciesStep(
PlatformType.WATERFALL -> {
buildSystem.repositories.add(
BuildRepository(
- "destroystokyo-repo",
- "https://repo.destroystokyo.com/repository/maven-public/"
+ "papermc-repo",
+ "https://papermc.io/repo/repository/maven-public/"
)
)
buildSystem.dependencies.add(
diff --git a/src/main/kotlin/platform/fabric/creator/FabricProjectConfig.kt b/src/main/kotlin/platform/fabric/creator/FabricProjectConfig.kt
index a9c7da5b3..117d4f8f2 100644
--- a/src/main/kotlin/platform/fabric/creator/FabricProjectConfig.kt
+++ b/src/main/kotlin/platform/fabric/creator/FabricProjectConfig.kt
@@ -71,6 +71,6 @@ class FabricProjectConfig : ProjectConfig(), GradleCreator {
override fun configureRootGradle(rootDirectory: Path, buildSystem: GradleBuildSystem) {
buildSystem.gradleVersion =
- if (semanticMcVersion >= MinecraftVersions.MC1_17) SemanticVersion.release(7, 1, 1) else gradleVersion
+ if (semanticMcVersion >= MinecraftVersions.MC1_17) SemanticVersion.release(7, 3) else gradleVersion
}
}
diff --git a/src/main/kotlin/platform/fabric/creator/FabricProjectSettingsWizard.kt b/src/main/kotlin/platform/fabric/creator/FabricProjectSettingsWizard.kt
index ea4179106..a0854ab64 100644
--- a/src/main/kotlin/platform/fabric/creator/FabricProjectSettingsWizard.kt
+++ b/src/main/kotlin/platform/fabric/creator/FabricProjectSettingsWizard.kt
@@ -258,7 +258,7 @@ class FabricProjectSettingsWizard(private val creator: MinecraftProjectCreator)
}
if (conf.loomVersion >= SemanticVersion.release(0, 7)) {
// TemplateMakerFabric incorrectly indicates loom 0.8 requires Gradle 6...
- conf.gradleVersion = SemanticVersion.release(7, 1, 1)
+ conf.gradleVersion = SemanticVersion.release(7, 3)
}
conf.environment = when ((environmentBox.selectedItem as? String)?.toLowerCase(Locale.ROOT)) {
"client" -> Side.CLIENT
diff --git a/src/main/kotlin/platform/forge/creator/ForgeProjectCreator.kt b/src/main/kotlin/platform/forge/creator/ForgeProjectCreator.kt
index ff8ea4657..3101b2310 100644
--- a/src/main/kotlin/platform/forge/creator/ForgeProjectCreator.kt
+++ b/src/main/kotlin/platform/forge/creator/ForgeProjectCreator.kt
@@ -180,7 +180,7 @@ open class Fg3ProjectCreator(
}
companion object {
- val FG5_WRAPPER_VERSION = SemanticVersion.release(7, 1, 1)
+ val FG5_WRAPPER_VERSION = SemanticVersion.release(7, 3)
}
}
diff --git a/src/main/kotlin/platform/mixin/MixinModule.kt b/src/main/kotlin/platform/mixin/MixinModule.kt
index b6c091f82..35dba6f67 100644
--- a/src/main/kotlin/platform/mixin/MixinModule.kt
+++ b/src/main/kotlin/platform/mixin/MixinModule.kt
@@ -11,9 +11,13 @@
package com.demonwav.mcdev.platform.mixin
import com.demonwav.mcdev.facet.MinecraftFacet
+import com.demonwav.mcdev.facet.MinecraftFacetDetector
import com.demonwav.mcdev.platform.AbstractModule
import com.demonwav.mcdev.platform.PlatformType
import com.demonwav.mcdev.platform.mixin.config.MixinConfig
+import com.demonwav.mcdev.platform.mixin.framework.MIXIN_LIBRARY_KIND
+import com.demonwav.mcdev.util.SemanticVersion
+import com.demonwav.mcdev.util.nullable
import com.intellij.json.psi.JsonFile
import com.intellij.json.psi.JsonObject
import com.intellij.openapi.fileTypes.FileTypeManager
@@ -27,6 +31,13 @@ import com.intellij.psi.search.GlobalSearchScope
import javax.swing.Icon
class MixinModule(facet: MinecraftFacet) : AbstractModule(facet) {
+ val mixinVersion by nullable {
+ var version = MinecraftFacetDetector.getLibraryVersions(facet.module)[MIXIN_LIBRARY_KIND]
+ ?: return@nullable null
+ // fabric mixin uses the format "0.10.4+mixin.0.8.4", return the original string otherwise.
+ version = version.substringAfter("+mixin.")
+ SemanticVersion.parse(version)
+ }
override val moduleType = MixinModuleType
override val type = PlatformType.MIXIN
diff --git a/src/main/kotlin/platform/mixin/config/reference/ConfigProperty.kt b/src/main/kotlin/platform/mixin/config/reference/ConfigProperty.kt
index 410dc1077..c03c7e925 100644
--- a/src/main/kotlin/platform/mixin/config/reference/ConfigProperty.kt
+++ b/src/main/kotlin/platform/mixin/config/reference/ConfigProperty.kt
@@ -11,6 +11,7 @@
package com.demonwav.mcdev.platform.mixin.config.reference
import com.demonwav.mcdev.platform.mixin.util.MixinConstants.Classes.MIXIN_CONFIG
+import com.demonwav.mcdev.platform.mixin.util.MixinConstants.Classes.MIXIN_SERIALIZED_NAME
import com.demonwav.mcdev.platform.mixin.util.MixinConstants.Classes.SERIALIZED_NAME
import com.demonwav.mcdev.util.constantStringValue
import com.demonwav.mcdev.util.findAnnotation
@@ -62,9 +63,16 @@ object ConfigProperty : PsiReferenceProvider() {
}
private inline fun forEachProperty(configClass: PsiClass, func: (PsiField, String) -> Unit) {
+ val mixinSerializedNameClass =
+ JavaPsiFacade.getInstance(configClass.project).findClass(MIXIN_SERIALIZED_NAME, configClass.resolveScope)
+ val serializedNameName = if (mixinSerializedNameClass != null) {
+ MIXIN_SERIALIZED_NAME
+ } else {
+ SERIALIZED_NAME
+ }
for (field in configClass.fields) {
- val name =
- field.findAnnotation(SERIALIZED_NAME)?.findDeclaredAttributeValue(null)?.constantStringValue ?: continue
+ val annotation = field.findAnnotation(serializedNameName)
+ val name = annotation?.findDeclaredAttributeValue(null)?.constantStringValue ?: continue
func(field, name)
}
}
diff --git a/src/main/kotlin/platform/mixin/debug/MixinPositionManager.kt b/src/main/kotlin/platform/mixin/debug/MixinPositionManager.kt
index 76084452d..6cc17eaee 100644
--- a/src/main/kotlin/platform/mixin/debug/MixinPositionManager.kt
+++ b/src/main/kotlin/platform/mixin/debug/MixinPositionManager.kt
@@ -14,7 +14,6 @@ import com.demonwav.mcdev.platform.mixin.util.MixinConstants
import com.demonwav.mcdev.platform.mixin.util.mixinTargets
import com.demonwav.mcdev.util.findContainingClass
import com.demonwav.mcdev.util.ifEmpty
-import com.demonwav.mcdev.util.mapNotNull
import com.intellij.debugger.MultiRequestPositionManager
import com.intellij.debugger.NoDataException
import com.intellij.debugger.SourcePosition
@@ -33,8 +32,6 @@ import com.sun.jdi.AbsentInformationException
import com.sun.jdi.Location
import com.sun.jdi.ReferenceType
import com.sun.jdi.request.ClassPrepareRequest
-import java.util.stream.Stream
-import kotlin.streams.toList
class MixinPositionManager(private val debugProcess: DebugProcess) : MultiRequestPositionManager {
@@ -86,7 +83,7 @@ class MixinPositionManager(private val debugProcess: DebugProcess) : MultiReques
override fun getAllClasses(classPosition: SourcePosition): List {
return runReadAction {
findMatchingClasses(classPosition)
- .flatMap { name -> debugProcess.virtualMachineProxy.classesByName(name).stream() }
+ .flatMap { name -> debugProcess.virtualMachineProxy.classesByName(name).asSequence() }
.toList()
}
}
@@ -122,11 +119,11 @@ class MixinPositionManager(private val debugProcess: DebugProcess) : MultiReques
}
}
- private fun findMatchingClasses(position: SourcePosition): Stream {
+ private fun findMatchingClasses(position: SourcePosition): Sequence {
val classElement = position.elementAt?.findContainingClass() ?: throw NoDataException.INSTANCE
return classElement.mixinTargets
.ifEmpty { throw NoDataException.INSTANCE }
- .stream()
- .map { it.name }
+ .asSequence()
+ .map { it.name.replace('/', '.') }
}
}
diff --git a/src/main/kotlin/platform/mixin/handlers/InjectAnnotationHandler.kt b/src/main/kotlin/platform/mixin/handlers/InjectAnnotationHandler.kt
index 7cd9e0d2e..c272ff188 100644
--- a/src/main/kotlin/platform/mixin/handlers/InjectAnnotationHandler.kt
+++ b/src/main/kotlin/platform/mixin/handlers/InjectAnnotationHandler.kt
@@ -12,13 +12,23 @@ package com.demonwav.mcdev.platform.mixin.handlers
import com.demonwav.mcdev.platform.mixin.inspection.injector.MethodSignature
import com.demonwav.mcdev.platform.mixin.inspection.injector.ParameterGroup
+import com.demonwav.mcdev.platform.mixin.util.LocalVariables
import com.demonwav.mcdev.platform.mixin.util.callbackInfoReturnableType
import com.demonwav.mcdev.platform.mixin.util.callbackInfoType
import com.demonwav.mcdev.platform.mixin.util.getGenericReturnType
+import com.demonwav.mcdev.platform.mixin.util.hasAccess
+import com.demonwav.mcdev.platform.mixin.util.toPsiType
import com.demonwav.mcdev.util.Parameter
+import com.demonwav.mcdev.util.findModule
+import com.demonwav.mcdev.util.firstIndexOrNull
+import com.intellij.psi.JavaPsiFacade
import com.intellij.psi.PsiAnnotation
+import com.intellij.psi.PsiMethod
import com.intellij.psi.PsiQualifiedReference
import com.intellij.psi.PsiType
+import com.intellij.psi.util.parentOfType
+import org.objectweb.asm.Opcodes
+import org.objectweb.asm.Type
import org.objectweb.asm.tree.ClassNode
import org.objectweb.asm.tree.MethodNode
@@ -36,7 +46,7 @@ class InjectAnnotationHandler : InjectorAnnotationHandler() {
result.add(
ParameterGroup(
collectTargetMethodParameters(annotation.project, targetClass, targetMethod),
- required = false,
+ required = ParameterGroup.RequiredLevel.OPTIONAL,
default = true
)
)
@@ -58,17 +68,57 @@ class InjectAnnotationHandler : InjectorAnnotationHandler() {
)
// Captured locals (only if local capture is enabled)
- // Right now we allow any parameters here since we can't easily
- // detect the local variables that can be captured
- // TODO: now we can work with the bytecode, revisit this?
- if ((
- (annotation.findDeclaredAttributeValue("locals") as? PsiQualifiedReference)
- ?.referenceName ?: "NO_CAPTURE"
- ) != "NO_CAPTURE"
- ) {
- result.add(ParameterGroup(null))
+ val localCapture = (annotation.findDeclaredAttributeValue("locals") as? PsiQualifiedReference)
+ ?.referenceName ?: "NO_CAPTURE"
+ if (localCapture != "NO_CAPTURE") {
+ annotation.findModule()?.let { module ->
+ var commonLocalsPrefix: MutableList? = null
+ val resolvedInsns = resolveInstructions(annotation, targetClass, targetMethod).ifEmpty { return@let }
+ for (insn in resolvedInsns) {
+ val locals = LocalVariables.getLocals(module, targetClass, targetMethod, insn.insn)
+ ?.drop(
+ Type.getArgumentTypes(targetMethod.desc).size +
+ if (targetMethod.hasAccess(Opcodes.ACC_STATIC)) 0 else 1
+ )
+ ?.filterNotNull()
+ ?.filter { it.desc != null }
+ ?: continue
+ if (commonLocalsPrefix == null) {
+ commonLocalsPrefix = locals.toMutableList()
+ } else {
+ val mismatch = commonLocalsPrefix.zip(locals).firstIndexOrNull { (a, b) -> a.desc != b.desc }
+ if (mismatch != null) {
+ commonLocalsPrefix.subList(mismatch, commonLocalsPrefix.size).clear()
+ }
+ }
+ }
+
+ if (commonLocalsPrefix != null) {
+ val elementFactory = JavaPsiFacade.getElementFactory(annotation.project)
+ val localParams = commonLocalsPrefix.map { local ->
+ val type =
+ Type.getType(local.desc).toPsiType(elementFactory, annotation.parentOfType())
+ sanitizedParameter(type, local.name)
+ }
+ val requiredLevel = if (localCapture == "CAPTURE_FAILSOFT") {
+ ParameterGroup.RequiredLevel.WARN_IF_ABSENT
+ } else {
+ ParameterGroup.RequiredLevel.ERROR_IF_ABSENT
+ }
+ result.add(
+ ParameterGroup(
+ localParams,
+ default = true,
+ required = requiredLevel,
+ isVarargs = true
+ )
+ )
+ }
+ }
}
return listOf(MethodSignature(result, PsiType.VOID))
}
+
+ override val allowCoerce = true
}
diff --git a/src/main/kotlin/platform/mixin/handlers/InjectorAnnotationHandler.kt b/src/main/kotlin/platform/mixin/handlers/InjectorAnnotationHandler.kt
index 253757f25..8ac6b2d52 100644
--- a/src/main/kotlin/platform/mixin/handlers/InjectorAnnotationHandler.kt
+++ b/src/main/kotlin/platform/mixin/handlers/InjectorAnnotationHandler.kt
@@ -113,10 +113,11 @@ abstract class InjectorAnnotationHandler : MixinAnnotationHandler {
open fun resolveInstructions(
annotation: PsiAnnotation,
targetClass: ClassNode,
- targetMethod: MethodNode
+ targetMethod: MethodNode,
+ mode: CollectVisitor.Mode = CollectVisitor.Mode.MATCH_ALL
): List> {
val at = annotation.findAttributeValue("at") as? PsiAnnotation ?: return emptyList()
- return AtResolver(at, targetClass, targetMethod).resolveInstructions()
+ return AtResolver(at, targetClass, targetMethod).resolveInstructions(mode)
}
/**
@@ -138,6 +139,8 @@ abstract class InjectorAnnotationHandler : MixinAnnotationHandler {
return "Cannot resolve any target instructions in target class"
}
+ open val allowCoerce = false
+
data class InsnResult(val method: ClassAndMethodNode, val result: CollectVisitor.Result<*>)
companion object {
@@ -165,9 +168,9 @@ abstract class InjectorAnnotationHandler : MixinAnnotationHandler {
protected fun sanitizedParameter(type: PsiType, name: String?): Parameter {
// Parameters should not use ellipsis because others like CallbackInfo may follow
return if (type is PsiEllipsisType) {
- Parameter(name, type.toArrayType())
+ Parameter(name?.toJavaIdentifier(), type.toArrayType())
} else {
- Parameter(name, type)
+ Parameter(name?.toJavaIdentifier(), type)
}
}
}
diff --git a/src/main/kotlin/platform/mixin/handlers/ModifyArgsHandler.kt b/src/main/kotlin/platform/mixin/handlers/ModifyArgsHandler.kt
index beb2bc590..a9e000985 100644
--- a/src/main/kotlin/platform/mixin/handlers/ModifyArgsHandler.kt
+++ b/src/main/kotlin/platform/mixin/handlers/ModifyArgsHandler.kt
@@ -40,7 +40,8 @@ class ModifyArgsHandler : InjectorAnnotationHandler() {
ParameterGroup(listOf(Parameter("args", argsType))),
ParameterGroup(
collectTargetMethodParameters(annotation.project, targetClass, targetMethod),
- required = false
+ required = ParameterGroup.RequiredLevel.OPTIONAL,
+ isVarargs = true
)
),
PsiType.VOID
diff --git a/src/main/kotlin/platform/mixin/handlers/ModifyConstantHandler.kt b/src/main/kotlin/platform/mixin/handlers/ModifyConstantHandler.kt
new file mode 100644
index 000000000..1b8720221
--- /dev/null
+++ b/src/main/kotlin/platform/mixin/handlers/ModifyConstantHandler.kt
@@ -0,0 +1,239 @@
+/*
+ * Minecraft Dev for IntelliJ
+ *
+ * https://minecraftdev.org
+ *
+ * Copyright (c) 2021 minecraft-dev
+ *
+ * MIT License
+ */
+
+package com.demonwav.mcdev.platform.mixin.handlers
+
+import com.demonwav.mcdev.platform.mixin.handlers.injectionPoint.CollectVisitor
+import com.demonwav.mcdev.platform.mixin.handlers.injectionPoint.ConstantInjectionPoint
+import com.demonwav.mcdev.platform.mixin.handlers.injectionPoint.InjectionPoint
+import com.demonwav.mcdev.platform.mixin.handlers.injectionPoint.InsnResolutionInfo
+import com.demonwav.mcdev.platform.mixin.inspection.injector.MethodSignature
+import com.demonwav.mcdev.platform.mixin.inspection.injector.ParameterGroup
+import com.demonwav.mcdev.platform.mixin.util.MixinConstants.Classes.CONSTANT_CONDITION
+import com.demonwav.mcdev.platform.mixin.util.findSourceElement
+import com.demonwav.mcdev.util.constantValue
+import com.demonwav.mcdev.util.descriptor
+import com.demonwav.mcdev.util.findAnnotations
+import com.demonwav.mcdev.util.fullQualifiedName
+import com.demonwav.mcdev.util.parseArray
+import com.demonwav.mcdev.util.resolveClass
+import com.intellij.psi.PsiAnnotation
+import com.intellij.psi.PsiElement
+import com.intellij.psi.PsiEnumConstant
+import com.intellij.psi.PsiManager
+import com.intellij.psi.PsiReferenceExpression
+import com.intellij.psi.PsiType
+import com.intellij.psi.search.GlobalSearchScope
+import org.objectweb.asm.Opcodes
+import org.objectweb.asm.Type
+import org.objectweb.asm.tree.AbstractInsnNode
+import org.objectweb.asm.tree.ClassNode
+import org.objectweb.asm.tree.MethodNode
+
+class ModifyConstantHandler : InjectorAnnotationHandler() {
+ private val allowedOpcodes = setOf(
+ Opcodes.ICONST_M1,
+ Opcodes.ICONST_0,
+ Opcodes.ICONST_1,
+ Opcodes.ICONST_2,
+ Opcodes.ICONST_3,
+ Opcodes.ICONST_4,
+ Opcodes.ICONST_5,
+ Opcodes.LCONST_0,
+ Opcodes.LCONST_1,
+ Opcodes.FCONST_0,
+ Opcodes.FCONST_1,
+ Opcodes.FCONST_2,
+ Opcodes.DCONST_0,
+ Opcodes.DCONST_1,
+ Opcodes.BIPUSH,
+ Opcodes.SIPUSH,
+ Opcodes.LDC,
+ Opcodes.IFLT,
+ Opcodes.IFGE,
+ Opcodes.IFGT,
+ Opcodes.IFLE
+ )
+
+ private class ModifyConstantInfo(
+ val constantInfo: ConstantInjectionPoint.ConstantInfo,
+ val constantAnnotation: PsiAnnotation
+ )
+
+ private fun getConstantInfos(modifyConstant: PsiAnnotation): List? {
+ val constants = modifyConstant.findDeclaredAttributeValue("constant")
+ ?.findAnnotations()
+ ?.takeIf { it.isNotEmpty() }
+ ?: return null
+ return constants.map { constant ->
+ val nullValue = constant.findDeclaredAttributeValue("nullValue")?.constantValue as? Boolean ?: false
+ val intValue = (constant.findDeclaredAttributeValue("intValue")?.constantValue as? Number)?.toInt()
+ val floatValue = (constant.findDeclaredAttributeValue("floatValue")?.constantValue as? Number)?.toFloat()
+ val longValue = (constant.findDeclaredAttributeValue("longValue")?.constantValue as? Number)?.toLong()
+ val doubleValue = (constant.findDeclaredAttributeValue("doubleValue")?.constantValue as? Number)?.toDouble()
+ val stringValue = constant.findDeclaredAttributeValue("stringValue")?.constantValue as? String
+ val classValue = constant.findDeclaredAttributeValue("classValue")?.resolveClass()?.descriptor?.let {
+ Type.getType(
+ it
+ )
+ }
+
+ fun Boolean.toInt(): Int {
+ return if (this) 1 else 0
+ }
+
+ val count = nullValue.toInt() +
+ (intValue != null).toInt() +
+ (floatValue != null).toInt() +
+ (longValue != null).toInt() +
+ (doubleValue != null).toInt() +
+ (stringValue != null).toInt() +
+ (classValue != null).toInt()
+ if (count != 1) {
+ return null
+ }
+
+ val value = if (nullValue) {
+ null
+ } else {
+ intValue ?: floatValue ?: longValue ?: doubleValue ?: stringValue ?: classValue
+ }
+
+ val expandConditions = constant.findDeclaredAttributeValue("expandZeroConditions")?.parseArray {
+ if (it !is PsiReferenceExpression) {
+ return@parseArray null
+ }
+ val field = it.resolve() as? PsiEnumConstant ?: return@parseArray null
+ val enumClass = field.containingClass ?: return@parseArray null
+ if (enumClass.fullQualifiedName != CONSTANT_CONDITION) {
+ return@parseArray null
+ }
+ try {
+ ConstantInjectionPoint.ExpandCondition.valueOf(field.name)
+ } catch (e: IllegalArgumentException) {
+ null
+ }
+ }?.toSet() ?: emptySet()
+
+ ModifyConstantInfo(ConstantInjectionPoint.ConstantInfo(value, expandConditions), constant)
+ }
+ }
+
+ override fun expectedMethodSignature(
+ annotation: PsiAnnotation,
+ targetClass: ClassNode,
+ targetMethod: MethodNode
+ ): List? {
+ val constantInfos = getConstantInfos(annotation) ?: return null
+ val psiManager = PsiManager.getInstance(annotation.project)
+ return constantInfos.asSequence().map {
+ when (it.constantInfo.constant) {
+ null -> PsiType.getJavaLangObject(psiManager, annotation.resolveScope)
+ is Int -> PsiType.INT
+ is Float -> PsiType.FLOAT
+ is Long -> PsiType.LONG
+ is Double -> PsiType.DOUBLE
+ is String -> PsiType.getJavaLangString(psiManager, annotation.resolveScope)
+ is Type -> PsiType.getJavaLangClass(psiManager, annotation.resolveScope)
+ else -> throw IllegalStateException("Unknown constant type: ${it.constantInfo.constant.javaClass.name}")
+ }
+ }.distinct().map { type ->
+ MethodSignature(
+ listOf(
+ ParameterGroup(listOf(sanitizedParameter(type, "constant"))),
+ ParameterGroup(
+ collectTargetMethodParameters(annotation.project, targetClass, targetMethod),
+ isVarargs = true,
+ required = ParameterGroup.RequiredLevel.OPTIONAL
+ )
+ ),
+ type
+ )
+ }.toList()
+ }
+
+ override fun resolveForNavigation(
+ annotation: PsiAnnotation,
+ targetClass: ClassNode,
+ targetMethod: MethodNode
+ ): List {
+ val constantInfos = getConstantInfos(annotation) ?: return emptyList()
+
+ val targetElement = targetMethod.findSourceElement(
+ targetClass,
+ annotation.project,
+ GlobalSearchScope.allScope(annotation.project),
+ canDecompile = true
+ ) ?: return emptyList()
+
+ return constantInfos.asSequence().flatMap { modifyConstantInfo ->
+ val collectVisitor = ConstantInjectionPoint.MyCollectVisitor(
+ annotation.project,
+ CollectVisitor.Mode.MATCH_ALL,
+ modifyConstantInfo.constantInfo
+ )
+ InjectionPoint.addStandardFilters(modifyConstantInfo.constantAnnotation, targetClass, collectVisitor)
+ collectVisitor.visit(targetMethod)
+ val bytecodeResults = collectVisitor.result
+
+ val navigationVisitor = ConstantInjectionPoint.MyNavigationVisitor(modifyConstantInfo.constantInfo)
+ targetElement.accept(navigationVisitor)
+
+ bytecodeResults.asSequence().mapNotNull { bytecodeResult ->
+ navigationVisitor.result.getOrNull(bytecodeResult.index)
+ }
+ }.sortedBy { it.textOffset }.toList()
+ }
+
+ override fun resolveInstructions(
+ annotation: PsiAnnotation,
+ targetClass: ClassNode,
+ targetMethod: MethodNode,
+ mode: CollectVisitor.Mode
+ ): List> {
+ val constantInfos = getConstantInfos(annotation) ?: return emptyList()
+ return constantInfos.asSequence().flatMap { modifyConstantInfo ->
+ val collectVisitor = ConstantInjectionPoint.MyCollectVisitor(
+ annotation.project,
+ mode,
+ modifyConstantInfo.constantInfo
+ )
+ InjectionPoint.addStandardFilters(modifyConstantInfo.constantAnnotation, targetClass, collectVisitor)
+ collectVisitor.visit(targetMethod)
+ collectVisitor.result.asSequence()
+ }.sortedBy { targetMethod.instructions.indexOf(it.insn) }.toList()
+ }
+
+ override fun isUnresolved(
+ annotation: PsiAnnotation,
+ targetClass: ClassNode,
+ targetMethod: MethodNode
+ ): InsnResolutionInfo.Failure? {
+ val constantInfos = getConstantInfos(annotation) ?: return InsnResolutionInfo.Failure()
+ return constantInfos.asSequence().mapNotNull { modifyConstantInfo ->
+ val collectVisitor = ConstantInjectionPoint.MyCollectVisitor(
+ annotation.project,
+ CollectVisitor.Mode.MATCH_FIRST,
+ modifyConstantInfo.constantInfo
+ )
+ InjectionPoint.addStandardFilters(modifyConstantInfo.constantAnnotation, targetClass, collectVisitor)
+ collectVisitor.visit(targetMethod)
+ if (collectVisitor.result.isEmpty()) {
+ collectVisitor.filterToBlame
+ } else {
+ null
+ }
+ }.firstOrNull()?.let(InsnResolutionInfo::Failure)
+ }
+
+ override fun isInsnAllowed(insn: AbstractInsnNode): Boolean {
+ return insn.opcode in allowedOpcodes
+ }
+}
diff --git a/src/main/kotlin/platform/mixin/handlers/ModifyVariableHandler.kt b/src/main/kotlin/platform/mixin/handlers/ModifyVariableHandler.kt
index 731f1ddaa..d89576ae6 100644
--- a/src/main/kotlin/platform/mixin/handlers/ModifyVariableHandler.kt
+++ b/src/main/kotlin/platform/mixin/handlers/ModifyVariableHandler.kt
@@ -10,8 +10,27 @@
package com.demonwav.mcdev.platform.mixin.handlers
+import com.demonwav.mcdev.platform.mixin.handlers.injectionPoint.AbstractLoadInjectionPoint
+import com.demonwav.mcdev.platform.mixin.handlers.injectionPoint.CollectVisitor
+import com.demonwav.mcdev.platform.mixin.handlers.injectionPoint.InjectionPoint
import com.demonwav.mcdev.platform.mixin.inspection.injector.MethodSignature
+import com.demonwav.mcdev.platform.mixin.inspection.injector.ParameterGroup
+import com.demonwav.mcdev.platform.mixin.util.LocalVariables
+import com.demonwav.mcdev.platform.mixin.util.hasAccess
+import com.demonwav.mcdev.platform.mixin.util.toPsiType
+import com.demonwav.mcdev.util.computeStringArray
+import com.demonwav.mcdev.util.constantStringValue
+import com.demonwav.mcdev.util.constantValue
+import com.demonwav.mcdev.util.descriptor
+import com.demonwav.mcdev.util.findContainingMethod
+import com.demonwav.mcdev.util.findModule
+import com.intellij.openapi.module.Module
+import com.intellij.psi.JavaPsiFacade
import com.intellij.psi.PsiAnnotation
+import com.intellij.psi.PsiType
+import org.objectweb.asm.Opcodes
+import org.objectweb.asm.Type
+import org.objectweb.asm.tree.AbstractInsnNode
import org.objectweb.asm.tree.ClassNode
import org.objectweb.asm.tree.MethodNode
@@ -21,7 +40,159 @@ class ModifyVariableHandler : InjectorAnnotationHandler() {
targetClass: ClassNode,
targetMethod: MethodNode
): List? {
- // TODO: implement properly
- return null
+ val module = annotation.findModule() ?: return null
+
+ val at = annotation.findAttributeValue("at") as? PsiAnnotation
+ val atCode = at?.findAttributeValue("value")?.constantStringValue
+ val isLoadStore = atCode != null && InjectionPoint.byAtCode(atCode) is AbstractLoadInjectionPoint
+ val mode = if (isLoadStore) CollectVisitor.Mode.COMPLETION else CollectVisitor.Mode.MATCH_ALL
+ val targets = resolveInstructions(annotation, targetClass, targetMethod, mode)
+
+ val targetParamsGroup = ParameterGroup(
+ collectTargetMethodParameters(annotation.project, targetClass, targetMethod),
+ required = ParameterGroup.RequiredLevel.OPTIONAL,
+ isVarargs = true
+ )
+
+ val info = ModifyVariableInfo.getModifyVariableInfo(annotation, CollectVisitor.Mode.COMPLETION)
+ ?: return null
+
+ val possibleTypes = mutableSetOf()
+ for (insn in targets) {
+ val locals = info.getLocals(module, targetClass, targetMethod, insn.insn) ?: continue
+ val matchedLocals = info.matchLocals(locals, CollectVisitor.Mode.COMPLETION, matchType = false)
+ for (local in matchedLocals) {
+ possibleTypes += local.desc!!
+ }
+ }
+
+ val result = mutableListOf()
+
+ val elementFactory = JavaPsiFacade.getElementFactory(annotation.project)
+ for (type in possibleTypes) {
+ val psiType = Type.getType(type).toPsiType(elementFactory)
+ result += MethodSignature(
+ listOf(
+ ParameterGroup(listOf(sanitizedParameter(psiType, "value"))),
+ targetParamsGroup
+ ),
+ psiType
+ )
+ }
+
+ return result
+ }
+}
+
+class ModifyVariableInfo(
+ val type: PsiType?,
+ val argsOnly: Boolean,
+ val index: Int?,
+ val ordinal: Int?,
+ val names: Set
+) {
+ fun getLocals(
+ module: Module,
+ targetClass: ClassNode,
+ methodNode: MethodNode,
+ insn: AbstractInsnNode
+ ): Array? {
+ return if (argsOnly) {
+ val args = mutableListOf()
+ if (!methodNode.hasAccess(Opcodes.ACC_STATIC)) {
+ val thisDesc = Type.getObjectType(targetClass.name).descriptor
+ args.add(LocalVariables.LocalVariable("this", thisDesc, null, null, null, 0))
+ }
+ for (argType in Type.getArgumentTypes(methodNode.desc)) {
+ args.add(
+ LocalVariables.LocalVariable("arg${args.size}", argType.descriptor, null, null, null, args.size)
+ )
+ if (argType.size == 2) {
+ args.add(null)
+ }
+ }
+ args.toTypedArray()
+ } else {
+ LocalVariables.getLocals(module, targetClass, methodNode, insn)
+ }
+ }
+
+ fun matchLocals(
+ locals: Array,
+ mode: CollectVisitor.Mode,
+ matchType: Boolean = true
+ ): List {
+ val typeDesc = type?.descriptor
+ if (ordinal != null) {
+ val ordinals = mutableMapOf()
+ val result = mutableListOf()
+ for (local in locals) {
+ if (local == null) {
+ continue
+ }
+ val ordinal = ordinals[local.desc] ?: 0
+ ordinals[local.desc!!] = ordinal + 1
+ if (ordinal == ordinal && (!matchType || typeDesc == null || local.desc == typeDesc)) {
+ result += local
+ }
+ }
+ return result
+ }
+
+ if (index != null) {
+ val local = locals.firstOrNull { it?.index == index }
+ if (local != null) {
+ if (!matchType || typeDesc == null || local.desc == typeDesc) {
+ return listOf(local)
+ }
+ }
+ return emptyList()
+ }
+
+ if (names.isNotEmpty()) {
+ val result = mutableListOf()
+ for (local in locals) {
+ if (local == null) {
+ continue
+ }
+ if (names.contains(local.name)) {
+ if (!matchType || typeDesc == null || local.desc == typeDesc) {
+ result += local
+ }
+ }
+ }
+ return result
+ }
+
+ // implicit mode
+ if (mode == CollectVisitor.Mode.COMPLETION) {
+ return locals.asSequence()
+ .filterNotNull()
+ .filter { local -> locals.count { it?.desc == local.desc } == 1 }
+ .toList()
+ }
+
+ return if (matchType && typeDesc != null) {
+ locals.singleOrNull { it?.desc == typeDesc }?.let { listOf(it) } ?: emptyList()
+ } else {
+ emptyList()
+ }
+ }
+
+ companion object {
+ fun getModifyVariableInfo(modifyVariable: PsiAnnotation, mode: CollectVisitor.Mode?): ModifyVariableInfo? {
+ val method = modifyVariable.findContainingMethod() ?: return null
+ val type = method.parameterList.getParameter(0)?.type
+ if (type == null && mode != CollectVisitor.Mode.COMPLETION) {
+ return null
+ }
+ val argsOnly = modifyVariable.findDeclaredAttributeValue("argsOnly")?.constantValue as? Boolean ?: false
+ val index = (modifyVariable.findDeclaredAttributeValue("index")?.constantValue as? Int)
+ ?.takeIf { it != -1 }
+ val ordinal = (modifyVariable.findDeclaredAttributeValue("ordinal")?.constantValue as? Int)
+ ?.takeIf { it != -1 }
+ val names = modifyVariable.findDeclaredAttributeValue("name")?.computeStringArray()?.toSet() ?: emptySet()
+ return ModifyVariableInfo(type, argsOnly, index, ordinal, names)
+ }
}
}
diff --git a/src/main/kotlin/platform/mixin/handlers/RedirectInjectorHandler.kt b/src/main/kotlin/platform/mixin/handlers/RedirectInjectorHandler.kt
index 4b43c4947..a51c88450 100644
--- a/src/main/kotlin/platform/mixin/handlers/RedirectInjectorHandler.kt
+++ b/src/main/kotlin/platform/mixin/handlers/RedirectInjectorHandler.kt
@@ -84,12 +84,15 @@ class RedirectInjectorHandler : InjectorAnnotationHandler() {
// add a parameter group for capturing the target method parameters
val extraGroup = ParameterGroup(
collectTargetMethodParameters(annotation.project, targetClass, targetMethod),
- required = false
+ required = ParameterGroup.RequiredLevel.OPTIONAL,
+ isVarargs = true
)
MethodSignature(paramGroups + extraGroup, returnType)
}
}
+ override val allowCoerce = true
+
private interface RedirectType {
fun isInsnAllowed(node: AbstractInsnNode) = true
@@ -383,7 +386,7 @@ class RedirectInjectorHandler : InjectorAnnotationHandler() {
}
).map { (paramGroups, _) ->
// drop the instance parameter, return the constructed type
- MethodSignature(listOf(ParameterGroup(paramGroups[0].parameters?.drop(1))), constructedType)
+ MethodSignature(listOf(ParameterGroup(paramGroups[0].parameters.drop(1))), constructedType)
}
}
}
diff --git a/src/main/kotlin/platform/mixin/handlers/injectionPoint/AtResolver.kt b/src/main/kotlin/platform/mixin/handlers/injectionPoint/AtResolver.kt
index 64ebf28be..e8c8e700f 100644
--- a/src/main/kotlin/platform/mixin/handlers/injectionPoint/AtResolver.kt
+++ b/src/main/kotlin/platform/mixin/handlers/injectionPoint/AtResolver.kt
@@ -121,16 +121,16 @@ class AtResolver(
}
}
- fun resolveInstructions(): List> {
- return (getInstructionResolutionInfo() as? InsnResolutionInfo.Success)?.results ?: emptyList()
+ fun resolveInstructions(mode: CollectVisitor.Mode = CollectVisitor.Mode.MATCH_ALL): List> {
+ return (getInstructionResolutionInfo(mode) as? InsnResolutionInfo.Success)?.results ?: emptyList()
}
- fun getInstructionResolutionInfo(): InsnResolutionInfo {
+ fun getInstructionResolutionInfo(mode: CollectVisitor.Mode = CollectVisitor.Mode.MATCH_ALL): InsnResolutionInfo {
val injectionPoint = getInjectionPoint(at) ?: return InsnResolutionInfo.Failure()
val targetAttr = at.findAttributeValue("target")
val target = targetAttr?.let { parseMixinSelector(it) }
- val collectVisitor = injectionPoint.createCollectVisitor(at, target, targetClass, CollectVisitor.Mode.MATCH_ALL)
+ val collectVisitor = injectionPoint.createCollectVisitor(at, target, targetClass, mode)
?: return InsnResolutionInfo.Failure()
collectVisitor.visit(targetMethod)
val result = collectVisitor.result
diff --git a/src/main/kotlin/platform/mixin/handlers/injectionPoint/ConstantInjectionPoint.kt b/src/main/kotlin/platform/mixin/handlers/injectionPoint/ConstantInjectionPoint.kt
new file mode 100644
index 000000000..2a063fe19
--- /dev/null
+++ b/src/main/kotlin/platform/mixin/handlers/injectionPoint/ConstantInjectionPoint.kt
@@ -0,0 +1,257 @@
+/*
+ * Minecraft Dev for IntelliJ
+ *
+ * https://minecraftdev.org
+ *
+ * Copyright (c) 2021 minecraft-dev
+ *
+ * MIT License
+ */
+
+package com.demonwav.mcdev.platform.mixin.handlers.injectionPoint
+
+import com.demonwav.mcdev.platform.mixin.reference.MixinSelector
+import com.demonwav.mcdev.util.constantValue
+import com.demonwav.mcdev.util.createLiteralExpression
+import com.demonwav.mcdev.util.descriptor
+import com.intellij.codeInsight.lookup.LookupElementBuilder
+import com.intellij.openapi.project.Project
+import com.intellij.psi.JavaPsiFacade
+import com.intellij.psi.JavaTokenType
+import com.intellij.psi.PsiAnnotation
+import com.intellij.psi.PsiArrayType
+import com.intellij.psi.PsiBinaryExpression
+import com.intellij.psi.PsiClass
+import com.intellij.psi.PsiClassObjectAccessExpression
+import com.intellij.psi.PsiElement
+import com.intellij.psi.PsiExpression
+import com.intellij.psi.PsiForeachStatement
+import com.intellij.psi.PsiLiteralExpression
+import com.intellij.psi.PsiSwitchLabelStatementBase
+import com.intellij.psi.util.PsiUtil
+import java.lang.IllegalArgumentException
+import java.util.Locale
+import org.objectweb.asm.Opcodes
+import org.objectweb.asm.Type
+import org.objectweb.asm.tree.ClassNode
+import org.objectweb.asm.tree.FrameNode
+import org.objectweb.asm.tree.InsnNode
+import org.objectweb.asm.tree.IntInsnNode
+import org.objectweb.asm.tree.JumpInsnNode
+import org.objectweb.asm.tree.LabelNode
+import org.objectweb.asm.tree.LdcInsnNode
+import org.objectweb.asm.tree.MethodNode
+
+class ConstantInjectionPoint : InjectionPoint() {
+ private fun getConstantInfo(at: PsiAnnotation): ConstantInfo? {
+ val args = AtResolver.getArgs(at)
+ val nullValue = args["nullValue"]?.let(java.lang.Boolean::parseBoolean) ?: false
+ val intValue = args["intValue"]?.toIntOrNull()
+ val floatValue = args["floatValue"]?.toFloatOrNull()
+ val longValue = args["longValue"]?.toLongOrNull()
+ val doubleValue = args["doubleValue"]?.toDoubleOrNull()
+ val stringValue = args["stringValue"]
+ val classValue = args["classValue"]?.let { Type.getObjectType(it.replace('.', '/')) }
+ val count =
+ nullValue.toInt() +
+ (intValue != null).toInt() +
+ (floatValue != null).toInt() +
+ (longValue != null).toInt() +
+ (doubleValue != null).toInt() +
+ (stringValue != null).toInt() +
+ (classValue != null).toInt()
+ if (count != 1) {
+ return null
+ }
+
+ val constant = if (nullValue) {
+ null
+ } else {
+ intValue ?: floatValue ?: longValue ?: doubleValue ?: stringValue ?: classValue!!
+ }
+
+ val expandConditions = args["expandZeroConditions"]
+ ?.replace(" ", "")
+ ?.split(',')
+ ?.mapNotNull {
+ try {
+ ExpandCondition.valueOf(it.toUpperCase(Locale.ROOT))
+ } catch (e: IllegalArgumentException) {
+ null
+ }
+ }
+ ?.toSet() ?: emptySet()
+
+ return ConstantInfo(constant, expandConditions)
+ }
+
+ private fun Boolean.toInt(): Int {
+ return if (this) 1 else 0
+ }
+
+ override fun createNavigationVisitor(
+ at: PsiAnnotation,
+ target: MixinSelector?,
+ targetClass: PsiClass
+ ): NavigationVisitor? {
+ val constantInfo = getConstantInfo(at) ?: return null
+ return MyNavigationVisitor(constantInfo)
+ }
+
+ override fun doCreateCollectVisitor(
+ at: PsiAnnotation,
+ target: MixinSelector?,
+ targetClass: ClassNode,
+ mode: CollectVisitor.Mode
+ ): CollectVisitor? {
+ val constantInfo = getConstantInfo(at) ?: return null
+ return MyCollectVisitor(at.project, mode, constantInfo)
+ }
+
+ override fun createLookup(
+ targetClass: ClassNode,
+ result: CollectVisitor.Result
+ ): LookupElementBuilder? {
+ return null
+ }
+
+ class ConstantInfo(val constant: Any?, val expandConditions: Set)
+
+ enum class ExpandCondition(vararg val opcodes: Int) {
+ LESS_THAN_ZERO(Opcodes.IFLT, Opcodes.IFGE),
+ LESS_THAN_OR_EQUAL_TO_ZERO(Opcodes.IFLE, Opcodes.IFGT),
+ GREATER_THAN_ZERO(Opcodes.IFLE, Opcodes.IFGT),
+ GREATER_THAN_OR_EQUAL_TO_ZERO(Opcodes.IFLT, Opcodes.IFGE),
+ }
+
+ class MyNavigationVisitor(
+ private val constantInfo: ConstantInfo
+ ) : NavigationVisitor() {
+ override fun visitForeachStatement(statement: PsiForeachStatement) {
+ if (statement.iteratedValue?.type is PsiArrayType) {
+ // index initialized to 0
+ visitConstant(statement, 0)
+ }
+ super.visitForeachStatement(statement)
+ }
+
+ override fun visitClassObjectAccessExpression(expression: PsiClassObjectAccessExpression) {
+ visitConstant(expression, Type.getType(expression.operand.type.descriptor))
+ }
+
+ override fun visitLiteralExpression(expression: PsiLiteralExpression) {
+ if (expression.textMatches("null")) {
+ visitConstant(expression, null)
+ } else {
+ super.visitLiteralExpression(expression)
+ }
+ }
+
+ override fun visitExpression(expression: PsiExpression) {
+ if (PsiUtil.isConstantExpression(expression)) {
+ val value = expression.constantValue
+ if (value != null) {
+ visitConstant(expression, value)
+ return
+ }
+ }
+ super.visitExpression(expression)
+ }
+
+ private fun visitConstant(element: PsiElement, value: Any?) {
+ if (value != constantInfo.constant) {
+ return
+ }
+
+ val parent = PsiUtil.skipParenthesizedExprUp(element.parent)
+
+ // check for expandZeroConditions
+ if (value == null || value == 0) {
+ if (parent is PsiBinaryExpression) {
+ val operation = parent.operationTokenType
+ if (operation == JavaTokenType.EQEQ || operation == JavaTokenType.NE) {
+ return
+ }
+ val opcode = when (operation) {
+ JavaTokenType.LT -> Opcodes.IFLT
+ JavaTokenType.LE -> Opcodes.IFLE
+ JavaTokenType.GT -> Opcodes.IFGT
+ JavaTokenType.GE -> Opcodes.IFGE
+ else -> null
+ }
+ if (opcode != null && !constantInfo.expandConditions.any { opcode in it.opcodes }) {
+ return
+ }
+ }
+ }
+
+ // check for switch statement (compiles to tableswitch or lookupswitch which aren't targeted)
+ if (parent is PsiSwitchLabelStatementBase) {
+ return
+ }
+
+ addResult(element)
+ }
+ }
+
+ class MyCollectVisitor(
+ private val project: Project,
+ mode: Mode,
+ private val constantInfo: ConstantInfo
+ ) : CollectVisitor(mode) {
+ override fun accept(methodNode: MethodNode) {
+ val elementFactory = JavaPsiFacade.getElementFactory(project)
+ methodNode.instructions?.iterator()?.forEachRemaining { insn ->
+ val constant = when (insn) {
+ is InsnNode -> when (insn.opcode) {
+ in Opcodes.ICONST_M1..Opcodes.ICONST_5 -> insn.opcode - Opcodes.ICONST_0
+ Opcodes.LCONST_0 -> 0L
+ Opcodes.LCONST_1 -> 1L
+ Opcodes.FCONST_0 -> 0.0f
+ Opcodes.FCONST_1 -> 1.0f
+ Opcodes.FCONST_2 -> 2.0f
+ Opcodes.DCONST_0 -> 0.0
+ Opcodes.DCONST_1 -> 1.0
+ Opcodes.ACONST_NULL -> null
+ else -> return@forEachRemaining
+ }
+ is IntInsnNode -> when (insn.opcode) {
+ Opcodes.BIPUSH, Opcodes.SIPUSH -> insn.operand
+ else -> return@forEachRemaining
+ }
+ is LdcInsnNode -> insn.cst
+ is JumpInsnNode -> {
+ if (!constantInfo.expandConditions.any { insn.opcode in it.opcodes }) {
+ return@forEachRemaining
+ }
+ var lastInsn = insn.previous
+ while (lastInsn != null && (lastInsn is LabelNode || lastInsn is FrameNode)) {
+ lastInsn = lastInsn.previous
+ }
+ if (lastInsn != null) {
+ val lastOpcode = lastInsn.opcode
+ if (lastOpcode == Opcodes.LCMP ||
+ lastOpcode == Opcodes.FCMPL ||
+ lastOpcode == Opcodes.FCMPG ||
+ lastOpcode == Opcodes.DCMPL ||
+ lastOpcode == Opcodes.DCMPG
+ ) {
+ return@forEachRemaining
+ }
+ }
+ 0
+ }
+ else -> return@forEachRemaining
+ }
+ if (constant == constantInfo.constant) {
+ val literal = if (constant is Type) {
+ elementFactory.createExpressionFromText("${constant.className}.class", null)
+ } else {
+ elementFactory.createLiteralExpression(constant)
+ }
+ addResult(insn, literal)
+ }
+ }
+ }
+ }
+}
diff --git a/src/main/kotlin/platform/mixin/handlers/injectionPoint/InjectionPoint.kt b/src/main/kotlin/platform/mixin/handlers/injectionPoint/InjectionPoint.kt
index 5ca35fda8..aecbba618 100644
--- a/src/main/kotlin/platform/mixin/handlers/injectionPoint/InjectionPoint.kt
+++ b/src/main/kotlin/platform/mixin/handlers/injectionPoint/InjectionPoint.kt
@@ -58,6 +58,100 @@ abstract class InjectionPoint {
fun byAtCode(atCode: String): InjectionPoint<*>? {
return COLLECTOR.findSingle(atCode)
}
+
+ fun addStandardFilters(at: PsiAnnotation, targetClass: ClassNode, collectVisitor: CollectVisitor<*>) {
+ addShiftSupport(at, collectVisitor)
+ addSliceFilter(at, targetClass, collectVisitor)
+ // make sure the ordinal filter is last, so that the ordinal only increments once the other filters have passed
+ addOrdinalFilter(at, collectVisitor)
+ }
+
+ private fun addShiftSupport(at: PsiAnnotation, collectVisitor: CollectVisitor<*>) {
+ val shiftAttr = at.findDeclaredAttributeValue("shift") as? PsiExpression ?: return
+ val shiftReference = PsiUtil.skipParenthesizedExprDown(shiftAttr) as? PsiReferenceExpression ?: return
+ val shift = shiftReference.resolve() as? PsiEnumConstant ?: return
+ val containingClass = shift.containingClass ?: return
+ val shiftClass = JavaPsiFacade.getInstance(at.project).findClass(SHIFT, at.resolveScope) ?: return
+ if (!(containingClass equivalentTo shiftClass)) return
+ when (shift.name) {
+ "BEFORE" -> collectVisitor.shiftBy = -1
+ "AFTER" -> collectVisitor.shiftBy = 1
+ "BY" -> {
+ val by = at.findDeclaredAttributeValue("by")?.constantValue as? Int ?: return
+ collectVisitor.shiftBy = by
+ }
+ }
+ }
+
+ private fun addSliceFilter(at: PsiAnnotation, targetClass: ClassNode, collectVisitor: CollectVisitor<*>) {
+ // resolve slice annotation, take into account slice id if present
+ val sliceId = at.findDeclaredAttributeValue("slice")?.constantStringValue
+ val parentAnnotation = at.parentOfType() ?: return
+ val slices = parentAnnotation.findDeclaredAttributeValue("slice")?.findAnnotations() ?: return
+ val slice = if (sliceId != null) {
+ slices.singleOrNull { aSlice ->
+ val aSliceId = aSlice.findDeclaredAttributeValue("id")?.constantStringValue
+ ?: return@singleOrNull false
+ aSliceId == sliceId
+ }
+ } else {
+ slices.singleOrNull()
+ } ?: return
+
+ // precompute what we can
+ val from = slice.findDeclaredAttributeValue("from") as? PsiAnnotation
+ val to = slice.findDeclaredAttributeValue("to") as? PsiAnnotation
+ if (from == null && to == null) {
+ return
+ }
+ val fromSelector = from?.findDeclaredAttributeValue("value")?.constantStringValue?.let { atCode ->
+ SliceSelector.values().firstOrNull { atCode.endsWith(":${it.name}") }
+ } ?: SliceSelector.FIRST
+ val toSelector = to?.findDeclaredAttributeValue("value")?.constantStringValue?.let { atCode ->
+ SliceSelector.values().firstOrNull { atCode.endsWith(":${it.name}") }
+ } ?: SliceSelector.FIRST
+
+ fun resolveSliceIndex(
+ sliceAt: PsiAnnotation?,
+ selector: SliceSelector,
+ insns: InsnList,
+ method: MethodNode
+ ): Int? {
+ return sliceAt?.let {
+ val results = AtResolver(sliceAt, targetClass, method).resolveInstructions()
+ val insn = if (selector == SliceSelector.LAST) {
+ results.lastOrNull()?.insn
+ } else {
+ results.firstOrNull()?.insn
+ }
+ insn?.let { insns.indexOf(it) }
+ }
+ }
+
+ // allocate lazy indexes so we don't have to re-run the at resolver for the slices each time
+ var fromInsnIndex: Int? = null
+ var toInsnIndex: Int? = null
+
+ collectVisitor.addResultFilter("slice") { result, method ->
+ val insns = method.instructions ?: return@addResultFilter true
+ if (fromInsnIndex == null) {
+ fromInsnIndex = resolveSliceIndex(from, fromSelector, insns, method) ?: 0
+ }
+ if (toInsnIndex == null) {
+ toInsnIndex = resolveSliceIndex(to, toSelector, insns, method) ?: insns.size()
+ }
+
+ insns.indexOf(result.insn) in fromInsnIndex!!..toInsnIndex!!
+ }
+ }
+
+ private fun addOrdinalFilter(at: PsiAnnotation, collectVisitor: CollectVisitor<*>) {
+ val ordinal = at.findDeclaredAttributeValue("ordinal")?.constantValue as? Int ?: return
+ if (ordinal < 0) return
+ collectVisitor.addResultFilter("ordinal") { _, _ ->
+ collectVisitor.ordinal++ == ordinal
+ }
+ }
}
open fun usesMemberReference() = false
@@ -90,100 +184,6 @@ abstract class InjectionPoint {
addStandardFilters(at, targetClass, collectVisitor)
}
- protected fun addStandardFilters(at: PsiAnnotation, targetClass: ClassNode, collectVisitor: CollectVisitor) {
- addShiftSupport(at, collectVisitor)
- addSliceFilter(at, targetClass, collectVisitor)
- // make sure the ordinal filter is last, so that the ordinal only increments once the other filters have passed
- addOrdinalFilter(at, collectVisitor)
- }
-
- private fun addShiftSupport(at: PsiAnnotation, collectVisitor: CollectVisitor) {
- val shiftAttr = at.findDeclaredAttributeValue("shift") as? PsiExpression ?: return
- val shiftReference = PsiUtil.skipParenthesizedExprDown(shiftAttr) as? PsiReferenceExpression ?: return
- val shift = shiftReference.resolve() as? PsiEnumConstant ?: return
- val containingClass = shift.containingClass ?: return
- val shiftClass = JavaPsiFacade.getInstance(at.project).findClass(SHIFT, at.resolveScope) ?: return
- if (!(containingClass equivalentTo shiftClass)) return
- when (shift.name) {
- "BEFORE" -> collectVisitor.shiftBy = -1
- "AFTER" -> collectVisitor.shiftBy = 1
- "BY" -> {
- val by = at.findDeclaredAttributeValue("by")?.constantValue as? Int ?: return
- collectVisitor.shiftBy = by
- }
- }
- }
-
- private fun addSliceFilter(at: PsiAnnotation, targetClass: ClassNode, collectVisitor: CollectVisitor) {
- // resolve slice annotation, take into account slice id if present
- val sliceId = at.findDeclaredAttributeValue("slice")?.constantStringValue
- val parentAnnotation = at.parentOfType() ?: return
- val slices = parentAnnotation.findDeclaredAttributeValue("slice")?.findAnnotations() ?: return
- val slice = if (sliceId != null) {
- slices.singleOrNull { aSlice ->
- val aSliceId = aSlice.findDeclaredAttributeValue("id")?.constantStringValue
- ?: return@singleOrNull false
- aSliceId == sliceId
- }
- } else {
- slices.singleOrNull()
- } ?: return
-
- // precompute what we can
- val from = slice.findDeclaredAttributeValue("from") as? PsiAnnotation
- val to = slice.findDeclaredAttributeValue("to") as? PsiAnnotation
- if (from == null && to == null) {
- return
- }
- val fromSelector = from?.findDeclaredAttributeValue("value")?.constantStringValue?.let { atCode ->
- SliceSelector.values().firstOrNull { atCode.endsWith(":${it.name}") }
- } ?: SliceSelector.FIRST
- val toSelector = to?.findDeclaredAttributeValue("value")?.constantStringValue?.let { atCode ->
- SliceSelector.values().firstOrNull { atCode.endsWith(":${it.name}") }
- } ?: SliceSelector.FIRST
-
- fun resolveSliceIndex(
- sliceAt: PsiAnnotation?,
- selector: SliceSelector,
- insns: InsnList,
- method: MethodNode
- ): Int? {
- return sliceAt?.let {
- val results = AtResolver(sliceAt, targetClass, method).resolveInstructions()
- val insn = if (selector == SliceSelector.LAST) {
- results.lastOrNull()?.insn
- } else {
- results.firstOrNull()?.insn
- }
- insn?.let { insns.indexOf(it) }
- }
- }
-
- // allocate lazy indexes so we don't have to re-run the at resolver for the slices each time
- var fromInsnIndex: Int? = null
- var toInsnIndex: Int? = null
-
- collectVisitor.addResultFilter("slice") { result, method ->
- val insns = method.instructions ?: return@addResultFilter true
- if (fromInsnIndex == null) {
- fromInsnIndex = resolveSliceIndex(from, fromSelector, insns, method) ?: 0
- }
- if (toInsnIndex == null) {
- toInsnIndex = resolveSliceIndex(to, toSelector, insns, method) ?: insns.size()
- }
-
- insns.indexOf(result.insn) in fromInsnIndex!!..toInsnIndex!!
- }
- }
-
- private fun addOrdinalFilter(at: PsiAnnotation, collectVisitor: CollectVisitor) {
- val ordinal = at.findDeclaredAttributeValue("ordinal")?.constantValue as? Int ?: return
- if (ordinal < 0) return
- collectVisitor.addResultFilter("ordinal") { _, _ ->
- collectVisitor.ordinal++ == ordinal
- }
- }
-
abstract fun createLookup(targetClass: ClassNode, result: CollectVisitor.Result): LookupElementBuilder?
protected fun LookupElementBuilder.setBoldIfInClass(member: PsiMember, clazz: ClassNode): LookupElementBuilder {
diff --git a/src/main/kotlin/platform/mixin/handlers/injectionPoint/LoadInjectionPoint.kt b/src/main/kotlin/platform/mixin/handlers/injectionPoint/LoadInjectionPoint.kt
new file mode 100644
index 000000000..c0f6886ef
--- /dev/null
+++ b/src/main/kotlin/platform/mixin/handlers/injectionPoint/LoadInjectionPoint.kt
@@ -0,0 +1,301 @@
+/*
+ * Minecraft Dev for IntelliJ
+ *
+ * https://minecraftdev.org
+ *
+ * Copyright (c) 2021 minecraft-dev
+ *
+ * MIT License
+ */
+
+package com.demonwav.mcdev.platform.mixin.handlers.injectionPoint
+
+import com.demonwav.mcdev.platform.mixin.handlers.ModifyVariableInfo
+import com.demonwav.mcdev.platform.mixin.reference.MixinSelector
+import com.demonwav.mcdev.platform.mixin.util.LocalVariables
+import com.demonwav.mcdev.platform.mixin.util.MixinConstants.Annotations.MODIFY_VARIABLE
+import com.demonwav.mcdev.util.constantValue
+import com.demonwav.mcdev.util.findModule
+import com.demonwav.mcdev.util.isErasureEquivalentTo
+import com.intellij.codeInsight.lookup.LookupElementBuilder
+import com.intellij.openapi.module.Module
+import com.intellij.psi.JavaPsiFacade
+import com.intellij.psi.JavaTokenType
+import com.intellij.psi.PsiAnnotation
+import com.intellij.psi.PsiAssignmentExpression
+import com.intellij.psi.PsiClass
+import com.intellij.psi.PsiElement
+import com.intellij.psi.PsiExpressionStatement
+import com.intellij.psi.PsiForeachStatement
+import com.intellij.psi.PsiParameter
+import com.intellij.psi.PsiPrimitiveType
+import com.intellij.psi.PsiReferenceExpression
+import com.intellij.psi.PsiThisExpression
+import com.intellij.psi.PsiType
+import com.intellij.psi.PsiUnaryExpression
+import com.intellij.psi.PsiVariable
+import com.intellij.psi.util.PsiUtil
+import com.intellij.psi.util.parentOfType
+import org.objectweb.asm.Opcodes
+import org.objectweb.asm.tree.ClassNode
+import org.objectweb.asm.tree.MethodNode
+import org.objectweb.asm.tree.VarInsnNode
+
+abstract class AbstractLoadInjectionPoint(private val store: Boolean) : InjectionPoint() {
+ private fun getModifyVariableInfo(at: PsiAnnotation, mode: CollectVisitor.Mode?): ModifyVariableInfo? {
+ val modifyVariable = at.parentOfType() ?: return null
+ if (!modifyVariable.hasQualifiedName(MODIFY_VARIABLE)) {
+ return null
+ }
+ return ModifyVariableInfo.getModifyVariableInfo(modifyVariable, mode)
+ }
+
+ override fun createNavigationVisitor(
+ at: PsiAnnotation,
+ target: MixinSelector?,
+ targetClass: PsiClass
+ ): NavigationVisitor? {
+ val info = getModifyVariableInfo(at, null) ?: return null
+ return MyNavigationVisitor(info, store)
+ }
+
+ override fun doCreateCollectVisitor(
+ at: PsiAnnotation,
+ target: MixinSelector?,
+ targetClass: ClassNode,
+ mode: CollectVisitor.Mode
+ ): CollectVisitor? {
+ val module = at.findModule() ?: return null
+ val info = getModifyVariableInfo(at, mode) ?: return null
+ return MyCollectVisitor(module, targetClass, mode, info, store)
+ }
+
+ override fun createLookup(
+ targetClass: ClassNode,
+ result: CollectVisitor.Result
+ ): LookupElementBuilder? {
+ return null
+ }
+
+ private class MyNavigationVisitor(
+ private val info: ModifyVariableInfo,
+ private val store: Boolean
+ ) : NavigationVisitor() {
+ override fun visitThisExpression(expression: PsiThisExpression) {
+ super.visitThisExpression(expression)
+ if (!store && expression.qualifier == null) {
+ addLocalUsage(expression, "this")
+ }
+ }
+
+ override fun visitVariable(variable: PsiVariable) {
+ super.visitVariable(variable)
+ if (store && variable.initializer != null) {
+ val name = variable.name
+ if (name != null) {
+ addLocalUsage(variable, name)
+ }
+ }
+ }
+
+ override fun visitReferenceExpression(expression: PsiReferenceExpression) {
+ super.visitReferenceExpression(expression)
+ val referenceName = expression.referenceName ?: return
+ if (expression.qualifierExpression == null) {
+ val isCorrectAccessType = if (store) {
+ PsiUtil.isAccessedForWriting(expression)
+ } else {
+ PsiUtil.isAccessedForReading(expression)
+ }
+ if (!isCorrectAccessType) {
+ return
+ }
+ val resolved = expression.resolve() as? PsiVariable ?: return
+ val type = resolved.type
+ if (type is PsiPrimitiveType &&
+ type != PsiType.FLOAT &&
+ type != PsiType.DOUBLE &&
+ type != PsiType.LONG &&
+ type != PsiType.BOOLEAN
+ ) {
+ // ModifyVariable currently cannot handle iinc
+ val parentExpr = PsiUtil.skipParenthesizedExprUp(expression.parent)
+ val isIincUnary = parentExpr is PsiUnaryExpression &&
+ (
+ parentExpr.operationSign.tokenType == JavaTokenType.PLUSPLUS ||
+ parentExpr.operationSign.tokenType == JavaTokenType.MINUSMINUS
+ )
+ val isIincAssignment = parentExpr is PsiAssignmentExpression &&
+ (
+ parentExpr.operationSign.tokenType == JavaTokenType.PLUSEQ ||
+ parentExpr.operationSign.tokenType == JavaTokenType.MINUSEQ
+ ) &&
+ PsiUtil.isConstantExpression(parentExpr.rExpression) &&
+ (parentExpr.rExpression?.constantValue as? Number)?.toInt()
+ ?.let { it >= Short.MIN_VALUE && it <= Short.MAX_VALUE } == true
+ val isIinc = isIincUnary || isIincAssignment
+ if (isIinc) {
+ if (store) {
+ return
+ }
+ val parentParent = PsiUtil.skipParenthesizedExprUp(parentExpr.parent)
+ if (parentParent is PsiExpressionStatement) {
+ return
+ }
+ }
+ }
+ if (!info.argsOnly || resolved is PsiParameter) {
+ addLocalUsage(expression, referenceName)
+ }
+ }
+ }
+
+ override fun visitForeachStatement(statement: PsiForeachStatement) {
+ checkImplicitLocalsPre(statement)
+ if (store) {
+ addLocalUsage(statement.iterationParameter, statement.iterationParameter.name)
+ }
+ super.visitForeachStatement(statement)
+ checkImplicitLocalsPost(statement)
+ }
+
+ private fun checkImplicitLocalsPre(location: PsiElement) {
+ val localsHere = LocalVariables.guessLocalsAt(location, info.argsOnly, true)
+ val localIndex = LocalVariables.guessLocalVariableIndex(location) ?: return
+ val localCount = LocalVariables.getLocalVariableSize(location)
+ for (i in localIndex until (localIndex + localCount)) {
+ val local = localsHere.firstOrNull { it.index == i } ?: continue
+ if (store) {
+ repeat(local.implicitStoreCountBefore) {
+ addLocalUsage(location, local.name, localsHere)
+ }
+ } else {
+ repeat(local.implicitLoadCountBefore) {
+ addLocalUsage(location, local.name, localsHere)
+ }
+ }
+ }
+ }
+
+ private fun checkImplicitLocalsPost(location: PsiElement) {
+ val localsHere = LocalVariables.guessLocalsAt(location, info.argsOnly, false)
+ val localIndex = LocalVariables.guessLocalVariableIndex(location) ?: return
+ val localCount = LocalVariables.getLocalVariableSize(location)
+ for (i in localIndex until (localIndex + localCount)) {
+ val local = localsHere.firstOrNull { it.index == i } ?: continue
+ if (store) {
+ repeat(local.implicitStoreCountAfter) {
+ addLocalUsage(location, local.name, localsHere)
+ }
+ } else {
+ repeat(local.implicitLoadCountAfter) {
+ addLocalUsage(location, local.name, localsHere)
+ }
+ }
+ }
+ }
+
+ private fun addLocalUsage(location: PsiElement, name: String) {
+ val localsHere = LocalVariables.guessLocalsAt(location, info.argsOnly, !store)
+ addLocalUsage(location, name, localsHere)
+ }
+
+ private fun addLocalUsage(
+ location: PsiElement,
+ name: String,
+ localsHere: List
+ ) {
+ if (info.ordinal != null) {
+ val local = localsHere.asSequence().filter {
+ it.type.isErasureEquivalentTo(info.type)
+ }.drop(info.ordinal).firstOrNull()
+ if (name == local?.name) {
+ addResult(location)
+ }
+ return
+ }
+
+ if (info.index != null) {
+ val local = localsHere.getOrNull(info.index)
+ if (name == local?.name) {
+ addResult(location)
+ }
+ return
+ }
+
+ if (info.names.isNotEmpty()) {
+ val matchingLocals = localsHere.filter {
+ info.names.contains(it.mixinName)
+ }
+ for (local in matchingLocals) {
+ if (local.name == name) {
+ addResult(location)
+ }
+ }
+ return
+ }
+
+ // implicit mode
+ val local = localsHere.singleOrNull {
+ it.type.isErasureEquivalentTo(info.type)
+ }
+ if (local != null && local.name == name) {
+ addResult(location)
+ }
+ }
+ }
+
+ private class MyCollectVisitor(
+ private val module: Module,
+ private val targetClass: ClassNode,
+ mode: Mode,
+ private val info: ModifyVariableInfo,
+ private val store: Boolean
+ ) : CollectVisitor(mode) {
+ override fun accept(methodNode: MethodNode) {
+ var opcode = when (info.type) {
+ null -> null
+ !is PsiPrimitiveType -> Opcodes.ALOAD
+ PsiType.LONG -> Opcodes.LLOAD
+ PsiType.FLOAT -> Opcodes.FLOAD
+ PsiType.DOUBLE -> Opcodes.DLOAD
+ else -> Opcodes.ILOAD
+ }
+ if (store && opcode != null) {
+ opcode += (Opcodes.ISTORE - Opcodes.ILOAD)
+ }
+ for (insn in methodNode.instructions) {
+ if (insn !is VarInsnNode) {
+ continue
+ }
+ if (opcode != null) {
+ if (opcode != insn.opcode) {
+ continue
+ }
+ } else {
+ if (store) {
+ if (insn.opcode < Opcodes.ISTORE || insn.opcode > Opcodes.ASTORE) {
+ continue
+ }
+ } else {
+ if (insn.opcode < Opcodes.ILOAD || insn.opcode > Opcodes.ALOAD) {
+ continue
+ }
+ }
+ }
+
+ val localLocation = if (store) insn.next ?: insn else insn
+ val locals = info.getLocals(module, targetClass, methodNode, localLocation) ?: continue
+
+ val elementFactory = JavaPsiFacade.getElementFactory(module.project)
+
+ for (result in info.matchLocals(locals, mode)) {
+ addResult(insn, elementFactory.createExpressionFromText(result.name, null))
+ }
+ }
+ }
+ }
+}
+
+class LoadInjectionPoint : AbstractLoadInjectionPoint(false)
+class StoreInjectionPoint : AbstractLoadInjectionPoint(true)
diff --git a/src/main/kotlin/platform/mixin/inspection/injector/InjectIntoConstructorInspection.kt b/src/main/kotlin/platform/mixin/inspection/injector/InjectIntoConstructorInspection.kt
new file mode 100644
index 000000000..dd5555977
--- /dev/null
+++ b/src/main/kotlin/platform/mixin/inspection/injector/InjectIntoConstructorInspection.kt
@@ -0,0 +1,80 @@
+/*
+ * Minecraft Dev for IntelliJ
+ *
+ * https://minecraftdev.org
+ *
+ * Copyright (c) 2021 minecraft-dev
+ *
+ * MIT License
+ */
+
+package com.demonwav.mcdev.platform.mixin.inspection.injector
+
+import com.demonwav.mcdev.facet.MinecraftFacet
+import com.demonwav.mcdev.platform.fabric.FabricModuleType
+import com.demonwav.mcdev.platform.mixin.handlers.InjectorAnnotationHandler
+import com.demonwav.mcdev.platform.mixin.handlers.MixinAnnotationHandler
+import com.demonwav.mcdev.platform.mixin.inspection.MixinInspection
+import com.demonwav.mcdev.platform.mixin.util.MethodTargetMember
+import com.demonwav.mcdev.platform.mixin.util.MixinConstants.Annotations.INJECT
+import com.demonwav.mcdev.platform.mixin.util.isConstructor
+import com.demonwav.mcdev.util.findAnnotation
+import com.demonwav.mcdev.util.findModule
+import com.intellij.codeInspection.ProblemsHolder
+import com.intellij.psi.JavaElementVisitor
+import com.intellij.psi.PsiElementVisitor
+import com.intellij.psi.PsiMethod
+import java.awt.FlowLayout
+import javax.swing.JCheckBox
+import javax.swing.JComponent
+import javax.swing.JPanel
+import org.objectweb.asm.Opcodes
+
+class InjectIntoConstructorInspection : MixinInspection() {
+ @JvmField
+ var ALLOW_ON_FABRIC = true
+
+ override fun createOptionsPanel(): JComponent {
+ val panel = JPanel(FlowLayout(FlowLayout.LEFT))
+ val checkbox = JCheckBox("Allow @Inject into constructors in Fabric", ALLOW_ON_FABRIC)
+ checkbox.addActionListener {
+ ALLOW_ON_FABRIC = checkbox.isSelected
+ }
+ panel.add(checkbox)
+ return panel
+ }
+
+ override fun buildVisitor(holder: ProblemsHolder): PsiElementVisitor {
+ val isFabric = holder.file.findModule()?.let { MinecraftFacet.getInstance(it) }?.isOfType(FabricModuleType)
+ ?: false
+ if (isFabric && ALLOW_ON_FABRIC) {
+ return PsiElementVisitor.EMPTY_VISITOR
+ }
+
+ return object : JavaElementVisitor() {
+ override fun visitMethod(method: PsiMethod) {
+ super.visitMethod(method)
+ val injectAnnotation = method.findAnnotation(INJECT) ?: return
+ val problemElement = injectAnnotation.nameReferenceElement ?: return
+ val handler = MixinAnnotationHandler.forMixinAnnotation(INJECT) as? InjectorAnnotationHandler ?: return
+ val targets = handler.resolveTarget(injectAnnotation)
+ for (target in targets) {
+ if (target !is MethodTargetMember || !target.classAndMethod.method.isConstructor) {
+ continue
+ }
+ val (targetClass, targetMethod) = target.classAndMethod
+ val instructions = handler.resolveInstructions(injectAnnotation, targetClass, targetMethod)
+ if (instructions.any { it.insn.opcode != Opcodes.RETURN }) {
+ holder.registerProblem(
+ problemElement,
+ "Cannot inject into constructors at non-return instructions"
+ )
+ return
+ }
+ }
+ }
+ }
+ }
+
+ override fun getStaticDescription() = "@Inject into Constructor"
+}
diff --git a/src/main/kotlin/platform/mixin/inspection/injector/InvalidInjectorMethodSignatureInspection.kt b/src/main/kotlin/platform/mixin/inspection/injector/InvalidInjectorMethodSignatureInspection.kt
index 0ba8559ac..90263c7d5 100644
--- a/src/main/kotlin/platform/mixin/inspection/injector/InvalidInjectorMethodSignatureInspection.kt
+++ b/src/main/kotlin/platform/mixin/inspection/injector/InvalidInjectorMethodSignatureInspection.kt
@@ -16,15 +16,17 @@ import com.demonwav.mcdev.platform.mixin.handlers.MixinAnnotationHandler
import com.demonwav.mcdev.platform.mixin.inspection.MixinInspection
import com.demonwav.mcdev.platform.mixin.reference.MethodReference
import com.demonwav.mcdev.platform.mixin.util.MixinConstants
+import com.demonwav.mcdev.platform.mixin.util.MixinConstants.Annotations.COERCE
import com.demonwav.mcdev.platform.mixin.util.hasAccess
+import com.demonwav.mcdev.platform.mixin.util.isAssignable
import com.demonwav.mcdev.platform.mixin.util.isConstructor
import com.demonwav.mcdev.util.Parameter
import com.demonwav.mcdev.util.fullQualifiedName
-import com.demonwav.mcdev.util.isErasureEquivalentTo
import com.demonwav.mcdev.util.synchronize
import com.intellij.codeInsight.intention.QuickFixFactory
import com.intellij.codeInspection.LocalQuickFix
import com.intellij.codeInspection.ProblemDescriptor
+import com.intellij.codeInspection.ProblemHighlightType
import com.intellij.codeInspection.ProblemsHolder
import com.intellij.openapi.project.Project
import com.intellij.psi.JavaElementVisitor
@@ -33,10 +35,12 @@ import com.intellij.psi.PsiClassType
import com.intellij.psi.PsiElementVisitor
import com.intellij.psi.PsiMethod
import com.intellij.psi.PsiModifier
-import com.intellij.psi.PsiParameter
import com.intellij.psi.PsiParameterList
+import com.intellij.psi.PsiPrimitiveType
+import com.intellij.psi.PsiType
import com.intellij.psi.codeStyle.JavaCodeStyleManager
import com.intellij.psi.codeStyle.VariableKind
+import com.intellij.psi.util.TypeConversionUtil
import org.objectweb.asm.Opcodes
import org.objectweb.asm.tree.AbstractInsnNode
import org.objectweb.asm.tree.InsnList
@@ -119,10 +123,12 @@ class InvalidInjectorMethodSignatureInspection : MixinInspection() {
var isValid = false
for ((expectedParameters, expectedReturnType) in possibleSignatures) {
- if (checkParameters(parameters, expectedParameters)) {
+ val paramsMatch =
+ checkParameters(parameters, expectedParameters, handler.allowCoerce) == CheckResult.OK
+ if (paramsMatch) {
val methodReturnType = method.returnType
if (methodReturnType != null &&
- methodReturnType.isErasureEquivalentTo(expectedReturnType)
+ checkReturnType(expectedReturnType, methodReturnType, method, handler.allowCoerce)
) {
isValid = true
break
@@ -133,19 +139,31 @@ class InvalidInjectorMethodSignatureInspection : MixinInspection() {
if (!isValid) {
val (expectedParameters, expectedReturnType) = possibleSignatures[0]
- if (!checkParameters(parameters, expectedParameters)) {
+ val checkResult = checkParameters(parameters, expectedParameters, handler.allowCoerce)
+ if (checkResult != CheckResult.OK) {
reportedSignature = true
- holder.registerProblem(
- parameters,
- "Method parameters do not match expected parameters for $annotationName",
- ParametersQuickFix(expectedParameters, handler is InjectAnnotationHandler)
+ val description =
+ "Method parameters do not match expected parameters for $annotationName"
+ val quickFix = ParametersQuickFix(
+ expectedParameters,
+ handler is InjectAnnotationHandler
)
+ if (checkResult == CheckResult.ERROR) {
+ holder.registerProblem(parameters, description, quickFix)
+ } else {
+ holder.registerProblem(
+ parameters,
+ description,
+ ProblemHighlightType.WARNING,
+ quickFix
+ )
+ }
}
val methodReturnType = method.returnType
if (methodReturnType == null ||
- !methodReturnType.isErasureEquivalentTo(expectedReturnType)
+ !checkReturnType(expectedReturnType, methodReturnType, method, handler.allowCoerce)
) {
reportedSignature = true
@@ -187,23 +205,67 @@ class InvalidInjectorMethodSignatureInspection : MixinInspection() {
return null
}
- private fun checkParameters(parameterList: PsiParameterList, expected: List): Boolean {
+ private fun checkReturnType(
+ expectedReturnType: PsiType,
+ methodReturnType: PsiType,
+ method: PsiMethod,
+ allowCoerce: Boolean
+ ): Boolean {
+ val expectedErasure = TypeConversionUtil.erasure(expectedReturnType)
+ val returnErasure = TypeConversionUtil.erasure(methodReturnType)
+ if (expectedErasure == returnErasure) {
+ return true
+ }
+ if (!allowCoerce || !method.hasAnnotation(COERCE)) {
+ return false
+ }
+ if (expectedReturnType is PsiPrimitiveType || methodReturnType is PsiPrimitiveType) {
+ return false
+ }
+ return isAssignable(expectedReturnType, methodReturnType)
+ }
+
+ private fun checkParameters(
+ parameterList: PsiParameterList,
+ expected: List,
+ allowCoerce: Boolean
+ ): CheckResult {
val parameters = parameterList.parameters
var pos = 0
for (group in expected) {
// Check if parameter group matches
- if (group.match(parameters, pos)) {
+ if (group.match(parameters, pos, allowCoerce)) {
pos += group.size
- } else if (group.required) {
- return false
+ } else if (group.required != ParameterGroup.RequiredLevel.OPTIONAL) {
+ return if (group.required == ParameterGroup.RequiredLevel.ERROR_IF_ABSENT) {
+ CheckResult.ERROR
+ } else {
+ CheckResult.WARNING
+ }
+ }
+ }
+
+ // check we have consumed all the parameters
+ if (pos < parameters.size) {
+ return if (
+ expected.lastOrNull()?.isVarargs == true &&
+ expected.last().required == ParameterGroup.RequiredLevel.WARN_IF_ABSENT
+ ) {
+ CheckResult.WARNING
+ } else {
+ CheckResult.ERROR
}
}
- return true
+ return CheckResult.OK
}
}
+ private enum class CheckResult {
+ OK, WARNING, ERROR
+ }
+
private class ParametersQuickFix(
private val expected: List,
isInject: Boolean
@@ -225,16 +287,16 @@ class InvalidInjectorMethodSignatureInspection : MixinInspection() {
return@dropWhile fqname != MixinConstants.Classes.CALLBACK_INFO &&
fqname != MixinConstants.Classes.CALLBACK_INFO_RETURNABLE
}.drop(1) // the first element in the list is the CallbackInfo but we don't want it
- val newParams = expected.flatMapTo(mutableListOf()) {
+ val newParams = expected.flatMapTo(mutableListOf()) {
if (it.default) {
- it.parameters?.mapIndexed { i: Int, p: Parameter ->
+ it.parameters.mapIndexed { i: Int, p: Parameter ->
JavaPsiFacade.getElementFactory(project).createParameter(
p.name ?: JavaCodeStyleManager.getInstance(project)
.suggestVariableName(VariableKind.PARAMETER, null, null, p.type).names
.firstOrNull() ?: "var$i",
p.type
)
- } ?: emptyList()
+ }
} else {
emptyList()
}
diff --git a/src/main/kotlin/platform/mixin/inspection/injector/ModifyVariableArgsOnlyInspection.kt b/src/main/kotlin/platform/mixin/inspection/injector/ModifyVariableArgsOnlyInspection.kt
new file mode 100644
index 000000000..12c6f28a5
--- /dev/null
+++ b/src/main/kotlin/platform/mixin/inspection/injector/ModifyVariableArgsOnlyInspection.kt
@@ -0,0 +1,105 @@
+/*
+ * Minecraft Dev for IntelliJ
+ *
+ * https://minecraftdev.org
+ *
+ * Copyright (c) 2021 minecraft-dev
+ *
+ * MIT License
+ */
+
+package com.demonwav.mcdev.platform.mixin.inspection.injector
+
+import com.demonwav.mcdev.platform.mixin.handlers.MixinAnnotationHandler
+import com.demonwav.mcdev.platform.mixin.inspection.MixinInspection
+import com.demonwav.mcdev.platform.mixin.util.MethodTargetMember
+import com.demonwav.mcdev.platform.mixin.util.MixinConstants.Annotations.MODIFY_VARIABLE
+import com.demonwav.mcdev.platform.mixin.util.hasAccess
+import com.demonwav.mcdev.util.constantValue
+import com.demonwav.mcdev.util.createLiteralExpression
+import com.demonwav.mcdev.util.descriptor
+import com.demonwav.mcdev.util.findAnnotation
+import com.demonwav.mcdev.util.ifEmpty
+import com.intellij.codeInspection.LocalQuickFixOnPsiElement
+import com.intellij.codeInspection.ProblemsHolder
+import com.intellij.openapi.project.Project
+import com.intellij.psi.JavaElementVisitor
+import com.intellij.psi.JavaPsiFacade
+import com.intellij.psi.PsiAnnotation
+import com.intellij.psi.PsiElement
+import com.intellij.psi.PsiElementVisitor
+import com.intellij.psi.PsiFile
+import com.intellij.psi.PsiMethod
+import org.objectweb.asm.Opcodes
+import org.objectweb.asm.Type
+
+class ModifyVariableArgsOnlyInspection : MixinInspection() {
+ override fun buildVisitor(holder: ProblemsHolder): PsiElementVisitor {
+ return object : JavaElementVisitor() {
+ override fun visitMethod(method: PsiMethod) {
+ val modifyVariable = method.findAnnotation(MODIFY_VARIABLE) ?: return
+ if (modifyVariable.findDeclaredAttributeValue("argsOnly")?.constantValue == true) {
+ return
+ }
+ val ordinal = (modifyVariable.findDeclaredAttributeValue("ordinal")?.constantValue as? Int?)
+ ?.takeIf { it != -1 }
+ val index = (modifyVariable.findDeclaredAttributeValue("index")?.constantValue as? Int?)
+ ?.takeIf { it != -1 }
+ if (ordinal == null && index == null && modifyVariable.findDeclaredAttributeValue("name") != null) {
+ return
+ }
+ val wantedType = method.parameterList.getParameter(0)?.type?.descriptor ?: return
+ val problemElement = modifyVariable.nameReferenceElement ?: return
+
+ val handler = MixinAnnotationHandler.forMixinAnnotation(MODIFY_VARIABLE) ?: return
+ val targets = handler.resolveTarget(modifyVariable).ifEmpty { return }
+ val methodTargets = targets.asSequence()
+ .filterIsInstance()
+ .map { it.classAndMethod }
+ for ((targetClass, targetMethod) in methodTargets) {
+ val argTypes = mutableListOf()
+ if (!targetMethod.hasAccess(Opcodes.ACC_STATIC)) {
+ argTypes += "L${targetClass.name};"
+ }
+ for (arg in Type.getArgumentTypes(targetMethod.desc)) {
+ argTypes += arg.descriptor
+ if (arg.size == 2) {
+ argTypes += null
+ }
+ }
+
+ if (ordinal != null) {
+ if (argTypes.asSequence().filter { it == wantedType }.count() <= ordinal) {
+ return
+ }
+ } else if (index != null) {
+ if (argTypes.size <= index) {
+ return
+ }
+ } else {
+ if (argTypes.asSequence().filter { it == wantedType }.count() != 1) {
+ return
+ }
+ }
+ }
+
+ val description = "ModifyVariable may be argsOnly = true"
+ holder.registerProblem(problemElement, description, AddArgsOnlyFix(modifyVariable))
+ }
+ }
+ }
+
+ override fun getStaticDescription() =
+ "Checks that ModifyVariable has argsOnly if it targets arguments, which improves performance of the mixin"
+
+ private class AddArgsOnlyFix(annotation: PsiAnnotation) : LocalQuickFixOnPsiElement(annotation) {
+ override fun getFamilyName() = "Add argsOnly = true"
+ override fun getText() = "Add argsOnly = true"
+
+ override fun invoke(project: Project, file: PsiFile, startElement: PsiElement, endElement: PsiElement) {
+ val annotation = startElement as? PsiAnnotation ?: return
+ val trueExpr = JavaPsiFacade.getElementFactory(project).createLiteralExpression(true)
+ annotation.setDeclaredAttributeValue("argsOnly", trueExpr)
+ }
+ }
+}
diff --git a/src/main/kotlin/platform/mixin/inspection/injector/ParameterGroup.kt b/src/main/kotlin/platform/mixin/inspection/injector/ParameterGroup.kt
index 94ee4847c..fd69494a8 100644
--- a/src/main/kotlin/platform/mixin/inspection/injector/ParameterGroup.kt
+++ b/src/main/kotlin/platform/mixin/inspection/injector/ParameterGroup.kt
@@ -10,29 +10,27 @@
package com.demonwav.mcdev.platform.mixin.inspection.injector
+import com.demonwav.mcdev.platform.mixin.util.MixinConstants.Annotations.COERCE
+import com.demonwav.mcdev.platform.mixin.util.isAssignable
import com.demonwav.mcdev.util.Parameter
-import com.demonwav.mcdev.util.isErasureEquivalentTo
-import com.intellij.psi.PsiArrayType
-import com.intellij.psi.PsiEllipsisType
+import com.demonwav.mcdev.util.normalize
import com.intellij.psi.PsiParameter
+import com.intellij.psi.PsiPrimitiveType
+import com.intellij.psi.PsiType
data class ParameterGroup(
- val parameters: List?,
- val required: Boolean = parameters != null,
- val default: Boolean = required
+ val parameters: List,
+ val required: RequiredLevel = RequiredLevel.ERROR_IF_ABSENT,
+ val default: Boolean = required != RequiredLevel.OPTIONAL,
+ val isVarargs: Boolean = false
) {
val size
- get() = this.parameters?.size ?: 0
-
- fun match(parameters: Array, currentPosition: Int): Boolean {
- if (this.parameters == null) {
- // Wildcard parameter groups always match
- return true
- }
+ get() = this.parameters.size
+ fun match(parameters: Array, currentPosition: Int, allowCoerce: Boolean): Boolean {
// Check if remaining parameter count is enough
- if (currentPosition + size > parameters.size) {
+ if (!isVarargs && currentPosition + size > parameters.size) {
return false
}
@@ -40,15 +38,42 @@ data class ParameterGroup(
// Check parameter types
for ((_, expectedType) in this.parameters) {
- val type = parameters[pos++].type
- if (!type.isErasureEquivalentTo(expectedType)) {
- // Allow using array instead of varargs
- if (expectedType !is PsiEllipsisType || type !is PsiArrayType || type != expectedType.toArrayType()) {
+ if (isVarargs && pos == parameters.size) {
+ break
+ }
+ val parameter = parameters[pos++]
+ if (!matchParameter(expectedType, parameter, allowCoerce)) {
+ return false
+ }
+ }
+
+ return !isVarargs || pos == parameters.size
+ }
+
+ enum class RequiredLevel {
+ OPTIONAL, WARN_IF_ABSENT, ERROR_IF_ABSENT
+ }
+
+ companion object {
+ private val INT_TYPES = setOf(PsiType.INT, PsiType.SHORT, PsiType.CHAR, PsiType.BYTE, PsiType.BOOLEAN)
+
+ private fun matchParameter(expectedType: PsiType, parameter: PsiParameter, allowCoerce: Boolean): Boolean {
+ val normalizedExpected = expectedType.normalize()
+ val normalizedParameter = parameter.type.normalize()
+ if (normalizedExpected == normalizedParameter) {
+ return true
+ }
+ if (!allowCoerce || !parameter.hasAnnotation(COERCE)) {
+ return false
+ }
+
+ if (normalizedExpected is PsiPrimitiveType) {
+ if (normalizedParameter !is PsiPrimitiveType) {
return false
}
+ return normalizedExpected in INT_TYPES && normalizedParameter in INT_TYPES
}
+ return isAssignable(normalizedParameter, normalizedExpected)
}
-
- return true
}
}
diff --git a/src/main/kotlin/platform/mixin/inspection/reference/UnresolvedReferenceInspection.kt b/src/main/kotlin/platform/mixin/inspection/reference/UnresolvedReferenceInspection.kt
index 82209c8fa..f5b75ba09 100644
--- a/src/main/kotlin/platform/mixin/inspection/reference/UnresolvedReferenceInspection.kt
+++ b/src/main/kotlin/platform/mixin/inspection/reference/UnresolvedReferenceInspection.kt
@@ -16,7 +16,6 @@ import com.demonwav.mcdev.platform.mixin.reference.InjectionPointReference
import com.demonwav.mcdev.platform.mixin.reference.MethodReference
import com.demonwav.mcdev.platform.mixin.reference.MixinReference
import com.demonwav.mcdev.platform.mixin.reference.target.TargetReference
-import com.demonwav.mcdev.platform.mixin.util.isWithinDynamicMixin
import com.demonwav.mcdev.util.annotationFromNameValuePair
import com.demonwav.mcdev.util.constantStringValue
import com.intellij.codeInspection.ProblemHighlightType
@@ -60,7 +59,7 @@ class UnresolvedReferenceInspection : MixinInspection() {
}
private fun checkResolved(resolver: MixinReference, value: PsiAnnotationMemberValue) {
- if (resolver.isUnresolved(value) && !value.isWithinDynamicMixin) {
+ if (resolver.isUnresolved(value)) {
holder.registerProblem(
value,
"Cannot resolve ${resolver.description}".format(value.constantStringValue),
diff --git a/src/main/kotlin/platform/mixin/inspection/suppress/DefaultAnnotationParamInspectionSuppressor.kt b/src/main/kotlin/platform/mixin/inspection/suppress/DefaultAnnotationParamInspectionSuppressor.kt
new file mode 100644
index 000000000..350624f59
--- /dev/null
+++ b/src/main/kotlin/platform/mixin/inspection/suppress/DefaultAnnotationParamInspectionSuppressor.kt
@@ -0,0 +1,96 @@
+/*
+ * Minecraft Dev for IntelliJ
+ *
+ * https://minecraftdev.org
+ *
+ * Copyright (c) 2021 minecraft-dev
+ *
+ * MIT License
+ */
+
+package com.demonwav.mcdev.platform.mixin.inspection.suppress
+
+import com.demonwav.mcdev.platform.mixin.util.MixinConstants.Annotations.ACCESSOR
+import com.demonwav.mcdev.platform.mixin.util.MixinConstants.Annotations.AT
+import com.demonwav.mcdev.platform.mixin.util.MixinConstants.Annotations.CONSTANT
+import com.demonwav.mcdev.platform.mixin.util.MixinConstants.Annotations.INJECT
+import com.demonwav.mcdev.platform.mixin.util.MixinConstants.Annotations.INVOKER
+import com.demonwav.mcdev.platform.mixin.util.MixinConstants.Annotations.MIXIN
+import com.demonwav.mcdev.platform.mixin.util.MixinConstants.Annotations.MODIFY_ARG
+import com.demonwav.mcdev.platform.mixin.util.MixinConstants.Annotations.MODIFY_ARGS
+import com.demonwav.mcdev.platform.mixin.util.MixinConstants.Annotations.MODIFY_VARIABLE
+import com.demonwav.mcdev.platform.mixin.util.MixinConstants.Annotations.REDIRECT
+import com.demonwav.mcdev.util.constantValue
+import com.demonwav.mcdev.util.findAnnotation
+import com.demonwav.mcdev.util.mapFirstNotNull
+import com.intellij.codeInspection.InspectionSuppressor
+import com.intellij.codeInspection.SuppressQuickFix
+import com.intellij.psi.PsiAnnotation
+import com.intellij.psi.PsiClass
+import com.intellij.psi.PsiElement
+import com.intellij.psi.PsiModifierListOwner
+import com.intellij.psi.PsiNameValuePair
+import com.intellij.psi.util.parentOfType
+
+class DefaultAnnotationParamInspectionSuppressor : InspectionSuppressor {
+ override fun isSuppressedFor(element: PsiElement, toolId: String): Boolean {
+ if (toolId != INSPECTION) {
+ return false
+ }
+
+ val name = element.parentOfType()?.attributeName ?: return false
+ val annotation = element.parentOfType() ?: return false
+
+ if (name in CONSTANT_SUPPRESSED && annotation.hasQualifiedName(CONSTANT)) {
+ return true
+ }
+
+ if (name == "remap" && REMAP_SUPPRESSED.any(annotation::hasQualifiedName)) {
+ val currentRemap = annotation.findAttributeValue("remap")?.constantValue as? Boolean
+ ?: return false
+ val parentRemap = generateSequence(annotation) { elem ->
+ elem.parent?.takeIf { elem !is PsiClass }
+ }
+ .filterIsInstance()
+ .drop(1) // don't look at our own owner
+ .mapNotNull { annotationOwner ->
+ REMAP_SUPPRESSED.mapFirstNotNull {
+ annotationOwner.findAnnotation(it)?.findDeclaredAttributeValue("remap")?.constantValue
+ as? Boolean
+ }
+ }
+ .firstOrNull() ?: true
+ if (currentRemap != parentRemap) {
+ return true
+ }
+ }
+
+ return false
+ }
+
+ override fun getSuppressActions(element: PsiElement?, toolId: String): Array =
+ SuppressQuickFix.EMPTY_ARRAY
+
+ companion object {
+ private const val INSPECTION = "DefaultAnnotationParam"
+ private val REMAP_SUPPRESSED = setOf(
+ AT,
+ INJECT,
+ MODIFY_ARG,
+ MODIFY_ARGS,
+ MODIFY_VARIABLE,
+ REDIRECT,
+ ACCESSOR,
+ INVOKER,
+ MIXIN
+ )
+ private val CONSTANT_SUPPRESSED = setOf(
+ "intValue",
+ "floatValue",
+ "longValue",
+ "doubleValue",
+ "stringValue",
+ "classValue"
+ )
+ }
+}
diff --git a/src/main/kotlin/platform/mixin/inspection/suppress/DynamicInspectionSuppressor.kt b/src/main/kotlin/platform/mixin/inspection/suppress/DynamicInspectionSuppressor.kt
new file mode 100644
index 000000000..eb26fba2f
--- /dev/null
+++ b/src/main/kotlin/platform/mixin/inspection/suppress/DynamicInspectionSuppressor.kt
@@ -0,0 +1,43 @@
+/*
+ * Minecraft Dev for IntelliJ
+ *
+ * https://minecraftdev.org
+ *
+ * Copyright (c) 2021 minecraft-dev
+ *
+ * MIT License
+ */
+
+package com.demonwav.mcdev.platform.mixin.inspection.suppress
+
+import com.demonwav.mcdev.platform.mixin.util.MixinConstants.Annotations.DYNAMIC
+import com.demonwav.mcdev.util.findAnnotation
+import com.demonwav.mcdev.util.findContainingMember
+import com.intellij.codeInspection.InspectionSuppressor
+import com.intellij.codeInspection.SuppressQuickFix
+import com.intellij.psi.PsiElement
+
+class DynamicInspectionSuppressor : InspectionSuppressor {
+ private val suppressedInspections = setOf(
+ "AmbiguousMixinReference",
+ "InvalidInjectorMethodSignature",
+ "InvalidMemberReference",
+ "MixinAnnotationTarget",
+ "OverwriteModifiers",
+ "ShadowModifiers",
+ "UnqualifiedMemberReference",
+ "UnnecessaryQualifiedMemberReference",
+ "UnresolvedMixinReference"
+ )
+
+ override fun isSuppressedFor(element: PsiElement, toolId: String): Boolean {
+ if (toolId !in suppressedInspections) {
+ return false
+ }
+ return element.findContainingMember()?.findAnnotation(DYNAMIC) != null
+ }
+
+ override fun getSuppressActions(element: PsiElement?, toolId: String): Array {
+ return SuppressQuickFix.EMPTY_ARRAY
+ }
+}
diff --git a/src/main/kotlin/platform/mixin/inspection/suppress/MixinClassCastInspectionSuppressor.kt b/src/main/kotlin/platform/mixin/inspection/suppress/MixinClassCastInspectionSuppressor.kt
index 32abfc577..6452698bc 100644
--- a/src/main/kotlin/platform/mixin/inspection/suppress/MixinClassCastInspectionSuppressor.kt
+++ b/src/main/kotlin/platform/mixin/inspection/suppress/MixinClassCastInspectionSuppressor.kt
@@ -11,24 +11,24 @@
package com.demonwav.mcdev.platform.mixin.inspection.suppress
import com.demonwav.mcdev.platform.mixin.action.FindMixinsAction
-import com.demonwav.mcdev.platform.mixin.util.findStubClass
-import com.demonwav.mcdev.platform.mixin.util.isMixin
-import com.demonwav.mcdev.platform.mixin.util.mixinTargets
+import com.demonwav.mcdev.platform.mixin.util.isAssignable
import com.intellij.codeInspection.InspectionSuppressor
import com.intellij.codeInspection.SuppressQuickFix
import com.intellij.codeInspection.dataFlow.CommonDataflow
+import com.intellij.codeInspection.dataFlow.TypeConstraint
+import com.intellij.codeInspection.dataFlow.TypeConstraints
import com.intellij.codeInspection.dataFlow.types.DfReferenceType
+import com.intellij.openapi.project.Project
+import com.intellij.psi.JavaPsiFacade
+import com.intellij.psi.JavaTokenType
import com.intellij.psi.PsiArrayType
-import com.intellij.psi.PsiClass
+import com.intellij.psi.PsiBinaryExpression
import com.intellij.psi.PsiClassType
-import com.intellij.psi.PsiDisjunctionType
import com.intellij.psi.PsiElement
import com.intellij.psi.PsiExpression
import com.intellij.psi.PsiInstanceOfExpression
-import com.intellij.psi.PsiIntersectionType
import com.intellij.psi.PsiType
import com.intellij.psi.PsiTypeCastExpression
-import com.intellij.psi.util.InheritanceUtil
import com.intellij.psi.util.PsiUtil
/**
@@ -52,6 +52,25 @@ class MixinClassCastInspectionSuppressor : InspectionSuppressor {
return isAssignable(castType, realType)
}
+ // check == and !=
+ if (element is PsiBinaryExpression && (
+ element.operationSign.tokenType == JavaTokenType.EQEQ ||
+ element.operationSign.tokenType == JavaTokenType.NE
+ )
+ ) {
+ val rightType = element.rOperand?.let(this::getTypeConstraint) ?: return false
+ val leftType = getTypeConstraint(element.lOperand) ?: return false
+ val isTypeWarning = leftType.meet(rightType) == TypeConstraints.BOTTOM
+ if (isTypeWarning) {
+ val leftWithMixins = addMixinsToTypeConstraint(element.project, leftType)
+ val rightWithMixins = addMixinsToTypeConstraint(element.project, rightType)
+ if (leftWithMixins == leftType && rightWithMixins == rightType) {
+ return false
+ }
+ return leftWithMixins.meet(rightWithMixins) != TypeConstraints.BOTTOM
+ }
+ }
+
val castExpression = element.parent as? PsiTypeCastExpression ?: return false
val castType = castExpression.type ?: return false
val realType = getRealType(castExpression) ?: return false
@@ -59,49 +78,38 @@ class MixinClassCastInspectionSuppressor : InspectionSuppressor {
return isAssignable(castType, realType)
}
- private fun isAssignable(left: PsiType, right: PsiType): Boolean {
- return when {
- left is PsiIntersectionType -> left.conjuncts.all { isAssignable(it, right) }
- right is PsiIntersectionType -> right.conjuncts.any { isAssignable(left, it) }
- left is PsiDisjunctionType -> left.disjunctions.any { isAssignable(it, right) }
- right is PsiDisjunctionType -> isAssignable(left, right.leastUpperBound)
- left is PsiArrayType -> right is PsiArrayType && isAssignable(left.componentType, right.componentType)
- else -> {
- if (left !is PsiClassType || right !is PsiClassType) {
- return false
+ private fun addMixinsToTypeConstraint(project: Project, typeConstraint: TypeConstraint): TypeConstraint {
+ val psiType = typeConstraint.getPsiType(project) ?: return typeConstraint
+ val targetClass = when (psiType) {
+ is PsiArrayType -> (psiType.deepComponentType as? PsiClassType)?.resolve()
+ is PsiClassType -> psiType.resolve()
+ else -> null
+ } ?: return typeConstraint
+ val mixins = FindMixinsAction.findMixins(targetClass, project) ?: return typeConstraint
+ if (mixins.isEmpty()) return typeConstraint
+ val elementFactory = JavaPsiFacade.getElementFactory(project)
+ val mixinTypes = mixins.map { mixinClass ->
+ var type: PsiType = elementFactory.createType(mixinClass)
+ if (psiType is PsiArrayType) {
+ repeat(psiType.arrayDimensions) {
+ type = type.createArrayType()
}
- val leftClass = left.resolve() ?: return false
- val rightClass = right.resolve() ?: return false
- if (rightClass.isMixin) {
- val isMixinAssignable = rightClass.mixinTargets.any {
- val stubClass = it.findStubClass(rightClass.project) ?: return@any false
- isClassAssignable(leftClass, stubClass)
- }
- if (isMixinAssignable) {
- return true
- }
- }
- val mixins = FindMixinsAction.findMixins(rightClass, rightClass.project) ?: return false
- return mixins.any { isClassAssignable(leftClass, it) }
}
- }
- }
-
- private fun isClassAssignable(leftClass: PsiClass, rightClass: PsiClass): Boolean {
- var result = false
- InheritanceUtil.processSupers(rightClass, true) {
- if (it == leftClass) {
- result = true
- false
+ if (typeConstraint.isExact) {
+ TypeConstraints.exact(type)
} else {
- true
+ TypeConstraints.instanceOf(type)
}
}
- return result
+ return typeConstraint.join(mixinTypes.reduce(TypeConstraint::join))
}
private fun getRealType(expression: PsiExpression): PsiType? {
- return (CommonDataflow.getDfType(expression) as? DfReferenceType)?.constraint?.getPsiType(expression.project)
+ return getTypeConstraint(expression)?.getPsiType(expression.project)
+ }
+
+ private fun getTypeConstraint(expression: PsiExpression): TypeConstraint? {
+ return (CommonDataflow.getDfType(expression) as? DfReferenceType)?.constraint
}
override fun getSuppressActions(element: PsiElement?, toolId: String): Array =
diff --git a/src/main/kotlin/platform/mixin/reference/InjectionPointReference.kt b/src/main/kotlin/platform/mixin/reference/InjectionPointReference.kt
index 8408bc831..23657a660 100644
--- a/src/main/kotlin/platform/mixin/reference/InjectionPointReference.kt
+++ b/src/main/kotlin/platform/mixin/reference/InjectionPointReference.kt
@@ -49,7 +49,8 @@ object InjectionPointReference : ReferenceResolver(), MixinReference {
override fun resolveReference(context: PsiElement): PsiElement? {
// Remove slice selectors from the injection point type
var name = context.constantStringValue ?: return null
- val isInsideSlice = context.parentOfType()?.hasQualifiedName(SLICE) == true
+ val at = context.parentOfType() ?: return null
+ val isInsideSlice = at.parentOfType()?.hasQualifiedName(SLICE) == true
if (isInsideSlice) {
for (sliceSelector in getSliceSelectors(context.project)) {
if (name.endsWith(":$sliceSelector")) {
diff --git a/src/main/kotlin/platform/mixin/reference/target/TargetReference.kt b/src/main/kotlin/platform/mixin/reference/target/TargetReference.kt
index 53cc1e489..39313c9ce 100644
--- a/src/main/kotlin/platform/mixin/reference/target/TargetReference.kt
+++ b/src/main/kotlin/platform/mixin/reference/target/TargetReference.kt
@@ -19,7 +19,6 @@ import com.demonwav.mcdev.platform.mixin.util.MethodTargetMember
import com.demonwav.mcdev.platform.mixin.util.MixinConstants.Annotations.AT
import com.demonwav.mcdev.util.ifEmpty
import com.demonwav.mcdev.util.insideAnnotationAttribute
-import com.demonwav.mcdev.util.mapFirstNotNull
import com.demonwav.mcdev.util.reference.PolyReferenceResolver
import com.demonwav.mcdev.util.reference.completeToLiteral
import com.intellij.patterns.ElementPattern
@@ -30,7 +29,6 @@ import com.intellij.psi.PsiElement
import com.intellij.psi.PsiElementResolveResult
import com.intellij.psi.PsiLiteral
import com.intellij.psi.PsiMember
-import com.intellij.psi.PsiMethod
import com.intellij.psi.ResolveResult
import com.intellij.psi.util.parentOfType
import com.intellij.util.ArrayUtilRt
@@ -60,11 +58,12 @@ object TargetReference : PolyReferenceResolver(), MixinReference {
* reference as unresolved.
*/
private fun getTargets(at: PsiAnnotation): List? {
- val method = at.parentOfType() ?: return emptyList()
- val (handler, annotation) = method.annotations.mapFirstNotNull { annotation ->
- val qName = annotation.qualifiedName ?: return@mapFirstNotNull null
- MixinAnnotationHandler.forMixinAnnotation(qName)?.let { it to annotation }
- } ?: return null
+ val (handler, annotation) = generateSequence(at.parent) { it.parent }
+ .filterIsInstance()
+ .mapNotNull { annotation ->
+ val qName = annotation.qualifiedName ?: return@mapNotNull null
+ MixinAnnotationHandler.forMixinAnnotation(qName)?.let { it to annotation }
+ }.firstOrNull() ?: return null
return handler.resolveTarget(annotation).mapNotNull { (it as? MethodTargetMember)?.classAndMethod }
}
diff --git a/src/main/kotlin/platform/mixin/util/AsmDfaUtil.kt b/src/main/kotlin/platform/mixin/util/AsmDfaUtil.kt
index ae52adeee..dd47b3e32 100644
--- a/src/main/kotlin/platform/mixin/util/AsmDfaUtil.kt
+++ b/src/main/kotlin/platform/mixin/util/AsmDfaUtil.kt
@@ -30,7 +30,7 @@ import org.objectweb.asm.tree.analysis.SimpleVerifier
object AsmDfaUtil {
private val LOGGER = Logger.getInstance(AsmDfaUtil::class.java)
- private fun analyzeMethod(project: Project, clazz: ClassNode, method: MethodNode): Array>? {
+ fun analyzeMethod(project: Project, clazz: ClassNode, method: MethodNode): Array?>? {
return method.cached(clazz, project) {
try {
Analyzer(
diff --git a/src/main/kotlin/platform/mixin/util/AsmUtil.kt b/src/main/kotlin/platform/mixin/util/AsmUtil.kt
index 558ecb28d..e663ea7ac 100644
--- a/src/main/kotlin/platform/mixin/util/AsmUtil.kt
+++ b/src/main/kotlin/platform/mixin/util/AsmUtil.kt
@@ -169,7 +169,7 @@ fun findClassNodeByPsiClass(psiClass: PsiClass, module: Module? = psiClass.findM
}
val classFile = parentDir.findChild("${fqn.substringAfterLast('.')}.class") ?: return null
val node = ClassNode()
- ClassReader(classFile.inputStream).accept(node, 0)
+ classFile.inputStream.use { ClassReader(it).accept(node, 0) }
node
} else {
val node = ClassNode()
@@ -178,11 +178,12 @@ fun findClassNodeByPsiClass(psiClass: PsiClass, module: Module? = psiClass.findM
}
} catch (e: Throwable) {
val actualThrowable = if (e is InvocationTargetException) e.cause ?: e else e
+ if (actualThrowable is ProcessCanceledException) {
+ throw actualThrowable
+ }
val message = actualThrowable.message
// TODO: display an error to the user?
- if (actualThrowable !is ProcessCanceledException &&
- (message == null || !message.contains("Unsupported class file major version"))
- ) {
+ if (message == null || !message.contains("Unsupported class file major version")) {
LOGGER.error(actualThrowable)
}
null
diff --git a/src/main/kotlin/platform/mixin/util/LocalVariables.kt b/src/main/kotlin/platform/mixin/util/LocalVariables.kt
new file mode 100644
index 000000000..8cb918392
--- /dev/null
+++ b/src/main/kotlin/platform/mixin/util/LocalVariables.kt
@@ -0,0 +1,907 @@
+/*
+ * Minecraft Dev for IntelliJ
+ *
+ * https://minecraftdev.org
+ *
+ * Copyright (c) 2021 minecraft-dev
+ *
+ * MIT License
+ */
+
+/*
+ * This file contains substantial amounts of code from Mixin, licensed under the MIT License (MIT).
+ * See https://github.com/SpongePowered/Mixin/blob/master/src/main/java/org/spongepowered/asm/util/Locals.java
+ *
+ * Copyright (c) SpongePowered
+ * Copyright (c) contributors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+
+package com.demonwav.mcdev.platform.mixin.util
+
+import com.demonwav.mcdev.facet.MinecraftFacet
+import com.demonwav.mcdev.platform.mixin.MixinModuleType
+import com.demonwav.mcdev.util.SemanticVersion
+import com.demonwav.mcdev.util.cached
+import com.demonwav.mcdev.util.mapToArray
+import com.demonwav.mcdev.util.psiType
+import com.intellij.openapi.module.Module
+import com.intellij.openapi.project.Project
+import com.intellij.openapi.util.Key
+import com.intellij.psi.CommonClassNames
+import com.intellij.psi.JavaPsiFacade
+import com.intellij.psi.JavaRecursiveElementVisitor
+import com.intellij.psi.PsiArrayType
+import com.intellij.psi.PsiClass
+import com.intellij.psi.PsiElement
+import com.intellij.psi.PsiForeachStatement
+import com.intellij.psi.PsiLambdaExpression
+import com.intellij.psi.PsiMethod
+import com.intellij.psi.PsiModifier
+import com.intellij.psi.PsiStatement
+import com.intellij.psi.PsiType
+import com.intellij.psi.PsiVariable
+import com.intellij.psi.controlFlow.ControlFlow
+import com.intellij.psi.controlFlow.ControlFlowFactory
+import com.intellij.psi.controlFlow.ControlFlowInstructionVisitor
+import com.intellij.psi.controlFlow.ControlFlowOptions
+import com.intellij.psi.controlFlow.Instruction
+import com.intellij.psi.controlFlow.LocalsControlFlowPolicy
+import com.intellij.psi.controlFlow.WriteVariableInstruction
+import com.intellij.psi.scope.util.PsiScopesUtil
+import com.intellij.psi.util.PsiModificationTracker
+import com.intellij.psi.util.PsiTreeUtil
+import com.intellij.psi.util.parentOfType
+import kotlin.math.min
+import org.objectweb.asm.Opcodes
+import org.objectweb.asm.Type
+import org.objectweb.asm.tree.AbstractInsnNode
+import org.objectweb.asm.tree.ClassNode
+import org.objectweb.asm.tree.FrameNode
+import org.objectweb.asm.tree.InsnList
+import org.objectweb.asm.tree.LabelNode
+import org.objectweb.asm.tree.LineNumberNode
+import org.objectweb.asm.tree.MethodNode
+import org.objectweb.asm.tree.VarInsnNode
+import org.objectweb.asm.tree.analysis.BasicValue
+
+object LocalVariables {
+ private val LOCAL_INDEX_KEY = Key("mcdev.local_index")
+
+ /**
+ * Guesses the local variable index of the given variable, or of implicit locals at the given element.
+ * Only valid after [guessLocalsAt] has been called.
+ */
+ fun guessLocalVariableIndex(element: PsiElement): Int? {
+ return element.getUserData(LOCAL_INDEX_KEY)
+ }
+
+ fun guessLocalsAt(element: PsiElement, argsOnly: Boolean, start: Boolean): List {
+ val method = PsiTreeUtil.getParentOfType(element, PsiMethod::class.java, PsiLambdaExpression::class.java)
+ ?: return emptyList()
+ val actualMethod = method.parentOfType(withSelf = true) ?: return emptyList()
+ val args = mutableListOf()
+
+ var argsIndex = 0
+ if (!actualMethod.hasModifierProperty(PsiModifier.STATIC)) {
+ args += SourceLocalVariable("this", actualMethod.containingClass?.psiType ?: return emptyList(), 0)
+ argsIndex++
+ }
+
+ for (parameter in method.parameterList.parameters) {
+ val mixinName = if (argsOnly) "var$argsIndex" else parameter.name
+ args += SourceLocalVariable(parameter.name, parameter.type, argsIndex, mixinName = mixinName)
+ argsIndex++
+ if (parameter.isDoubleSlot) {
+ argsIndex++
+ }
+ }
+
+ if (argsOnly) {
+ return args
+ }
+
+ val body = method.body ?: return args
+ val controlFlow = ControlFlowFactory.getControlFlow(
+ body,
+ LocalsControlFlowPolicy(body),
+ ControlFlowOptions.NO_CONST_EVALUATE
+ )
+
+ val allLocalVariables = guessAllLocalVariables(argsIndex, body, controlFlow)
+ val elementOffset = if (start) controlFlow.getStartOffset(element) else controlFlow.getEndOffset(element)
+ return args + (allLocalVariables.getOrNull(elementOffset) ?: emptyList())
+ }
+
+ private fun guessAllLocalVariables(
+ argsSize: Int,
+ body: PsiElement,
+ controlFlow: ControlFlow
+ ): Array> {
+ return body.cached(PsiModificationTracker.MODIFICATION_COUNT) {
+ guessAllLocalVariablesUncached(argsSize, body, controlFlow)
+ }
+ }
+
+ private fun guessAllLocalVariablesUncached(
+ argsSize: Int,
+ body: PsiElement,
+ controlFlow: ControlFlow
+ ): Array> {
+ val method = body.parent
+ val allLocalVariables = getAllLocalVariables(body)
+ for (variable in allLocalVariables) {
+ var localIndex = argsSize
+ // gets all local variable declarations in scope at the declaration of variable
+ PsiScopesUtil.treeWalkUp(
+ { elem, _ ->
+ localIndex += getLocalVariableSize(elem)
+ true
+ },
+ variable,
+ method
+ )
+ // add on other implicit declarations in scope
+ for (parent in generateSequence(variable.parent, PsiElement::getParent).takeWhile { it != method }) {
+ localIndex += getLocalVariableSize(parent)
+ }
+ variable.putUserData(LOCAL_INDEX_KEY, localIndex)
+ }
+
+ // take into account implicit locals for certain constructs (e.g. foreach loops)
+ val extraVariables = mutableMapOf>()
+ for (variable in allLocalVariables) {
+ val extraVars = when (variable) {
+ is PsiVariable -> continue
+ is PsiForeachStatement -> variable.getExtraLocals()
+ else -> continue
+ }
+ val enclosingStatement = variable.parentOfType(withSelf = true) ?: continue
+ extraVariables.getOrPut(controlFlow.getStartOffset(enclosingStatement)) { mutableListOf() } += extraVars
+ }
+
+ // compute the local variables that are definitely initialized and not overwritten at each offset
+ class MyVisitor : ControlFlowInstructionVisitor() {
+ val locals = arrayOfNulls>(controlFlow.size + 1)
+ val instructionQueue = ArrayDeque()
+
+ override fun visitWriteVariableInstruction(
+ instruction: WriteVariableInstruction,
+ offset: Int,
+ nextOffset: Int
+ ) {
+ if (instruction.variable in allLocalVariables) {
+ val localIndex = instruction.variable.getUserData(LOCAL_INDEX_KEY)!!
+ var localsHere = this.locals[offset]
+ ?: arrayOfNulls(localIndex + 1).also { this.locals[offset] = it }
+ if (localIndex >= localsHere.size) {
+ localsHere = localsHere.copyOf(localIndex + 1)
+ }
+ val name = instruction.variable.name ?: return
+ localsHere[localIndex] = SourceLocalVariable(name, instruction.variable.type, localIndex)
+ if (instruction.variable.isDoubleSlot && localIndex + 1 < localsHere.size) {
+ localsHere[localIndex + 1] = null
+ }
+ this.locals[offset] = localsHere
+ }
+ visitInstruction(instruction, offset, nextOffset)
+ }
+
+ override fun visitInstruction(instruction: Instruction, offset: Int, nextOffset: Int) {
+ val extraVars = extraVariables[offset]
+ if (extraVars != null) {
+ for (variable in extraVars) {
+ val localsHere = this.locals[offset]
+ ?: arrayOfNulls(variable.index + 1).also { this.locals[offset] = it }
+ localsHere[variable.index] = variable
+ if (variable.type == PsiType.LONG || variable.type == PsiType.DOUBLE) {
+ if (variable.index + 1 < localsHere.size) {
+ localsHere[variable.index + 1] = null
+ }
+ }
+ }
+ }
+ for (i in 0 until instruction.nNext()) {
+ visitEdge(offset, instruction.getNext(offset, i))
+ }
+ }
+
+ private fun visitEdge(offset: Int, nextOffset: Int) {
+ val localsHere = this.locals[offset] ?: emptyArray()
+ var changed = false
+ val nextLocals = this.locals[nextOffset]
+ if (nextLocals == null) {
+ this.locals[nextOffset] = localsHere.clone()
+ changed = true
+ } else {
+ for (i in localsHere.size until nextLocals.size) {
+ if (nextLocals[i] != null) {
+ nextLocals[i] = null
+ changed = true
+ }
+ }
+ for (i in 0 until min(localsHere.size, nextLocals.size)) {
+ if (nextLocals[i] != localsHere[i]) {
+ if (nextLocals[i] != null) {
+ nextLocals[i] = null
+ changed = true
+ }
+ }
+ }
+ }
+ if (changed) {
+ instructionQueue.add(nextOffset)
+ }
+ }
+ }
+
+ // walk the control flow graph
+ val visitor = MyVisitor()
+ visitor.instructionQueue.add(0)
+ while (visitor.instructionQueue.isNotEmpty()) {
+ val offset = visitor.instructionQueue.removeFirst()
+ val insn = controlFlow.instructions.getOrNull(offset) ?: continue
+ insn.accept(visitor, offset, offset + 1)
+ }
+
+ return visitor.locals.mapToArray { it?.filterNotNull() ?: emptyList() }
+ }
+
+ private fun getAllLocalVariables(body: PsiElement): List {
+ val locals = mutableListOf()
+ body.accept(
+ object : JavaRecursiveElementVisitor() {
+ override fun visitVariable(variable: PsiVariable) {
+ locals += variable
+ super.visitVariable(variable)
+ }
+
+ override fun visitForeachStatement(statement: PsiForeachStatement) {
+ locals += statement
+ super.visitForeachStatement(statement)
+ }
+
+ override fun visitClass(aClass: PsiClass?) {
+ // don't recurse into classes
+ }
+
+ override fun visitMethod(method: PsiMethod?) {
+ // don't recurse into methods
+ }
+
+ override fun visitLambdaExpression(expression: PsiLambdaExpression?) {
+ // don't recurse into lambdas
+ }
+ }
+ )
+ return locals
+ }
+
+ fun getLocalVariableSize(element: PsiElement): Int {
+ return when (element) {
+ // longs and doubles take two slots
+ is PsiVariable -> if (element.isDoubleSlot) 2 else 1
+ // arrays have copy of array, length and index variables, iterables have the iterator variable
+ is PsiForeachStatement -> if (element.iterationParameter.type is PsiArrayType) 3 else 1
+ else -> 0
+ }
+ }
+
+ private val PsiVariable.isDoubleSlot: Boolean
+ get() = type == PsiType.DOUBLE || type == PsiType.LONG
+
+ private fun PsiForeachStatement.getExtraLocals(): List {
+ val localIndex = getUserData(LOCAL_INDEX_KEY)!!
+ val iterable = iteratedValue ?: return emptyList()
+ val type = iterable.type
+ if (type is PsiArrayType) {
+ return listOf(
+ // array
+ SourceLocalVariable(
+ "var$localIndex",
+ type,
+ localIndex,
+ implicitLoadCountBefore = 1,
+ implicitStoreCountBefore = 1
+ ),
+ // length
+ SourceLocalVariable(
+ "var${localIndex + 1}",
+ PsiType.INT,
+ localIndex + 1,
+ implicitStoreCountBefore = 1,
+ implicitLoadCountAfter = 1
+ ),
+ // index
+ SourceLocalVariable(
+ "var${localIndex + 2}",
+ PsiType.INT,
+ localIndex + 2,
+ implicitStoreCountBefore = 1,
+ implicitLoadCountBefore = 1,
+ implicitLoadCountAfter = 1
+ )
+ )
+ } else {
+ val iteratorType = JavaPsiFacade.getElementFactory(project)
+ .createTypeByFQClassName(
+ CommonClassNames.JAVA_UTIL_ITERATOR,
+ resolveScope
+ )
+ return listOf(
+ // iterator
+ SourceLocalVariable(
+ "var$localIndex",
+ iteratorType,
+ localIndex,
+ implicitStoreCountBefore = 1,
+ implicitLoadCountBefore = 1
+ )
+ )
+ }
+ }
+
+ fun getLocals(
+ module: Module,
+ classNode: ClassNode,
+ method: MethodNode,
+ node: AbstractInsnNode
+ ): Array? {
+ return getLocals(module.project, classNode, method, node, detectCurrentSettings(module))
+ }
+
+ private fun getLocals(
+ project: Project,
+ classNode: ClassNode,
+ method: MethodNode,
+ nodeArg: AbstractInsnNode,
+ settings: Settings
+ ): Array? {
+ return try {
+ doGetLocals(project, classNode, method, nodeArg, settings)
+ } catch (e: LocalAnalysisFailedException) {
+ null
+ }
+ }
+
+ private val resurrectLocalsChange = SemanticVersion.release(0, 8, 3)
+ private fun detectCurrentSettings(module: Module): Settings {
+ val mixinVersion = MinecraftFacet.getInstance(module, MixinModuleType)?.mixinVersion
+ ?: throw LocalAnalysisFailedException()
+ return if (mixinVersion < resurrectLocalsChange) {
+ Settings.NO_RESURRECT
+ } else {
+ Settings.DEFAULT
+ }
+ }
+
+ private fun doGetLocals(
+ project: Project,
+ classNode: ClassNode,
+ method: MethodNode,
+ nodeArg: AbstractInsnNode,
+ settings: Settings
+ ): Array {
+ var node = nodeArg
+ for (i in 0 until 3) {
+ if (node !is LabelNode && node !is LineNumberNode) {
+ break
+ }
+ val nextNode = method.instructions.nextNode(node)
+ if (nextNode is FrameNode) { // Do not ffwd over frames
+ break
+ }
+ node = nextNode
+ }
+
+ val frames = method.instructions.iterator().asSequence().filterIsInstance().toList()
+ val frame = arrayOfNulls(method.maxLocals)
+ var local = 0
+ var index = 0
+
+ // Initialise implicit "this" reference in non-static methods
+ if (!method.hasAccess(Opcodes.ACC_STATIC)) {
+ frame[local++] = LocalVariable("this", Type.getObjectType(classNode.name).toString(), null, null, null, 0)
+ }
+
+ // Initialise method arguments
+ for (argType in Type.getArgumentTypes(method.desc)) {
+ frame[local] = LocalVariable("arg" + index++, argType.toString(), null, null, null, local)
+ local += argType.size
+ }
+
+ val initialFrameSize = local
+ var frameSize = local
+ var frameIndex = -1
+ var lastFrameSize = local
+ var knownFrameSize = local
+ var storeInsn: VarInsnNode? = null
+
+ for (insn in method.instructions) {
+ // Tick the zombies
+ for (zombie in frame.asSequence().filterIsInstance()) {
+ zombie.lifetime++
+ if (insn is FrameNode) {
+ zombie.frames++
+ }
+ }
+
+ if (storeInsn != null) {
+ val storedLocal = getLocalVariableAt(project, classNode, method, insn, storeInsn.`var`)
+ frame[storeInsn.`var`] = storedLocal
+ knownFrameSize = knownFrameSize.coerceAtLeast(storeInsn.`var` + 1)
+ if (storedLocal != null &&
+ storeInsn.`var` < method.maxLocals - 1 &&
+ storedLocal.desc != null &&
+ Type.getType(storedLocal.desc).size == 2
+ ) {
+ frame[storeInsn.`var` + 1] = null // TOP
+ knownFrameSize = knownFrameSize.coerceAtLeast(storeInsn.`var` + 2)
+ if (settings.resurrectExposedOnStore) {
+ resurrect(frame, knownFrameSize, settings)
+ }
+ }
+ storeInsn = null
+ }
+
+ if (insn is FrameNode) {
+ fun handleFrame() {
+ frameIndex++
+ if (insn.type == Opcodes.F_SAME || insn.type == Opcodes.F_SAME1) {
+ return
+ }
+ val frameNodeSize = insn.computeFrameSize(initialFrameSize)
+ val frameData = frames.getOrNull(frameIndex)
+ if (frameData != null) {
+ if (frameData.type == Opcodes.F_FULL) {
+ frameSize = frameNodeSize.coerceAtLeast(initialFrameSize)
+ lastFrameSize = frameSize
+ knownFrameSize = lastFrameSize
+ } else {
+ frameSize = getAdjustedFrameSize(
+ frameSize,
+ frameData.type,
+ frameData.computeFrameSize(initialFrameSize),
+ initialFrameSize
+ )
+ }
+ } else {
+ frameSize =
+ getAdjustedFrameSize(
+ frameSize,
+ insn.type,
+ frameNodeSize,
+ initialFrameSize
+ )
+ }
+
+ // Sanity check
+ if (frameSize < initialFrameSize) {
+ throw IllegalStateException(
+ "Locals entered an invalid state evaluating " +
+ "${classNode.name}::${method.name}${method.desc} at instruction " +
+ "${method.instructions.indexOf(insn)}. Initial frame size is" +
+ " $initialFrameSize, calculated a frame size of $frameSize"
+ )
+ }
+ if ((
+ (frameData == null && (insn.type == Opcodes.F_CHOP || insn.type == Opcodes.F_NEW)) ||
+ (frameData != null && frameData.type == Opcodes.F_CHOP)
+ )
+ ) {
+ for (framePos in frameSize until frame.size) {
+ frame[framePos] = ZombieLocalVariable.of(frame[framePos], ZombieLocalVariable.CHOP)
+ }
+ lastFrameSize = frameSize
+ knownFrameSize = lastFrameSize
+ return
+ }
+ var framePos = if (insn.type == Opcodes.F_APPEND) lastFrameSize else 0
+ lastFrameSize = frameSize
+
+ // localPos tracks the location in the frame node's locals list, which doesn't leave space for TOP entries
+ var localPos = 0
+ while (framePos < frame.size) {
+ // Get the local at the current position in the FrameNode's locals list
+ val localType = if ((localPos < insn.local.size)) insn.local[localPos] else null
+ if (localType is String) { // String refers to a reference type
+ frame[framePos] =
+ getLocalVariableAt(
+ project,
+ classNode,
+ method,
+ method.instructions.indexOf(insn),
+ framePos
+ )
+ } else if (localType is Int) { // Integer refers to a primitive type or other marker
+ val isMarkerType = localType == Opcodes.UNINITIALIZED_THIS || localType == Opcodes.NULL
+ val is32bitValue = localType == Opcodes.INTEGER || localType == Opcodes.FLOAT
+ val is64bitValue = localType == Opcodes.DOUBLE || localType == Opcodes.LONG
+ if (localType == Opcodes.TOP) {
+ // Explicit TOP entries are pretty much always bogus, but depending on our resurrection
+ // strategy we may want to resurrect eligible zombies here. Real TOP entries are handled below
+ if (frame[framePos] is ZombieLocalVariable && settings.resurrectForBogusTop) {
+ val zombie = frame[framePos] as ZombieLocalVariable
+ if (zombie.type == ZombieLocalVariable.TRIM) {
+ frame[framePos] = zombie.ancestor
+ }
+ }
+ } else if (isMarkerType) {
+ frame[framePos] = null
+ } else if (is32bitValue || is64bitValue) {
+ frame[framePos] =
+ getLocalVariableAt(
+ project,
+ classNode,
+ method,
+ method.instructions.indexOf(insn),
+ framePos
+ )
+ if (is64bitValue) {
+ framePos++
+ frame[framePos] = null // TOP
+ }
+ } else {
+ throw IllegalStateException(
+ "Unrecognised locals opcode $localType in locals array at position" +
+ " $localPos in ${classNode.name}.${method.name}${method.desc}"
+ )
+ }
+ } else if (localType == null) {
+ if ((framePos >= initialFrameSize) && (framePos >= frameSize) && (frameSize > 0)) {
+ if (framePos < knownFrameSize) {
+ frame[framePos] = getLocalVariableAt(
+ project,
+ classNode,
+ method,
+ insn,
+ framePos
+ )
+ } else {
+ frame[framePos] = ZombieLocalVariable.of(frame[framePos], ZombieLocalVariable.TRIM)
+ }
+ }
+ } else if (localType is LabelNode) {
+ // Uninitialised
+ } else {
+ throw IllegalStateException(
+ "Invalid value $localType in locals array at position" +
+ " $localPos in ${classNode.name}.${method.name}${method.desc}"
+ )
+ }
+ framePos++
+ localPos++
+ }
+ }
+
+ handleFrame()
+ } else if (insn is VarInsnNode) {
+ val isLoad = insn.getOpcode() >= Opcodes.ILOAD && insn.getOpcode() <= Opcodes.SALOAD
+ if (isLoad) {
+ val loadedVar = getLocalVariableAt(project, classNode, method, insn, insn.`var`)
+ frame[insn.`var`] = loadedVar
+ val varSize = loadedVar?.desc?.let { Type.getType(it).size } ?: 1
+ knownFrameSize = (insn.`var` + varSize).coerceAtLeast(knownFrameSize)
+ if (settings.resurrectExposedOnLoad) {
+ resurrect(frame, knownFrameSize, settings)
+ }
+ } else {
+ // Update the LVT for the opcode AFTER this one, since we always want to know
+ // the frame state BEFORE the *current* instruction to match the contract of
+ // injection points
+ storeInsn = insn
+ }
+ }
+
+ if (insn === node) {
+ break
+ }
+ }
+
+ // Null out any "unknown" or mixin-provided locals
+ for (l in frame.indices) {
+ val variable = frame[l]
+ if (variable is ZombieLocalVariable) {
+ // preserve zombies where the frame node which culled them was immediately prior to
+ // the matched instruction, or *was itself* the matched instruction, the returned
+ // frame will contain the original node (the zombie ancestor)
+ frame[l] = if (variable.lifetime > 1) null else variable.ancestor
+ }
+ if (variable != null && variable.desc == null) {
+ frame[l] = null
+ }
+ }
+
+ return frame
+ }
+
+ private fun getAdjustedFrameSize(currentSize: Int, type: Int, size: Int, initialFrameSize: Int): Int {
+ return when (type) {
+ Opcodes.F_NEW, Opcodes.F_FULL -> size.coerceAtLeast(initialFrameSize)
+ Opcodes.F_APPEND -> currentSize + size
+ Opcodes.F_CHOP -> (size - currentSize).coerceAtLeast(initialFrameSize)
+ Opcodes.F_SAME, Opcodes.F_SAME1 -> currentSize
+ else -> currentSize
+ }
+ }
+
+ private fun resurrect(frame: Array, knownFrameSize: Int, settings: Settings) {
+ for ((index, node) in frame.withIndex()) {
+ if (index >= knownFrameSize) {
+ break
+ }
+ if (node is ZombieLocalVariable && node.checkResurrect(settings)) {
+ frame[index] = node.ancestor
+ }
+ }
+ }
+
+ private fun FrameNode.computeFrameSize(initialFrameSize: Int): Int {
+ if (this.local == null) {
+ return initialFrameSize
+ }
+ var size = 0
+ for (local in this.local) {
+ size += if (local == Opcodes.DOUBLE || local == Opcodes.LONG) 2 else 1
+ }
+ return size.coerceAtLeast(initialFrameSize)
+ }
+
+ private fun getLocalVariableAt(
+ project: Project,
+ classNode: ClassNode,
+ method: MethodNode,
+ pos: AbstractInsnNode,
+ index: Int
+ ): LocalVariable? {
+ return getLocalVariableAt(project, classNode, method, method.instructions.indexOf(pos), index)
+ }
+
+ private fun getLocalVariableAt(
+ project: Project,
+ classNode: ClassNode,
+ method: MethodNode,
+ pos: Int,
+ index: Int
+ ): LocalVariable? {
+ var localVariableNode: LocalVariable? = null
+ var fallbackNode: LocalVariable? = null
+ for (local in method.getLocalVariableTable(project, classNode)) {
+ if (local.index != index) {
+ continue
+ }
+ if (local.isInRange(pos)) {
+ localVariableNode = local
+ } else if (localVariableNode == null) {
+ fallbackNode = local
+ }
+ }
+ if (localVariableNode == null && method.localVariables.isNotEmpty()) {
+ for (local in getGeneratedLocalVariableTable(project, classNode, method)) {
+ if (local.index == index && local.isInRange(pos)) {
+ localVariableNode = local
+ }
+ }
+ }
+ return localVariableNode ?: fallbackNode
+ }
+
+ private fun InsnList.nextNode(insn: AbstractInsnNode): AbstractInsnNode {
+ val index = indexOf(insn) + 1
+ if (index > 0 && index < size()) {
+ return get(index)
+ }
+ return insn
+ }
+
+ private fun MethodNode.getLocalVariableTable(project: Project, classNode: ClassNode): List {
+ if (localVariables.isEmpty()) {
+ return getGeneratedLocalVariableTable(project, classNode, this)
+ }
+ return localVariables.map {
+ LocalVariable(
+ it.name,
+ it.desc,
+ it.signature,
+ instructions.indexOf(it.start),
+ instructions.indexOf(it.end),
+ it.index
+ )
+ }
+ }
+
+ private fun getGeneratedLocalVariableTable(
+ project: Project,
+ classNode: ClassNode,
+ method: MethodNode
+ ): List {
+ val frames = AsmDfaUtil.analyzeMethod(project, classNode, method) ?: throw LocalAnalysisFailedException()
+
+ // Record the original size of the method
+ val methodSize = method.instructions.size()
+
+ // List of LocalVariableNodes to return
+ val localVariables = mutableListOf()
+
+ // LocalVariableNodes for current frame
+ val localVars = arrayOfNulls(method.maxLocals)
+
+ // locals in previous frame, used to work out what changes between frames
+ val locals = arrayOfNulls(method.maxLocals)
+
+ val lastKnownType = arrayOfNulls(method.maxLocals)
+
+ // Traverse the frames and work out when locals begin and end
+ for (i in 0 until methodSize) {
+ val f = frames[i] ?: continue
+ for (j in 0 until f.locals) {
+ val local = f.getLocal(j)
+ if (local == null && locals[j] == null) {
+ continue
+ }
+ if (local != null && local == locals[j]) {
+ continue
+ }
+ if (local == null && locals[j] != null) {
+ val localVar = localVars[j]!!
+ localVariables.add(localVar)
+ localVar.end = i
+ localVars[j] = null
+ } else if (local != null) {
+ if (locals[j] != null) {
+ val localVar = localVars[j]!!
+ localVariables.add(localVar)
+ localVar.end = i
+ localVars[j] = null
+ }
+ var desc = lastKnownType[j]
+ val localType = local.type
+ if (localType != null) {
+ desc = if (localType.sort >= Type.ARRAY && localType.internalName == "null") {
+ "Ljava/lang/Object;"
+ } else {
+ localType.descriptor
+ }
+ }
+ localVars[j] = LocalVariable("var$j", desc, null, i, null, j)
+ if (desc != null) {
+ lastKnownType[j] = desc
+ }
+ }
+ locals[j] = local
+ }
+ }
+
+ // Reached the end of the method so flush all current locals and mark the end
+ for (k in localVars.indices) {
+ val localVar = localVars[k]
+ if (localVar != null) {
+ localVar.end = methodSize
+ localVariables.add(localVar)
+ }
+ }
+
+ return localVariables
+ }
+
+ data class Settings(
+ val choppedInsnThreshold: Int,
+ val trimmedInsnThreshold: Int,
+ val choppedFrameThreshold: Int,
+ val trimmedFrameThreshold: Int,
+ val resurrectExposedOnLoad: Boolean,
+ val resurrectExposedOnStore: Boolean,
+ val resurrectForBogusTop: Boolean
+ ) {
+ companion object {
+ val NO_RESURRECT = Settings(
+ choppedInsnThreshold = 0,
+ choppedFrameThreshold = 0,
+ trimmedInsnThreshold = 0,
+ trimmedFrameThreshold = 0,
+ resurrectExposedOnLoad = false,
+ resurrectExposedOnStore = false,
+ resurrectForBogusTop = false
+ )
+
+ val DEFAULT = Settings(
+ choppedInsnThreshold = -1,
+ choppedFrameThreshold = 1,
+ trimmedInsnThreshold = -1,
+ trimmedFrameThreshold = -1,
+ resurrectExposedOnLoad = true,
+ resurrectExposedOnStore = true,
+ resurrectForBogusTop = true
+ )
+ }
+ }
+
+ data class SourceLocalVariable(
+ val name: String,
+ val type: PsiType,
+ val index: Int,
+ val mixinName: String = name,
+ val implicitLoadCountBefore: Int = 0,
+ val implicitLoadCountAfter: Int = 0,
+ val implicitStoreCountBefore: Int = 0,
+ val implicitStoreCountAfter: Int = 0
+ )
+
+ open class LocalVariable(
+ val name: String,
+ val desc: String?,
+ val signature: String?,
+ val start: Int?,
+ var end: Int?,
+ val index: Int
+ ) {
+ fun isInRange(index: Int): Boolean {
+ val end = this.end
+ return (start == null || index >= start) && (end == null || index < end)
+ }
+ }
+
+ private class LocalAnalysisFailedException : Exception() {
+ override fun fillInStackTrace(): Throwable {
+ return this
+ }
+ }
+
+ private class ZombieLocalVariable private constructor(
+ val ancestor: LocalVariable,
+ val type: Char
+ ) : LocalVariable(
+ ancestor.name,
+ ancestor.desc,
+ ancestor.signature,
+ ancestor.start,
+ ancestor.end,
+ ancestor.index
+ ) {
+ var lifetime = 0
+ var frames = 0
+
+ fun checkResurrect(settings: Settings): Boolean {
+ val insnThreshold = if (type == CHOP) settings.choppedInsnThreshold else settings.trimmedInsnThreshold
+ if (insnThreshold > -1 && lifetime > insnThreshold) {
+ return false
+ }
+ val frameThreshold = if (type == CHOP) settings.choppedFrameThreshold else settings.trimmedFrameThreshold
+ return frameThreshold == -1 || frames <= frameThreshold
+ }
+
+ override fun toString(): String {
+ return String.format("Z(%s,%-2d)", type, lifetime)
+ }
+
+ companion object {
+ const val CHOP = 'C'
+ const val TRIM = 'X'
+
+ fun of(ancestor: LocalVariable?, type: Char): ZombieLocalVariable? {
+ return if (ancestor is ZombieLocalVariable) {
+ ancestor
+ } else {
+ ancestor?.let { ZombieLocalVariable(it, type) }
+ }
+ }
+ }
+ }
+}
diff --git a/src/main/kotlin/platform/mixin/util/Mixin.kt b/src/main/kotlin/platform/mixin/util/Mixin.kt
index 5b9a3b56b..fcb0797ae 100644
--- a/src/main/kotlin/platform/mixin/util/Mixin.kt
+++ b/src/main/kotlin/platform/mixin/util/Mixin.kt
@@ -10,6 +10,7 @@
package com.demonwav.mcdev.platform.mixin.util
+import com.demonwav.mcdev.platform.mixin.action.FindMixinsAction
import com.demonwav.mcdev.platform.mixin.util.MixinConstants.Annotations.ACCESSOR
import com.demonwav.mcdev.platform.mixin.util.MixinConstants.Annotations.INVOKER
import com.demonwav.mcdev.platform.mixin.util.MixinConstants.Annotations.MIXIN
@@ -23,11 +24,16 @@ import com.demonwav.mcdev.util.resolveClassArray
import com.intellij.openapi.project.Project
import com.intellij.psi.JavaPsiFacade
import com.intellij.psi.PsiAnnotation
+import com.intellij.psi.PsiArrayType
import com.intellij.psi.PsiClass
+import com.intellij.psi.PsiClassType
+import com.intellij.psi.PsiDisjunctionType
import com.intellij.psi.PsiElement
+import com.intellij.psi.PsiIntersectionType
import com.intellij.psi.PsiPrimitiveType
import com.intellij.psi.PsiType
import com.intellij.psi.search.GlobalSearchScope
+import com.intellij.psi.util.InheritanceUtil
import com.intellij.psi.util.PsiModificationTracker
import org.objectweb.asm.Opcodes
import org.objectweb.asm.tree.ClassNode
@@ -134,3 +140,44 @@ fun callbackInfoReturnableType(project: Project, context: PsiElement, returnType
fun argsType(project: Project): PsiType =
PsiType.getTypeByName(ARGS, project, GlobalSearchScope.allScope(project))
+
+fun isAssignable(left: PsiType, right: PsiType): Boolean {
+ return when {
+ left is PsiIntersectionType -> left.conjuncts.all { isAssignable(it, right) }
+ right is PsiIntersectionType -> right.conjuncts.any { isAssignable(left, it) }
+ left is PsiDisjunctionType -> left.disjunctions.any { isAssignable(it, right) }
+ right is PsiDisjunctionType -> isAssignable(left, right.leastUpperBound)
+ left is PsiArrayType -> right is PsiArrayType && isAssignable(left.componentType, right.componentType)
+ else -> {
+ if (left !is PsiClassType || right !is PsiClassType) {
+ return false
+ }
+ val leftClass = left.resolve() ?: return false
+ val rightClass = right.resolve() ?: return false
+ if (rightClass.isMixin) {
+ val isMixinAssignable = rightClass.mixinTargets.any {
+ val stubClass = it.findStubClass(rightClass.project) ?: return@any false
+ isClassAssignable(leftClass, stubClass)
+ }
+ if (isMixinAssignable) {
+ return true
+ }
+ }
+ val mixins = FindMixinsAction.findMixins(rightClass, rightClass.project) ?: return false
+ return mixins.any { isClassAssignable(leftClass, it) }
+ }
+ }
+}
+
+private fun isClassAssignable(leftClass: PsiClass, rightClass: PsiClass): Boolean {
+ var result = false
+ InheritanceUtil.processSupers(rightClass, true) {
+ if (it == leftClass) {
+ result = true
+ false
+ } else {
+ true
+ }
+ }
+ return result
+}
diff --git a/src/main/kotlin/platform/mixin/util/MixinConstants.kt b/src/main/kotlin/platform/mixin/util/MixinConstants.kt
index 589d2502f..4ce520ba9 100644
--- a/src/main/kotlin/platform/mixin/util/MixinConstants.kt
+++ b/src/main/kotlin/platform/mixin/util/MixinConstants.kt
@@ -21,6 +21,7 @@ object MixinConstants {
const val CALLBACK_INFO_RETURNABLE = "org.spongepowered.asm.mixin.injection.callback.CallbackInfoReturnable"
const val ARGS = "org.spongepowered.asm.mixin.injection.invoke.arg.Args"
const val COMPATIBILITY_LEVEL = "org.spongepowered.asm.mixin.MixinEnvironment.CompatibilityLevel"
+ const val CONSTANT_CONDITION = "org.spongepowered.asm.mixin.injection.Constant.Condition"
const val INJECTION_POINT = "org.spongepowered.asm.mixin.injection.InjectionPoint"
const val SELECTOR = "org.spongepowered.asm.mixin.injection.InjectionPoint.Selector"
const val MIXIN_AGENT = "org.spongepowered.tools.agent.MixinAgent"
@@ -31,12 +32,15 @@ object MixinConstants {
const val SHIFT = "org.spongepowered.asm.mixin.injection.At.Shift"
const val SERIALIZED_NAME = "com.google.gson.annotations.SerializedName"
+ const val MIXIN_SERIALIZED_NAME = "org.spongepowered.include.$SERIALIZED_NAME"
}
object Annotations {
const val ACCESSOR = "org.spongepowered.asm.mixin.gen.Accessor"
const val AT = "org.spongepowered.asm.mixin.injection.At"
const val AT_CODE = "org.spongepowered.asm.mixin.injection.InjectionPoint.AtCode"
+ const val COERCE = "org.spongepowered.asm.mixin.injection.Coerce"
+ const val CONSTANT = "org.spongepowered.asm.mixin.injection.Constant"
const val DEBUG = "org.spongepowered.asm.mixin.Debug"
const val DESC = "org.spongepowered.asm.mixin.injection.Desc"
const val DYNAMIC = "org.spongepowered.asm.mixin.Dynamic"
diff --git a/src/main/kotlin/platform/mixin/util/TargetClass.kt b/src/main/kotlin/platform/mixin/util/TargetClass.kt
index d565a6a9e..bb17436d9 100644
--- a/src/main/kotlin/platform/mixin/util/TargetClass.kt
+++ b/src/main/kotlin/platform/mixin/util/TargetClass.kt
@@ -13,11 +13,9 @@ package com.demonwav.mcdev.platform.mixin.util
import com.demonwav.mcdev.platform.mixin.util.MixinConstants.Annotations.DYNAMIC
import com.demonwav.mcdev.util.equivalentTo
import com.demonwav.mcdev.util.findAnnotation
-import com.demonwav.mcdev.util.findContainingMethod
import com.demonwav.mcdev.util.findMethods
import com.demonwav.mcdev.util.resolveClass
import com.intellij.psi.PsiClass
-import com.intellij.psi.PsiElement
import com.intellij.psi.PsiMember
import org.objectweb.asm.Opcodes
import org.objectweb.asm.tree.ClassNode
@@ -28,9 +26,6 @@ fun PsiMember.findUpstreamMixin(): PsiClass? {
return findAnnotation(DYNAMIC)?.findDeclaredAttributeValue("mixin")?.resolveClass()
}
-val PsiElement.isWithinDynamicMixin: Boolean
- get() = findContainingMethod()?.findAnnotation(DYNAMIC) != null
-
data class ClassAndMethodNode(val clazz: ClassNode, val method: MethodNode)
fun findMethods(psiClass: PsiClass, allowClinit: Boolean = true): Sequence? {
diff --git a/src/main/kotlin/util/Parameter.kt b/src/main/kotlin/util/Parameter.kt
index 732749dfc..61eb87416 100644
--- a/src/main/kotlin/util/Parameter.kt
+++ b/src/main/kotlin/util/Parameter.kt
@@ -15,4 +15,7 @@ import com.intellij.psi.PsiType
data class Parameter(val name: String?, val type: PsiType) {
constructor(parameter: PsiParameter) : this(parameter.name, parameter.type)
+ init {
+ assert(name?.isJavaKeyword() != true)
+ }
}
diff --git a/src/main/kotlin/util/annotation-utils.kt b/src/main/kotlin/util/annotation-utils.kt
index 169773713..48c5b4e9e 100644
--- a/src/main/kotlin/util/annotation-utils.kt
+++ b/src/main/kotlin/util/annotation-utils.kt
@@ -65,7 +65,7 @@ fun PsiAnnotationMemberValue?.isNotEmpty(): Boolean {
return this != null && (this !is PsiArrayInitializerMemberValue || initializers.isNotEmpty())
}
-private inline fun PsiAnnotationMemberValue.parseArray(func: (PsiAnnotationMemberValue) -> T?): List {
+inline fun PsiAnnotationMemberValue.parseArray(func: (PsiAnnotationMemberValue) -> T?): List {
return if (this is PsiArrayInitializerMemberValue) {
initializers.mapNotNull(func)
} else {
diff --git a/src/main/kotlin/util/psi-utils.kt b/src/main/kotlin/util/psi-utils.kt
index d9c28b0a4..8f4e9779f 100644
--- a/src/main/kotlin/util/psi-utils.kt
+++ b/src/main/kotlin/util/psi-utils.kt
@@ -22,15 +22,19 @@ import com.intellij.openapi.roots.ModuleRootManager
import com.intellij.openapi.roots.ProjectFileIndex
import com.intellij.openapi.roots.impl.OrderEntryUtil
import com.intellij.openapi.util.Key
+import com.intellij.openapi.util.text.StringUtil
import com.intellij.psi.ElementManipulator
import com.intellij.psi.ElementManipulators
import com.intellij.psi.JavaPsiFacade
import com.intellij.psi.PsiClass
import com.intellij.psi.PsiDirectory
import com.intellij.psi.PsiElement
+import com.intellij.psi.PsiElementFactory
import com.intellij.psi.PsiElementResolveResult
+import com.intellij.psi.PsiEllipsisType
import com.intellij.psi.PsiFile
import com.intellij.psi.PsiKeyword
+import com.intellij.psi.PsiLiteralExpression
import com.intellij.psi.PsiMember
import com.intellij.psi.PsiMethod
import com.intellij.psi.PsiMethodReferenceExpression
@@ -46,8 +50,10 @@ import com.intellij.psi.filters.ElementFilter
import com.intellij.psi.util.CachedValueProvider
import com.intellij.psi.util.CachedValuesManager
import com.intellij.psi.util.PsiTreeUtil
+import com.intellij.psi.util.PsiTypesUtil
import com.intellij.psi.util.TypeConversionUtil
import com.intellij.refactoring.changeSignature.ChangeSignatureUtil
+import com.intellij.util.IncorrectOperationException
import com.siyeh.ig.psiutils.ImportUtils
// Parent
@@ -181,8 +187,15 @@ infix fun PsiElement.equivalentTo(other: PsiElement): Boolean {
}
fun PsiType?.isErasureEquivalentTo(other: PsiType?): Boolean {
- // TODO: Do more checks for generics instead
- return TypeConversionUtil.erasure(this) == TypeConversionUtil.erasure(other)
+ return this?.normalize() == other?.normalize()
+}
+
+fun PsiType.normalize(): PsiType {
+ var normalized = TypeConversionUtil.erasure(this)
+ if (normalized is PsiEllipsisType) {
+ normalized = normalized.toArrayType()
+ }
+ return normalized
}
val PsiMethod.nameAndParameterTypes: String
@@ -249,3 +262,19 @@ val PsiMethodReferenceExpression.hasSyntheticMethod: Boolean
if (qualifier !is PsiReferenceExpression) return true
return qualifier.resolve() !is PsiClass
}
+
+val PsiClass.psiType: PsiType
+ get() = PsiTypesUtil.getClassType(this)
+
+fun PsiElementFactory.createLiteralExpression(constant: Any?): PsiLiteralExpression {
+ return when (constant) {
+ null -> createExpressionFromText("null", null)
+ is Boolean, is Double, is Int -> createExpressionFromText(constant.toString(), null)
+ is Char -> createExpressionFromText("'${StringUtil.escapeCharCharacters(constant.toString())}'", null)
+ is Float -> createExpressionFromText("${constant}F", null)
+ is Long -> createExpressionFromText("${constant}L", null)
+ is String -> createExpressionFromText("\"${StringUtil.escapeStringCharacters(constant)}\"", null)
+
+ else -> throw IncorrectOperationException("Unsupported literal type: ${constant.javaClass.name}")
+ } as PsiLiteralExpression
+}
diff --git a/src/main/kotlin/util/utils.kt b/src/main/kotlin/util/utils.kt
index 41e723fb0..97f2a6a4a 100644
--- a/src/main/kotlin/util/utils.kt
+++ b/src/main/kotlin/util/utils.kt
@@ -184,6 +184,15 @@ fun Array.rotate(amount: Int) {
}
}
+inline fun Iterable.firstIndexOrNull(predicate: (T) -> Boolean): Int? {
+ for ((index, element) in this.withIndex()) {
+ if (predicate(element)) {
+ return index
+ }
+ }
+ return null
+}
+
fun Module.findChildren(): Set {
return runReadAction {
val manager = ModuleManager.getInstance(project)
diff --git a/src/main/resources/META-INF/plugin.xml b/src/main/resources/META-INF/plugin.xml
index c9805bf70..e83c90bfb 100644
--- a/src/main/resources/META-INF/plugin.xml
+++ b/src/main/resources/META-INF/plugin.xml
@@ -59,7 +59,7 @@
-
+
@@ -71,7 +71,9 @@
-
+
+
+
@@ -316,6 +318,8 @@
+
+
@@ -677,6 +681,14 @@
implementationClass="com.demonwav.mcdev.platform.mixin.inspection.implements.DuplicateInterfacePrefixInspection"/>
+
+