-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][Transforms] Dialect conversion: Add missing erasure notifications #145030
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Transforms] Dialect conversion: Add missing erasure notifications #145030
Conversation
@llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesAdd missing listener notifications when erasing nested blocks/operations. This commit also moves some of the functionality from Full diff: https://github.com/llvm/llvm-project/pull/145030.diff 2 Files Affected:
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index ff48647f43305..7419d79cd8856 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -274,6 +274,26 @@ struct RewriterState {
// IR rewrites
//===----------------------------------------------------------------------===//
+static void notifyIRErased(RewriterBase::Listener *listener, Operation &op);
+
+/// Notify the listener that the given block and its contents are being erased.
+static void notifyIRErased(RewriterBase::Listener *listener, Block &b) {
+ for (Operation &op : b)
+ notifyIRErased(listener, op);
+ listener->notifyBlockErased(&b);
+}
+
+/// Notify the listener that the given operation and its contents are being
+/// erased.
+static void notifyIRErased(RewriterBase::Listener *listener, Operation &op) {
+ for (Region &r : op.getRegions()) {
+ for (Block &b : r) {
+ notifyIRErased(listener, b);
+ }
+ }
+ listener->notifyOperationErased(&op);
+}
+
/// An IR rewrite that can be committed (upon success) or rolled back (upon
/// failure).
///
@@ -422,17 +442,20 @@ class EraseBlockRewrite : public BlockRewrite {
}
void commit(RewriterBase &rewriter) override {
- // Erase the block.
assert(block && "expected block");
- assert(block->empty() && "expected empty block");
- // Notify the listener that the block is about to be erased.
+ // Notify the listener that the block and its contents are being erased.
if (auto *listener =
dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
- listener->notifyBlockErased(block);
+ notifyIRErased(listener, *block);
}
void cleanup(RewriterBase &rewriter) override {
+ // Erase the contents of the block.
+ for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block)))
+ rewriter.eraseOp(&op);
+ assert(block->empty() && "expected empty block");
+
// Erase the block.
block->dropAllDefinedValueUses();
delete block;
@@ -1147,12 +1170,9 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
if (getConfig().unlegalizedOps)
getConfig().unlegalizedOps->erase(op);
- // Notify the listener that the operation (and its nested operations) was
- // erased.
- if (listener) {
- op->walk<WalkOrder::PostOrder>(
- [&](Operation *op) { listener->notifyOperationErased(op); });
- }
+ // Notify the listener that the operation and its contents are being erased.
+ if (listener)
+ notifyIRErased(listener, *op);
// Do not erase the operation yet. It may still be referenced in `mapping`.
// Just unlink it for now and erase it during cleanup.
@@ -1605,6 +1625,8 @@ void ConversionPatternRewriterImpl::replaceOp(
}
void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
+ assert(!wasOpReplaced(block->getParentOp()) &&
+ "attempting to erase a block within a replaced/erased op");
appendRewrite<EraseBlockRewrite>(block);
// Unlink the block from its parent region. The block is kept in the rewrite
@@ -1612,12 +1634,16 @@ void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
// allows us to keep the operations in the block live and undo the removal by
// re-inserting the block.
block->getParent()->getBlocks().remove(block);
+
+ // Mark all nested ops as erased.
+ block->walk([&](Operation *op) { replacedOps.insert(op); });
}
void ConversionPatternRewriterImpl::notifyBlockInserted(
Block *block, Region *previous, Region::iterator previousIt) {
- assert(!wasOpReplaced(block->getParentOp()) &&
- "attempting to insert into a region within a replaced/erased op");
+ assert(
+ (!config.allowPatternRollback || !wasOpReplaced(block->getParentOp())) &&
+ "attempting to insert into a region within a replaced/erased op");
LLVM_DEBUG(
{
Operation *parent = block->getParentOp();
@@ -1630,6 +1656,11 @@ void ConversionPatternRewriterImpl::notifyBlockInserted(
}
});
+ if (!config.allowPatternRollback) {
+ // Pattern rollback is not allowed. No extra bookkeeping is needed.
+ return;
+ }
+
if (!previous) {
// This is a newly created block.
appendRewrite<CreateBlockRewrite>(block);
@@ -1709,13 +1740,6 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
}
void ConversionPatternRewriter::eraseBlock(Block *block) {
- assert(!impl->wasOpReplaced(block->getParentOp()) &&
- "attempting to erase a block within a replaced/erased op");
-
- // Mark all ops for erasure.
- for (Operation &op : *block)
- eraseOp(&op);
-
impl->eraseBlock(block);
}
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 34948ae685f0a..204c8c1456826 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -461,12 +461,26 @@ func.func @convert_detached_signature() {
// -----
+// CHECK: notifyOperationReplaced: test.erase_op
+// CHECK: notifyOperationErased: test.dummy_op_lvl_2
+// CHECK: notifyBlockErased
+// CHECK: notifyOperationErased: test.dummy_op_lvl_1
+// CHECK: notifyBlockErased
+// CHECK: notifyOperationErased: test.erase_op
+// CHECK: notifyOperationInserted: test.valid, was unlinked
+// CHECK: notifyOperationReplaced: test.drop_operands_and_replace_with_valid
+// CHECK: notifyOperationErased: test.drop_operands_and_replace_with_valid
+
// CHECK-LABEL: func @circular_mapping()
// CHECK-NEXT: "test.valid"() : () -> ()
func.func @circular_mapping() {
// Regression test that used to crash due to circular
- // unrealized_conversion_cast ops.
- %0 = "test.erase_op"() : () -> (i64)
+ // unrealized_conversion_cast ops.
+ %0 = "test.erase_op"() ({
+ "test.dummy_op_lvl_1"() ({
+ "test.dummy_op_lvl_2"() : () -> ()
+ }) : () -> ()
+ }): () -> (i64)
"test.drop_operands_and_replace_with_valid"(%0) : (i64) -> ()
}
|
@llvm/pr-subscribers-mlir-core Author: Matthias Springer (matthias-springer) ChangesAdd missing listener notifications when erasing nested blocks/operations. This commit also moves some of the functionality from Full diff: https://github.com/llvm/llvm-project/pull/145030.diff 2 Files Affected:
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index ff48647f43305..7419d79cd8856 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -274,6 +274,26 @@ struct RewriterState {
// IR rewrites
//===----------------------------------------------------------------------===//
+static void notifyIRErased(RewriterBase::Listener *listener, Operation &op);
+
+/// Notify the listener that the given block and its contents are being erased.
+static void notifyIRErased(RewriterBase::Listener *listener, Block &b) {
+ for (Operation &op : b)
+ notifyIRErased(listener, op);
+ listener->notifyBlockErased(&b);
+}
+
+/// Notify the listener that the given operation and its contents are being
+/// erased.
+static void notifyIRErased(RewriterBase::Listener *listener, Operation &op) {
+ for (Region &r : op.getRegions()) {
+ for (Block &b : r) {
+ notifyIRErased(listener, b);
+ }
+ }
+ listener->notifyOperationErased(&op);
+}
+
/// An IR rewrite that can be committed (upon success) or rolled back (upon
/// failure).
///
@@ -422,17 +442,20 @@ class EraseBlockRewrite : public BlockRewrite {
}
void commit(RewriterBase &rewriter) override {
- // Erase the block.
assert(block && "expected block");
- assert(block->empty() && "expected empty block");
- // Notify the listener that the block is about to be erased.
+ // Notify the listener that the block and its contents are being erased.
if (auto *listener =
dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
- listener->notifyBlockErased(block);
+ notifyIRErased(listener, *block);
}
void cleanup(RewriterBase &rewriter) override {
+ // Erase the contents of the block.
+ for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block)))
+ rewriter.eraseOp(&op);
+ assert(block->empty() && "expected empty block");
+
// Erase the block.
block->dropAllDefinedValueUses();
delete block;
@@ -1147,12 +1170,9 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
if (getConfig().unlegalizedOps)
getConfig().unlegalizedOps->erase(op);
- // Notify the listener that the operation (and its nested operations) was
- // erased.
- if (listener) {
- op->walk<WalkOrder::PostOrder>(
- [&](Operation *op) { listener->notifyOperationErased(op); });
- }
+ // Notify the listener that the operation and its contents are being erased.
+ if (listener)
+ notifyIRErased(listener, *op);
// Do not erase the operation yet. It may still be referenced in `mapping`.
// Just unlink it for now and erase it during cleanup.
@@ -1605,6 +1625,8 @@ void ConversionPatternRewriterImpl::replaceOp(
}
void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
+ assert(!wasOpReplaced(block->getParentOp()) &&
+ "attempting to erase a block within a replaced/erased op");
appendRewrite<EraseBlockRewrite>(block);
// Unlink the block from its parent region. The block is kept in the rewrite
@@ -1612,12 +1634,16 @@ void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
// allows us to keep the operations in the block live and undo the removal by
// re-inserting the block.
block->getParent()->getBlocks().remove(block);
+
+ // Mark all nested ops as erased.
+ block->walk([&](Operation *op) { replacedOps.insert(op); });
}
void ConversionPatternRewriterImpl::notifyBlockInserted(
Block *block, Region *previous, Region::iterator previousIt) {
- assert(!wasOpReplaced(block->getParentOp()) &&
- "attempting to insert into a region within a replaced/erased op");
+ assert(
+ (!config.allowPatternRollback || !wasOpReplaced(block->getParentOp())) &&
+ "attempting to insert into a region within a replaced/erased op");
LLVM_DEBUG(
{
Operation *parent = block->getParentOp();
@@ -1630,6 +1656,11 @@ void ConversionPatternRewriterImpl::notifyBlockInserted(
}
});
+ if (!config.allowPatternRollback) {
+ // Pattern rollback is not allowed. No extra bookkeeping is needed.
+ return;
+ }
+
if (!previous) {
// This is a newly created block.
appendRewrite<CreateBlockRewrite>(block);
@@ -1709,13 +1740,6 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
}
void ConversionPatternRewriter::eraseBlock(Block *block) {
- assert(!impl->wasOpReplaced(block->getParentOp()) &&
- "attempting to erase a block within a replaced/erased op");
-
- // Mark all ops for erasure.
- for (Operation &op : *block)
- eraseOp(&op);
-
impl->eraseBlock(block);
}
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 34948ae685f0a..204c8c1456826 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -461,12 +461,26 @@ func.func @convert_detached_signature() {
// -----
+// CHECK: notifyOperationReplaced: test.erase_op
+// CHECK: notifyOperationErased: test.dummy_op_lvl_2
+// CHECK: notifyBlockErased
+// CHECK: notifyOperationErased: test.dummy_op_lvl_1
+// CHECK: notifyBlockErased
+// CHECK: notifyOperationErased: test.erase_op
+// CHECK: notifyOperationInserted: test.valid, was unlinked
+// CHECK: notifyOperationReplaced: test.drop_operands_and_replace_with_valid
+// CHECK: notifyOperationErased: test.drop_operands_and_replace_with_valid
+
// CHECK-LABEL: func @circular_mapping()
// CHECK-NEXT: "test.valid"() : () -> ()
func.func @circular_mapping() {
// Regression test that used to crash due to circular
- // unrealized_conversion_cast ops.
- %0 = "test.erase_op"() : () -> (i64)
+ // unrealized_conversion_cast ops.
+ %0 = "test.erase_op"() ({
+ "test.dummy_op_lvl_1"() ({
+ "test.dummy_op_lvl_2"() : () -> ()
+ }) : () -> ()
+ }): () -> (i64)
"test.drop_operands_and_replace_with_valid"(%0) : (i64) -> ()
}
|
d65161f
to
edb49ec
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
edb49ec
to
2fe4d14
Compare
Add missing listener notifications when erasing nested blocks/operations.
This commit also moves some of the functionality from
ConversionPatternRewriter
toConversionPatternRewriterImpl
. This is in preparation of the One-Shot Dialect Conversion refactoring: The implementations inConversionPatternRewriter
should be as simple as possible, so that a switch between "rollback allowed" and "rollback not allowed" can be inserted at that level. (In the latter case,ConversionPatternRewriterImpl
can be bypassed to some degree, andPatternRewriter::eraseBlock
etc. can be used.)Depends on #145018.