[mlir] Add reshape propagation patterns for tensor.pad #94489
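This change adds two rewrite patterns to ElementwiseOpFusion.cpp: FoldPadWithProducerReshapeOpByExpansion, which folds a tensor.pad whose source is a tensor.collapse_shape by padding the expanded source and collapsing the padded result, and FoldPadWithProducerReshapeOpByCollapsing, which folds a tensor.pad whose source is a tensor.expand_shape by padding the collapsed source and re-expanding the padded result. Both patterns require that the padded dimensions are untouched by the reshape, and both are registered through the existing populateFoldReshapeOpsByExpansionPatterns and populateFoldReshapeOpsByCollapsingPatterns entry points.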
@@ -956,6 +956,69 @@ class FoldWithProducerReshapeOpByExpansion
  ControlFusionFn controlFoldingReshapes;
};

/// Pattern to fold a tensor.pad op with its producer tensor.collapse_shape op
/// by applying the padding on the expanded source and collapsing the padded
/// result. Padding is only supported on dimensions the reshape does not
/// collapse (reassociation groups of size one).
class FoldPadWithProducerReshapeOpByExpansion
    : public OpRewritePattern<tensor::PadOp> {
public:
  FoldPadWithProducerReshapeOpByExpansion(MLIRContext *context,
                                          ControlFusionFn foldReshapes,
                                          PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::PadOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    tensor::CollapseShapeOp reshapeOp =
        padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
    if (!reshapeOp)
      return failure();
    if (!reshapeOp->hasOneUse())
      return failure();

    if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
      return rewriter.notifyMatchFailure(padOp,
                                         "fusion blocked by control function");
    }

    ArrayRef<int64_t> low = padOp.getStaticLow();
    ArrayRef<int64_t> high = padOp.getStaticHigh();
    SmallVector<ReassociationIndices> reassociations =
        reshapeOp.getReassociationIndices();

    // Bail out if a collapsed dimension is padded: the padding cannot be
    // distributed over the dimensions it was collapsed from.
    for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
      if (reInd.size() != 1 && (l != 0 || h != 0))
        return failure();
    }

    // Compute the shape of the expanded padded result and the per-dimension
    // padding amounts in the expanded domain.
    SmallVector<OpFoldResult> newLow, newHigh;
    RankedTensorType expandedType = reshapeOp.getSrcType();
    RankedTensorType paddedType = padOp.getResultType();
    SmallVector<int64_t> expandedPaddedShape(expandedType.getShape());
    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
      if (reInd.size() == 1) {
        expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
      }
      for (size_t i = 0; i < reInd.size(); ++i) {
        newLow.push_back(padOp.getMixedLowPad()[idx]);
        newHigh.push_back(padOp.getMixedHighPad()[idx]);
      }
    }

    Location loc = padOp->getLoc();
    RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape);
    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
        padOp.getConstantPaddingValue(), padOp.getNofold());

    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);

    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};

/// Pattern to fold a tensor.expand_shape op with its producer generic op
/// by expanding the dimensionality of the loop in the producer op.
struct FoldReshapeWithGenericOpByExpansion
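To illustrate the pattern above, here is a minimal before/after sketch in MLIR assembly. The shapes, value names, and %cst padding value are illustrative assumptions, not taken from the PR's tests. The pad of the collapsed result becomes a pad of the expanded source followed by the same collapse_shape; only the un-collapsed dimension (reassociation group [2]) carries padding:

// Before: tensor.pad consumes a tensor.collapse_shape.
%cst = arith.constant 0.0 : f32
%collapsed = tensor.collapse_shape %src [[0, 1], [2]]
    : tensor<8x4x16xf32> into tensor<32x16xf32>
%padded = tensor.pad %collapsed low[0, 1] high[0, 3] {
^bb0(%i: index, %j: index):
  tensor.yield %cst : f32
} : tensor<32x16xf32> to tensor<32x20xf32>

// After: the pad moves onto the expanded source. The collapsed group
// [0, 1] must have zero padding, so only dim 2 is padded (16 + 1 + 3 = 20).
%pad_expanded = tensor.pad %src low[0, 0, 1] high[0, 0, 3] {
^bb0(%i: index, %j: index, %k: index):
  tensor.yield %cst : f32
} : tensor<8x4x16xf32> to tensor<8x4x20xf32>
%result = tensor.collapse_shape %pad_expanded [[0, 1], [2]]
    : tensor<8x4x20xf32> into tensor<32x20xf32>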
@@ -1702,6 +1765,85 @@ class FoldWithProducerReshapeOpByCollapsing
  ControlFusionFn controlFoldingReshapes;
};

/// Pattern to fold a tensor.pad op with its producer tensor.expand_shape op
/// by applying the padding on the collapsed source and re-expanding the
/// padded result. Padding is only supported on dimensions the reshape does
/// not split (reassociation groups of size one).
class FoldPadWithProducerReshapeOpByCollapsing
    : public OpRewritePattern<tensor::PadOp> {
public:
  FoldPadWithProducerReshapeOpByCollapsing(MLIRContext *context,
                                           ControlFusionFn foldReshapes,
                                           PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::PadOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    tensor::ExpandShapeOp reshapeOp =
        padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
    if (!reshapeOp)
      return failure();
    if (!reshapeOp->hasOneUse())
      return failure();

    if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
      return rewriter.notifyMatchFailure(padOp,
                                         "fusion blocked by control function");
    }

    ArrayRef<int64_t> low = padOp.getStaticLow();
    ArrayRef<int64_t> high = padOp.getStaticHigh();
    SmallVector<ReassociationIndices> reassociations =
        reshapeOp.getReassociationIndices();

    // Bail out if any dimension produced by splitting is padded: that
    // padding cannot be expressed on the collapsed source.
    for (auto reInd : reassociations) {
      if (reInd.size() == 1)
        continue;
      if (llvm::any_of(reInd, [&](int64_t ind) {
            return low[ind] != 0 || high[ind] != 0;
          })) {
        return failure();
      }
    }

    // Compute the collapsed padded shape, the padding amounts in the
    // collapsed domain, and the updated output sizes of the trailing
    // expand_shape (padded size = low + size + high).
    SmallVector<OpFoldResult> newLow, newHigh;
    RankedTensorType collapsedType = reshapeOp.getSrcType();
    RankedTensorType paddedType = padOp.getResultType();
    SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape());
    SmallVector<OpFoldResult> expandedPaddedSizes(
        getMixedValues(reshapeOp.getStaticOutputShape(),
                       reshapeOp.getOutputShape(), rewriter));
    AffineExpr d0, d1, d2;
    bindDims(rewriter.getContext(), d0, d1, d2);
    auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2});
    Location loc = reshapeOp->getLoc();
    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
      OpFoldResult l = padOp.getMixedLowPad()[reInd[0]];
      OpFoldResult h = padOp.getMixedHighPad()[reInd[0]];
      if (reInd.size() == 1) {
        collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
        OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply(
            rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
        expandedPaddedSizes[reInd[0]] = paddedSize;
      }
      newLow.push_back(l);
      newHigh.push_back(h);
    }

    RankedTensorType collapsedPaddedType =
        paddedType.clone(collapsedPaddedShape);
    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
        padOp.getConstantPaddingValue(), padOp.getNofold());

    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations,
        expandedPaddedSizes);

    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};

/// Pattern to collapse dimensions.
template <typename LinalgType>
class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
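Again as illustration, a hedged before/after sketch in MLIR assembly for the collapsing pattern; the shapes and names are assumptions for the example. The pad of the expanded result becomes a pad of the collapsed source followed by the same expand_shape, with the output_shape of the expand updated for the padded dimension (when that size is dynamic, the code above computes it as low + high + size via an affine apply):

// Before: tensor.pad consumes a tensor.expand_shape.
%cst = arith.constant 0.0 : f32
%expanded = tensor.expand_shape %src [[0, 1], [2]] output_shape [8, 4, 16]
    : tensor<32x16xf32> into tensor<8x4x16xf32>
%padded = tensor.pad %expanded low[0, 0, 1] high[0, 0, 3] {
^bb0(%i: index, %j: index, %k: index):
  tensor.yield %cst : f32
} : tensor<8x4x16xf32> to tensor<8x4x20xf32>

// After: the pad moves onto the collapsed source. The split group [0, 1]
// must have zero padding, so only the un-split dimension is padded.
%pad_collapsed = tensor.pad %src low[0, 1] high[0, 3] {
^bb0(%i: index, %j: index):
  tensor.yield %cst : f32
} : tensor<32x16xf32> to tensor<32x20xf32>
%result = tensor.expand_shape %pad_collapsed [[0, 1], [2]]
    output_shape [8, 4, 20] : tensor<32x20xf32> into tensor<8x4x20xf32>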
@@ -1937,6 +2079,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
    const ControlFusionFn &controlFoldingReshapes) {
  patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
                                                    controlFoldingReshapes);
  patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                        controlFoldingReshapes);
  patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                     controlFoldingReshapes);
}

Review comment (on the FoldPadWithProducerReshapeOpByExpansion registration): AFAICS both the patterns added here are "propagating by collapsing". You are moving the expand_shape down and collapse_shape up. In both cases the pad is happening on collapsed dimensions. So you should add both to the …

Follow-up: Ok, I twisted myself in a knot when looking at tests. You are right.
@@ -1946,6 +2090,8 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
    const ControlFusionFn &controlFoldingReshapes) {
  patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
                                                      controlFoldingReshapes);
  patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
      patterns.getContext(), controlFoldingReshapes);
}
void mlir::linalg::populateElementwiseOpsFusionPatterns(
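For context, here is a minimal C++ sketch of how a client might exercise these patterns. The populate entry points and the ControlFusionFn type are real Linalg transform APIs touched by this patch; the surrounding greedy-rewrite pass boilerplate and the always-true control function are illustrative assumptions, not part of the PR:

// Hedged usage sketch, assumed to run inside a pass's runOnOperation().
RewritePatternSet patterns(&getContext());
// Control function: allow propagation everywhere. Real clients inspect the
// OpOperand to restrict which reshapes get propagated.
linalg::ControlFusionFn allowAll = [](OpOperand *) { return true; };
linalg::populateFoldReshapeOpsByExpansionPatterns(patterns, allowAll);
linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, allowAll);
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
  signalPassFailure();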