diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index c92a0c2a06d4..0b0600aa067b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -474,6 +474,7 @@ namespace {
     SDValue visitCTPOP(SDNode *N);
     SDValue visitSELECT(SDNode *N);
     SDValue visitVSELECT(SDNode *N);
+    SDValue visitVP_SELECT(SDNode *N);
     SDValue visitSELECT_CC(SDNode *N);
     SDValue visitSETCC(SDNode *N);
     SDValue visitSETCCCARRY(SDNode *N);
@@ -922,6 +923,9 @@ class VPMatchContext {
     assert(Root->isVPOpcode());
     if (auto RootMaskPos = ISD::getVPMaskIdx(Root->getOpcode()))
       RootMaskOp = Root->getOperand(*RootMaskPos);
+    else if (Root->getOpcode() == ISD::VP_SELECT)
+      RootMaskOp = DAG.getAllOnesConstant(SDLoc(Root),
+                                          Root->getOperand(0).getValueType());
 
     if (auto RootVLenPos =
             ISD::getVPExplicitVectorLengthIdx(Root->getOpcode()))
@@ -11401,35 +11405,42 @@ SDValue DAGCombiner::foldSelectOfConstants(SDNode *N) {
   return SDValue();
 }
 
+template <class MatchContextClass>
 static SDValue foldBoolSelectToLogic(SDNode *N, SelectionDAG &DAG) {
-  assert((N->getOpcode() == ISD::SELECT || N->getOpcode() == ISD::VSELECT) &&
-         "Expected a (v)select");
+  assert((N->getOpcode() == ISD::SELECT || N->getOpcode() == ISD::VSELECT ||
+          N->getOpcode() == ISD::VP_SELECT) &&
+         "Expected a (v)(vp.)select");
   SDValue Cond = N->getOperand(0);
   SDValue T = N->getOperand(1), F = N->getOperand(2);
   EVT VT = N->getValueType(0);
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  MatchContextClass matcher(DAG, TLI, N);
+
   if (VT != Cond.getValueType() || VT.getScalarSizeInBits() != 1)
     return SDValue();
 
   // select Cond, Cond, F --> or Cond, F
   // select Cond, 1, F    --> or Cond, F
   if (Cond == T || isOneOrOneSplat(T, /* AllowUndefs */ true))
-    return DAG.getNode(ISD::OR, SDLoc(N), VT, Cond, F);
+    return matcher.getNode(ISD::OR, SDLoc(N), VT, Cond, F);
 
   // select Cond, T, Cond --> and Cond, T
   // select Cond, T, 0    --> and Cond, T
   if (Cond == F || isNullOrNullSplat(F, /* AllowUndefs */ true))
-    return DAG.getNode(ISD::AND, SDLoc(N), VT, Cond, T);
+    return matcher.getNode(ISD::AND, SDLoc(N), VT, Cond, T);
 
   // select Cond, T, 1 --> or (not Cond), T
   if (isOneOrOneSplat(F, /* AllowUndefs */ true)) {
-    SDValue NotCond = DAG.getNOT(SDLoc(N), Cond, VT);
-    return DAG.getNode(ISD::OR, SDLoc(N), VT, NotCond, T);
+    SDValue NotCond = matcher.getNode(ISD::XOR, SDLoc(N), VT, Cond,
+                                      DAG.getAllOnesConstant(SDLoc(N), VT));
+    return matcher.getNode(ISD::OR, SDLoc(N), VT, NotCond, T);
   }
 
   // select Cond, 0, F --> and (not Cond), F
   if (isNullOrNullSplat(T, /* AllowUndefs */ true)) {
-    SDValue NotCond = DAG.getNOT(SDLoc(N), Cond, VT);
-    return DAG.getNode(ISD::AND, SDLoc(N), VT, NotCond, F);
+    SDValue NotCond = matcher.getNode(ISD::XOR, SDLoc(N), VT, Cond,
+                                      DAG.getAllOnesConstant(SDLoc(N), VT));
+    return matcher.getNode(ISD::AND, SDLoc(N), VT, NotCond, F);
   }
 
   return SDValue();
@@ -11505,7 +11516,7 @@ SDValue DAGCombiner::visitSELECT(SDNode *N) {
   if (SDValue V = DAG.simplifySelect(N0, N1, N2))
     return V;
 
-  if (SDValue V = foldBoolSelectToLogic(N, DAG))
+  if (SDValue V = foldBoolSelectToLogic<EmptyMatchContext>(N, DAG))
     return V;
 
   // select (not Cond), N1, N2 -> select Cond, N2, N1
@@ -12119,6 +12130,20 @@ SDValue DAGCombiner::foldVSelectOfConstants(SDNode *N) {
   return SDValue();
 }
 
+SDValue DAGCombiner::visitVP_SELECT(SDNode *N) {
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  SDValue N2 = N->getOperand(2);
+
+  if (SDValue V = DAG.simplifySelect(N0, N1, N2))
+    return V;
+
+  if (SDValue V = foldBoolSelectToLogic<VPMatchContext>(N, DAG))
+    return V;
+
+  return SDValue();
+}
+
 SDValue DAGCombiner::visitVSELECT(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
@@ -12129,7 +12154,7 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) {
   if (SDValue V = DAG.simplifySelect(N0, N1, N2))
     return V;
 
-  if (SDValue V = foldBoolSelectToLogic(N, DAG))
+  if (SDValue V = foldBoolSelectToLogic<EmptyMatchContext>(N, DAG))
     return V;
 
   // vselect (not Cond), N1, N2 -> vselect Cond, N2, N1
@@ -26374,6 +26399,8 @@ SDValue DAGCombiner::visitVPOp(SDNode *N) {
       return visitVP_FSUB(N);
     case ISD::VP_FMA:
       return visitFMA<VPMatchContext>(N);
+    case ISD::VP_SELECT:
+      return visitVP_SELECT(N);
     }
     return SDValue();
   }
diff --git a/llvm/test/CodeGen/RISCV/rvv/vselect-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vselect-vp.ll
index 9e7df5eab8dd..38e190a2336d 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vselect-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vselect-vp.ll
@@ -745,3 +745,136 @@ define <vscale x 16 x double> @select_nxv16f64(<vscale x 16 x i1> %a, <vscale x
   %v = call <vscale x 16 x double> @llvm.vp.select.nxv16f64(<vscale x 16 x i1> %a, <vscale x 16 x double> %b, <vscale x 16 x double> %c, i32 %evl)
   ret <vscale x 16 x double> %v
 }
+
+define <vscale x 2 x i1> @select_zero(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, <vscale x 2 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: select_zero:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e8, mf4, ta, ma
+; CHECK-NEXT:    vmand.mm v0, v0, v8
+; CHECK-NEXT:    ret
+  %a = call <vscale x 2 x i1> @llvm.vp.select.nxv2i1(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, <vscale x 2 x i1> zeroinitializer, i32 %evl)
+  ret <vscale x 2 x i1> %a
+}
+
+define <vscale x 2 x i1> @select_one(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, <vscale x 2 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: select_one:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e8, mf4, ta, ma
+; CHECK-NEXT:    vmorn.mm v0, v8, v0
+; CHECK-NEXT:    ret
+  %a = call <vscale x 2 x i1> @llvm.vp.select.nxv2i1(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> undef, i1 true, i32 0), <vscale x 2 x i1> undef, <vscale x 2 x i32> zeroinitializer), i32 %evl)
+  ret <vscale x 2 x i1> %a
+}
+
+define <vscale x 2 x i1> @select_x_zero(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, i32 zeroext %evl) {
+; CHECK-LABEL: select_x_zero:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e8, mf4, ta, ma
+; CHECK-NEXT:    vmand.mm v0, v0, v8
+; CHECK-NEXT:    ret
+  %a = call <vscale x 2 x i1> @llvm.vp.select.nxv2i1(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, <vscale x 2 x i1> zeroinitializer, i32 %evl)
+  ret <vscale x 2 x i1> %a
+}
+
+define <vscale x 2 x i1> @select_x_one(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, i32 zeroext %evl) {
+; CHECK-LABEL: select_x_one:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e8, mf4, ta, ma
+; CHECK-NEXT:    vmorn.mm v0, v8, v0
+; CHECK-NEXT:    ret
+  %a = call <vscale x 2 x i1> @llvm.vp.select.nxv2i1(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> undef, i1 true, i32 0), <vscale x 2 x i1> undef, <vscale x 2 x i32> zeroinitializer), i32 %evl)
+  ret <vscale x 2 x i1> %a
+}
+
+define <vscale x 2 x i1> @select_zero_x(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, i32 zeroext %evl) {
+; CHECK-LABEL: select_zero_x:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e8, mf4, ta, ma
+; CHECK-NEXT:    vmandn.mm v0, v8, v0
+; CHECK-NEXT:    ret
+  %a = call <vscale x 2 x i1> @llvm.vp.select.nxv2i1(<vscale x 2 x i1> %x, <vscale x 2 x i1> zeroinitializer, <vscale x 2 x i1> %y, i32 %evl)
+  ret <vscale x 2 x i1> %a
+}
+
+define <vscale x 2 x i1> @select_one_x(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, i32 zeroext %evl) {
+; CHECK-LABEL: select_one_x:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e8, mf4, ta, ma
+; CHECK-NEXT:    vmor.mm v0, v0, v8
+; CHECK-NEXT:    ret
+  %a = call <vscale x 2 x i1> @llvm.vp.select.nxv2i1(<vscale x 2 x i1> %x, <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> undef, i1 true, i32 0), <vscale x 2 x i1> undef, <vscale x 2 x i32> zeroinitializer), <vscale x 2 x i1> %y, i32 %evl)
+  ret <vscale x 2 x i1> %a
+}
+
+define <vscale x 2 x i1> @select_cond_cond_x(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, <vscale x 2 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: select_cond_cond_x:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e8, mf4, ta, ma
+; CHECK-NEXT:    vmor.mm v0, v0, v8
+; CHECK-NEXT:    ret
+  %a = call <vscale x 2 x i1> @llvm.vp.select.nxv2i1(<vscale x 2 x i1> %x, <vscale x 2 x i1> %x, <vscale x 2 x i1> %y, i32 %evl)
+  ret <vscale x 2 x i1> %a
+}
+
+define <vscale x 2 x i1> @select_cond_x_cond(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, <vscale x 2 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: select_cond_x_cond:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e8, mf4, ta, ma
+; CHECK-NEXT:    vmand.mm v0, v0, v8
+; CHECK-NEXT:    ret
+  %a = call <vscale x 2 x i1> @llvm.vp.select.nxv2i1(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, <vscale x 2 x i1> %x, i32 %evl)
+  ret <vscale x 2 x i1> %a
+}
+
+define <vscale x 2 x i1> @select_undef_T_F(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, i32 zeroext %evl) {
+; CHECK-LABEL: select_undef_T_F:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vmv1r.v v0, v8
+; CHECK-NEXT:    ret
+  %a = call <vscale x 2 x i1> @llvm.vp.select.nxv2i1(<vscale x 2 x i1> undef, <vscale x 2 x i1> %x, <vscale x 2 x i1> %y, i32 %evl)
+  ret <vscale x 2 x i1> %a
+}
+
+define <vscale x 2 x i1> @select_undef_undef_F(<vscale x 2 x i1> %x, i32 zeroext %evl) {
+; CHECK-LABEL: select_undef_undef_F:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ret
+  %a = call <vscale x 2 x i1> @llvm.vp.select.nxv2i1(<vscale x 2 x i1> undef, <vscale x 2 x i1> undef, <vscale x 2 x i1> %x, i32 %evl)
+  ret <vscale x 2 x i1> %a
+}
+
+define <vscale x 2 x i1> @select_unknown_undef_F(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, i32 zeroext %evl) {
+; CHECK-LABEL: select_unknown_undef_F:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vmv1r.v v0, v8
+; CHECK-NEXT:    ret
+  %a = call <vscale x 2 x i1> @llvm.vp.select.nxv2i1(<vscale x 2 x i1> %x, <vscale x 2 x i1> undef, <vscale x 2 x i1> %y, i32 %evl)
+  ret <vscale x 2 x i1> %a
+}
+
+define <vscale x 2 x i1> @select_unknown_T_undef(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, i32 zeroext %evl) {
+; CHECK-LABEL: select_unknown_T_undef:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vmv1r.v v0, v8
+; CHECK-NEXT:    ret
+  %a = call <vscale x 2 x i1> @llvm.vp.select.nxv2i1(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, <vscale x 2 x i1> undef, i32 %evl)
+  ret <vscale x 2 x i1> %a
+}
+
+define <vscale x 2 x i1> @select_false_T_F(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, <vscale x 2 x i1> %z, i32 zeroext %evl) {
+; CHECK-LABEL: select_false_T_F:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vmv1r.v v0, v9
+; CHECK-NEXT:    ret
+  %a = call <vscale x 2 x i1> @llvm.vp.select.nxv2i1(<vscale x 2 x i1> zeroinitializer, <vscale x 2 x i1> %y, <vscale x 2 x i1> %z, i32 %evl)
+  ret <vscale x 2 x i1> %a
+}
+
+define <vscale x 2 x i1> @select_unknown_T_T(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, i32 zeroext %evl) {
+; CHECK-LABEL: select_unknown_T_T:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vmv1r.v v0, v8
+; CHECK-NEXT:    ret
+  %a = call <vscale x 2 x i1> @llvm.vp.select.nxv2i1(<vscale x 2 x i1> %x, <vscale x 2 x i1> %y, <vscale x 2 x i1> %y, i32 %evl)
+  ret <vscale x 2 x i1> %a
+}