From 930a4d72e8d818af62744070c00cd667aaacbd9e Mon Sep 17 00:00:00 2001 From: Max Dawkins Date: Thu, 6 Jun 2024 10:46:45 -0400 Subject: [PATCH 1/4] [mlir] Fix bugs in expand_shape patterns after semantics changes --- .../mlir/Dialect/Utils/ReshapeOpsUtils.h | 56 ++++++++++++++---- mlir/test/Dialect/Tensor/canonicalize.mlir | 57 ++++++++++++++++++- 2 files changed, 101 insertions(+), 12 deletions(-) diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h index e8f6edc3f133e..3b986f4a60064 100644 --- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h @@ -85,21 +85,55 @@ bool isReassociationValid(ArrayRef reassociation, template static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, ArrayRef operands) { - + // Fold identity reshape. if (reshapeOp.getSrcType() == reshapeOp.getType()) return reshapeOp.getSrc(); - // Fold producer-consumer reshape ops where the operand type of the - // producer is same as the return type of the consumer. - auto reshapeSrcOp = - reshapeOp.getSrc().template getDefiningOp(); - if (reshapeSrcOp && reshapeSrcOp.getSrcType() == reshapeOp.getResultType()) - return reshapeSrcOp.getSrc(); - // Reshape of a constant can be replaced with a new constant. if (auto elements = dyn_cast_or_null(operands.front())) return elements.reshape(cast(reshapeOp.getResult().getType())); + // Fold if the producer reshape source has the same shape with at most 1 + // dynamic dimension. + auto reshapeSrcOp = + reshapeOp.getSrc().template getDefiningOp(); + if (!reshapeSrcOp) + return nullptr; + auto srcType = reshapeSrcOp.getSrcType(); + auto resultType = reshapeOp.getResultType(); + if (srcType != resultType) + return nullptr; + + // If the reshapes are expanding and then collapsing, the ops can be folded + // despite multiple dynamic dimensions. + if (srcType.getRank() < reshapeSrcOp.getResultType().getRank()) + return reshapeSrcOp.getSrc(); + // Otherwise, only 1 dynamic dimension is allowed. + if (srcType == resultType && + llvm::count_if(srcType.getShape(), ShapedType::isDynamic) < 2) { + return reshapeSrcOp.getSrc(); + } + + // Fold producer-consumer reshape ops when they are perfect inverses of each + // other: + // 1) Reassociation indices are equivalent. + // 2) Boundary types are equivalent. + // 3) No reassociations have more than 1 dynamic dimension, and reassociated + // shapes are equal for each reassociation. + auto reassociations = reshapeOp.getReassociationIndices(); + auto inverseReassociations = reshapeSrcOp.getReassociationIndices(); + if (reassociations != inverseReassociations) + return nullptr; + ArrayRef expandedSrcShape = srcType.getShape(); + ArrayRef expandedResultShape = resultType.getShape(); + if (llvm::none_of(reassociations, [&](auto reInd) { + auto srcSlice = expandedSrcShape.slice(reInd.front(), reInd.size()); + auto resSlice = expandedResultShape.slice(reInd.front(), reInd.size()); + return srcSlice == resSlice && + llvm::count_if(srcSlice, ShapedType::isDynamic) > 1; + })) { + return reshapeSrcOp.getSrc(); + } return nullptr; } @@ -360,10 +394,12 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern { resultShape.slice(resultIndices.front(), resultIndices.size()); if (srcSubShape.size() == resultSubShape.size()) { - if (srcSubShape == resultSubShape) + if (srcSubShape == resultSubShape && + llvm::count_if(srcSubShape, ShapedType::isDynamic) < 2) { composedReassociation.push_back(srcIndices); - else + } else { return std::nullopt; + } } // Find reassociation to collapse `srcSubShape` into `resultSubShape`. diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir index f7fbd3834288b..4a04d37d4be29 100644 --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -1139,7 +1139,7 @@ func.func @fold_collapse_of_expand(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32> return %1 : tensor<12x4xf32> } // CHECK-LABEL: @fold_collapse_of_expand -// CHECK-NOT: linalg.{{.*}}shape +// CHECK-NOT: tensor.{{.*}}_shape // ----- @@ -1152,7 +1152,60 @@ func.func @fold_collapse_of_expand_dynamic(%arg0 : tensor, %arg1: index return %1 : tensor } // CHECK-LABEL: @fold_collapse_of_expand_dynamic -// CHECK-NOT: linalg.{{.*}}_shape +// CHECK-NOT: tensor.{{.*}}_shape + +// ----- + +func.func @fold_collapse_of_expand_fully_dynamic(%arg0 : tensor, %arg1: index, %arg2: index, %arg3: index) + -> tensor { + %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [%arg1, %arg2, %arg3] + : tensor into tensor + %1 = tensor.collapse_shape %0 [[0, 1], [2]] + : tensor into tensor + return %1 : tensor +} +// CHECK-LABEL: @fold_collapse_of_expand_fully_dynamic +// CHECK-NOT: tensor.{{.*}}_shape + +// ----- + +func.func @fold_expand_of_collapse(%arg0 : tensor<3x4x4xf32>) -> tensor<3x4x4xf32> { + %0 = tensor.collapse_shape %arg0 [[0, 1], [2]] + : tensor<3x4x4xf32> into tensor<12x4xf32> + %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [3, 4, 4] + : tensor<12x4xf32> into tensor<3x4x4xf32> + return %1 : tensor<3x4x4xf32> +} +// CHECK-LABEL: @fold_expand_of_collapse +// CHECK-NOT: tensor.{{.*}}_shape + +// ----- + +func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor, %arg1: index, %arg2: index) + -> tensor { + %0 = tensor.collapse_shape %arg0 [[0, 1], [2]] + : tensor into tensor + %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2] + : tensor into tensor + return %1 : tensor +} +// CHECK-LABEL: @fold_expand_of_collapse_dynamic +// CHECK-NOT: tensor.{{.*}}_shape + +// ----- + +func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor, %arg1: index, %arg2: index, %arg3: index) + -> tensor { + %0 = tensor.collapse_shape %arg0 [[0, 1], [2]] + : tensor into tensor + %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, %arg2, %arg3] + : tensor into tensor + return %1 : tensor +} +// CHECK-LABEL: @no_fold_expand_of_collapse_dynamic +// CHECK: tensor.collapse_shape +// CHECK: %[[EXPAND:.+]] = tensor.expand_shape +// CHECK: return %[[EXPAND]] // ----- From 0536c7ec051aa8b2d6b2b9cc04a54a2d5bcfdb8d Mon Sep 17 00:00:00 2001 From: Max Dawkins Date: Thu, 6 Jun 2024 12:07:17 -0400 Subject: [PATCH 2/4] fix bug --- mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h | 9 ++++----- mlir/test/Dialect/Tensor/canonicalize.mlir | 15 +++++++++++++++ 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h index 3b986f4a60064..31a23be26d5a7 100644 --- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h @@ -104,11 +104,6 @@ static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, if (srcType != resultType) return nullptr; - // If the reshapes are expanding and then collapsing, the ops can be folded - // despite multiple dynamic dimensions. - if (srcType.getRank() < reshapeSrcOp.getResultType().getRank()) - return reshapeSrcOp.getSrc(); - // Otherwise, only 1 dynamic dimension is allowed. if (srcType == resultType && llvm::count_if(srcType.getShape(), ShapedType::isDynamic) < 2) { return reshapeSrcOp.getSrc(); @@ -124,6 +119,10 @@ static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, auto inverseReassociations = reshapeSrcOp.getReassociationIndices(); if (reassociations != inverseReassociations) return nullptr; + // If the reshapes are expanding and then collapsing, the ops can be folded + // despite multiple dynamic dimensions. + if (srcType.getRank() < reshapeSrcOp.getResultType().getRank()) + return reshapeSrcOp.getSrc(); ArrayRef expandedSrcShape = srcType.getShape(); ArrayRef expandedResultShape = resultType.getShape(); if (llvm::none_of(reassociations, [&](auto reInd) { diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir index 4a04d37d4be29..9a6b03986ccb6 100644 --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -1169,6 +1169,21 @@ func.func @fold_collapse_of_expand_fully_dynamic(%arg0 : tensor, %arg1: // ----- +func.func @no_fold_parallel_collapse_of_expand_dynamic(%arg0 : tensor, %arg1: index, %arg2: index, %arg3: index, %arg4: index) + -> tensor { + %0 = tensor.expand_shape %arg0 [[0, 1], [2], [3]] output_shape [%arg1, %arg2, %arg3, %arg4] + : tensor into tensor + %1 = tensor.collapse_shape %0 [[0], [1], [2, 3]] + : tensor into tensor + return %1 : tensor +} +// CHECK-LABEL: @no_fold_parallel_collapse_of_expand_dynamic +// CHECK: tensor.expand_shape +// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape +// CHECK: return %[[COLLAPSE]] + +// ----- + func.func @fold_expand_of_collapse(%arg0 : tensor<3x4x4xf32>) -> tensor<3x4x4xf32> { %0 = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<3x4x4xf32> into tensor<12x4xf32> From 2dc8fea7edb4797e15bf1c555dae57ec42a393b4 Mon Sep 17 00:00:00 2001 From: Max Dawkins Date: Thu, 6 Jun 2024 12:16:45 -0400 Subject: [PATCH 3/4] address comments --- mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h index 31a23be26d5a7..96f0f7bf1aa49 100644 --- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h @@ -104,8 +104,7 @@ static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, if (srcType != resultType) return nullptr; - if (srcType == resultType && - llvm::count_if(srcType.getShape(), ShapedType::isDynamic) < 2) { + if (llvm::count_if(srcType.getShape(), ShapedType::isDynamic) < 2) { return reshapeSrcOp.getSrc(); } @@ -116,8 +115,7 @@ static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, // 3) No reassociations have more than 1 dynamic dimension, and reassociated // shapes are equal for each reassociation. auto reassociations = reshapeOp.getReassociationIndices(); - auto inverseReassociations = reshapeSrcOp.getReassociationIndices(); - if (reassociations != inverseReassociations) + if (reassociations != reshapeSrcOp.getReassociationIndices()) return nullptr; // If the reshapes are expanding and then collapsing, the ops can be folded // despite multiple dynamic dimensions. @@ -125,11 +123,10 @@ static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, return reshapeSrcOp.getSrc(); ArrayRef expandedSrcShape = srcType.getShape(); ArrayRef expandedResultShape = resultType.getShape(); - if (llvm::none_of(reassociations, [&](auto reInd) { - auto srcSlice = expandedSrcShape.slice(reInd.front(), reInd.size()); - auto resSlice = expandedResultShape.slice(reInd.front(), reInd.size()); - return srcSlice == resSlice && - llvm::count_if(srcSlice, ShapedType::isDynamic) > 1; + if (llvm::all_of(reassociations, [&](auto reInd) { + ArrayRef srcSlice = + expandedSrcShape.slice(reInd.front(), reInd.size()); + return llvm::count_if(srcSlice, ShapedType::isDynamic) < 2; })) { return reshapeSrcOp.getSrc(); } From e842079c0420b94739f7fe4e44b39dfea303214d Mon Sep 17 00:00:00 2001 From: Max Dawkins Date: Thu, 6 Jun 2024 13:51:54 -0400 Subject: [PATCH 4/4] more comments --- mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h index 96f0f7bf1aa49..89bc57f09ec8b 100644 --- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h @@ -121,11 +121,9 @@ static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, // despite multiple dynamic dimensions. if (srcType.getRank() < reshapeSrcOp.getResultType().getRank()) return reshapeSrcOp.getSrc(); - ArrayRef expandedSrcShape = srcType.getShape(); - ArrayRef expandedResultShape = resultType.getShape(); if (llvm::all_of(reassociations, [&](auto reInd) { ArrayRef srcSlice = - expandedSrcShape.slice(reInd.front(), reInd.size()); + srcType.getShape().slice(reInd.front(), reInd.size()); return llvm::count_if(srcSlice, ShapedType::isDynamic) < 2; })) { return reshapeSrcOp.getSrc();