Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

8346664: C2: Optimize mask check with constant offset #22856

Open
wants to merge 18 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 66 additions & 92 deletions src/hotspot/share/opto/mulnode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -670,9 +670,13 @@ const Type *AndINode::mul_ring( const Type *t0, const Type *t1 ) const {
return and_value<TypeInt>(r0, r1);
}

// Is expr a neutral element wrt addition under mask?
static bool AndIL_is_zero_element(const PhaseGVN* phase, const Node* expr, const Node* mask, BasicType bt);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer a more expressive function name over a comment here. The comment is a little confusing to me too.

The old name at least talked about shift and mask - is that not relevant any more?

Or you just decide to name it AndIL_is_zero, and drop out the comment. Because who knows someone might add other things that check for zero in that method, and then your comment would be out-dated (but probably people would forget to adjust it).

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renamed to is_zero_element_under_mask. "zero element" for me drives down that it's neither

  • checking for expr == 0
  • nor checking for expr & mask == 0
    but really (X + expr) & mask == X & mask for all X.

There is no requirement for shift node, e.g, we recognize is_zero_under_mask(192, 7). However, the constants that are recognized here are "shifts in spirit" (e.g. expanded from (i + 24) << 3). If you can think of a good term for this that doesn't suggest there's an actual "shift node" we could try and incorporate that.

Since this is a forward declaration, the elaborate comment is below. Went and dropped the short one here.


const Type* AndINode::Value(PhaseGVN* phase) const {
// patterns similar to (v << 2) & 3
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this comment be updated now for a more general pattern?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's better to drop this and refer the reader to the definition.

if (AndIL_shift_and_mask_is_always_zero(phase, in(1), in(2), T_INT, true)) {
if (AndIL_is_zero_element(phase, in(1), in(2), T_INT) ||
AndIL_is_zero_element(phase, in(2), in(1), T_INT)) {
Copy link
Author

@mernst-github mernst-github Dec 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find it easier to reason about the "reverse" check when we simply expand it here.

return TypeInt::ZERO;
}

Expand Down Expand Up @@ -719,7 +723,7 @@ Node* AndINode::Identity(PhaseGVN* phase) {
//------------------------------Ideal------------------------------------------
Node *AndINode::Ideal(PhaseGVN *phase, bool can_reshape) {
// pattern similar to (v1 + (v2 << 2)) & 3 transformed to v1 & 3
Node* progress = AndIL_add_shift_and_mask(phase, T_INT);
Node* progress = AndIL_sum_and_mask(phase, T_INT);
if (progress != nullptr) {
return progress;
}
Expand Down Expand Up @@ -803,7 +807,8 @@ const Type *AndLNode::mul_ring( const Type *t0, const Type *t1 ) const {

const Type* AndLNode::Value(PhaseGVN* phase) const {
// patterns similar to (v << 2) & 3
if (AndIL_shift_and_mask_is_always_zero(phase, in(1), in(2), T_LONG, true)) {
if (AndIL_is_zero_element(phase, in(1), in(2), T_LONG) ||
AndIL_is_zero_element(phase, in(2), in(1), T_LONG)) {
return TypeLong::ZERO;
}

Expand Down Expand Up @@ -851,7 +856,7 @@ Node* AndLNode::Identity(PhaseGVN* phase) {
//------------------------------Ideal------------------------------------------
Node *AndLNode::Ideal(PhaseGVN *phase, bool can_reshape) {
// pattern similar to (v1 + (v2 << 2)) & 3 transformed to v1 & 3
Node* progress = AndIL_add_shift_and_mask(phase, T_LONG);
Node* progress = AndIL_sum_and_mask(phase, T_LONG);
if (progress != nullptr) {
return progress;
}
Expand Down Expand Up @@ -2052,99 +2057,70 @@ const Type* RotateRightNode::Value(PhaseGVN* phase) const {
}
}

// Given an expression (AndX shift mask) or (AndX mask shift),
// determine if the AndX must always produce zero, because the
// the shift (x<<N) is bitwise disjoint from the mask #M.
// The X in AndX must be I or L, depending on bt.
// Specifically, the following cases fold to zero,
// when the shift value N is large enough to zero out
// all the set positions of the and-mask M.
// (AndI (LShiftI _ #N) #M) => #0
// (AndL (LShiftL _ #N) #M) => #0
// (AndL (ConvI2L (LShiftI _ #N)) #M) => #0
// The M and N values must satisfy ((-1 << N) & M) == 0.
// Because the optimization might work for a non-constant
// mask M, we check the AndX for both operand orders.
bool MulNode::AndIL_shift_and_mask_is_always_zero(PhaseGVN* phase, Node* shift, Node* mask, BasicType bt, bool check_reverse) {
if (mask == nullptr || shift == nullptr) {
return false;
}
const TypeInteger* mask_t = phase->type(mask)->isa_integer(bt);
if (mask_t == nullptr || phase->type(shift)->isa_integer(bt) == nullptr) {
return false;
}
shift = shift->uncast();
if (shift == nullptr) {
return false;
//------------------------------ Sum & Mask ------------------------------

// Returns a lower bound on the number of trailing zeros in expr.
static jint AndIL_min_trailing_zeros(const PhaseGVN* phase, const Node* expr, BasicType bt) {
Comment on lines +2059 to +2060
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this method restricted to use in AndIL? Because it looks like it is doing something more generic: trying to figure out a lower bound on the trailing zeros of an expression.

If that is the case: Why not put it in Node::get_trailing_zeros_lower_bound(phase, bt), so it can be used elsewhere too?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would argue that while this might be incidentally reusable outside of the scope of "And" nodes, as long as there is no actual demand to reuse this, I would rather not add it to the rather prominent Node class to avoid api bloat.

Iff the notion of "is known to be a multiple of a certain power of two" is really of general interest, I would expect it to become part of TypeInteger.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, just leave it where it is for now. I'm ok with it.

expr = expr->uncast();
const TypeInteger* type = phase->type(expr)->isa_integer(bt);
if (type == nullptr) {
return 0;
}
if (phase->type(shift)->isa_integer(bt) == nullptr) {
return false;

if (type->is_con()) {
long con = type->get_con_as_long(bt);
return con == 0L ? (type2aelembytes(bt) * BitsPerByte) : count_trailing_zeros(con);
}
BasicType shift_bt = bt;
if (bt == T_LONG && shift->Opcode() == Op_ConvI2L) {

if (expr->Opcode() == Op_ConvI2L) {
expr = expr->in(1)->uncast();
bt = T_INT;
Node* val = shift->in(1);
if (val == nullptr) {
return false;
}
val = val->uncast();
if (val == nullptr) {
return false;
}
if (val->Opcode() == Op_LShiftI) {
shift_bt = T_INT;
shift = val;
if (phase->type(shift)->isa_integer(bt) == nullptr) {
return false;
}
}
}
if (shift->Opcode() != Op_LShift(shift_bt)) {
if (check_reverse &&
(mask->Opcode() == Op_LShift(bt) ||
(bt == T_LONG && mask->Opcode() == Op_ConvI2L))) {
// try it the other way around
return AndIL_shift_and_mask_is_always_zero(phase, mask, shift, bt, false);
}
return false;
}
Node* shift2 = shift->in(2);
if (shift2 == nullptr) {
return false;
}
const Type* shift2_t = phase->type(shift2);
if (!shift2_t->isa_int() || !shift2_t->is_int()->is_con()) {
return false;
type = phase->type(expr)->isa_int();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are trying to look through a ConvI2L, I think for the sake of consistency, you can reassign bt to T_INT at this point.

Copy link
Author

@mernst-github mernst-github Dec 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

}

jint shift_con = shift2_t->is_int()->get_con() & ((shift_bt == T_INT ? BitsPerJavaInteger : BitsPerJavaLong) - 1);
if ((((jlong)1) << shift_con) > mask_t->hi_as_long() && mask_t->lo_as_long() >= 0) {
return true;
if (expr->Opcode() == Op_LShift(bt)) {
const TypeInt* rhs_t = phase->type(expr->in(2))->isa_int();
if (rhs_t == nullptr || !rhs_t->is_con()) {
return 0;
}
return rhs_t->get_con() % (type2aelembytes(bt) * BitsPerByte);
}

return false;
return 0;
}

// Given an expression (AndX (AddX v1 (LShiftX v2 #N)) #M)
// determine if the AndX must always produce (AndX v1 #M),
// because the shift (v2<<N) is bitwise disjoint from the mask #M.
// The X in AndX will be I or L, depending on bt.
// Specifically, the following cases fold,
// Given an expression (AndX T+expr mask), determine
// whether expr is neutral wrt addition under mask
// and hence the result is always equivalent to (AndX T mask),
// The X in AndX must be I or L, depending on bt.
// Specifically, this holds for the following cases,
// when the shift value N is large enough to zero out
// all the set positions of the and-mask M.
// (AndI (AddI v1 (LShiftI _ #N)) #M) => (AndI v1 #M)
// (AndL (AddI v1 (LShiftL _ #N)) #M) => (AndL v1 #M)
// (AndL (AddL v1 (ConvI2L (LShiftI _ #N))) #M) => (AndL v1 #M)
// all the set positions of the and-mask M:
// (AndI (LShiftI _ #N) #M)
// (AndL (LShiftL _ #N) #M)
// (AndL (ConvI2L (LShiftI _ #N)) #M)
// including constant operands:
// (AndI (ConI (_ << #N)) #M)
// (AndL (ConL (_ << #N)) #M)
// The M and N values must satisfy ((-1 << N) & M) == 0.
// Because the optimization might work for a non-constant
// mask M, and because the AddX operands can come in either
// order, we check for every operand order.
Node* MulNode::AndIL_add_shift_and_mask(PhaseGVN* phase, BasicType bt) {
static bool AndIL_is_zero_element(const PhaseGVN* phase, const Node* expr, const Node* mask, BasicType bt) {
const TypeInteger* mask_t = phase->type(mask)->isa_integer(bt);
if (mask_t == nullptr) {
return false;
}

jint zeros = AndIL_min_trailing_zeros(phase, expr, bt);
return zeros > 0 && ((((jlong)1) << zeros) > mask_t->hi_as_long() && mask_t->lo_as_long() >= 0);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line indicates that the mask could be a variable. You should make sure to add some tests for that in your patterns. You can create a variable in a specific range like this Math.min(5, Math.max(1, x)), should get you x clamped into the region 1..5.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a variant for adding consts using the same pattern as the other "NonConstMask" tests.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please point me to the tests where mask_t is a range, and not a constant?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tests addConstNonConstMask[Int,Long] (https://github.com/openjdk/jdk/pull/22856/files#diff-2c6beb2b7bcb76601adb439471e786963c6e0d5cb6db132381f64e10df5819daR207), copied from the existing NonConstMask tests exercise this.

Comment on lines +2106 to +2112
Copy link
Author

@mernst-github mernst-github Jan 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW I came up with another way to formulate this. I think it makes even more clear how this is all about "expr is shifted to be completely left of the mask", i.e. we're comparing right-most value bit position to left-most mask bit position, all other bits are irrelevant.

This could also inform how to bias the random generator. Let me know what you think.

Suggested change
const TypeInteger* mask_t = phase->type(mask)->isa_integer(bt);
if (mask_t == nullptr) {
return false;
}
jint zeros = AndIL_min_trailing_zeros(phase, expr, bt);
return zeros > 0 && ((((jlong)1) << zeros) > mask_t->hi_as_long() && mask_t->lo_as_long() >= 0);
jint expr_trailing_zeros = AndIL_min_trailing_zeros(phase, expr, bt);
if (expr_trailing_zeros == 0) return false; // zero mask handled in MulNode::Value
const TypeInteger* mask_t = phase->type(mask)->isa_integer(bt);
if (mask_t == nullptr || mask_t->lo_as_long() < 0) return false;
jint mask_bit_width = mask_t->hi_as_long() == 0 ? 0 : (BitsPerLong - count_leading_zeros(mask_t->hi_as_long()));
return expr_trailing_zeros >= mask_bit_width;

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mask_t->lo_as_long() == 0 does not imply mask == 0, though. Other than that I think it is a great suggestion.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right. Fixed above. I was trying to get around special casing 0 (count_*_zeros don't like them).

}

// Given an expression (AndX (AddX v1 v2) mask)
// determine if the AndX must always produce (AndX v1 mask),
// because v2 is zero wrt addition under mask.
// Because the AddX operands can come in either
// order, we check for both orders.
Node* MulNode::AndIL_sum_and_mask(PhaseGVN* phase, BasicType bt) {
Node* add = in(1);
Node* mask = in(2);
if (add == nullptr || mask == nullptr) {
return nullptr;
}
int addidx = 0;
if (add->Opcode() == Op_Add(bt)) {
addidx = 1;
Expand All @@ -2156,14 +2132,12 @@ Node* MulNode::AndIL_add_shift_and_mask(PhaseGVN* phase, BasicType bt) {
if (addidx > 0) {
Node* add1 = add->in(1);
Node* add2 = add->in(2);
if (add1 != nullptr && add2 != nullptr) {
if (AndIL_shift_and_mask_is_always_zero(phase, add1, mask, bt, false)) {
set_req_X(addidx, add2, phase);
return this;
} else if (AndIL_shift_and_mask_is_always_zero(phase, add2, mask, bt, false)) {
set_req_X(addidx, add1, phase);
return this;
}
if (AndIL_is_zero_element(phase, add1, mask, bt)) {
set_req_X(addidx, add2, phase);
return this;
} else if (AndIL_is_zero_element(phase, add2, mask, bt)) {
set_req_X(addidx, add1, phase);
return this;
}
}
return nullptr;
Expand Down
4 changes: 2 additions & 2 deletions src/hotspot/share/opto/mulnode.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ class MulNode : public Node {

static MulNode* make(Node* in1, Node* in2, BasicType bt);

static bool AndIL_shift_and_mask_is_always_zero(PhaseGVN* phase, Node* shift, Node* mask, BasicType bt, bool check_reverse);
Node* AndIL_add_shift_and_mask(PhaseGVN* phase, BasicType bt);
protected:
Node* AndIL_sum_and_mask(PhaseGVN* phase, BasicType bt);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please update the copyright from 2024 -> 2025?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perfect!

};

//------------------------------MulINode---------------------------------------
Expand Down
43 changes: 42 additions & 1 deletion test/hotspot/jtreg/compiler/c2/irTests/TestShiftAndMask.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

/*
* @test
* @bug 8277850 8278949 8285793
* @bug 8277850 8278949 8285793 8346664
* @summary C2: optimize mask checks in counted loops
* @library /test/lib /
* @run driver compiler.c2.irTests.TestShiftAndMask
Expand Down Expand Up @@ -133,6 +133,30 @@ public static void addShiftMaskInt_runner() {
}
}

@Test
@IR(counts = { IRNode.AND_I, "1" })
@IR(failOn = { IRNode.ADD_I, IRNode.LSHIFT_I })
public static int addShiftPlusConstMaskInt(int i, int j) {
return (j + ((i + 5) << 2)) & 3; // transformed to: return j & 3;
}

@Run(test = "addShiftPlusConstMaskInt")
public static void addShiftPlusConstMaskInt_runner() {
int i = RANDOM.nextInt();
int j = RANDOM.nextInt();
int res = addShiftPlusConstMaskInt(i, j);
if (res != (j & 3)) {
throw new RuntimeException("incorrect result: " + res);
}
}

@Test
@Arguments(values = {Argument.RANDOM_EACH, Argument.RANDOM_EACH})
@IR(counts = { IRNode.ADD_I, "2", IRNode.LSHIFT_I, "1" })
public static int addShiftPlusConstDisjointMaskInt(int i, int j) {
return (j + ((i + 5) << 2)) & 32; // NOT transformed even though (5<<2) & 32 == 0
}

@Test
@IR(counts = { IRNode.AND_I, "1" })
@IR(failOn = { IRNode.ADD_I, IRNode.LSHIFT_I })
Expand Down Expand Up @@ -178,6 +202,23 @@ public static void addShiftMaskLong_runner() {
}
}

@Test
@IR(counts = { IRNode.AND_L, "1" })
@IR(failOn = { IRNode.ADD_L, IRNode.LSHIFT_L })
public static long addShiftPlusConstMaskLong(long i, long j) {
return (j + ((i - 5) << 2)) & 3; // transformed to: return j & 3;
}

@Run(test = "addShiftPlusConstMaskLong")
public static void addShiftPlusConstMaskLong_runner() {
long i = RANDOM.nextLong();
long j = RANDOM.nextLong();
long res = addShiftPlusConstMaskLong(i, j);
if (res != (j & 3)) {
throw new RuntimeException("incorrect result: " + res);
}
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice to see that you have some examples here!

I think it would be great to have some more though. The divil hides in the details. In the edge cases usually.

You currently have patterns like this:
(j + ((i + c1) << c2)) & c3;
What if you generate the constants c1, c2, c3 randomly:
public static final int C1 = random.nextInt() (or some other random distribution that makes more sense).
Then the compiler will see them as constants (because final), and attempt constant folding.

You can then do result verification: You create a method copy that you restrict to the interpreter, and the other copy can be compiled. Then you test the method with all sorts of random inputs for i, j, and verify the results of the two methods (compiled vs interpreted).

Maybe you can add some more patterns as well, just to have a better test coverage.

Does that make sense?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe I understand the intent, and I've now randomized all constant masks / shifts / consts in this file. But just to make sure: IIUC the tests are only compiled once per invocation, there is no way I can tell the framework to "C2 compile this x times with different random constants". I.e. I can make test this a hundred times locally, but I cannot create large coverage via the framework, right?

Also not quite sure I understand the verification proposal. How would that be different from the current comparisons if (result != expected simplified form) ? Now if the framework supported an automatic comparison of compiled vs interpreted invocation, that would be nice.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, creating large coverage with a single run under the IR framework is not currently possible I think.

Generally, there are other tricks to get "changing constants", see what I did with setConstant and int_con in this test:
test/hotspot/jtreg/compiler/loopopts/superword/TestAlignVectorFuzzer.java

But the tests are rerun a lot anyway, so that is not super necessary.

I am working on a Template framework that makes using random constants much easier, and also generating multiple methods where only the constants differ. That should make things a little easier.

I suppose that works: if (result != expected simplified form)
Though only for cases where we have a valid simplification. If you also want to test the cases that have a very similar pattern, but should not accidentally wrongly optimize, then you would have to do the compiled/interpreted comparison.

@Test
@IR(counts = { IRNode.AND_L, "1" })
@IR(failOn = { IRNode.ADD_L, IRNode.LSHIFT_L })
Expand Down