@@ -1436,16 +1436,18 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
1436
1436
}
1437
1437
1438
1438
// Vector min/max reductions
1439
- if (Subtarget.hasSSE41())
1440
- {
1439
+ // These are lowered to PHMINPOSUW if possible,
1440
+ // otherwise they are expaned to shuffles + binops.
1441
+ if (Subtarget.hasSSE41()) {
1441
1442
for (MVT VT : MVT::vector_valuetypes()) {
1442
- if (VT.getScalarType() == MVT::i8 || VT.getScalarType() == MVT::i16)
1443
- {
1444
- setOperationAction(ISD::VECREDUCE_UMIN, VT, Custom);
1445
- setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom);
1446
- setOperationAction(ISD::VECREDUCE_SMIN, VT, Custom);
1447
- setOperationAction(ISD::VECREDUCE_SMAX, VT, Custom);
1448
- }
1443
+ if (!VT.isFixedLengthVector() || (VT.getSizeInBits() % 128) != 0 ||
1444
+ !(VT.getScalarType() == MVT::i8 || VT.getScalarType() == MVT::i16))
1445
+ continue;
1446
+
1447
+ setOperationAction(ISD::VECREDUCE_UMIN, VT, Custom);
1448
+ setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom);
1449
+ setOperationAction(ISD::VECREDUCE_SMIN, VT, Custom);
1450
+ setOperationAction(ISD::VECREDUCE_SMAX, VT, Custom);
1449
1451
}
1450
1452
}
1451
1453
@@ -25426,9 +25428,11 @@ static SDValue LowerEXTEND_VECTOR_INREG(SDValue Op,
25426
25428
// Create a min/max v8i16/v16i8 horizontal reduction with PHMINPOSUW.
25427
25429
static SDValue createMinMaxReduction(SDValue Src, EVT TargetVT, SDLoc DL,
25428
25430
ISD::NodeType BinOp, SelectionDAG &DAG,
25429
- const X86Subtarget &Subtarget)
25430
- {
25431
- assert(Subtarget.hasSSE41() && "The caller must check if SSE4.1 is available");
25431
+ const X86Subtarget &Subtarget) {
25432
+ assert(Subtarget.hasSSE41() &&
25433
+ "The caller must check if SSE4.1 is available");
25434
+ assert(TargetVT == MVT::i16 ||
25435
+ TargetVT == MVT::i8 && "Unexpected return type");
25432
25436
25433
25437
EVT SrcVT = Src.getValueType();
25434
25438
EVT SrcSVT = SrcVT.getScalarType();
@@ -25484,31 +25488,11 @@ static SDValue createMinMaxReduction(SDValue Src, EVT TargetVT, SDLoc DL,
25484
25488
}
25485
25489
25486
25490
static SDValue LowerVECTOR_REDUCE_MINMAX(SDValue Op,
25487
- const X86Subtarget& Subtarget,
25488
- SelectionDAG& DAG)
25489
- {
25490
- ISD::NodeType BinOp;
25491
- switch (Op.getOpcode())
25492
- {
25493
- default:
25494
- assert(false && "Expected min/max reduction");
25495
- break;
25496
- case ISD::VECREDUCE_UMIN:
25497
- BinOp = ISD::UMIN;
25498
- break;
25499
- case ISD::VECREDUCE_UMAX:
25500
- BinOp = ISD::UMAX;
25501
- break;
25502
- case ISD::VECREDUCE_SMIN:
25503
- BinOp = ISD::SMIN;
25504
- break;
25505
- case ISD::VECREDUCE_SMAX:
25506
- BinOp = ISD::SMAX;
25507
- break;
25508
- }
25509
-
25491
+ const X86Subtarget &Subtarget,
25492
+ SelectionDAG &DAG) {
25493
+ ISD::NodeType BinOp = ISD::getVecReduceBaseOpcode(Op.getOpcode());
25510
25494
return createMinMaxReduction(Op->getOperand(0), Op.getValueType(), SDLoc(Op),
25511
- BinOp, DAG, Subtarget);
25495
+ BinOp, DAG, Subtarget);
25512
25496
}
25513
25497
25514
25498
static SDValue LowerSIGN_EXTEND(SDValue Op, const X86Subtarget &Subtarget,
@@ -46299,8 +46283,8 @@ static SDValue combineMinMaxReduction(SDNode *Extract, SelectionDAG &DAG,
46299
46283
if (!Src)
46300
46284
return SDValue();
46301
46285
46302
- return createMinMaxReduction(Src, ExtractVT, SDLoc(Extract),
46303
- BinOp, DAG, Subtarget);
46286
+ return createMinMaxReduction(Src, ExtractVT, SDLoc(Extract), BinOp, DAG,
46287
+ Subtarget);
46304
46288
}
46305
46289
46306
46290
// Attempt to replace an all_of/any_of/parity style horizontal reduction with a MOVMSK.
@@ -47136,8 +47120,8 @@ static SDValue combineArithReduction(SDNode *ExtElt, SelectionDAG &DAG,
47136
47120
/// scalars back, while for x64 we should use 64-bit extracts and shifts.
47137
47121
static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG,
47138
47122
TargetLowering::DAGCombinerInfo &DCI,
47139
- const X86Subtarget &Subtarget,
47140
- bool& TransformedBinOpReduction) {
47123
+ const X86Subtarget &Subtarget,
47124
+ bool & TransformedBinOpReduction) {
47141
47125
if (SDValue NewOp = combineExtractWithShuffle(N, DAG, DCI, Subtarget))
47142
47126
return NewOp;
47143
47127
@@ -47321,26 +47305,27 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG,
47321
47305
return SDValue();
47322
47306
}
47323
47307
47324
- static SDValue combineExtractVectorEltAndOperand(SDNode* N, SelectionDAG& DAG,
47325
- TargetLowering::DAGCombinerInfo& DCI ,
47326
- const X86Subtarget& Subtarget)
47327
- {
47308
+ static SDValue
47309
+ combineExtractVectorEltAndOperand(SDNode *N, SelectionDAG &DAG ,
47310
+ TargetLowering::DAGCombinerInfo &DCI,
47311
+ const X86Subtarget &Subtarget) {
47328
47312
bool TransformedBinOpReduction = false;
47329
- auto Op = combineExtractVectorElt(N, DAG, DCI, Subtarget, TransformedBinOpReduction);
47313
+ auto Op = combineExtractVectorElt(N, DAG, DCI, Subtarget,
47314
+ TransformedBinOpReduction);
47330
47315
47331
- if (TransformedBinOpReduction)
47332
- {
47316
+ if (TransformedBinOpReduction) {
47333
47317
// In case we simplified N = extract_vector_element(V, 0) with Op and V
47334
47318
// resulted from a reduction, then we need to replace all uses of V with
47335
47319
// scalar_to_vector(Op) to make sure that we eliminated the binop + shuffle
47336
- // pyramid. This is safe to do, because the elements of V are undefined except
47337
- // for the zeroth element and Op does not depend on V.
47320
+ // pyramid. This is safe to do, because the elements of V are undefined
47321
+ // except for the zeroth element and Op does not depend on V.
47338
47322
47339
47323
auto OldV = N->getOperand(0);
47340
- assert(!Op.getNode()->hasPredecessor(OldV.getNode()) &&
47341
- "Op must not depend on the converted reduction");
47324
+ assert(!Op.getNode()->hasPredecessor(OldV.getNode()) &&
47325
+ "Op must not depend on the converted reduction");
47342
47326
47343
- auto NewV = DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), OldV->getValueType(0), Op);
47327
+ auto NewV =
47328
+ DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), OldV->getValueType(0), Op);
47344
47329
47345
47330
auto NV = DCI.CombineTo(N, Op);
47346
47331
DCI.CombineTo(OldV.getNode(), NewV);
0 commit comments