Skip to content

Commit

Permalink
Merge pull request #66 from AinsleySnow/vp-select
Browse files Browse the repository at this point in the history
[VP][DAGCombiner] Use `simplifySelect` when combining vp.select.
  • Loading branch information
ChunyuLiao authored Apr 1, 2024
2 parents e009baa + 2b7b578 commit 3ff1963
Show file tree
Hide file tree
Showing 2 changed files with 170 additions and 10 deletions.
47 changes: 37 additions & 10 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,7 @@ namespace {
SDValue visitCTPOP(SDNode *N);
SDValue visitSELECT(SDNode *N);
SDValue visitVSELECT(SDNode *N);
SDValue visitVP_SELECT(SDNode *N);
SDValue visitSELECT_CC(SDNode *N);
SDValue visitSETCC(SDNode *N);
SDValue visitSETCCCARRY(SDNode *N);
Expand Down Expand Up @@ -922,6 +923,9 @@ class VPMatchContext {
assert(Root->isVPOpcode());
if (auto RootMaskPos = ISD::getVPMaskIdx(Root->getOpcode()))
RootMaskOp = Root->getOperand(*RootMaskPos);
else if (Root->getOpcode() == ISD::VP_SELECT)
RootMaskOp = DAG.getAllOnesConstant(SDLoc(Root),
Root->getOperand(0).getValueType());

if (auto RootVLenPos =
ISD::getVPExplicitVectorLengthIdx(Root->getOpcode()))
Expand Down Expand Up @@ -11401,35 +11405,42 @@ SDValue DAGCombiner::foldSelectOfConstants(SDNode *N) {
return SDValue();
}

template <class MatchContextClass>
static SDValue foldBoolSelectToLogic(SDNode *N, SelectionDAG &DAG) {
assert((N->getOpcode() == ISD::SELECT || N->getOpcode() == ISD::VSELECT) &&
"Expected a (v)select");
assert((N->getOpcode() == ISD::SELECT || N->getOpcode() == ISD::VSELECT ||
N->getOpcode() == ISD::VP_SELECT) &&
"Expected a (v)(vp.)select");
SDValue Cond = N->getOperand(0);
SDValue T = N->getOperand(1), F = N->getOperand(2);
EVT VT = N->getValueType(0);
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
MatchContextClass matcher(DAG, TLI, N);

if (VT != Cond.getValueType() || VT.getScalarSizeInBits() != 1)
return SDValue();

// select Cond, Cond, F --> or Cond, F
// select Cond, 1, F --> or Cond, F
if (Cond == T || isOneOrOneSplat(T, /* AllowUndefs */ true))
return DAG.getNode(ISD::OR, SDLoc(N), VT, Cond, F);
return matcher.getNode(ISD::OR, SDLoc(N), VT, Cond, F);

// select Cond, T, Cond --> and Cond, T
// select Cond, T, 0 --> and Cond, T
if (Cond == F || isNullOrNullSplat(F, /* AllowUndefs */ true))
return DAG.getNode(ISD::AND, SDLoc(N), VT, Cond, T);
return matcher.getNode(ISD::AND, SDLoc(N), VT, Cond, T);

// select Cond, T, 1 --> or (not Cond), T
if (isOneOrOneSplat(F, /* AllowUndefs */ true)) {
SDValue NotCond = DAG.getNOT(SDLoc(N), Cond, VT);
return DAG.getNode(ISD::OR, SDLoc(N), VT, NotCond, T);
SDValue NotCond = matcher.getNode(ISD::XOR, SDLoc(N), VT, Cond,
DAG.getAllOnesConstant(SDLoc(N), VT));
return matcher.getNode(ISD::OR, SDLoc(N), VT, NotCond, T);
}

// select Cond, 0, F --> and (not Cond), F
if (isNullOrNullSplat(T, /* AllowUndefs */ true)) {
SDValue NotCond = DAG.getNOT(SDLoc(N), Cond, VT);
return DAG.getNode(ISD::AND, SDLoc(N), VT, NotCond, F);
SDValue NotCond = matcher.getNode(ISD::XOR, SDLoc(N), VT, Cond,
DAG.getAllOnesConstant(SDLoc(N), VT));
return matcher.getNode(ISD::AND, SDLoc(N), VT, NotCond, F);
}

return SDValue();
Expand Down Expand Up @@ -11505,7 +11516,7 @@ SDValue DAGCombiner::visitSELECT(SDNode *N) {
if (SDValue V = DAG.simplifySelect(N0, N1, N2))
return V;

if (SDValue V = foldBoolSelectToLogic(N, DAG))
if (SDValue V = foldBoolSelectToLogic<EmptyMatchContext>(N, DAG))
return V;

// select (not Cond), N1, N2 -> select Cond, N2, N1
Expand Down Expand Up @@ -12119,6 +12130,20 @@ SDValue DAGCombiner::foldVSelectOfConstants(SDNode *N) {
return SDValue();
}

SDValue DAGCombiner::visitVP_SELECT(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
SDValue N2 = N->getOperand(2);

if (SDValue V = DAG.simplifySelect(N0, N1, N2))
return V;

if (SDValue V = foldBoolSelectToLogic<VPMatchContext>(N, DAG))
return V;

return SDValue();
}

SDValue DAGCombiner::visitVSELECT(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
Expand All @@ -12129,7 +12154,7 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) {
if (SDValue V = DAG.simplifySelect(N0, N1, N2))
return V;

if (SDValue V = foldBoolSelectToLogic(N, DAG))
if (SDValue V = foldBoolSelectToLogic<EmptyMatchContext>(N, DAG))
return V;

// vselect (not Cond), N1, N2 -> vselect Cond, N2, N1
Expand Down Expand Up @@ -26374,6 +26399,8 @@ SDValue DAGCombiner::visitVPOp(SDNode *N) {
return visitVP_FSUB(N);
case ISD::VP_FMA:
return visitFMA<VPMatchContext>(N);
case ISD::VP_SELECT:
return visitVP_SELECT(N);
}
return SDValue();
}
Expand Down
133 changes: 133 additions & 0 deletions llvm/test/CodeGen/RISCV/rvv/vselect-vp.ll
Original file line number Diff line number Diff line change
Expand Up @@ -745,3 +745,136 @@ define <vscale x 16 x double> @select_nxv16f64(<vscale x 16 x i1> %a, <vscale x
%v = call <vscale x 16 x double> @llvm.vp.select.nxv16f64(<vscale x 16 x i1> %a, <vscale x 16 x double> %b, <vscale x 16 x double> %c, i32 %evl)
ret <vscale x 16 x double> %v
}

define <vscale x 2 x i1> @select_zero(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, <vscale x 2 x i1> %m, i32 zeroext %evl) {
; CHECK-LABEL: select_zero:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli zero, a0, e8, mf4, ta, ma
; CHECK-NEXT: vmand.mm v0, v0, v8
; CHECK-NEXT: ret
%a = call <vscale x 2 x i1> @llvm.vp.select.nxv2i1(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, <vscale x 2 x i1> zeroinitializer, i32 %evl)
ret <vscale x 2 x i1> %a
}

define <vscale x 2 x i1> @select_one(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, <vscale x 2 x i1> %m, i32 zeroext %evl) {
; CHECK-LABEL: select_one:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli zero, a0, e8, mf4, ta, ma
; CHECK-NEXT: vmorn.mm v0, v8, v0
; CHECK-NEXT: ret
%a = call <vscale x 2 x i1> @llvm.vp.select.nxv2i1(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> undef, i1 true, i32 0), <vscale x 2 x i1> undef, <vscale x 2 x i32> zeroinitializer), i32 %evl)
ret <vscale x 2 x i1> %a
}

define <vscale x 2 x i1> @select_x_zero(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, i32 zeroext %evl) {
; CHECK-LABEL: select_x_zero:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli zero, a0, e8, mf4, ta, ma
; CHECK-NEXT: vmand.mm v0, v0, v8
; CHECK-NEXT: ret
%a = call <vscale x 2 x i1> @llvm.vp.select.nxv2i1(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, <vscale x 2 x i1> zeroinitializer, i32 %evl)
ret <vscale x 2 x i1> %a
}

define <vscale x 2 x i1> @select_x_one(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, i32 zeroext %evl) {
; CHECK-LABEL: select_x_one:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli zero, a0, e8, mf4, ta, ma
; CHECK-NEXT: vmorn.mm v0, v8, v0
; CHECK-NEXT: ret
%a = call <vscale x 2 x i1> @llvm.vp.select.nxv2i1(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> undef, i1 true, i32 0), <vscale x 2 x i1> undef, <vscale x 2 x i32> zeroinitializer), i32 %evl)
ret <vscale x 2 x i1> %a
}

define <vscale x 2 x i1> @select_zero_x(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, i32 zeroext %evl) {
; CHECK-LABEL: select_zero_x:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli zero, a0, e8, mf4, ta, ma
; CHECK-NEXT: vmandn.mm v0, v8, v0
; CHECK-NEXT: ret
%a = call <vscale x 2 x i1> @llvm.vp.select.nxv2i1(<vscale x 2 x i1> %x, <vscale x 2 x i1> zeroinitializer, <vscale x 2 x i1> %y, i32 %evl)
ret <vscale x 2 x i1> %a
}

define <vscale x 2 x i1> @select_one_x(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, i32 zeroext %evl) {
; CHECK-LABEL: select_one_x:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli zero, a0, e8, mf4, ta, ma
; CHECK-NEXT: vmor.mm v0, v0, v8
; CHECK-NEXT: ret
%a = call <vscale x 2 x i1> @llvm.vp.select.nxv2i1(<vscale x 2 x i1> %x, <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> undef, i1 true, i32 0), <vscale x 2 x i1> undef, <vscale x 2 x i32> zeroinitializer), <vscale x 2 x i1> %y, i32 %evl)
ret <vscale x 2 x i1> %a
}

define <vscale x 2 x i1> @select_cond_cond_x(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, <vscale x 2 x i1> %m, i32 zeroext %evl) {
; CHECK-LABEL: select_cond_cond_x:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli zero, a0, e8, mf4, ta, ma
; CHECK-NEXT: vmor.mm v0, v0, v8
; CHECK-NEXT: ret
%a = call <vscale x 2 x i1> @llvm.vp.select.nxv2i1(<vscale x 2 x i1> %x, <vscale x 2 x i1> %x, <vscale x 2 x i1> %y, i32 %evl)
ret <vscale x 2 x i1> %a
}

define <vscale x 2 x i1> @select_cond_x_cond(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, <vscale x 2 x i1> %m, i32 zeroext %evl) {
; CHECK-LABEL: select_cond_x_cond:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli zero, a0, e8, mf4, ta, ma
; CHECK-NEXT: vmand.mm v0, v0, v8
; CHECK-NEXT: ret
%a = call <vscale x 2 x i1> @llvm.vp.select.nxv2i1(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, <vscale x 2 x i1> %x, i32 %evl)
ret <vscale x 2 x i1> %a
}

define <vscale x 2 x i1> @select_undef_T_F(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, i32 zeroext %evl) {
; CHECK-LABEL: select_undef_T_F:
; CHECK: # %bb.0:
; CHECK-NEXT: vmv1r.v v0, v8
; CHECK-NEXT: ret
%a = call <vscale x 2 x i1> @llvm.vp.select.nxv2i1(<vscale x 2 x i1> undef, <vscale x 2 x i1> %x, <vscale x 2 x i1> %y, i32 %evl)
ret <vscale x 2 x i1> %a
}

define <vscale x 2 x i1> @select_undef_undef_F(<vscale x 2 x i1> %x, i32 zeroext %evl) {
; CHECK-LABEL: select_undef_undef_F:
; CHECK: # %bb.0:
; CHECK-NEXT: ret
%a = call <vscale x 2 x i1> @llvm.vp.select.nxv2i1(<vscale x 2 x i1> undef, <vscale x 2 x i1> undef, <vscale x 2 x i1> %x, i32 %evl)
ret <vscale x 2 x i1> %a
}

define <vscale x 2 x i1> @select_unknown_undef_F(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, i32 zeroext %evl) {
; CHECK-LABEL: select_unknown_undef_F:
; CHECK: # %bb.0:
; CHECK-NEXT: vmv1r.v v0, v8
; CHECK-NEXT: ret
%a = call <vscale x 2 x i1> @llvm.vp.select.nxv2i1(<vscale x 2 x i1> %x, <vscale x 2 x i1> undef, <vscale x 2 x i1> %y, i32 %evl)
ret <vscale x 2 x i1> %a
}

define <vscale x 2 x i1> @select_unknown_T_undef(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, i32 zeroext %evl) {
; CHECK-LABEL: select_unknown_T_undef:
; CHECK: # %bb.0:
; CHECK-NEXT: vmv1r.v v0, v8
; CHECK-NEXT: ret
%a = call <vscale x 2 x i1> @llvm.vp.select.nxv2i1(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, <vscale x 2 x i1> undef, i32 %evl)
ret <vscale x 2 x i1> %a
}

define <vscale x 2 x i1> @select_false_T_F(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, <vscale x 2 x i1> %z, i32 zeroext %evl) {
; CHECK-LABEL: select_false_T_F:
; CHECK: # %bb.0:
; CHECK-NEXT: vmv1r.v v0, v9
; CHECK-NEXT: ret
%a = call <vscale x 2 x i1> @llvm.vp.select.nxv2i1(<vscale x 2 x i1> zeroinitializer, <vscale x 2 x i1> %y, <vscale x 2 x i1> %z, i32 %evl)
ret <vscale x 2 x i1> %a
}

define <vscale x 2 x i1> @select_unknown_T_T(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, i32 zeroext %evl) {
; CHECK-LABEL: select_unknown_T_T:
; CHECK: # %bb.0:
; CHECK-NEXT: vmv1r.v v0, v8
; CHECK-NEXT: ret
%a = call <vscale x 2 x i1> @llvm.vp.select.nxv2i1(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, <vscale x 2 x i1> %y, i32 %evl)
ret <vscale x 2 x i1> %a
}

0 comments on commit 3ff1963

Please sign in to comment.