Skip to content

Commit aa7cfa7

Browse files
committed
[SDAG][AArch64] Legalize VECREDUCE
Fixes https://bugs.llvm.org/show_bug.cgi?id=36796. Implement basic legalizations (PromoteIntRes, PromoteIntOp, ExpandIntRes, ScalarizeVecOp, WidenVecOp) for VECREDUCE opcodes. There are more legalizations missing (esp float legalizations), but there's no way to test them right now, so I'm not adding them. This also includes a few more changes to make this work somewhat reasonably: * Add support for expanding VECREDUCE in SDAG. Usually experimental.vector.reduce is expanded prior to codegen, but if the target does have native vector reduce, it may of course still be necessary to expand due to legalization issues. This uses a shuffle reduction if possible, followed by a naive scalar reduction. * Allow the result type of integer VECREDUCE to be larger than the vector element type. For example we need to be able to reduce a v8i8 into an (nominally) i32 result type on AArch64. * Use the vector operand type rather than the scalar result type to determine the action, so we can control exactly which vector types are supported. Also change the legalize vector op code to handle operations that only have vector operands, but no vector results, as is the case for VECREDUCE. * Default VECREDUCE to Expand. On AArch64 (only target using VECREDUCE), explicitly specify for which vector types the reductions are supported. This does not handle anything related to VECREDUCE_STRICT_*. Differential Revision: https://reviews.llvm.org/D58015 llvm-svn: 355860
1 parent a495c64 commit aa7cfa7

16 files changed

+1068
-10
lines changed

llvm/include/llvm/CodeGen/ISDOpcodes.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -872,11 +872,14 @@ namespace ISD {
872872
VECREDUCE_STRICT_FADD, VECREDUCE_STRICT_FMUL,
873873
/// These reductions are non-strict, and have a single vector operand.
874874
VECREDUCE_FADD, VECREDUCE_FMUL,
875+
/// FMIN/FMAX nodes can have flags, for NaN/NoNaN variants.
876+
VECREDUCE_FMAX, VECREDUCE_FMIN,
877+
/// Integer reductions may have a result type larger than the vector element
878+
/// type. However, the reduction is performed using the vector element type
879+
/// and the value in the top bits is unspecified.
875880
VECREDUCE_ADD, VECREDUCE_MUL,
876881
VECREDUCE_AND, VECREDUCE_OR, VECREDUCE_XOR,
877882
VECREDUCE_SMAX, VECREDUCE_SMIN, VECREDUCE_UMAX, VECREDUCE_UMIN,
878-
/// FMIN/FMAX nodes can have flags, for NaN/NoNaN variants.
879-
VECREDUCE_FMAX, VECREDUCE_FMIN,
880883

881884
/// BUILTIN_OP_END - This must be the last enum value in this list.
882885
/// The target-specific pre-isel opcode values start here.

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3893,6 +3893,10 @@ class TargetLowering : public TargetLoweringBase {
38933893
bool expandMULO(SDNode *Node, SDValue &Result, SDValue &Overflow,
38943894
SelectionDAG &DAG) const;
38953895

3896+
/// Expand a VECREDUCE_* into an explicit calculation. If Count is specified,
3897+
/// only the first Count elements of the vector are used.
3898+
SDValue expandVecReduce(SDNode *Node, SelectionDAG &DAG) const;
3899+
38963900
//===--------------------------------------------------------------------===//
38973901
// Instruction Emitting Hooks
38983902
//

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,7 @@ namespace {
398398
SDValue visitMSCATTER(SDNode *N);
399399
SDValue visitFP_TO_FP16(SDNode *N);
400400
SDValue visitFP16_TO_FP(SDNode *N);
401+
SDValue visitVECREDUCE(SDNode *N);
401402

402403
SDValue visitFADDForFMACombine(SDNode *N);
403404
SDValue visitFSUBForFMACombine(SDNode *N);
@@ -1592,6 +1593,19 @@ SDValue DAGCombiner::visit(SDNode *N) {
15921593
case ISD::MSTORE: return visitMSTORE(N);
15931594
case ISD::FP_TO_FP16: return visitFP_TO_FP16(N);
15941595
case ISD::FP16_TO_FP: return visitFP16_TO_FP(N);
1596+
case ISD::VECREDUCE_FADD:
1597+
case ISD::VECREDUCE_FMUL:
1598+
case ISD::VECREDUCE_ADD:
1599+
case ISD::VECREDUCE_MUL:
1600+
case ISD::VECREDUCE_AND:
1601+
case ISD::VECREDUCE_OR:
1602+
case ISD::VECREDUCE_XOR:
1603+
case ISD::VECREDUCE_SMAX:
1604+
case ISD::VECREDUCE_SMIN:
1605+
case ISD::VECREDUCE_UMAX:
1606+
case ISD::VECREDUCE_UMIN:
1607+
case ISD::VECREDUCE_FMAX:
1608+
case ISD::VECREDUCE_FMIN: return visitVECREDUCE(N);
15951609
}
15961610
return SDValue();
15971611
}
@@ -18307,6 +18321,24 @@ SDValue DAGCombiner::visitFP16_TO_FP(SDNode *N) {
1830718321
return SDValue();
1830818322
}
1830918323

18324+
SDValue DAGCombiner::visitVECREDUCE(SDNode *N) {
18325+
SDValue N0 = N->getOperand(0);
18326+
EVT VT = N0.getValueType();
18327+
18328+
// VECREDUCE over 1-element vector is just an extract.
18329+
if (VT.getVectorNumElements() == 1) {
18330+
SDLoc dl(N);
18331+
SDValue Res = DAG.getNode(
18332+
ISD::EXTRACT_VECTOR_ELT, dl, VT.getVectorElementType(), N0,
18333+
DAG.getConstant(0, dl, TLI.getVectorIdxTy(DAG.getDataLayout())));
18334+
if (Res.getValueType() != N->getValueType(0))
18335+
Res = DAG.getNode(ISD::ANY_EXTEND, dl, N->getValueType(0), Res);
18336+
return Res;
18337+
}
18338+
18339+
return SDValue();
18340+
}
18341+
1831018342
/// Returns a vector_shuffle if it able to transform an AND to a vector_shuffle
1831118343
/// with the destination vector and a zero vector.
1831218344
/// e.g. AND V, <0xffffffff, 0, 0xffffffff, 0>. ==>

llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,6 +1140,22 @@ void SelectionDAGLegalize::LegalizeOp(SDNode *Node) {
11401140
Action = TLI.getOperationAction(Node->getOpcode(),
11411141
cast<MaskedStoreSDNode>(Node)->getValue().getValueType());
11421142
break;
1143+
case ISD::VECREDUCE_FADD:
1144+
case ISD::VECREDUCE_FMUL:
1145+
case ISD::VECREDUCE_ADD:
1146+
case ISD::VECREDUCE_MUL:
1147+
case ISD::VECREDUCE_AND:
1148+
case ISD::VECREDUCE_OR:
1149+
case ISD::VECREDUCE_XOR:
1150+
case ISD::VECREDUCE_SMAX:
1151+
case ISD::VECREDUCE_SMIN:
1152+
case ISD::VECREDUCE_UMAX:
1153+
case ISD::VECREDUCE_UMIN:
1154+
case ISD::VECREDUCE_FMAX:
1155+
case ISD::VECREDUCE_FMIN:
1156+
Action = TLI.getOperationAction(
1157+
Node->getOpcode(), Node->getOperand(0).getValueType());
1158+
break;
11431159
default:
11441160
if (Node->getOpcode() >= ISD::BUILTIN_OP_END) {
11451161
Action = TargetLowering::Legal;
@@ -3602,6 +3618,21 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
36023618
ReplaceNode(SDValue(Node, 0), Result);
36033619
break;
36043620
}
3621+
case ISD::VECREDUCE_FADD:
3622+
case ISD::VECREDUCE_FMUL:
3623+
case ISD::VECREDUCE_ADD:
3624+
case ISD::VECREDUCE_MUL:
3625+
case ISD::VECREDUCE_AND:
3626+
case ISD::VECREDUCE_OR:
3627+
case ISD::VECREDUCE_XOR:
3628+
case ISD::VECREDUCE_SMAX:
3629+
case ISD::VECREDUCE_SMIN:
3630+
case ISD::VECREDUCE_UMAX:
3631+
case ISD::VECREDUCE_UMIN:
3632+
case ISD::VECREDUCE_FMAX:
3633+
case ISD::VECREDUCE_FMIN:
3634+
Results.push_back(TLI.expandVecReduce(Node, DAG));
3635+
break;
36053636
case ISD::GLOBAL_OFFSET_TABLE:
36063637
case ISD::GlobalAddress:
36073638
case ISD::GlobalTLSAddress:

llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,18 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
172172
case ISD::ATOMIC_CMP_SWAP_WITH_SUCCESS:
173173
Res = PromoteIntRes_AtomicCmpSwap(cast<AtomicSDNode>(N), ResNo);
174174
break;
175+
176+
case ISD::VECREDUCE_ADD:
177+
case ISD::VECREDUCE_MUL:
178+
case ISD::VECREDUCE_AND:
179+
case ISD::VECREDUCE_OR:
180+
case ISD::VECREDUCE_XOR:
181+
case ISD::VECREDUCE_SMAX:
182+
case ISD::VECREDUCE_SMIN:
183+
case ISD::VECREDUCE_UMAX:
184+
case ISD::VECREDUCE_UMIN:
185+
Res = PromoteIntRes_VECREDUCE(N);
186+
break;
175187
}
176188

177189
// If the result is null then the sub-method took care of registering it.
@@ -1107,6 +1119,16 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
11071119
case ISD::UMULFIX: Res = PromoteIntOp_MULFIX(N); break;
11081120

11091121
case ISD::FPOWI: Res = PromoteIntOp_FPOWI(N); break;
1122+
1123+
case ISD::VECREDUCE_ADD:
1124+
case ISD::VECREDUCE_MUL:
1125+
case ISD::VECREDUCE_AND:
1126+
case ISD::VECREDUCE_OR:
1127+
case ISD::VECREDUCE_XOR:
1128+
case ISD::VECREDUCE_SMAX:
1129+
case ISD::VECREDUCE_SMIN:
1130+
case ISD::VECREDUCE_UMAX:
1131+
case ISD::VECREDUCE_UMIN: Res = PromoteIntOp_VECREDUCE(N); break;
11101132
}
11111133

11121134
// If the result is null, the sub-method took care of registering results etc.
@@ -1483,6 +1505,39 @@ SDValue DAGTypeLegalizer::PromoteIntOp_FPOWI(SDNode *N) {
14831505
return SDValue(DAG.UpdateNodeOperands(N, N->getOperand(0), Op), 0);
14841506
}
14851507

1508+
SDValue DAGTypeLegalizer::PromoteIntOp_VECREDUCE(SDNode *N) {
1509+
SDLoc dl(N);
1510+
SDValue Op;
1511+
switch (N->getOpcode()) {
1512+
default: llvm_unreachable("Expected integer vector reduction");
1513+
case ISD::VECREDUCE_ADD:
1514+
case ISD::VECREDUCE_MUL:
1515+
case ISD::VECREDUCE_AND:
1516+
case ISD::VECREDUCE_OR:
1517+
case ISD::VECREDUCE_XOR:
1518+
Op = GetPromotedInteger(N->getOperand(0));
1519+
break;
1520+
case ISD::VECREDUCE_SMAX:
1521+
case ISD::VECREDUCE_SMIN:
1522+
Op = SExtPromotedInteger(N->getOperand(0));
1523+
break;
1524+
case ISD::VECREDUCE_UMAX:
1525+
case ISD::VECREDUCE_UMIN:
1526+
Op = ZExtPromotedInteger(N->getOperand(0));
1527+
break;
1528+
}
1529+
1530+
EVT EltVT = Op.getValueType().getVectorElementType();
1531+
EVT VT = N->getValueType(0);
1532+
if (VT.bitsGE(EltVT))
1533+
return DAG.getNode(N->getOpcode(), SDLoc(N), VT, Op);
1534+
1535+
// Result size must be >= element size. If this is not the case after
1536+
// promotion, also promote the result type and then truncate.
1537+
SDValue Reduce = DAG.getNode(N->getOpcode(), dl, EltVT, Op);
1538+
return DAG.getNode(ISD::TRUNCATE, dl, VT, Reduce);
1539+
}
1540+
14861541
//===----------------------------------------------------------------------===//
14871542
// Integer Result Expansion
14881543
//===----------------------------------------------------------------------===//
@@ -1624,6 +1679,16 @@ void DAGTypeLegalizer::ExpandIntegerResult(SDNode *N, unsigned ResNo) {
16241679
case ISD::USUBSAT: ExpandIntRes_ADDSUBSAT(N, Lo, Hi); break;
16251680
case ISD::SMULFIX:
16261681
case ISD::UMULFIX: ExpandIntRes_MULFIX(N, Lo, Hi); break;
1682+
1683+
case ISD::VECREDUCE_ADD:
1684+
case ISD::VECREDUCE_MUL:
1685+
case ISD::VECREDUCE_AND:
1686+
case ISD::VECREDUCE_OR:
1687+
case ISD::VECREDUCE_XOR:
1688+
case ISD::VECREDUCE_SMAX:
1689+
case ISD::VECREDUCE_SMIN:
1690+
case ISD::VECREDUCE_UMAX:
1691+
case ISD::VECREDUCE_UMIN: ExpandIntRes_VECREDUCE(N, Lo, Hi); break;
16271692
}
16281693

16291694
// If Lo/Hi is null, the sub-method took care of registering results etc.
@@ -3172,6 +3237,14 @@ void DAGTypeLegalizer::ExpandIntRes_ATOMIC_LOAD(SDNode *N,
31723237
ReplaceValueWith(SDValue(N, 1), Swap.getValue(2));
31733238
}
31743239

3240+
void DAGTypeLegalizer::ExpandIntRes_VECREDUCE(SDNode *N,
3241+
SDValue &Lo, SDValue &Hi) {
3242+
// TODO For VECREDUCE_(AND|OR|XOR) we could split the vector and calculate
3243+
// both halves independently.
3244+
SDValue Res = TLI.expandVecReduce(N, DAG);
3245+
SplitInteger(Res, Lo, Hi);
3246+
}
3247+
31753248
//===----------------------------------------------------------------------===//
31763249
// Integer Operand Expansion
31773250
//===----------------------------------------------------------------------===//
@@ -3840,6 +3913,14 @@ SDValue DAGTypeLegalizer::PromoteIntRes_INSERT_VECTOR_ELT(SDNode *N) {
38403913
V0, ConvElem, N->getOperand(2));
38413914
}
38423915

3916+
SDValue DAGTypeLegalizer::PromoteIntRes_VECREDUCE(SDNode *N) {
3917+
// The VECREDUCE result size may be larger than the element size, so
3918+
// we can simply change the result type.
3919+
SDLoc dl(N);
3920+
EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
3921+
return DAG.getNode(N->getOpcode(), dl, NVT, N->getOperand(0));
3922+
}
3923+
38433924
SDValue DAGTypeLegalizer::PromoteIntOp_EXTRACT_VECTOR_ELT(SDNode *N) {
38443925
SDLoc dl(N);
38453926
SDValue V0 = GetPromotedInteger(N->getOperand(0));

llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
346346
SDValue PromoteIntRes_ADDSUBSAT(SDNode *N);
347347
SDValue PromoteIntRes_MULFIX(SDNode *N);
348348
SDValue PromoteIntRes_FLT_ROUNDS(SDNode *N);
349+
SDValue PromoteIntRes_VECREDUCE(SDNode *N);
349350

350351
// Integer Operand Promotion.
351352
bool PromoteIntegerOperand(SDNode *N, unsigned OpNo);
@@ -380,6 +381,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
380381
SDValue PromoteIntOp_PREFETCH(SDNode *N, unsigned OpNo);
381382
SDValue PromoteIntOp_MULFIX(SDNode *N);
382383
SDValue PromoteIntOp_FPOWI(SDNode *N);
384+
SDValue PromoteIntOp_VECREDUCE(SDNode *N);
383385

384386
void PromoteSetCCOperands(SDValue &LHS,SDValue &RHS, ISD::CondCode Code);
385387

@@ -438,6 +440,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
438440
void ExpandIntRes_MULFIX (SDNode *N, SDValue &Lo, SDValue &Hi);
439441

440442
void ExpandIntRes_ATOMIC_LOAD (SDNode *N, SDValue &Lo, SDValue &Hi);
443+
void ExpandIntRes_VECREDUCE (SDNode *N, SDValue &Lo, SDValue &Hi);
441444

442445
void ExpandShiftByConstant(SDNode *N, const APInt &Amt,
443446
SDValue &Lo, SDValue &Hi);
@@ -705,6 +708,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
705708
SDValue ScalarizeVecOp_VSETCC(SDNode *N);
706709
SDValue ScalarizeVecOp_STORE(StoreSDNode *N, unsigned OpNo);
707710
SDValue ScalarizeVecOp_FP_ROUND(SDNode *N, unsigned OpNo);
711+
SDValue ScalarizeVecOp_VECREDUCE(SDNode *N);
708712

709713
//===--------------------------------------------------------------------===//
710714
// Vector Splitting Support: LegalizeVectorTypes.cpp
@@ -835,6 +839,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
835839

836840
SDValue WidenVecOp_Convert(SDNode *N);
837841
SDValue WidenVecOp_FCOPYSIGN(SDNode *N);
842+
SDValue WidenVecOp_VECREDUCE(SDNode *N);
838843

839844
//===--------------------------------------------------------------------===//
840845
// Vector Widening Utilities Support: LegalizeVectorTypes.cpp

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -294,12 +294,13 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
294294
}
295295
}
296296

297-
bool HasVectorValue = false;
298-
for (SDNode::value_iterator J = Node->value_begin(), E = Node->value_end();
299-
J != E;
300-
++J)
301-
HasVectorValue |= J->isVector();
302-
if (!HasVectorValue)
297+
bool HasVectorValueOrOp = false;
298+
for (auto J = Node->value_begin(), E = Node->value_end(); J != E; ++J)
299+
HasVectorValueOrOp |= J->isVector();
300+
for (const SDValue &Op : Node->op_values())
301+
HasVectorValueOrOp |= Op.getValueType().isVector();
302+
303+
if (!HasVectorValueOrOp)
303304
return TranslateLegalizeResults(Op, Result);
304305

305306
TargetLowering::LegalizeAction Action = TargetLowering::Legal;
@@ -441,6 +442,19 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
441442
break;
442443
case ISD::SINT_TO_FP:
443444
case ISD::UINT_TO_FP:
445+
case ISD::VECREDUCE_ADD:
446+
case ISD::VECREDUCE_MUL:
447+
case ISD::VECREDUCE_AND:
448+
case ISD::VECREDUCE_OR:
449+
case ISD::VECREDUCE_XOR:
450+
case ISD::VECREDUCE_SMAX:
451+
case ISD::VECREDUCE_SMIN:
452+
case ISD::VECREDUCE_UMAX:
453+
case ISD::VECREDUCE_UMIN:
454+
case ISD::VECREDUCE_FADD:
455+
case ISD::VECREDUCE_FMUL:
456+
case ISD::VECREDUCE_FMAX:
457+
case ISD::VECREDUCE_FMIN:
444458
Action = TLI.getOperationAction(Node->getOpcode(),
445459
Node->getOperand(0).getValueType());
446460
break;
@@ -816,6 +830,20 @@ SDValue VectorLegalizer::Expand(SDValue Op) {
816830
case ISD::STRICT_FROUND:
817831
case ISD::STRICT_FTRUNC:
818832
return ExpandStrictFPOp(Op);
833+
case ISD::VECREDUCE_ADD:
834+
case ISD::VECREDUCE_MUL:
835+
case ISD::VECREDUCE_AND:
836+
case ISD::VECREDUCE_OR:
837+
case ISD::VECREDUCE_XOR:
838+
case ISD::VECREDUCE_SMAX:
839+
case ISD::VECREDUCE_SMIN:
840+
case ISD::VECREDUCE_UMAX:
841+
case ISD::VECREDUCE_UMIN:
842+
case ISD::VECREDUCE_FADD:
843+
case ISD::VECREDUCE_FMUL:
844+
case ISD::VECREDUCE_FMAX:
845+
case ISD::VECREDUCE_FMIN:
846+
return TLI.expandVecReduce(Op.getNode(), DAG);
819847
default:
820848
return DAG.UnrollVectorOp(Op.getNode());
821849
}

0 commit comments

Comments
 (0)