[mlir] Add reshape propagation patterns for tensor.pad #94489

Merged — 3 commits, merged on Jun 7, 2024.

146 changes: 146 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -956,6 +956,69 @@ class FoldWithProducerReshapeOpByExpansion
  ControlFusionFn controlFoldingReshapes;
};

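/// Pattern to fold a tensor.pad op with its producer tensor.collapse_shape op
/// by moving the pad to the expanded (higher-rank) source of the collapse_shape
/// and collapsing the padded result. Dimensions that the reshape collapses
/// together must not be padded.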
class FoldPadWithProducerReshapeOpByExpansion
    : public OpRewritePattern<tensor::PadOp> {
public:
  FoldPadWithProducerReshapeOpByExpansion(MLIRContext *context,
                                          ControlFusionFn foldReshapes,
                                          PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::PadOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    tensor::CollapseShapeOp reshapeOp =
        padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
    if (!reshapeOp)
      return failure();
    if (!reshapeOp->hasOneUse())
      return failure();

    if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
      return rewriter.notifyMatchFailure(padOp,
                                         "fusion blocked by control function");
    }

    ArrayRef<int64_t> low = padOp.getStaticLow();
    ArrayRef<int64_t> high = padOp.getStaticHigh();
    SmallVector<ReassociationIndices> reassociations =
        reshapeOp.getReassociationIndices();

    for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
      if (reInd.size() != 1 && (l != 0 || h != 0))
        return failure();
    }

    SmallVector<OpFoldResult> newLow, newHigh;
    RankedTensorType expandedType = reshapeOp.getSrcType();
    RankedTensorType paddedType = padOp.getResultType();
    SmallVector<int64_t> expandedPaddedShape(expandedType.getShape());
    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
      if (reInd.size() == 1) {
        expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
      }
      for (size_t i = 0; i < reInd.size(); ++i) {
        newLow.push_back(padOp.getMixedLowPad()[idx]);
        newHigh.push_back(padOp.getMixedHighPad()[idx]);
      }
    }

    Location loc = padOp->getLoc();
    RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape);
    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
        padOp.getConstantPaddingValue(), padOp.getNofold());

    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);

    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};

/// Pattern to fold a tensor.expand_shape op with its producer generic op
/// by expanding the dimensionality of the loop in the producer op.
struct FoldReshapeWithGenericOpByExpansion
@@ -1702,6 +1765,85 @@ class FoldWithProducerReshapeOpByCollapsing
  ControlFusionFn controlFoldingReshapes;
};

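/// Pattern to fold a tensor.pad op with its producer tensor.expand_shape op
/// by moving the pad to the collapsed (lower-rank) source of the expand_shape
/// and re-expanding the padded result. Dimensions that the reshape expands
/// into multiple dimensions must not be padded.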
class FoldPadWithProducerReshapeOpByCollapsing
    : public OpRewritePattern<tensor::PadOp> {
public:
  FoldPadWithProducerReshapeOpByCollapsing(MLIRContext *context,
                                           ControlFusionFn foldReshapes,
                                           PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::PadOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    tensor::ExpandShapeOp reshapeOp =
        padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
    if (!reshapeOp)
      return failure();
    if (!reshapeOp->hasOneUse())
      return failure();

    if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
      return rewriter.notifyMatchFailure(padOp,
                                         "fusion blocked by control function");
    }

    ArrayRef<int64_t> low = padOp.getStaticLow();
    ArrayRef<int64_t> high = padOp.getStaticHigh();
    SmallVector<ReassociationIndices> reassociations =
        reshapeOp.getReassociationIndices();

    for (auto reInd : reassociations) {
      if (reInd.size() == 1)
        continue;
      if (llvm::any_of(reInd, [&](int64_t ind) {
            return low[ind] != 0 || high[ind] != 0;
          })) {
        return failure();
      }
    }

    SmallVector<OpFoldResult> newLow, newHigh;
    RankedTensorType collapsedType = reshapeOp.getSrcType();
    RankedTensorType paddedType = padOp.getResultType();
    SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape());
    SmallVector<OpFoldResult> expandedPaddedSizes(
        getMixedValues(reshapeOp.getStaticOutputShape(),
                       reshapeOp.getOutputShape(), rewriter));
    AffineExpr d0, d1, d2;
    bindDims(rewriter.getContext(), d0, d1, d2);
    auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2});
    Location loc = reshapeOp->getLoc();
    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
      OpFoldResult l = padOp.getMixedLowPad()[reInd[0]];
      OpFoldResult h = padOp.getMixedHighPad()[reInd[0]];
      if (reInd.size() == 1) {
        collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
        OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply(
            rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
        expandedPaddedSizes[reInd[0]] = paddedSize;
      }
      newLow.push_back(l);
      newHigh.push_back(h);
    }

    RankedTensorType collapsedPaddedType =
        paddedType.clone(collapsedPaddedShape);
    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
        padOp.getConstantPaddingValue(), padOp.getNofold());

    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations,
        expandedPaddedSizes);

    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};

/// Pattern to collapse dimensions.
template <typename LinalgType>
class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
@@ -1937,6 +2079,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
    const ControlFusionFn &controlFoldingReshapes) {
  patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
                                                    controlFoldingReshapes);
  patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                        controlFoldingReshapes);
  patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                     controlFoldingReshapes);
}

Review thread on the new FoldPadWithProducerReshapeOpByExpansion registration:

Contributor: AFAICS both the patterns added here are "propagating by collapsing". You are moving the expand_shape down and collapse_shape up. In both cases the pad is happening on collapsed dimensions. So you should add both to the populateFoldReshapeOpsByCollapsingPatterns.

Contributor (Author): FoldPadWithProducerReshapeOpByExpansion pushes the producer collapse_shape down, expanding the pad. FoldPadWithProducerReshapeOpByCollapsing pushes the producer expand_shape down, collapsing the pad. I think these are in the right places here.

Contributor: Ok, I twisted myself in a knot when looking at tests. You are right.
@@ -1946,6 +2090,8 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
    const ControlFusionFn &controlFoldingReshapes) {
  patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
                                                      controlFoldingReshapes);
  patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
      patterns.getContext(), controlFoldingReshapes);
}
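
For reference, a minimal sketch of how a pass could exercise these entry points. The pass wiring (getContext(), getOperation(), signalPassFailure()), the controlFn name, and the always-true control function are illustrative assumptions and not part of this patch:

  // Hypothetical pass body: collect the reshape-by-collapsing patterns, which
  // now also include FoldPadWithProducerReshapeOpByCollapsing, and apply them
  // greedily. The control function below allows every propagation.
  RewritePatternSet patterns(&getContext());
  linalg::ControlFusionFn controlFn = [](OpOperand *fusedOperand) {
    return true;
  };
  linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn);
  if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
    return signalPassFailure();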

void mlir::linalg::populateElementwiseOpsFusionPatterns(
68 changes: 68 additions & 0 deletions mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
@@ -537,3 +537,71 @@ func.func @no_fold_non_consecutive_reduction_dims(%arg0 : tensor<?x?xi32>, %sz0:
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[EXPAND_ARG0]] :
// CHECK: return %[[GENERIC]]

// -----

func.func @fuse_by_collapsing_pad(%arg0 : tensor<2x12x5x336x9xi32>) -> tensor<8x3x4x17x6x7x8x14xi32> {
%expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
%cst = arith.constant 0 : i32
%padded_0 = tensor.pad %expand low[1, 0, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2] {
^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index,
%arg5: index, %arg6: index, %arg7: index, %arg8: index):
tensor.yield %cst : i32
} : tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x3x4x17x6x7x8x14xi32>
return %padded_0 : tensor<8x3x4x17x6x7x8x14xi32>
}
// CHECK: func @fuse_by_collapsing_pad(
// CHECK-SAME: %[[ARG0:.+]]: tensor<2x12x5x336x9xi32>)
// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]]
// CHECK-SAME: low[1, 0, 8, 0, 3] high[5, 0, 4, 0, 2]
// CHECK: tensor<2x12x5x336x9xi32> to tensor<8x12x17x336x14xi32>
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
// CHECK-SAME: output_shape [8, 3, 4, 17, 6, 7, 8, 14] : tensor<8x12x17x336x14xi32> into tensor<8x3x4x17x6x7x8x14xi32>
// CHECK: return %[[EXPAND]]

// -----

func.func @no_fuse_by_collapsing_pad(%arg0 : tensor<2x12x5x336x9xi32>) -> tensor<8x5x4x17x6x7x8x14xi32> {
%expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
%cst = arith.constant 0 : i32
%padded_0 = tensor.pad %expand low[1, 2, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2] {
^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index,
%arg5: index, %arg6: index, %arg7: index, %arg8: index):
tensor.yield %cst : i32
} : tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x5x4x17x6x7x8x14xi32>
return %padded_0 : tensor<8x5x4x17x6x7x8x14xi32>
}
// CHECK: func @no_fuse_by_collapsing_pad(
// CHECK-SAME: %[[ARG0:.+]]: tensor<2x12x5x336x9xi32>)
// CHECK: %[[EXPAND_ARG0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
// CHECK-SAME: output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
// CHECK: %[[PAD:.+]] = tensor.pad %[[EXPAND_ARG0]]
// CHECK-SAME: low[1, 2, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2]
// CHECK: tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x5x4x17x6x7x8x14xi32>
// CHECK: return %[[PAD]]

// -----

func.func @fuse_by_collapsing_dynamic_pad(%arg0 : tensor<?x?x?x?xf32>,
%s0 : index, %s1 : index, %s2 : index, %s3 : index, %s4 : index, %s5 : index,
%l0 : index, %l1 : index, %h0 : index, %h1 : index) -> tensor<?x?x?x?x?x?xf32> {
%expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5]] output_shape [%s0, %s1, %s2, %s3, %s4, %s5] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
%cst = arith.constant 0.0 : f32
%padded_0 = tensor.pad %expand low[%l0, 0, 0, %l1, 0, 0] high[%h0, 0, 0, %h1, 0, 0] {
^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index):
tensor.yield %cst : f32
} : tensor<?x?x?x?x?x?xf32> to tensor<?x?x?x?x?x?xf32>
return %padded_0 : tensor<?x?x?x?x?x?xf32>
}
// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s0 + s1 + s2)>
// CHECK: func @fuse_by_collapsing_dynamic_pad(
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>
// CHECK-SAME: %[[S0:.+]]: index, %[[S1:.+]]: index, %[[S2:.+]]: index, %[[S3:.+]]: index, %[[S4:.+]]: index, %[[S5:.+]]: index, %[[L0:.+]]: index, %[[L1:.+]]: index, %[[H0:.+]]: index, %[[H1:.+]]: index
// CHECK: %[[PAD_SIZE0:.+]] = affine.apply #[[MAP]]()[%[[L0]], %[[H0]], %[[S0]]]
// CHECK: %[[PAD_SIZE1:.+]] = affine.apply #[[MAP]]()[%[[L1]], %[[H1]], %[[S3]]]
// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]]
// CHECK-SAME: low[%[[L0]], 0, %[[L1]], 0] high[%[[H0]], 0, %[[H1]], 0]
// CHECK: tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5]]
// CHECK-SAME: output_shape [%[[PAD_SIZE0]], %[[S1]], %[[S2]], %[[PAD_SIZE1]], %[[S4]], %[[S5]]] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
// CHECK: return %[[EXPAND]]
61 changes: 61 additions & 0 deletions mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -826,3 +826,64 @@ func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
// CHECK-SAME: [0, 1], [2, 3]
// CHECK-SAME: tensor<?x7x?x8xf32> into tensor<?x?xf32>
// CHECK: return %[[T4]]

// -----

func.func @fuse_by_expanding_pad(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<8x12x17x336x14xi32> {
%collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
%cst = arith.constant 0 : i32
%padded_0 = tensor.pad %collapse low[1, 0, 8, 0, 3] high[5, 0, 4, 0, 2] {
^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index):
tensor.yield %cst : i32
} : tensor<2x12x5x336x9xi32> to tensor<8x12x17x336x14xi32>
return %padded_0 : tensor<8x12x17x336x14xi32>
}
// CHECK: func @fuse_by_expanding_pad(
// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8x9xi32>)
// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]]
// CHECK-SAME: low[1, 0, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2]
// CHECK: tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x3x4x17x6x7x8x14xi32>
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
// CHECK-SAME: : tensor<8x3x4x17x6x7x8x14xi32> into tensor<8x12x17x336x14xi32>
// CHECK: return %[[COLLAPSE]]

// -----

func.func @no_fuse_by_expanding_pad(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<8x12x17x339x14xi32> {
%collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
%cst = arith.constant 0 : i32
%padded_0 = tensor.pad %collapse low[1, 0, 8, 0, 3] high[5, 0, 4, 3, 2] {
^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index):
tensor.yield %cst : i32
} : tensor<2x12x5x336x9xi32> to tensor<8x12x17x339x14xi32>
return %padded_0 : tensor<8x12x17x339x14xi32>
}
// CHECK: func @no_fuse_by_expanding_pad(
// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8x9xi32>)
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
// CHECK-SAME: : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
// CHECK: %[[PAD:.+]] = tensor.pad %[[COLLAPSE]]
// CHECK-SAME: low[1, 0, 8, 0, 3] high[5, 0, 4, 3, 2]
// CHECK: tensor<2x12x5x336x9xi32> to tensor<8x12x17x339x14xi32>
// CHECK: return %[[PAD]]

// -----

func.func @fuse_by_expanding_dynamic_pad(%arg0 : tensor<?x?x?x?x?x?xi32>, %l0: index, %l1: index, %h0: index, %h1: index) -> tensor<?x?x?x?xi32> {
%collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4, 5]] : tensor<?x?x?x?x?x?xi32> into tensor<?x?x?x?xi32>
%cst = arith.constant 0 : i32
%padded_0 = tensor.pad %collapse low[%l0, 0, %l1, 0] high[%h0, 0, %h1, 0] {
^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
tensor.yield %cst : i32
} : tensor<?x?x?x?xi32> to tensor<?x?x?x?xi32>
return %padded_0 : tensor<?x?x?x?xi32>
}
// CHECK: func @fuse_by_expanding_dynamic_pad(
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?x?x?xi32>
// CHECK-SAME: %[[L0:.+]]: index, %[[L1:.+]]: index, %[[H0:.+]]: index, %[[H1:.+]]: index
// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]]
// CHECK-SAME: low[%[[L0]], 0, 0, %[[L1]], 0, 0] high[%[[H0]], 0, 0, %[[H1]], 0, 0]
// CHECK: tensor<?x?x?x?x?x?xi32> to tensor<?x?x?x?x?x?xi32>
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5]]
// CHECK-SAME: : tensor<?x?x?x?x?x?xi32> into tensor<?x?x?x?xi32>
// CHECK: return %[[COLLAPSE]]