Skip to content

Commit 25d43b7

Browse files
author
Elena Demikhovsky
committed
AVX-512: Optimized pattern for truncate with unsigned saturation.
DAG patterns optimization: truncate + unsigned saturation supported by VPMOVUS* instructions in AVX-512. Differential revision: https://reviews.llvm.org/D28216 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@291092 91177308-0d34-0410-b5e6-96231b3b80d8
1 parent 388e476 commit 25d43b7

File tree

2 files changed

+170
-0
lines changed

2 files changed

+170
-0
lines changed

lib/Target/X86/X86ISelLowering.cpp

+63
Original file line numberDiff line numberDiff line change
@@ -30921,6 +30921,59 @@ static SDValue foldVectorXorShiftIntoCmp(SDNode *N, SelectionDAG &DAG,
3092130921
return DAG.getNode(X86ISD::PCMPGT, SDLoc(N), VT, Shift.getOperand(0), Ones);
3092230922
}
3092330923

30924+
/// Check if truncation with saturation form type \p SrcVT to \p DstVT
30925+
/// is valid for the given \p Subtarget.
30926+
static bool
30927+
isSATValidOnSubtarget(EVT SrcVT, EVT DstVT, const X86Subtarget &Subtarget) {
30928+
if (!Subtarget.hasAVX512())
30929+
return false;
30930+
EVT SrcElVT = SrcVT.getScalarType();
30931+
EVT DstElVT = DstVT.getScalarType();
30932+
if (SrcElVT.getSizeInBits() < 16 || SrcElVT.getSizeInBits() > 64)
30933+
return false;
30934+
if (DstElVT.getSizeInBits() < 8 || DstElVT.getSizeInBits() > 32)
30935+
return false;
30936+
if (SrcVT.is512BitVector() || Subtarget.hasVLX())
30937+
return SrcElVT.getSizeInBits() >= 32 || Subtarget.hasBWI();
30938+
return false;
30939+
}
30940+
30941+
/// Detect a pattern of truncation with saturation:
30942+
/// (truncate (umin (x, unsigned_max_of_dest_type)) to dest_type).
30943+
/// Return the source value to be truncated or SDValue() if the pattern was not
30944+
/// matched or the unsupported on the current target.
30945+
static SDValue
30946+
detectUSatPattern(SDValue In, EVT VT, const X86Subtarget &Subtarget) {
30947+
if (In.getOpcode() != ISD::UMIN)
30948+
return SDValue();
30949+
30950+
EVT InVT = In.getValueType();
30951+
// FIXME: Scalar type may be supported if we move it to vector register.
30952+
if (!InVT.isVector() || !InVT.isSimple())
30953+
return SDValue();
30954+
30955+
if (!isSATValidOnSubtarget(InVT, VT, Subtarget))
30956+
return SDValue();
30957+
30958+
//Saturation with truncation. We truncate from InVT to VT.
30959+
assert(InVT.getScalarSizeInBits() > VT.getScalarSizeInBits() &&
30960+
"Unexpected types for truncate operation");
30961+
30962+
SDValue SrcVal;
30963+
APInt C;
30964+
if (ISD::isConstantSplatVector(In.getOperand(0).getNode(), C))
30965+
SrcVal = In.getOperand(1);
30966+
else if (ISD::isConstantSplatVector(In.getOperand(1).getNode(), C))
30967+
SrcVal = In.getOperand(0);
30968+
else
30969+
return SDValue();
30970+
30971+
// C should be equal to UINT32_MAX / UINT16_MAX / UINT8_MAX according
30972+
// the element size of the destination type.
30973+
return (C == ((uint64_t)1 << VT.getScalarSizeInBits()) - 1) ?
30974+
SrcVal : SDValue();
30975+
}
30976+
3092430977
/// This function detects the AVG pattern between vectors of unsigned i8/i16,
3092530978
/// which is c = (a + b + 1) / 2, and replace this operation with the efficient
3092630979
/// X86ISD::AVG instruction.
@@ -31487,6 +31540,12 @@ static SDValue combineStore(SDNode *N, SelectionDAG &DAG,
3148731540
St->getPointerInfo(), St->getAlignment(),
3148831541
St->getMemOperand()->getFlags());
3148931542

31543+
if (SDValue Val =
31544+
detectUSatPattern(St->getValue(), St->getMemoryVT(), Subtarget))
31545+
return EmitTruncSStore(false /* Unsigned saturation */, St->getChain(),
31546+
dl, Val, St->getBasePtr(),
31547+
St->getMemoryVT(), St->getMemOperand(), DAG);
31548+
3149031549
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
3149131550
unsigned NumElems = VT.getVectorNumElements();
3149231551
assert(StVT != VT && "Cannot truncate to the same type");
@@ -32104,6 +32163,10 @@ static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG,
3210432163
if (SDValue Avg = detectAVGPattern(Src, VT, DAG, Subtarget, DL))
3210532164
return Avg;
3210632165

32166+
// Try the truncation with unsigned saturation.
32167+
if (SDValue Val = detectUSatPattern(Src, VT, Subtarget))
32168+
return DAG.getNode(X86ISD::VTRUNCUS, DL, VT, Val);
32169+
3210732170
// The bitcast source is a direct mmx result.
3210832171
// Detect bitcasts between i32 to x86mmx
3210932172
if (Src.getOpcode() == ISD::BITCAST && VT == MVT::i32) {

test/CodeGen/X86/avx512-trunc.ll

+107
Original file line numberDiff line numberDiff line change
@@ -500,3 +500,110 @@ define void @trunc_wb_128_mem(<8 x i16> %i, <8 x i8>* %res) #0 {
500500
store <8 x i8> %x, <8 x i8>* %res
501501
ret void
502502
}
503+
504+
505+
; Unsigned-saturating truncate to memory: umin(%i, 255) expressed as
; icmp ult + select, then trunc <16 x i16> -> <16 x i8>; the SKX checks
; show it folding to a single vpmovuswb store.
define void @usat_trunc_wb_256_mem(<16 x i16> %i, <16 x i8>* %res) {
; KNL-LABEL: usat_trunc_wb_256_mem:
; KNL: ## BB#0:
; KNL-NEXT: vpminuw {{.*}}(%rip), %ymm0, %ymm0
; KNL-NEXT: vpmovsxwd %ymm0, %zmm0
; KNL-NEXT: vpmovdb %zmm0, %xmm0
; KNL-NEXT: vmovdqu %xmm0, (%rdi)
; KNL-NEXT: retq
;
; SKX-LABEL: usat_trunc_wb_256_mem:
; SKX: ## BB#0:
; SKX-NEXT: vpmovuswb %ymm0, (%rdi)
; SKX-NEXT: retq
  %x3 = icmp ult <16 x i16> %i, <i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255>
  %x5 = select <16 x i1> %x3, <16 x i16> %i, <16 x i16> <i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255>
  %x6 = trunc <16 x i16> %x5 to <16 x i8>
  store <16 x i8> %x6, <16 x i8>* %res, align 1
  ret void
}
524+
525+
; Same umin(%i, 255) + trunc pattern as above, but returning the value
; instead of storing it; SKX checks show the register form of vpmovuswb.
define <16 x i8> @usat_trunc_wb_256(<16 x i16> %i) {
; KNL-LABEL: usat_trunc_wb_256:
; KNL: ## BB#0:
; KNL-NEXT: vpminuw {{.*}}(%rip), %ymm0, %ymm0
; KNL-NEXT: vpmovsxwd %ymm0, %zmm0
; KNL-NEXT: vpmovdb %zmm0, %xmm0
; KNL-NEXT: retq
;
; SKX-LABEL: usat_trunc_wb_256:
; SKX: ## BB#0:
; SKX-NEXT: vpmovuswb %ymm0, %xmm0
; SKX-NEXT: retq
  %x3 = icmp ult <16 x i16> %i, <i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255>
  %x5 = select <16 x i1> %x3, <16 x i16> %i, <16 x i16> <i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255>
  %x6 = trunc <16 x i16> %x5 to <16 x i8>
  ret <16 x i8> %x6
}
542+
543+
; 128-bit variant: umin(%i, 255) + trunc <8 x i16> -> <8 x i8> stored to
; memory; SKX checks show a single vpmovuswb with an xmm source.
define void @usat_trunc_wb_128_mem(<8 x i16> %i, <8 x i8>* %res) {
; KNL-LABEL: usat_trunc_wb_128_mem:
; KNL: ## BB#0:
; KNL-NEXT: vpminuw {{.*}}(%rip), %xmm0, %xmm0
; KNL-NEXT: vpshufb {{.*#+}} xmm0 = xmm0[0,2,4,6,8,10,12,14,u,u,u,u,u,u,u,u]
; KNL-NEXT: vmovq %xmm0, (%rdi)
; KNL-NEXT: retq
;
; SKX-LABEL: usat_trunc_wb_128_mem:
; SKX: ## BB#0:
; SKX-NEXT: vpmovuswb %xmm0, (%rdi)
; SKX-NEXT: retq
  %x3 = icmp ult <8 x i16> %i, <i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255>
  %x5 = select <8 x i1> %x3, <8 x i16> %i, <8 x i16> <i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255>
  %x6 = trunc <8 x i16> %x5 to <8 x i8>
  store <8 x i8> %x6, <8 x i8>* %res, align 1
  ret void
}
561+
562+
; 512-bit dword->byte: umin(%i, 255) + trunc <16 x i32> -> <16 x i8>
; stored to memory; checked as a single vpmovusdb on all AVX-512 targets.
define void @usat_trunc_db_512_mem(<16 x i32> %i, <16 x i8>* %res) {
; ALL-LABEL: usat_trunc_db_512_mem:
; ALL: ## BB#0:
; ALL-NEXT: vpmovusdb %zmm0, (%rdi)
; ALL-NEXT: retq
  %x3 = icmp ult <16 x i32> %i, <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
  %x5 = select <16 x i1> %x3, <16 x i32> %i, <16 x i32> <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
  %x6 = trunc <16 x i32> %x5 to <16 x i8>
  store <16 x i8> %x6, <16 x i8>* %res, align 1
  ret void
}
573+
574+
; 512-bit qword->byte: umin(%i, 255) + trunc <8 x i64> -> <8 x i8>
; stored to memory; checked as a single vpmovusqb on all AVX-512 targets.
define void @usat_trunc_qb_512_mem(<8 x i64> %i, <8 x i8>* %res) {
; ALL-LABEL: usat_trunc_qb_512_mem:
; ALL: ## BB#0:
; ALL-NEXT: vpmovusqb %zmm0, (%rdi)
; ALL-NEXT: retq
  %x3 = icmp ult <8 x i64> %i, <i64 255, i64 255, i64 255, i64 255, i64 255, i64 255, i64 255, i64 255>
  %x5 = select <8 x i1> %x3, <8 x i64> %i, <8 x i64> <i64 255, i64 255, i64 255, i64 255, i64 255, i64 255, i64 255, i64 255>
  %x6 = trunc <8 x i64> %x5 to <8 x i8>
  store <8 x i8> %x6, <8 x i8>* %res, align 1
  ret void
}
585+
586+
; 512-bit qword->dword: umin(%i, 4294967295 = UINT32_MAX) + trunc
; <8 x i64> -> <8 x i32> stored to memory; checked as a single vpmovusqd.
define void @usat_trunc_qd_512_mem(<8 x i64> %i, <8 x i32>* %res) {
; ALL-LABEL: usat_trunc_qd_512_mem:
; ALL: ## BB#0:
; ALL-NEXT: vpmovusqd %zmm0, (%rdi)
; ALL-NEXT: retq
  %x3 = icmp ult <8 x i64> %i, <i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295>
  %x5 = select <8 x i1> %x3, <8 x i64> %i, <8 x i64> <i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295>
  %x6 = trunc <8 x i64> %x5 to <8 x i32>
  store <8 x i32> %x6, <8 x i32>* %res, align 1
  ret void
}
597+
598+
; 512-bit qword->word: umin(%i, 65535 = UINT16_MAX) + trunc
; <8 x i64> -> <8 x i16> stored to memory; checked as a single vpmovusqw.
define void @usat_trunc_qw_512_mem(<8 x i64> %i, <8 x i16>* %res) {
; ALL-LABEL: usat_trunc_qw_512_mem:
; ALL: ## BB#0:
; ALL-NEXT: vpmovusqw %zmm0, (%rdi)
; ALL-NEXT: retq
  %x3 = icmp ult <8 x i64> %i, <i64 65535, i64 65535, i64 65535, i64 65535, i64 65535, i64 65535, i64 65535, i64 65535>
  %x5 = select <8 x i1> %x3, <8 x i64> %i, <8 x i64> <i64 65535, i64 65535, i64 65535, i64 65535, i64 65535, i64 65535, i64 65535, i64 65535>
  %x6 = trunc <8 x i64> %x5 to <8 x i16>
  store <8 x i16> %x6, <8 x i16>* %res, align 1
  ret void
}
609+

0 commit comments

Comments
 (0)