diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index e8f6edc3f133e..89bc57f09ec8b 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -85,21 +85,49 @@ bool isReassociationValid(ArrayRef<AffineMap> reassociation,
 template <typename ReshapeOpTy, typename InverseReshapeOpTy>
 static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
                                   ArrayRef<Attribute> operands) {
-
+  // Fold identity reshape.
   if (reshapeOp.getSrcType() == reshapeOp.getType())
     return reshapeOp.getSrc();
 
-  // Fold producer-consumer reshape ops where the operand type of the
-  // producer is same as the return type of the consumer.
-  auto reshapeSrcOp =
-      reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
-  if (reshapeSrcOp && reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
-    return reshapeSrcOp.getSrc();
-
   // Reshape of a constant can be replaced with a new constant.
   if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front()))
     return elements.reshape(cast<ShapedType>(reshapeOp.getResult().getType()));
 
+  // Fold if the producer reshape source has the same shape with at most 1
+  // dynamic dimension.
+  auto reshapeSrcOp =
+      reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
+  if (!reshapeSrcOp)
+    return nullptr;
+  auto srcType = reshapeSrcOp.getSrcType();
+  auto resultType = reshapeOp.getResultType();
+  if (srcType != resultType)
+    return nullptr;
+
+  if (llvm::count_if(srcType.getShape(), ShapedType::isDynamic) < 2) {
+    return reshapeSrcOp.getSrc();
+  }
+
+  // Fold producer-consumer reshape ops when they are perfect inverses of each
+  // other:
+  //   1) Reassociation indices are equivalent.
+  //   2) Boundary types are equivalent.
+  //   3) No reassociations have more than 1 dynamic dimension, and reassociated
+  //      shapes are equal for each reassociation.
+  auto reassociations = reshapeOp.getReassociationIndices();
+  if (reassociations != reshapeSrcOp.getReassociationIndices())
+    return nullptr;
+  // If the reshapes are expanding and then collapsing, the ops can be folded
+  // despite multiple dynamic dimensions.
+  if (srcType.getRank() < reshapeSrcOp.getResultType().getRank())
+    return reshapeSrcOp.getSrc();
+  if (llvm::all_of(reassociations, [&](auto reInd) {
+        ArrayRef<int64_t> srcSlice =
+            srcType.getShape().slice(reInd.front(), reInd.size());
+        return llvm::count_if(srcSlice, ShapedType::isDynamic) < 2;
+      })) {
+    return reshapeSrcOp.getSrc();
+  }
 
   return nullptr;
 }
@@ -360,10 +388,12 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
           resultShape.slice(resultIndices.front(), resultIndices.size());
 
       if (srcSubShape.size() == resultSubShape.size()) {
-        if (srcSubShape == resultSubShape)
+        if (srcSubShape == resultSubShape &&
+            llvm::count_if(srcSubShape, ShapedType::isDynamic) < 2) {
           composedReassociation.push_back(srcIndices);
-        else
+        } else {
           return std::nullopt;
+        }
       }
 
       // Find reassociation to collapse `srcSubShape` into `resultSubShape`.
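The key subtlety is condition 3): when one reassociation group contains two or more dynamic dimensions, a collapse followed by an expand only constrains the *product* of the group's sizes, so the individual expanded sizes may differ from the source dimensions even though the types match. A minimal sketch of the unsafe case (hypothetical IR, not part of this patch; the function and value names are illustrative):

```mlir
// Must NOT fold to %arg0: %h * %w equals dim0(%arg0) * dim1(%arg0),
// but %h and %w individually may differ from the original sizes.
func.func @two_dynamic_dims_in_one_group(%arg0: tensor<?x?xf32>,
    %h: index, %w: index) -> tensor<?x?xf32> {
  %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>
  %1 = tensor.expand_shape %0 [[0, 1]] output_shape [%h, %w]
      : tensor<?xf32> into tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}
```

With at most one dynamic size per group, the static sizes pin the dynamic one down, so the fold is sound; and expand-then-collapse is safe regardless of dynamic-dimension count because the collapse recomputes each result size as the product of its group, which is why the `srcType.getRank() < reshapeSrcOp.getResultType().getRank()` check returns early.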
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index f7fbd3834288b..9a6b03986ccb6 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1139,7 +1139,7 @@ func.func @fold_collapse_of_expand(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32
   return %1 : tensor<12x4xf32>
 }
 // CHECK-LABEL: @fold_collapse_of_expand
-// CHECK-NOT: linalg.{{.*}}shape
+// CHECK-NOT: tensor.{{.*}}_shape
 
 // -----
 
@@ -1152,7 +1152,75 @@ func.func @fold_collapse_of_expand_dynamic(%arg0 : tensor<?x?xf32>, %arg1: index
   return %1 : tensor<?x?xf32>
 }
 // CHECK-LABEL: @fold_collapse_of_expand_dynamic
-// CHECK-NOT: linalg.{{.*}}_shape
+// CHECK-NOT: tensor.{{.*}}_shape
+
+// -----
+
+func.func @fold_collapse_of_expand_fully_dynamic(%arg0 : tensor<?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
+    -> tensor<?x?xf32> {
+  %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [%arg1, %arg2, %arg3]
+      : tensor<?x?xf32> into tensor<?x?x?xf32>
+  %1 = tensor.collapse_shape %0 [[0, 1], [2]]
+      : tensor<?x?x?xf32> into tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: @fold_collapse_of_expand_fully_dynamic
+// CHECK-NOT: tensor.{{.*}}_shape
+
+// -----
+
+func.func @no_fold_parallel_collapse_of_expand_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index, %arg4: index)
+    -> tensor<?x?x?xf32> {
+  %0 = tensor.expand_shape %arg0 [[0, 1], [2], [3]] output_shape [%arg1, %arg2, %arg3, %arg4]
+      : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+  %1 = tensor.collapse_shape %0 [[0], [1], [2, 3]]
+      : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: @no_fold_parallel_collapse_of_expand_dynamic
+// CHECK: tensor.expand_shape
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape
+// CHECK: return %[[COLLAPSE]]
+
+// -----
+
+func.func @fold_expand_of_collapse(%arg0 : tensor<3x4x4xf32>) -> tensor<3x4x4xf32> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
+      : tensor<3x4x4xf32> into tensor<12x4xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [3, 4, 4]
+      : tensor<12x4xf32> into tensor<3x4x4xf32>
+  return %1 : tensor<3x4x4xf32>
+}
+// CHECK-LABEL: @fold_expand_of_collapse
+// CHECK-NOT: tensor.{{.*}}_shape
+
+// -----
+
+func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
+    -> tensor<?x4x?xf32> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
+      : tensor<?x4x?xf32> into tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
+      : tensor<?x?xf32> into tensor<?x4x?xf32>
+  return %1 : tensor<?x4x?xf32>
+}
+// CHECK-LABEL: @fold_expand_of_collapse_dynamic
+// CHECK-NOT: tensor.{{.*}}_shape
+
+// -----
+
+func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
+    -> tensor<?x?x?xf32> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
+      : tensor<?x?x?xf32> into tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, %arg2, %arg3]
+      : tensor<?x?xf32> into tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: @no_fold_expand_of_collapse_dynamic
+// CHECK: tensor.collapse_shape
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape
+// CHECK: return %[[EXPAND]]
 
 // -----
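For the positive tests, the `CHECK-NOT` lines assert that canonicalization erases the reshape pair entirely, which means the folded function simply returns its block argument. A hand-written sketch of the expected output for one of the new tests (assuming this file's usual `mlir-opt -split-input-file -canonicalize | FileCheck` RUN line; this is not generated output):

```mlir
// Expected canonical form of @fold_expand_of_collapse_dynamic: the
// collapse/expand pair folds away and the source value is returned.
func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
    -> tensor<?x4x?xf32> {
  return %arg0 : tensor<?x4x?xf32>
}
```

The two `no_fold_*` tests instead pin both reshape ops and the returned value, guarding against the miscompile the patch fixes: in one case the expand and collapse use different reassociations, and in the other a single reassociation group carries two dynamic sizes.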