Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

8346664: C2: Optimize mask check with constant offset #22856

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 69 additions & 75 deletions src/hotspot/share/opto/mulnode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ const Type *AndINode::mul_ring( const Type *t0, const Type *t1 ) const {

const Type* AndINode::Value(PhaseGVN* phase) const {
// patterns similar to (v << 2) & 3
if (AndIL_shift_and_mask_is_always_zero(phase, in(1), in(2), T_INT, true)) {
if (AndIL_is_always_zero(phase, in(1), in(2), T_INT, true)) {
return TypeInt::ZERO;
}

Expand Down Expand Up @@ -719,7 +719,7 @@ Node* AndINode::Identity(PhaseGVN* phase) {
//------------------------------Ideal------------------------------------------
Node *AndINode::Ideal(PhaseGVN *phase, bool can_reshape) {
// pattern similar to (v1 + (v2 << 2)) & 3 transformed to v1 & 3
Node* progress = AndIL_add_shift_and_mask(phase, T_INT);
Node* progress = AndIL_sum_and_mask(phase, T_INT);
if (progress != nullptr) {
return progress;
}
Expand Down Expand Up @@ -803,7 +803,7 @@ const Type *AndLNode::mul_ring( const Type *t0, const Type *t1 ) const {

const Type* AndLNode::Value(PhaseGVN* phase) const {
// patterns similar to (v << 2) & 3
if (AndIL_shift_and_mask_is_always_zero(phase, in(1), in(2), T_LONG, true)) {
if (AndIL_is_always_zero(phase, in(1), in(2), T_LONG, true)) {
return TypeLong::ZERO;
}

Expand Down Expand Up @@ -851,7 +851,7 @@ Node* AndLNode::Identity(PhaseGVN* phase) {
//------------------------------Ideal------------------------------------------
Node *AndLNode::Ideal(PhaseGVN *phase, bool can_reshape) {
// pattern similar to (v1 + (v2 << 2)) & 3 transformed to v1 & 3
Node* progress = AndIL_add_shift_and_mask(phase, T_LONG);
Node* progress = AndIL_sum_and_mask(phase, T_LONG);
if (progress != nullptr) {
return progress;
}
Expand Down Expand Up @@ -2052,94 +2052,88 @@ const Type* RotateRightNode::Value(PhaseGVN* phase) const {
}
}

// Given an integer expression expr (a candidate operand of an AndX),
// returns a lower bound of the number of trailing zero bits in expr.
// Returns 0 when no lower bound can be established.
jint MulNode::AndIL_min_trailing_zeros(PhaseGVN* phase, Node* expr, BasicType bt) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could be a static function, I don't see much value in it being a method in MulNode.

expr = expr->uncast();
if (expr == nullptr) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should not be nullptr, you can safely remove it.

return 0;
}
const TypeInteger* type = phase->type(expr)->isa_integer(bt);
if (type == nullptr) {
return 0;
}

if (type->is_con()) {
long con = type->get_con_as_long(type->basic_type());
return con == 0L ? 0 : count_trailing_zeros(con);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the sake of consistency, we should return the type width for con == 0, you can obtain this by type2aelementbytes(bt) * 8

}

if (expr->Opcode() == Op_ConvI2L) {
expr = expr->in(1);
if (expr == nullptr) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This cannot be nullptr, so you can safely remove the check; the same goes for the expr->uncast() check below. In general, the only case in which the input of a ConvI2L (or of other nodes) is not an int is when it is top, which means its type is empty, i.e. the code is unreachable.

return 0;
}
expr = expr->uncast();
if (expr == nullptr) {
return 0;
}
type = phase->type(expr)->isa_int();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are trying to look through a ConvI2L, I think for the sake of consistency, you can reassign bt to T_INT at this point.

}

if (expr->Opcode() == Op_LShift(type->basic_type())) {
Node* rhs = expr->in(2);
if (rhs == nullptr) {
return 0;
}
const TypeInt* rhs_t = phase->type(rhs)->isa_int();
if (!rhs_t || !rhs_t->is_con()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are trying to avoid implicit conversions to bool; you can use an explicit comparison against nullptr here (rhs_t == nullptr instead of !rhs_t).

return 0;
}
return rhs_t->get_con() & ((type->isa_int() ? BitsPerJavaInteger : BitsPerJavaLong) - 1);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you reassign bt, you can do type2aelementbytes(bt), which IMO is clearer.

}

return 0;
}

// Given an expression (AndX expr mask) or (AndX mask expr),
// determine if the AndX must always produce zero, because the
// shift (x<<N) is bitwise disjoint from the mask #M.
// expr is bitwise disjoint from the mask.
// The X in AndX must be I or L, depending on bt.
// Specifically, the following cases fold to zero,
// when the shift value N is large enough to zero out
// all the set positions of the and-mask M.
// (AndI (LShiftI _ #N) #M) => #0
// (AndL (LShiftL _ #N) #M) => #0
// (AndL (ConvI2L (LShiftI _ #N)) #M) => #0
// as well as for constant operands:
// (AndI (ConI [+-] _ << #N) #M) => #0
// (AndL (ConL [+-] _ << #N) #M) => #0
// The M and N values must satisfy ((-1 << N) & M) == 0.
// Because the optimization might work for a non-constant
// mask M, we check the AndX for both operand orders.
bool MulNode::AndIL_shift_and_mask_is_always_zero(PhaseGVN* phase, Node* shift, Node* mask, BasicType bt, bool check_reverse) {
if (mask == nullptr || shift == nullptr) {
// mask M, we check for both operand orders.
bool MulNode::AndIL_is_always_zero(PhaseGVN* phase, Node* expr, Node* mask, BasicType bt, bool check_reverse) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually you cannot conclude that ((x + y) & m) == 0 iff (x & m) == 0 when (y & m) == 0 because the addition x + y can carry some bit into the positions at which m is set. Consider this example for illustration:

(0b1010 + 0b0010) & 0b0100 == 0b1100 & 0b0100 == 0b0100 != 0

even when

0b1010 & 0b0100 == 0
0b0010 & 0b0100 == 0

The simplest sufficient condition, and the one we are using here, is that the lowest set bit of y is higher than the highest set bit of m: adding y to x then cannot change any bit of the result at a position where m is set. This method can be a static function too, IMO.

if (mask == nullptr || expr == nullptr) {
return false;
}
const TypeInteger* mask_t = phase->type(mask)->isa_integer(bt);
if (mask_t == nullptr || phase->type(shift)->isa_integer(bt) == nullptr) {
if (mask_t == nullptr) {
return false;
}
shift = shift->uncast();
if (shift == nullptr) {
return false;
}
if (phase->type(shift)->isa_integer(bt) == nullptr) {
return false;
}
BasicType shift_bt = bt;
if (bt == T_LONG && shift->Opcode() == Op_ConvI2L) {
bt = T_INT;
Node* val = shift->in(1);
if (val == nullptr) {
return false;
}
val = val->uncast();
if (val == nullptr) {
return false;
}
if (val->Opcode() == Op_LShiftI) {
shift_bt = T_INT;
shift = val;
if (phase->type(shift)->isa_integer(bt) == nullptr) {
return false;
}
}
}
if (shift->Opcode() != Op_LShift(shift_bt)) {
if (check_reverse &&
(mask->Opcode() == Op_LShift(bt) ||
(bt == T_LONG && mask->Opcode() == Op_ConvI2L))) {
// try it the other way around
return AndIL_shift_and_mask_is_always_zero(phase, mask, shift, bt, false);
}
return false;
}
Node* shift2 = shift->in(2);
if (shift2 == nullptr) {
return false;
}
const Type* shift2_t = phase->type(shift2);
if (!shift2_t->isa_int() || !shift2_t->is_int()->is_con()) {
return false;
}

jint shift_con = shift2_t->is_int()->get_con() & ((shift_bt == T_INT ? BitsPerJavaInteger : BitsPerJavaLong) - 1);
if ((((jlong)1) << shift_con) > mask_t->hi_as_long() && mask_t->lo_as_long() >= 0) {
return true;
jint zeros = AndIL_min_trailing_zeros(phase, expr, bt);
if (zeros == 0) {
// try it the other way around
return check_reverse && AndIL_is_always_zero(phase, mask, expr, bt, false);
}

return false;
return ((((jlong)1) << zeros) > mask_t->hi_as_long() && mask_t->lo_as_long() >= 0);
}

// Given an expression (AndX (AddX v1 (LShiftX v2 #N)) #M)
// determine if the AndX must always produce (AndX v1 #M),
// because the shift (v2<<N) is bitwise disjoint from the mask #M.
// The X in AndX will be I or L, depending on bt.
// Specifically, the following cases fold,
// when the shift value N is large enough to zero out
// all the set positions of the and-mask M.
// (AndI (AddI v1 (LShiftI _ #N)) #M) => (AndI v1 #M)
// (AndL (AddL v1 (LShiftL _ #N)) #M) => (AndL v1 #M)
// (AndL (AddL v1 (ConvI2L (LShiftI _ #N))) #M) => (AndL v1 #M)
// The M and N values must satisfy ((-1 << N) & M) == 0.
// Because the optimization might work for a non-constant
// mask M, and because the AddX operands can come in either
// order, we check for every operand order.
Node* MulNode::AndIL_add_shift_and_mask(PhaseGVN* phase, BasicType bt) {
// Given an expression (AndX (AddX v1 v2) mask)
// determine if the AndX must always produce (AndX v1 mask),
// because v2 is bitwise disjoint from the mask.
// Because the AddX operands can come in either
// order, we check for both orders.
Node* MulNode::AndIL_sum_and_mask(PhaseGVN* phase, BasicType bt) {
Node* add = in(1);
Node* mask = in(2);
if (add == nullptr || mask == nullptr) {
Expand All @@ -2157,10 +2151,10 @@ Node* MulNode::AndIL_add_shift_and_mask(PhaseGVN* phase, BasicType bt) {
Node* add1 = add->in(1);
Node* add2 = add->in(2);
if (add1 != nullptr && add2 != nullptr) {
if (AndIL_shift_and_mask_is_always_zero(phase, add1, mask, bt, false)) {
if (AndIL_is_always_zero(phase, add1, mask, bt, false)) {
set_req_X(addidx, add2, phase);
return this;
} else if (AndIL_shift_and_mask_is_always_zero(phase, add2, mask, bt, false)) {
} else if (AndIL_is_always_zero(phase, add2, mask, bt, false)) {
set_req_X(addidx, add1, phase);
return this;
}
Expand Down
5 changes: 3 additions & 2 deletions src/hotspot/share/opto/mulnode.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,9 @@ class MulNode : public Node {

static MulNode* make(Node* in1, Node* in2, BasicType bt);

static bool AndIL_shift_and_mask_is_always_zero(PhaseGVN* phase, Node* shift, Node* mask, BasicType bt, bool check_reverse);
Node* AndIL_add_shift_and_mask(PhaseGVN* phase, BasicType bt);
static jint AndIL_min_trailing_zeros(PhaseGVN* phase, Node* addend, BasicType bt);
static bool AndIL_is_always_zero(PhaseGVN* phase, Node* expr, Node* mask, BasicType bt, bool check_reverse);
Node* AndIL_sum_and_mask(PhaseGVN* phase, BasicType bt);
};

//------------------------------MulINode---------------------------------------
Expand Down
4 changes: 2 additions & 2 deletions test/hotspot/jtreg/compiler/c2/irTests/TestShiftAndMask.java
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ public static void checkShiftNonConstMaskLong(long res) {
@IR(counts = { IRNode.AND_I, "1" })
@IR(failOn = { IRNode.ADD_I, IRNode.LSHIFT_I })
public static int addShiftMaskInt(int i, int j) {
return (j + (i << 2)) & 3; // transformed to: return j & 3;
return (j + ((i + 1) << 2)) & 3; // transformed to: return j & 3;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer you adding other test cases instead of modifying existing ones.

}

@Run(test = "addShiftMaskInt")
Expand Down Expand Up @@ -165,7 +165,7 @@ public static void addSshiftNonConstMaskInt_runner() {
@IR(counts = { IRNode.AND_L, "1" })
@IR(failOn = { IRNode.ADD_L, IRNode.LSHIFT_L })
public static long addShiftMaskLong(long i, long j) {
return (j + (i << 2)) & 3; // transformed to: return j & 3;
return (j + ((i - 3) << 2)) & 3; // transformed to: return j & 3;
}

@Run(test = "addShiftMaskLong")
Expand Down