Skip to content

Commit

Permalink
[aarch64] atan2 intrinsic lowering (p5) (#112611)
Browse files Browse the repository at this point in the history
This change is part of this proposal:
https://discourse.llvm.org/t/rfc-all-the-math-intrinsics/78294

- `VecFuncs.def`: define intrinsic to sleef/armpl mapping
- `LegalizerHelper.cpp`: add missing fewerElementsVector handling for
the new atan2 intrinsic
- `AArch64ISelLowering.cpp`: Add arch64 specializations for lowering
like neon instructions
- `AArch64LegalizerInfo.cpp`: Legalize atan2.

Part 5 for Implement the atan2 HLSL Function #70096.
  • Loading branch information
tex3d authored Oct 25, 2024
1 parent b1be213 commit c03d09c
Show file tree
Hide file tree
Showing 19 changed files with 698 additions and 22 deletions.
11 changes: 11 additions & 0 deletions llvm/include/llvm/Analysis/VecFuncs.def
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ TLI_DEFINE_VECFUNC("llvm.atan.f64", "_simd_atan_d2", FIXED(2), "_ZGV_LLVM_N2v")
TLI_DEFINE_VECFUNC("atanf", "_simd_atan_f4", FIXED(4), "_ZGV_LLVM_N4v")
TLI_DEFINE_VECFUNC("llvm.atan.f32", "_simd_atan_f4", FIXED(4), "_ZGV_LLVM_N4v")
TLI_DEFINE_VECFUNC("atan2", "_simd_atan2_d2", FIXED(2), "_ZGV_LLVM_N2vv")
TLI_DEFINE_VECFUNC("llvm.atan2.f64", "_simd_atan2_d2", FIXED(2), "_ZGV_LLVM_N2vv")
TLI_DEFINE_VECFUNC("atan2f", "_simd_atan2_f4", FIXED(4), "_ZGV_LLVM_N4vv")
TLI_DEFINE_VECFUNC("llvm.atan2.f32", "_simd_atan2_f4", FIXED(4), "_ZGV_LLVM_N4vv")

TLI_DEFINE_VECFUNC("cos", "_simd_cos_d2", FIXED(2), "_ZGV_LLVM_N2v")
TLI_DEFINE_VECFUNC("llvm.cos.f64", "_simd_cos_d2", FIXED(2), "_ZGV_LLVM_N2v")
Expand Down Expand Up @@ -531,6 +533,7 @@ TLI_DEFINE_VECFUNC("atan", "_ZGVnN2v_atan", FIXED(2), "_ZGV_LLVM_N2v")
TLI_DEFINE_VECFUNC("llvm.atan.f64", "_ZGVnN2v_atan", FIXED(2), "_ZGV_LLVM_N2v")

TLI_DEFINE_VECFUNC("atan2", "_ZGVnN2vv_atan2", FIXED(2), "_ZGV_LLVM_N2vv")
TLI_DEFINE_VECFUNC("llvm.atan2.f64", "_ZGVnN2vv_atan2", FIXED(2), "_ZGV_LLVM_N2vv")

TLI_DEFINE_VECFUNC("atanh", "_ZGVnN2v_atanh", FIXED(2), "_ZGV_LLVM_N2v")

Expand Down Expand Up @@ -635,6 +638,7 @@ TLI_DEFINE_VECFUNC("atanf", "_ZGVnN4v_atanf", FIXED(4), "_ZGV_LLVM_N4v")
TLI_DEFINE_VECFUNC("llvm.atan.f32", "_ZGVnN4v_atanf", FIXED(4), "_ZGV_LLVM_N4v")

TLI_DEFINE_VECFUNC("atan2f", "_ZGVnN4vv_atan2f", FIXED(4), "_ZGV_LLVM_N4vv")
TLI_DEFINE_VECFUNC("llvm.atan2.f32", "_ZGVnN4vv_atan2f", FIXED(4), "_ZGV_LLVM_N4vv")

TLI_DEFINE_VECFUNC("atanhf", "_ZGVnN4v_atanhf", FIXED(4), "_ZGV_LLVM_N4v")

Expand Down Expand Up @@ -748,6 +752,8 @@ TLI_DEFINE_VECFUNC("llvm.atan.f32", "_ZGVsMxv_atanf", SCALABLE(4), MASKED, "_ZGV

TLI_DEFINE_VECFUNC("atan2", "_ZGVsMxvv_atan2", SCALABLE(2), MASKED, "_ZGVsMxvv")
TLI_DEFINE_VECFUNC("atan2f", "_ZGVsMxvv_atan2f", SCALABLE(4), MASKED, "_ZGVsMxvv")
TLI_DEFINE_VECFUNC("llvm.atan2.f64", "_ZGVsMxvv_atan2", SCALABLE(2), MASKED, "_ZGVsMxvv")
TLI_DEFINE_VECFUNC("llvm.atan2.f32", "_ZGVsMxvv_atan2f", SCALABLE(4), MASKED, "_ZGVsMxvv")

TLI_DEFINE_VECFUNC("atanh", "_ZGVsMxv_atanh", SCALABLE(2), MASKED, "_ZGVsMxv")
TLI_DEFINE_VECFUNC("atanhf", "_ZGVsMxv_atanhf", SCALABLE(4), MASKED, "_ZGVsMxv")
Expand Down Expand Up @@ -933,6 +939,11 @@ TLI_DEFINE_VECFUNC("atan2f", "armpl_vatan2q_f32", FIXED(4), NOMASK, "_ZGV_LLVM_N
TLI_DEFINE_VECFUNC("atan2", "armpl_svatan2_f64_x", SCALABLE(2), MASKED, "_ZGVsMxvv")
TLI_DEFINE_VECFUNC("atan2f", "armpl_svatan2_f32_x", SCALABLE(4), MASKED, "_ZGVsMxvv")

TLI_DEFINE_VECFUNC("llvm.atan2.f64", "armpl_vatan2q_f64", FIXED(2), NOMASK, "_ZGV_LLVM_N2vv")
TLI_DEFINE_VECFUNC("llvm.atan2.f32", "armpl_vatan2q_f32", FIXED(4), NOMASK, "_ZGV_LLVM_N4vv")
TLI_DEFINE_VECFUNC("llvm.atan2.f64", "armpl_svatan2_f64_x", SCALABLE(2), MASKED, "_ZGVsMxvv")
TLI_DEFINE_VECFUNC("llvm.atan2.f32", "armpl_svatan2_f32_x", SCALABLE(4), MASKED, "_ZGVsMxvv")

TLI_DEFINE_VECFUNC("atanh", "armpl_vatanhq_f64", FIXED(2), NOMASK, "_ZGV_LLVM_N2v")
TLI_DEFINE_VECFUNC("atanhf", "armpl_vatanhq_f32", FIXED(4), NOMASK, "_ZGV_LLVM_N4v")
TLI_DEFINE_VECFUNC("atanh", "armpl_svatanh_f64_x", SCALABLE(2), MASKED, "_ZGVsMxv")
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/Target/GlobalISel/SelectionDAGCompat.td
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ def : GINodeEquiv<G_FTAN, ftan>;
def : GINodeEquiv<G_FACOS, facos>;
def : GINodeEquiv<G_FASIN, fasin>;
def : GINodeEquiv<G_FATAN, fatan>;
def : GINodeEquiv<G_FATAN2, fatan2>;
def : GINodeEquiv<G_FCOSH, fcosh>;
def : GINodeEquiv<G_FSINH, fsinh>;
def : GINodeEquiv<G_FTANH, ftanh>;
Expand Down
5 changes: 5 additions & 0 deletions llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,8 @@ static RTLIB::Libcall getRTLibDesc(unsigned Opcode, unsigned Size) {
RTLIBCASE(ACOS_F);
case TargetOpcode::G_FATAN:
RTLIBCASE(ATAN_F);
case TargetOpcode::G_FATAN2:
RTLIBCASE(ATAN2_F);
case TargetOpcode::G_FSINH:
RTLIBCASE(SINH_F);
case TargetOpcode::G_FCOSH:
Expand Down Expand Up @@ -1202,6 +1204,7 @@ LegalizerHelper::libcall(MachineInstr &MI, LostDebugLocObserver &LocObserver) {
case TargetOpcode::G_FACOS:
case TargetOpcode::G_FASIN:
case TargetOpcode::G_FATAN:
case TargetOpcode::G_FATAN2:
case TargetOpcode::G_FCOSH:
case TargetOpcode::G_FSINH:
case TargetOpcode::G_FTANH:
Expand Down Expand Up @@ -3122,6 +3125,7 @@ LegalizerHelper::widenScalar(MachineInstr &MI, unsigned TypeIdx, LLT WideTy) {
case TargetOpcode::G_FACOS:
case TargetOpcode::G_FASIN:
case TargetOpcode::G_FATAN:
case TargetOpcode::G_FATAN2:
case TargetOpcode::G_FCOSH:
case TargetOpcode::G_FSINH:
case TargetOpcode::G_FTANH:
Expand Down Expand Up @@ -5141,6 +5145,7 @@ LegalizerHelper::fewerElementsVector(MachineInstr &MI, unsigned TypeIdx,
case G_FACOS:
case G_FASIN:
case G_FATAN:
case G_FATAN2:
case G_FCOSH:
case G_FSINH:
case G_FTANH:
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/CodeGen/GlobalISel/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -828,6 +828,7 @@ bool llvm::isKnownNeverNaN(Register Val, const MachineRegisterInfo &MRI,
case TargetOpcode::G_FACOS:
case TargetOpcode::G_FASIN:
case TargetOpcode::G_FATAN:
case TargetOpcode::G_FATAN2:
case TargetOpcode::G_FCOSH:
case TargetOpcode::G_FSINH:
case TargetOpcode::G_FTANH:
Expand Down Expand Up @@ -1715,6 +1716,7 @@ bool llvm::isPreISelGenericFloatingPointOpcode(unsigned Opc) {
case TargetOpcode::G_FACOS:
case TargetOpcode::G_FASIN:
case TargetOpcode::G_FATAN:
case TargetOpcode::G_FATAN2:
case TargetOpcode::G_FCOSH:
case TargetOpcode::G_FSINH:
case TargetOpcode::G_FTANH:
Expand Down
29 changes: 16 additions & 13 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -734,18 +734,19 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FCOPYSIGN, MVT::bf16, Promote);
}

for (auto Op : {ISD::FREM, ISD::FPOW, ISD::FPOWI,
ISD::FCOS, ISD::FSIN, ISD::FSINCOS,
ISD::FACOS, ISD::FASIN, ISD::FATAN,
ISD::FCOSH, ISD::FSINH, ISD::FTANH,
ISD::FTAN, ISD::FEXP, ISD::FEXP2,
ISD::FEXP10, ISD::FLOG, ISD::FLOG2,
ISD::FLOG10, ISD::STRICT_FREM, ISD::STRICT_FPOW,
ISD::STRICT_FPOWI, ISD::STRICT_FCOS, ISD::STRICT_FSIN,
ISD::STRICT_FACOS, ISD::STRICT_FASIN, ISD::STRICT_FATAN,
ISD::STRICT_FCOSH, ISD::STRICT_FSINH, ISD::STRICT_FTANH,
ISD::STRICT_FEXP, ISD::STRICT_FEXP2, ISD::STRICT_FLOG,
ISD::STRICT_FLOG2, ISD::STRICT_FLOG10, ISD::STRICT_FTAN}) {
for (auto Op : {ISD::FREM, ISD::FPOW, ISD::FPOWI,
ISD::FCOS, ISD::FSIN, ISD::FSINCOS,
ISD::FACOS, ISD::FASIN, ISD::FATAN,
ISD::FATAN2, ISD::FCOSH, ISD::FSINH,
ISD::FTANH, ISD::FTAN, ISD::FEXP,
ISD::FEXP2, ISD::FEXP10, ISD::FLOG,
ISD::FLOG2, ISD::FLOG10, ISD::STRICT_FREM,
ISD::STRICT_FPOW, ISD::STRICT_FPOWI, ISD::STRICT_FCOS,
ISD::STRICT_FSIN, ISD::STRICT_FACOS, ISD::STRICT_FASIN,
ISD::STRICT_FATAN, ISD::STRICT_FATAN2, ISD::STRICT_FCOSH,
ISD::STRICT_FSINH, ISD::STRICT_FTANH, ISD::STRICT_FEXP,
ISD::STRICT_FEXP2, ISD::STRICT_FLOG, ISD::STRICT_FLOG2,
ISD::STRICT_FLOG10, ISD::STRICT_FTAN}) {
setOperationAction(Op, MVT::f16, Promote);
setOperationAction(Op, MVT::v4f16, Expand);
setOperationAction(Op, MVT::v8f16, Expand);
Expand Down Expand Up @@ -1190,7 +1191,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
// silliness like this:
// clang-format off
for (auto Op :
{ISD::SELECT, ISD::SELECT_CC,
{ISD::SELECT, ISD::SELECT_CC, ISD::FATAN2,
ISD::BR_CC, ISD::FADD, ISD::FSUB,
ISD::FMUL, ISD::FDIV, ISD::FMA,
ISD::FNEG, ISD::FABS, ISD::FCEIL,
Expand Down Expand Up @@ -1649,6 +1650,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FACOS, VT, Expand);
setOperationAction(ISD::FASIN, VT, Expand);
setOperationAction(ISD::FATAN, VT, Expand);
setOperationAction(ISD::FATAN2, VT, Expand);
setOperationAction(ISD::FCOSH, VT, Expand);
setOperationAction(ISD::FSINH, VT, Expand);
setOperationAction(ISD::FTANH, VT, Expand);
Expand Down Expand Up @@ -1904,6 +1906,7 @@ void AArch64TargetLowering::addTypeForNEON(MVT VT) {
setOperationAction(ISD::FASIN, VT, Expand);
setOperationAction(ISD::FACOS, VT, Expand);
setOperationAction(ISD::FATAN, VT, Expand);
setOperationAction(ISD::FATAN2, VT, Expand);
setOperationAction(ISD::FSINH, VT, Expand);
setOperationAction(ISD::FCOSH, VT, Expand);
setOperationAction(ISD::FTANH, VT, Expand);
Expand Down
7 changes: 4 additions & 3 deletions llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,9 +269,10 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
.libcallFor({{s64, s128}})
.minScalarOrElt(1, MinFPScalar);

getActionDefinitionsBuilder(
{G_FCOS, G_FSIN, G_FPOW, G_FLOG, G_FLOG2, G_FLOG10, G_FTAN, G_FEXP,
G_FEXP2, G_FEXP10, G_FACOS, G_FASIN, G_FATAN, G_FCOSH, G_FSINH, G_FTANH})
getActionDefinitionsBuilder({G_FCOS, G_FSIN, G_FPOW, G_FLOG, G_FLOG2,
G_FLOG10, G_FTAN, G_FEXP, G_FEXP2, G_FEXP10,
G_FACOS, G_FASIN, G_FATAN, G_FATAN2, G_FCOSH,
G_FSINH, G_FTANH})
// We need a call for these, so we always need to scalarize.
.scalarize(0)
// Regardless of FP16 support, widen 16-bit elements to 32-bits.
Expand Down
8 changes: 8 additions & 0 deletions llvm/test/CodeGen/AArch64/GlobalISel/arm64-irtranslator.ll
Original file line number Diff line number Diff line change
Expand Up @@ -2345,6 +2345,14 @@ define float @test_atan_f32(float %x) {
ret float %y
}

declare float @llvm.atan2.f32(float, float)
define float @test_atan2_f32(float %x, float %y) {
; CHECK-LABEL: name: test_atan2_f32
; CHECK: %{{[0-9]+}}:_(s32) = G_FATAN2 %{{[0-9]+}}
%z = call float @llvm.atan2.f32(float %x, float %y)
ret float %z
}

declare float @llvm.cosh.f32(float)
define float @test_cosh_f32(float %x) {
; CHECK-LABEL: name: test_cosh_f32
Expand Down
Loading

0 comments on commit c03d09c

Please sign in to comment.