diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 6074498d9144f..57a3f6a65e002 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -28,6 +28,8 @@
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/VectorUtils.h"
 #include "llvm/CodeGen/ISDOpcodes.h"
 #include "llvm/CodeGen/SelectionDAG.h"
 #include "llvm/CodeGen/SelectionDAGNodes.h"
@@ -147,6 +149,14 @@ class VectorLegalizer {
   void ExpandStrictFPOp(SDNode *Node, SmallVectorImpl<SDValue> &Results);
   void ExpandREM(SDNode *Node, SmallVectorImpl<SDValue> &Results);
 
+  bool tryExpandVecMathCall(SDNode *Node, RTLIB::Libcall LC,
+                            SmallVectorImpl<SDValue> &Results);
+  bool tryExpandVecMathCall(SDNode *Node, RTLIB::Libcall Call_F32,
+                            RTLIB::Libcall Call_F64, RTLIB::Libcall Call_F80,
+                            RTLIB::Libcall Call_F128,
+                            RTLIB::Libcall Call_PPCF128,
+                            SmallVectorImpl<SDValue> &Results);
+
   void UnrollStrictFPOp(SDNode *Node, SmallVectorImpl<SDValue> &Results);
 
   /// Implements vector promotion.
@@ -1139,6 +1149,13 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
   case ISD::VP_MERGE:
     Results.push_back(ExpandVP_MERGE(Node));
     return;
+  case ISD::FREM:
+    if (tryExpandVecMathCall(Node, RTLIB::REM_F32, RTLIB::REM_F64,
+                             RTLIB::REM_F80, RTLIB::REM_F128,
+                             RTLIB::REM_PPCF128, Results))
+      return;
+
+    break;
   }
 
   SDValue Unrolled = DAG.UnrollVectorOp(Node);
@@ -1842,6 +1859,117 @@ void VectorLegalizer::ExpandREM(SDNode *Node,
   Results.push_back(Result);
 }
 
+// Try to expand libm nodes into vector math routine calls. Callers provide the
+// LibFunc equivalent of the passed-in Node, which is used to look up mappings
+// within TargetLibraryInfo. The only mappings considered are those where the
+// result and all operands are the same vector type. While predicated nodes are
+// not supported, we will emit calls to masked routines by passing in an
+// all-true mask.
+bool VectorLegalizer::tryExpandVecMathCall(SDNode *Node, RTLIB::Libcall LC,
+                                           SmallVectorImpl<SDValue> &Results) {
+  // Chains must be propagated, but currently strict fp operations are
+  // converted to their non-strict counterparts.
+  assert(!Node->isStrictFPOpcode() && "Unexpected strict fp operation!");
+
+  const char *LCName = TLI.getLibcallName(LC);
+  if (!LCName)
+    return false;
+  LLVM_DEBUG(dbgs() << "Looking for vector variant of " << LCName << "\n");
+
+  EVT VT = Node->getValueType(0);
+  ElementCount VL = VT.getVectorElementCount();
+
+  // Look up a vector function equivalent to the specified libcall. Prefer
+  // unmasked variants, but we will generate a mask if need be.
+  const TargetLibraryInfo &TLibInfo = DAG.getLibInfo();
+  const VecDesc *VD = TLibInfo.getVectorMappingInfo(LCName, VL, false);
+  if (!VD)
+    VD = TLibInfo.getVectorMappingInfo(LCName, VL, /*Masked=*/true);
+  if (!VD)
+    return false;
+
+  LLVMContext *Ctx = DAG.getContext();
+  Type *Ty = VT.getTypeForEVT(*Ctx);
+  Type *ScalarTy = Ty->getScalarType();
+
+  // Construct a scalar function type based on Node's operands.
+  SmallVector<Type *> ArgTys;
+  for (unsigned i = 0; i < Node->getNumOperands(); ++i) {
+    assert(Node->getOperand(i).getValueType() == VT &&
+           "Expected matching vector types!");
+    ArgTys.push_back(ScalarTy);
+  }
+  FunctionType *ScalarFTy = FunctionType::get(ScalarTy, ArgTys, false);
+
+  // Generate call information for the vector function.
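+  // Illustrative example (editorial assumption, not part of the original
+  // patch): for a <4 x float> FREM with -vector-library=sleefgnuabi, LCName is
+  // "fmodf" and the mapping's ABI variant string is expected to look like
+  // "_ZGVnN4vv_fmodf(_ZGVnN4vv_fmodf)"; demangling it yields a VFShape with
+  // two Vector parameters and no GlobalPredicate, so no mask operand is
+  // needed below.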
+  const std::string MangledName = VD->getVectorFunctionABIVariantString();
+  auto OptVFInfo = VFABI::tryDemangleForVFABI(MangledName, ScalarFTy);
+  if (!OptVFInfo)
+    return false;
+
+  LLVM_DEBUG(dbgs() << "Found vector variant " << VD->getVectorFnName()
+                    << "\n");
+
+  // Sanity check just in case OptVFInfo has unexpected parameters.
+  if (OptVFInfo->Shape.Parameters.size() !=
+      Node->getNumOperands() + VD->isMasked())
+    return false;
+
+  // Collect vector call operands.
+
+  SDLoc DL(Node);
+  TargetLowering::ArgListTy Args;
+  TargetLowering::ArgListEntry Entry;
+  Entry.IsSExt = false;
+  Entry.IsZExt = false;
+
+  unsigned OpNum = 0;
+  for (auto &VFParam : OptVFInfo->Shape.Parameters) {
+    if (VFParam.ParamKind == VFParamKind::GlobalPredicate) {
+      EVT MaskVT = TLI.getSetCCResultType(DAG.getDataLayout(), *Ctx, VT);
+      Entry.Node = DAG.getBoolConstant(true, DL, MaskVT, VT);
+      Entry.Ty = MaskVT.getTypeForEVT(*Ctx);
+      Args.push_back(Entry);
+      continue;
+    }
+
+    // Only vector operands are supported.
+    if (VFParam.ParamKind != VFParamKind::Vector)
+      return false;
+
+    Entry.Node = Node->getOperand(OpNum++);
+    Entry.Ty = Ty;
+    Args.push_back(Entry);
+  }
+
+  // Emit a call to the vector function.
+  SDValue Callee = DAG.getExternalSymbol(VD->getVectorFnName().data(),
+                                         TLI.getPointerTy(DAG.getDataLayout()));
+  TargetLowering::CallLoweringInfo CLI(DAG);
+  CLI.setDebugLoc(DL)
+      .setChain(DAG.getEntryNode())
+      .setLibCallee(CallingConv::C, Ty, Callee, std::move(Args));
+
+  std::pair<SDValue, SDValue> CallResult = TLI.LowerCallTo(CLI);
+  Results.push_back(CallResult.first);
+  return true;
+}
+
+/// Try to expand the node to a vector libcall based on the result type.
+bool VectorLegalizer::tryExpandVecMathCall(
+    SDNode *Node, RTLIB::Libcall Call_F32, RTLIB::Libcall Call_F64,
+    RTLIB::Libcall Call_F80, RTLIB::Libcall Call_F128,
+    RTLIB::Libcall Call_PPCF128, SmallVectorImpl<SDValue> &Results) {
+  RTLIB::Libcall LC = RTLIB::getFPLibCall(
+      Node->getValueType(0).getVectorElementType(), Call_F32, Call_F64,
+      Call_F80, Call_F128, Call_PPCF128);
+
+  if (LC == RTLIB::UNKNOWN_LIBCALL)
+    return false;
+
+  return tryExpandVecMathCall(Node, LC, Results);
+}
+
 void VectorLegalizer::UnrollStrictFPOp(SDNode *Node,
                                        SmallVectorImpl<SDValue> &Results) {
   EVT VT = Node->getValueType(0);
diff --git a/llvm/lib/CodeGen/TargetPassConfig.cpp b/llvm/lib/CodeGen/TargetPassConfig.cpp
index cf068ece8d4ca..8832b51333d91 100644
--- a/llvm/lib/CodeGen/TargetPassConfig.cpp
+++ b/llvm/lib/CodeGen/TargetPassConfig.cpp
@@ -205,6 +205,10 @@ static cl::opt<bool> MISchedPostRA(
 static cl::opt<bool> EarlyLiveIntervals("early-live-intervals", cl::Hidden,
     cl::desc("Run live interval analysis earlier in the pipeline"));
 
+static cl::opt<bool> DisableReplaceWithVecLib(
+    "disable-replace-with-vec-lib", cl::Hidden,
+    cl::desc("Disable replace with vector math call pass"));
+
 /// Option names for limiting the codegen pipeline.
 /// Those are used in error reporting and we didn't want
 /// to duplicate their names all over the place.
@@ -856,7 +860,7 @@ void TargetPassConfig::addIRPasses() {
   if (getOptLevel() != CodeGenOptLevel::None && !DisableConstantHoisting)
     addPass(createConstantHoistingPass());
 
-  if (getOptLevel() != CodeGenOptLevel::None)
+  if (getOptLevel() != CodeGenOptLevel::None && !DisableReplaceWithVecLib)
     addPass(createReplaceWithVeclibLegacyPass());
 
   if (getOptLevel() != CodeGenOptLevel::None && !DisablePartialLibcallInlining)
diff --git a/llvm/test/CodeGen/AArch64/fp-veclib-expansion.ll b/llvm/test/CodeGen/AArch64/fp-veclib-expansion.ll
new file mode 100644
index 0000000000000..67c056c780cc8
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/fp-veclib-expansion.ll
@@ -0,0 +1,116 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc --disable-replace-with-vec-lib --vector-library=ArmPL < %s -o - | FileCheck --check-prefix=ARMPL %s
+; RUN: llc --disable-replace-with-vec-lib --vector-library=sleefgnuabi < %s -o - | FileCheck --check-prefix=SLEEF %s
+
+target triple = "aarch64-unknown-linux-gnu"
+
+define <2 x double> @frem_v2f64(<2 x double> %unused, <2 x double> %a, <2 x double> %b) #0 {
+; ARMPL-LABEL: frem_v2f64:
+; ARMPL:       // %bb.0:
+; ARMPL-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; ARMPL-NEXT:    .cfi_def_cfa_offset 16
+; ARMPL-NEXT:    .cfi_offset w30, -16
+; ARMPL-NEXT:    mov v0.16b, v1.16b
+; ARMPL-NEXT:    mov v1.16b, v2.16b
+; ARMPL-NEXT:    bl armpl_vfmodq_f64
+; ARMPL-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; ARMPL-NEXT:    ret
+;
+; SLEEF-LABEL: frem_v2f64:
+; SLEEF:       // %bb.0:
+; SLEEF-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; SLEEF-NEXT:    .cfi_def_cfa_offset 16
+; SLEEF-NEXT:    .cfi_offset w30, -16
+; SLEEF-NEXT:    mov v0.16b, v1.16b
+; SLEEF-NEXT:    mov v1.16b, v2.16b
+; SLEEF-NEXT:    bl _ZGVnN2vv_fmod
+; SLEEF-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; SLEEF-NEXT:    ret
+  %res = frem <2 x double> %a, %b
+  ret <2 x double> %res
+}
+
+define <4 x float> @frem_strict_v4f32(<4 x float> %unused, <4 x float> %a, <4 x float> %b) #1 {
+; ARMPL-LABEL: frem_strict_v4f32:
+; ARMPL:       // %bb.0:
+; ARMPL-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; ARMPL-NEXT:    .cfi_def_cfa_offset 16
+; ARMPL-NEXT:    .cfi_offset w30, -16
+; ARMPL-NEXT:    mov v0.16b, v1.16b
+; ARMPL-NEXT:    mov v1.16b, v2.16b
+; ARMPL-NEXT:    bl armpl_vfmodq_f32
+; ARMPL-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; ARMPL-NEXT:    ret
+;
+; SLEEF-LABEL: frem_strict_v4f32:
+; SLEEF:       // %bb.0:
+; SLEEF-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; SLEEF-NEXT:    .cfi_def_cfa_offset 16
+; SLEEF-NEXT:    .cfi_offset w30, -16
+; SLEEF-NEXT:    mov v0.16b, v1.16b
+; SLEEF-NEXT:    mov v1.16b, v2.16b
+; SLEEF-NEXT:    bl _ZGVnN4vv_fmodf
+; SLEEF-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; SLEEF-NEXT:    ret
+  %res = frem <4 x float> %a, %b
+  ret <4 x float> %res
+}
+
+define <vscale x 4 x float> @frem_nxv4f32(<vscale x 4 x float> %unused, <vscale x 4 x float> %a, <vscale x 4 x float> %b) #0 {
+; ARMPL-LABEL: frem_nxv4f32:
+; ARMPL:       // %bb.0:
+; ARMPL-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; ARMPL-NEXT:    .cfi_def_cfa_offset 16
+; ARMPL-NEXT:    .cfi_offset w30, -16
+; ARMPL-NEXT:    ptrue p0.s
+; ARMPL-NEXT:    mov z0.d, z1.d
+; ARMPL-NEXT:    mov z1.d, z2.d
+; ARMPL-NEXT:    bl armpl_svfmod_f32_x
+; ARMPL-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; ARMPL-NEXT:    ret
+;
+; SLEEF-LABEL: frem_nxv4f32:
+; SLEEF:       // %bb.0:
+; SLEEF-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; SLEEF-NEXT:    .cfi_def_cfa_offset 16
+; SLEEF-NEXT:    .cfi_offset w30, -16
+; SLEEF-NEXT:    ptrue p0.s
+; SLEEF-NEXT:    mov z0.d, z1.d
+; SLEEF-NEXT:    mov z1.d, z2.d
+; SLEEF-NEXT:    bl _ZGVsMxvv_fmodf
+; SLEEF-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; SLEEF-NEXT:    ret
+  %res = frem <vscale x 4 x float> %a, %b
+  ret <vscale x 4 x float> %res
+}
+
+define <vscale x 2 x double> @frem_strict_nxv2f64(<vscale x 2 x double> %unused, <vscale x 2 x double> %a, <vscale x 2 x double> %b) #1 {
+; ARMPL-LABEL: frem_strict_nxv2f64:
+; ARMPL:       // %bb.0:
+; ARMPL-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; ARMPL-NEXT:    .cfi_def_cfa_offset 16
+; ARMPL-NEXT:    .cfi_offset w30, -16
+; ARMPL-NEXT:    ptrue p0.d
+; ARMPL-NEXT:    mov z0.d, z1.d
+; ARMPL-NEXT:    mov z1.d, z2.d
+; ARMPL-NEXT:    bl armpl_svfmod_f64_x
+; ARMPL-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; ARMPL-NEXT:    ret
+;
+; SLEEF-LABEL: frem_strict_nxv2f64:
+; SLEEF:       // %bb.0:
+; SLEEF-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; SLEEF-NEXT:    .cfi_def_cfa_offset 16
+; SLEEF-NEXT:    .cfi_offset w30, -16
+; SLEEF-NEXT:    ptrue p0.d
+; SLEEF-NEXT:    mov z0.d, z1.d
+; SLEEF-NEXT:    mov z1.d, z2.d
+; SLEEF-NEXT:    bl _ZGVsMxvv_fmod
+; SLEEF-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; SLEEF-NEXT:    ret
+  %res = frem <vscale x 2 x double> %a, %b
+  ret <vscale x 2 x double> %res
+}
+
+attributes #0 = { "target-features"="+sve" }
+attributes #1 = { "target-features"="+sve" strictfp }
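+
+; Editorial note (not part of the original patch): both attribute groups
+; require +sve so the <vscale x ...> cases are legal. Group #1 additionally
+; marks the functions strictfp; the checks above show that the plain
+; (non-constrained) frem in those functions is still expanded to a vector
+; math call.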