diff --git a/core/src/main/java/com/google/errorprone/bugpatterns/ByteBufferBackingArray.java b/core/src/main/java/com/google/errorprone/bugpatterns/ByteBufferBackingArray.java index c2df097a585..e14d992cdb5 100644 --- a/core/src/main/java/com/google/errorprone/bugpatterns/ByteBufferBackingArray.java +++ b/core/src/main/java/com/google/errorprone/bugpatterns/ByteBufferBackingArray.java @@ -30,11 +30,16 @@ import com.google.errorprone.matchers.Matcher; import com.google.errorprone.util.ASTHelpers; import com.sun.source.tree.AssignmentTree; +import com.sun.source.tree.BinaryTree; import com.sun.source.tree.ClassTree; import com.sun.source.tree.ExpressionTree; +import com.sun.source.tree.IfTree; import com.sun.source.tree.MethodInvocationTree; import com.sun.source.tree.MethodTree; +import com.sun.source.tree.ParenthesizedTree; import com.sun.source.tree.Tree; +import com.sun.source.tree.Tree.Kind; +import com.sun.source.tree.UnaryTree; import com.sun.source.tree.VariableTree; import com.sun.source.util.TreeScanner; import com.sun.tools.javac.code.Symbol; @@ -44,12 +49,13 @@ import java.util.Optional; /** - * Checks when ByteBuffer.array() is used without calling .arrayOffset() to know the offset of the - * array, or when the buffer wasn't initialized using ByteBuffer.wrap() or ByteBuffer.allocate(). + * Checks when ByteBuffer.array() is used without calling .arrayOffset() or .hasArray() to ensure + * safe access to the backing array, or when the buffer wasn't initialized using ByteBuffer.wrap() + * or ByteBuffer.allocate(). */ @BugPattern( summary = - "ByteBuffer.array() shouldn't be called unless ByteBuffer.arrayOffset() is used or " + "ByteBuffer.array() shouldn't be called unless ByteBuffer.arrayOffset() or ByteBuffer.hasArray() is used or " + "if the ByteBuffer was initialized using ByteBuffer.wrap() or ByteBuffer.allocate().", severity = WARNING) public class ByteBufferBackingArray extends BugChecker implements MethodInvocationTreeMatcher { @@ -60,6 +66,9 @@ public class ByteBufferBackingArray extends BugChecker implements MethodInvocati private static final Matcher BYTE_BUFFER_ARRAY_OFFSET_MATCHER = anyOf(instanceMethod().onDescendantOf(ByteBuffer.class.getName()).named("arrayOffset")); + private static final Matcher BYTE_BUFFER_HAS_ARRAY_MATCHER = + anyOf(instanceMethod().onDescendantOf(ByteBuffer.class.getName()).named("hasArray")); + private static final Matcher BYTE_BUFFER_ALLOWED_INITIALIZERS_MATCHER = staticMethod().onClass(ByteBuffer.class.getName()).namedAnyOf("allocate", "wrap"); @@ -85,6 +94,10 @@ public Description matchMethodInvocation(MethodInvocationTree tree, VisitorState return Description.NO_MATCH; } + if (isGuardedByHasArrayTrueBranch(tree, bufferSymbol, state)) { + return Description.NO_MATCH; + } + // Checks for validating use on method scope. if (bufferSymbol.owner instanceof MethodSymbol methodSymbol) { MethodTree enclosingMethod = ASTHelpers.findMethod(methodSymbol, state); @@ -120,14 +133,15 @@ private static boolean isValidInitializerOrNotAByteBuffer( } /** - * Scan for a call to ByteBuffer.arrayOffset() or check if buffer was initialized with either - * ByteBuffer.wrap() or ByteBuffer.allocate(). + * Scan for a call to ByteBuffer.arrayOffset() or ByteBuffer.hasArray(), or check if buffer was + * initialized with either ByteBuffer.wrap() or ByteBuffer.allocate(). */ private static class ValidByteBufferArrayScanner extends TreeScanner { private final Symbol searchedBufferSymbol; private boolean visited; private boolean valid; + private boolean guardActive; static boolean scan(Tree tree, VisitorState state, Symbol searchedBufferSymbol) { ValidByteBufferArrayScanner visitor = new ValidByteBufferArrayScanner(searchedBufferSymbol); @@ -160,6 +174,9 @@ public Void visitMethodInvocation(MethodInvocationTree tree, VisitorState state) if (searchedBufferSymbol.equals(bufferSymbol)) { if (BYTE_BUFFER_ARRAY_MATCHER.matches(tree, state)) { visited = true; + if (guardActive) { + valid = true; + } } else if (BYTE_BUFFER_ARRAY_OFFSET_MATCHER.matches(tree, state)) { valid = true; } @@ -167,6 +184,37 @@ public Void visitMethodInvocation(MethodInvocationTree tree, VisitorState state) return super.visitMethodInvocation(tree, state); } + @Override + public Void visitIf(IfTree tree, VisitorState state) { + if (valid) { + return null; + } + boolean thenGuard = + conditionIsConjunctWithPositiveHasArray(tree.getCondition(), searchedBufferSymbol, state); + boolean elseGuard = + conditionIsNegatedConjunctWithPositiveHasArray( + tree.getCondition(), searchedBufferSymbol, state); + + // Scan THEN, optionally under guard + boolean oldValid = valid; + boolean oldGuard = guardActive; + guardActive = guardActive || thenGuard; + scan(tree.getThenStatement(), state); + valid = oldValid || valid; // preserve any success found in THEN + guardActive = oldGuard; + + // Scan ELSE, optionally under guard + if (tree.getElseStatement() != null) { + oldValid = valid; + oldGuard = guardActive; + guardActive = guardActive || elseGuard; + scan(tree.getElseStatement(), state); + valid = oldValid || valid; // preserve any success found in ELSE + guardActive = oldGuard; + } + return null; + } + private void checkForInitializer( Symbol foundSymbol, ExpressionTree expression, VisitorState state) { if (visited || valid) { @@ -184,6 +232,107 @@ private void checkForInitializer( } } + private static boolean isGuardedByHasArrayTrueBranch( + MethodInvocationTree arrayCall, Symbol bufferSymbol, VisitorState state) { + var path = state.getPath(); + while (path != null) { + Tree leaf = path.getLeaf(); + if (leaf instanceof IfTree ifTree) { + // array() is inside THEN branch guarded by a conjunction that includes hasArray() + if (containsTree(ifTree.getThenStatement(), arrayCall) + && conditionIsConjunctWithPositiveHasArray(ifTree.getCondition(), bufferSymbol, state)) { + return true; + } + // array() is inside ELSE branch and condition is negation of a conjunction including hasArray() + if (ifTree.getElseStatement() != null + && containsTree(ifTree.getElseStatement(), arrayCall) + && conditionIsNegatedConjunctWithPositiveHasArray( + ifTree.getCondition(), bufferSymbol, state)) { + return true; + } + } + path = path.getParentPath(); + } + return false; + } + + private static boolean containsTree(Tree container, Tree target) { + if (container == null) { + return false; + } + class Finder extends TreeScanner { + @Override + public Boolean reduce(Boolean r1, Boolean r2) { + return firstNonNull(r1, false) || firstNonNull(r2, false); + } + + @Override + public Boolean scan(Tree node, Void unused) { + if (node == null) { + return false; + } + if (node == target) { + return true; + } + return super.scan(node, unused); + } + } + return firstNonNull(container.accept(new Finder(), null), false); + } + + private static boolean conditionIsConjunctWithPositiveHasArray( + ExpressionTree condition, Symbol bufferSymbol, VisitorState state) { + return containsPositiveHasArray(condition, /*negated=*/ false, bufferSymbol, state) + && !containsLogicalOr(condition); + } + + private static boolean conditionIsNegatedConjunctWithPositiveHasArray( + ExpressionTree condition, Symbol bufferSymbol, VisitorState state) { + return containsPositiveHasArray(condition, /*negated=*/ true, bufferSymbol, state) + && !containsLogicalOr(condition); + } + + private static boolean containsPositiveHasArray( + ExpressionTree tree, boolean negated, Symbol bufferSymbol, VisitorState state) { + return switch (tree.getKind()) { + case PARENTHESIZED -> containsPositiveHasArray( + ((ParenthesizedTree) tree).getExpression(), negated, bufferSymbol, state); + case LOGICAL_COMPLEMENT -> containsPositiveHasArray( + ((UnaryTree) tree).getExpression(), !negated, bufferSymbol, state); + case CONDITIONAL_AND, CONDITIONAL_OR -> { + BinaryTree bt = (BinaryTree) tree; + yield containsPositiveHasArray(bt.getLeftOperand(), negated, bufferSymbol, state) + || containsPositiveHasArray(bt.getRightOperand(), negated, bufferSymbol, state); + } + default -> { + if (tree instanceof MethodInvocationTree mit + && BYTE_BUFFER_HAS_ARRAY_MATCHER.matches(mit, state)) { + Symbol recv = ASTHelpers.getSymbol(ASTHelpers.getReceiver(mit)); + yield !negated && bufferSymbol.equals(recv); + } + yield false; + } + }; + } + + private static boolean containsLogicalOr(ExpressionTree tree) { + class OrFinder extends TreeScanner { + @Override + public Boolean reduce(Boolean r1, Boolean r2) { + return firstNonNull(r1, false) || firstNonNull(r2, false); + } + + @Override + public Boolean visitBinary(BinaryTree node, Void unused) { + if (node.getKind() == Kind.CONDITIONAL_OR) { + return true; + } + return super.visitBinary(node, unused); + } + } + return firstNonNull(tree.accept(new OrFinder(), null), false); + } + /** Scan for a call to ByteBuffer.wrap() or ByteBuffer.allocate(). */ private static class ValidByteBufferInitializerScanner extends TreeScanner { diff --git a/core/src/test/java/com/google/errorprone/bugpatterns/ByteBufferBackingArrayTest.java b/core/src/test/java/com/google/errorprone/bugpatterns/ByteBufferBackingArrayTest.java index 48c882487f0..01f03ec24d4 100644 --- a/core/src/test/java/com/google/errorprone/bugpatterns/ByteBufferBackingArrayTest.java +++ b/core/src/test/java/com/google/errorprone/bugpatterns/ByteBufferBackingArrayTest.java @@ -104,6 +104,13 @@ void array_precededByNotAValidMethod_isFlagged() { // BUG: Diagnostic contains: ByteBuffer.array() buff.array(); } + + void array_precededByHasArray_isFlagged() { + ByteBuffer buffer = ByteBuffer.allocateDirect(10); + buffer.hasArray(); + // BUG: Diagnostic contains: ByteBuffer.array() + buffer.array(); + } }\ """) .doTest(); @@ -249,6 +256,33 @@ void array_inLambdaExpression_precededByByteBufferAllocate_isNotFlagged() { return null; }; } + + void array_precededByHasArray_inConditional_isNotFlagged() { + ByteBuffer buffer = ByteBuffer.allocateDirect(10); + if (buffer.hasArray()) { + buffer.array(); + } + } + + void array_inElseOfNegatedHasArray_isNotFlagged() { + ByteBuffer buffer = ByteBuffer.allocateDirect(10); + if (!buffer.hasArray()) { + // no array() here + } else { + buffer.array(); // safe due to else of !hasArray() + } + } + + void array_precededByHasArray_inIfElse_isNotFlagged() { + final int frameSize = 100; + final ByteBuffer buffer = ByteBuffer.allocateDirect(frameSize); + if (buffer.hasArray()) { + buffer.array(); + } else { + final byte[] array = new byte[frameSize]; + buffer.get(array); + } + } }\ """) .doTest();