Skip to content

Commit

Permalink
Support upgrading foreach loops wrt. subtyping to susport suspendable…
Browse files Browse the repository at this point in the history
… iterators puniverse#285
  • Loading branch information
FroMage committed Nov 24, 2017
1 parent a01a8da commit 1c8a7a8
Show file tree
Hide file tree
Showing 4 changed files with 225 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ public boolean isAlreadyInstrumented() {
public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) {
this.className = name;
this.isInterface = (access & Opcodes.ACC_INTERFACE) != 0;
this.classEntry = new ClassEntry(superName);
this.classEntry = new ClassEntry(name, superName);
classEntry.setInterfaces(interfaces);
classEntry.setIsInterface(isInterface);
}
Expand Down Expand Up @@ -137,7 +137,7 @@ public MethodVisitor visitMethod(final int access, final String name, final Stri
}
}
suspendable = InstrumentClass.suspendableToSuperIfAbstract(access, suspendable);
classEntry.set(name, desc, suspendable);
classEntry.set(name, desc, suspendable, (access & Opcodes.ACC_BRIDGE) != 0);

if (suspendable == null) // look for @Suspendable annotation
return new MethodVisitor(ASMAPI) {
Expand All @@ -153,7 +153,8 @@ public AnnotationVisitor visitAnnotation(String adesc, boolean visible) {
@Override
public void visitEnd() {
super.visitEnd();
classEntry.set(name, desc, InstrumentClass.suspendableToSuperIfAbstract(access, susp ? SuspendableType.SUSPENDABLE : SuspendableType.NON_SUSPENDABLE));
classEntry.set(name, desc, InstrumentClass.suspendableToSuperIfAbstract(access, susp ? SuspendableType.SUSPENDABLE : SuspendableType.NON_SUSPENDABLE),
(access & Opcodes.ACC_BRIDGE) != 0);
hasSuspendable = hasSuspendable | susp;
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,8 @@ public MethodVisitor visitMethod(final int access, final String name, final Stri
final SuspendableType setSuspendable = classEntry.check(name, desc);

if (setSuspendable == null)
classEntry.set(name, desc, markedSuspendable != null ? markedSuspendable : SuspendableType.NON_SUSPENDABLE);
classEntry.set(name, desc, markedSuspendable != null ? markedSuspendable : SuspendableType.NON_SUSPENDABLE,
(access & Opcodes.ACC_BRIDGE) != 0);

final SuspendableType suspendable = max(markedSuspendable, setSuspendable, SuspendableType.NON_SUSPENDABLE);

Expand Down Expand Up @@ -200,7 +201,7 @@ private void commit() {

if (db.isDebug())
db.log(LogLevel.INFO, "Method %s#%s%s suspendable: %s (markedSuspendable: %s setSuspendable: %s)", className, name, desc, susp, susp, setSuspendable);
classEntry.set(name, desc, susp);
classEntry.set(name, desc, susp, (access & Opcodes.ACC_BRIDGE) != 0);

if (susp == SuspendableType.SUSPENDABLE && checkAccessForMethodInstrumentation(access)) {
if (isSynchronized(access)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,27 +60,15 @@
import static co.paralleluniverse.fibers.instrument.MethodDatabase.isMethodHandleInvocation;
import static co.paralleluniverse.fibers.instrument.MethodDatabase.isReflectInvocation;
import static co.paralleluniverse.fibers.instrument.MethodDatabase.isSyntheticAccess;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import java.util.*;

import org.objectweb.asm.AnnotationVisitor;
import org.objectweb.asm.Label;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;
import org.objectweb.asm.Type;
import org.objectweb.asm.tree.AbstractInsnNode;
import org.objectweb.asm.tree.AnnotationNode;
import org.objectweb.asm.tree.InsnList;
import org.objectweb.asm.tree.InvokeDynamicInsnNode;
import org.objectweb.asm.tree.JumpInsnNode;
import org.objectweb.asm.tree.LabelNode;
import org.objectweb.asm.tree.LineNumberNode;
import org.objectweb.asm.tree.LocalVariableNode;
import org.objectweb.asm.tree.MethodInsnNode;
import org.objectweb.asm.tree.MethodNode;
import org.objectweb.asm.tree.TryCatchBlockNode;
import org.objectweb.asm.tree.*;
import org.objectweb.asm.tree.analysis.Analyzer;
import org.objectweb.asm.tree.analysis.AnalyzerException;
import org.objectweb.asm.tree.analysis.BasicValue;
Expand Down Expand Up @@ -144,6 +132,7 @@ class InstrumentMethod {
this.mn = mn;

try {
upgradeForeach(mn);
Analyzer a = new TypeAnalyzer(db);
this.frames = a.analyze(className, mn);
this.lvarStack = mn.maxLocals;
Expand All @@ -157,6 +146,147 @@ class InstrumentMethod {
}
}

public void upgradeForeach(MethodNode mn) {
ListIterator<AbstractInsnNode> it = mn.instructions.iterator();
int i = 0;
while(it.hasNext()) {
AbstractInsnNode instr = it.next();
if(instr.getType() == AbstractInsnNode.METHOD_INSN
&& (instr.getOpcode() == Opcodes.INVOKEVIRTUAL
|| instr.getOpcode() == Opcodes.INVOKEINTERFACE)) {
MethodInsnNode mCall = (MethodInsnNode) instr;
if(mCall.name.equals("iterator")
// we can't check the return type here because Eclipse makes it Iterator
// but javac respects subtypes (but only here)
// && mCall.desc.equals("()Ljava/util/Iterator;")
) {
checkForeach(mCall);
}
}
}
}

private void checkForeach(MethodInsnNode iteratorCall) {
// iterable.iterator(): invoke iterator(), store, jump to test [only Eclipse]
AbstractInsnNode mCallPlus1 = iteratorCall.getNext();
if(mCallPlus1 == null
|| mCallPlus1.getType() != AbstractInsnNode.VAR_INSN
|| mCallPlus1.getOpcode() != Opcodes.ASTORE)
return;
VarInsnNode mCallStore = (VarInsnNode) mCallPlus1;
int iteratorVarIndex = mCallStore.var;
AbstractInsnNode mCallPlus2 = mCallPlus1.getNext();
if(mCallPlus2 == null)
return;

boolean testBeforeNext;
AbstractInsnNode testInstr;
if(mCallPlus2.getType() == AbstractInsnNode.JUMP_INSN
&& mCallPlus2.getOpcode() == Opcodes.GOTO){
testBeforeNext = true;
// jump to the hasNext() test: load, invoke hasNext(), ifne to body
JumpInsnNode jumpToTest = (JumpInsnNode) mCallPlus2;
testInstr = getJumpTarget(jumpToTest.label);
}else{
// continue hasNext() test: label, load, invoke hasNext(), ifeq to end
testBeforeNext = false;
if(mCallPlus2.getType() != AbstractInsnNode.LABEL)
return;
testInstr = mCallPlus2.getNext();
}

if(testInstr == null
|| testInstr.getType() != AbstractInsnNode.VAR_INSN
|| testInstr.getOpcode() != Opcodes.ALOAD)
return;
VarInsnNode testLoad = (VarInsnNode) testInstr;
if(testLoad.var != iteratorVarIndex)
return;
AbstractInsnNode testLoadPlus1 = testLoad.getNext();
if(testLoadPlus1 == null
|| testLoadPlus1.getType() != AbstractInsnNode.METHOD_INSN
|| testLoadPlus1.getOpcode() != Opcodes.INVOKEINTERFACE)
return;
MethodInsnNode hasNextCall = (MethodInsnNode) testLoadPlus1;
if(!hasNextCall.name.equals("hasNext")
|| !hasNextCall.owner.equals("java/util/Iterator")
|| !hasNextCall.desc.equals("()Z"))
return;
AbstractInsnNode testLoadPlus2 = hasNextCall.getNext();
if(testLoadPlus2 == null
|| testLoadPlus2.getType() != AbstractInsnNode.JUMP_INSN)
return;
if(testBeforeNext && testLoadPlus2.getOpcode() != Opcodes.IFNE)
return;
if(!testBeforeNext && testLoadPlus2.getOpcode() != Opcodes.IFEQ)
return;

// Now check body: load, invoke next()
JumpInsnNode jumpToBody = (JumpInsnNode) testLoadPlus2;
AbstractInsnNode bodyInstr = testBeforeNext ? getJumpTarget(jumpToBody.label) : jumpToBody.getNext();
if(bodyInstr == null
|| bodyInstr.getType() != AbstractInsnNode.VAR_INSN
|| bodyInstr.getOpcode() != Opcodes.ALOAD)
return;
VarInsnNode bodyLoad = (VarInsnNode) bodyInstr;
if(bodyLoad.var != iteratorVarIndex)
return;
AbstractInsnNode bodyLoadPlus1 = bodyLoad.getNext();
if(bodyLoadPlus1 == null
|| bodyLoadPlus1.getType() != AbstractInsnNode.METHOD_INSN
|| bodyLoadPlus1.getOpcode() != Opcodes.INVOKEINTERFACE)
return;
MethodInsnNode nextCall = (MethodInsnNode) bodyLoadPlus1;
if(!nextCall.name.equals("next")
|| !nextCall.owner.equals("java/util/Iterator")
|| !nextCall.desc.equals("()Ljava/lang/Object;"))
return;

MethodDatabase.ClassEntry iterableClassEntry = db.getOrLoadClassEntry(iteratorCall.owner);
if(iterableClassEntry == null)
return;
if(!iterableClassEntry.implementsInterface("java/lang/Iterable", db))
return;
MethodDatabase.ClassEntry methodOwnerClass = iterableClassEntry.getClassImplementingMethod("iterator()", db);
// iteratorType contains the "L...;" parts
String iteratorType = methodOwnerClass.getReturnType("iterator()");
if(iteratorType == null || iteratorType.equals("Ljava/util/Iterator;"))
return;

MethodDatabase.ClassEntry iteratorClass =
db.getOrLoadClassEntry(iteratorType.substring(1, iteratorType.length()-1));
if(iteratorClass == null)
return;
MethodDatabase.ClassEntry nextOwnerClass = iteratorClass.getClassImplementingMethod("next()", db);
if(nextOwnerClass == null)
return;
String nextMethodOwner = nextOwnerClass.getName();
boolean nextMethodInterface = nextOwnerClass.isInterface();

MethodDatabase.ClassEntry hasNextOwnerClass = iteratorClass.getClassImplementingMethod("hasNext()", db);
if(hasNextOwnerClass == null)
return;
String hasNextMethodOwner = hasNextOwnerClass.getName();
boolean hasNextMethodInterface = hasNextOwnerClass.isInterface();

iteratorCall.desc = "()"+iteratorType;
hasNextCall.owner = hasNextMethodOwner;
hasNextCall.setOpcode(hasNextMethodInterface ? Opcodes.INVOKEINTERFACE : Opcodes.INVOKEVIRTUAL);
hasNextCall.itf = hasNextMethodInterface;
nextCall.owner = nextMethodOwner;
nextCall.setOpcode(nextMethodInterface ? Opcodes.INVOKEINTERFACE : Opcodes.INVOKEVIRTUAL);
nextCall.itf = nextMethodInterface;
}

private AbstractInsnNode getJumpTarget(LabelNode label) {
AbstractInsnNode next = label.getNext();
while(next.getType() == AbstractInsnNode.FRAME
|| next.getType() == AbstractInsnNode.LINE) {
next = next.getNext();
}
return next;
}

private void collectCallsites() {
if (suspCallsBcis == null) {
suspCallsBcis = new int[8];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,8 @@
import java.io.IOException;
import java.io.InputStream;
import java.lang.ref.WeakReference;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.NavigableMap;
import java.util.TreeMap;
import java.util.*;

import org.objectweb.asm.ClassReader;
import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.Opcodes;
Expand Down Expand Up @@ -239,7 +235,7 @@ public synchronized ClassEntry getClassEntry(String className) {
public synchronized ClassEntry getOrCreateClassEntry(String className, String superType) {
ClassEntry ce = classes.get(className);
if (ce == null) {
ce = new ClassEntry(superType);
ce = new ClassEntry(className, superType);
classes.put(className, ce);
}
return ce;
Expand Down Expand Up @@ -454,30 +450,40 @@ public static boolean isProblematicClass(String className) {
|| className.startsWith("org/apache/log4j/");
}

private static final ClassEntry CLASS_NOT_FOUND = new ClassEntry("<class not found>");
private static final ClassEntry CLASS_NOT_FOUND = new ClassEntry("<class not found>", "<class not found>");

public enum SuspendableType {
NON_SUSPENDABLE, SUSPENDABLE_SUPER, SUSPENDABLE
};

public static final class ClassEntry {
private final HashMap<String, SuspendableType> methods;
public final HashMap<String, SuspendableType> methods;
public final HashSet<String> bridges;
private String sourceName;
private String sourceDebugInfo;
private boolean isInterface;
private String[] interfaces;
private final String superName;
private final String name;
private boolean instrumented;
private volatile boolean requiresInstrumentation;

public ClassEntry(String superName) {
public ClassEntry(String name, String superName) {
this.name = name;
this.superName = superName;
this.methods = new HashMap<>();
this.bridges = new HashSet<>();
}

public String getName() {
return name;
}

public void set(String name, String desc, SuspendableType suspendable) {
public void set(String name, String desc, SuspendableType suspendable, boolean bridge) {
String nameAndDesc = key(name, desc);
methods.put(nameAndDesc, suspendable);
if(bridge)
bridges.add(nameAndDesc);
}

public String getSourceName() {
Expand Down Expand Up @@ -568,6 +574,61 @@ public boolean isInstrumented() {
public void setInstrumented(boolean instrumented) {
this.instrumented = instrumented;
}

public boolean implementsInterface(String name, MethodDatabase db) {
for(String interf : interfaces){
if(interf.equals(name))
return true;
}
if(superName != null){
ClassEntry superClass = db.getOrLoadClassEntry(superName);
if(superClass != null && superClass.implementsInterface(name, db))
return true;
}
for(String interf : interfaces){
ClassEntry superClass = db.getOrLoadClassEntry(interf);
if(superClass != null && superClass.implementsInterface(name, db))
return true;
}
return false;
}

public ClassEntry getClassImplementingMethod(String methodNameAndParams, MethodDatabase db) {
for (Map.Entry<String, SuspendableType> entry : methods.entrySet()) {
String key = entry.getKey();
if (key.substring(0, key.indexOf(')')+1).equals(methodNameAndParams)
&& !bridges.contains(key))
return this;
}
if(superName != null){
ClassEntry superClass = db.getOrLoadClassEntry(superName);
if(superClass != null) {
ClassEntry ret = superClass.getClassImplementingMethod(methodNameAndParams, db);
if (ret != null)
return ret;
}
}
for(String interf : interfaces){
ClassEntry superClass = db.getOrLoadClassEntry(interf);
if(superClass != null) {
ClassEntry ret = superClass.getClassImplementingMethod(methodNameAndParams, db);
if (ret != null)
return ret;
}
}
return null;
}

public String getReturnType(String methodNameAndParams) {
for (Map.Entry<String, SuspendableType> entry : methods.entrySet()) {
String key = entry.getKey();
int retIndex = key.indexOf(")")+1;
if (key.substring(0, retIndex).equals(methodNameAndParams)
&& !bridges.contains(key))
return key.substring(retIndex);
}
return null;
}
}

public static class ExtractSuperClass extends ClassVisitor {
Expand Down

0 comments on commit 1c8a7a8

Please sign in to comment.