@@ -5138,6 +5138,8 @@ static bool getTargetConstantBitsFromNode(SDValue Op, unsigned EltSizeInBits,
5138
5138
assert(UndefElts.empty() && "Expected an empty UndefElts vector");
5139
5139
assert(EltBits.empty() && "Expected an empty EltBits vector");
5140
5140
5141
+ Op = peekThroughBitcasts(Op);
5142
+
5141
5143
EVT VT = Op.getValueType();
5142
5144
unsigned SizeInBits = VT.getSizeInBits();
5143
5145
assert((SizeInBits % EltSizeInBits) == 0 && "Can't split constant!");
@@ -5170,35 +5172,35 @@ static bool getTargetConstantBitsFromNode(SDValue Op, unsigned EltSizeInBits,
5170
5172
return true;
5171
5173
};
5172
5174
5173
- // Extract constant bits from constant pool scalar/vector.
5175
+ auto ExtractConstantBits = [SizeInBits](const Constant *Cst, APInt &Mask,
5176
+ APInt &Undefs) {
5177
+ if (!Cst)
5178
+ return false;
5179
+ unsigned CstSizeInBits = Cst->getType()->getPrimitiveSizeInBits();
5180
+ if (isa<UndefValue>(Cst)) {
5181
+ Mask = APInt::getNullValue(SizeInBits);
5182
+ Undefs = APInt::getLowBitsSet(SizeInBits, CstSizeInBits);
5183
+ return true;
5184
+ }
5185
+ if (auto *CInt = dyn_cast<ConstantInt>(Cst)) {
5186
+ Mask = CInt->getValue().zextOrTrunc(SizeInBits);
5187
+ Undefs = APInt::getNullValue(SizeInBits);
5188
+ return true;
5189
+ }
5190
+ if (auto *CFP = dyn_cast<ConstantFP>(Cst)) {
5191
+ Mask = CFP->getValueAPF().bitcastToAPInt().zextOrTrunc(SizeInBits);
5192
+ Undefs = APInt::getNullValue(SizeInBits);
5193
+ return true;
5194
+ }
5195
+ return false;
5196
+ };
5197
+
5198
+ // Extract constant bits from constant pool vector.
5174
5199
if (auto *Cst = getTargetConstantFromNode(Op)) {
5175
5200
Type *CstTy = Cst->getType();
5176
5201
if (!CstTy->isVectorTy() || (SizeInBits != CstTy->getPrimitiveSizeInBits()))
5177
5202
return false;
5178
5203
5179
- auto ExtractConstantBits = [SizeInBits](const Constant *Cst, APInt &Mask,
5180
- APInt &Undefs) {
5181
- if (!Cst)
5182
- return false;
5183
- unsigned CstSizeInBits = Cst->getType()->getPrimitiveSizeInBits();
5184
- if (isa<UndefValue>(Cst)) {
5185
- Mask = APInt::getNullValue(SizeInBits);
5186
- Undefs = APInt::getLowBitsSet(SizeInBits, CstSizeInBits);
5187
- return true;
5188
- }
5189
- if (auto *CInt = dyn_cast<ConstantInt>(Cst)) {
5190
- Mask = CInt->getValue().zextOrTrunc(SizeInBits);
5191
- Undefs = APInt::getNullValue(SizeInBits);
5192
- return true;
5193
- }
5194
- if (auto *CFP = dyn_cast<ConstantFP>(Cst)) {
5195
- Mask = CFP->getValueAPF().bitcastToAPInt().zextOrTrunc(SizeInBits);
5196
- Undefs = APInt::getNullValue(SizeInBits);
5197
- return true;
5198
- }
5199
- return false;
5200
- };
5201
-
5202
5204
unsigned CstEltSizeInBits = CstTy->getScalarSizeInBits();
5203
5205
for (unsigned i = 0, e = CstTy->getVectorNumElements(); i != e; ++i) {
5204
5206
APInt Bits, Undefs;
@@ -5211,9 +5213,27 @@ static bool getTargetConstantBitsFromNode(SDValue Op, unsigned EltSizeInBits,
5211
5213
return SplitBitData();
5212
5214
}
5213
5215
5216
+ // Extract constant bits from a broadcasted constant pool scalar.
5217
+ if (Op.getOpcode() == X86ISD::VBROADCAST &&
5218
+ EltSizeInBits <= Op.getScalarValueSizeInBits()) {
5219
+ if (auto *Broadcast = getTargetConstantFromNode(Op.getOperand(0))) {
5220
+ APInt Bits, Undefs;
5221
+ if (ExtractConstantBits(Broadcast, Bits, Undefs)) {
5222
+ unsigned NumBroadcastBits = Op.getScalarValueSizeInBits();
5223
+ unsigned NumBroadcastElts = SizeInBits / NumBroadcastBits;
5224
+ for (unsigned i = 0; i != NumBroadcastElts; ++i) {
5225
+ MaskBits |= Bits.shl(i * NumBroadcastBits);
5226
+ UndefBits |= Undefs.shl(i * NumBroadcastBits);
5227
+ }
5228
+ return SplitBitData();
5229
+ }
5230
+ }
5231
+ }
5232
+
5214
5233
return false;
5215
5234
}
5216
5235
5236
+ // TODO: Merge more of this with getTargetConstantBitsFromNode.
5217
5237
static bool getTargetShuffleMaskIndices(SDValue MaskNode,
5218
5238
unsigned MaskEltSizeInBits,
5219
5239
SmallVectorImpl<uint64_t> &RawMask) {
0 commit comments