diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index ff48647f43305..ad82a007b7996 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -274,6 +274,26 @@ struct RewriterState { // IR rewrites //===----------------------------------------------------------------------===// +static void notifyIRErased(RewriterBase::Listener *listener, Operation &op); + +/// Notify the listener that the given block and its contents are being erased. +static void notifyIRErased(RewriterBase::Listener *listener, Block &b) { + for (Operation &op : b) + notifyIRErased(listener, op); + listener->notifyBlockErased(&b); +} + +/// Notify the listener that the given operation and its contents are being +/// erased. +static void notifyIRErased(RewriterBase::Listener *listener, Operation &op) { + for (Region &r : op.getRegions()) { + for (Block &b : r) { + notifyIRErased(listener, b); + } + } + listener->notifyOperationErased(&op); +} + /// An IR rewrite that can be committed (upon success) or rolled back (upon /// failure). /// @@ -422,17 +442,20 @@ class EraseBlockRewrite : public BlockRewrite { } void commit(RewriterBase &rewriter) override { - // Erase the block. assert(block && "expected block"); - assert(block->empty() && "expected empty block"); - // Notify the listener that the block is about to be erased. + // Notify the listener that the block and its contents are being erased. if (auto *listener = dyn_cast_or_null(rewriter.getListener())) - listener->notifyBlockErased(block); + notifyIRErased(listener, *block); } void cleanup(RewriterBase &rewriter) override { + // Erase the contents of the block. + for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) + rewriter.eraseOp(&op); + assert(block->empty() && "expected empty block"); + // Erase the block. block->dropAllDefinedValueUses(); delete block; @@ -1147,12 +1170,9 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) { if (getConfig().unlegalizedOps) getConfig().unlegalizedOps->erase(op); - // Notify the listener that the operation (and its nested operations) was - // erased. - if (listener) { - op->walk( - [&](Operation *op) { listener->notifyOperationErased(op); }); - } + // Notify the listener that the operation and its contents are being erased. + if (listener) + notifyIRErased(listener, *op); // Do not erase the operation yet. It may still be referenced in `mapping`. // Just unlink it for now and erase it during cleanup. @@ -1605,6 +1625,8 @@ void ConversionPatternRewriterImpl::replaceOp( } void ConversionPatternRewriterImpl::eraseBlock(Block *block) { + assert(!wasOpReplaced(block->getParentOp()) && + "attempting to erase a block within a replaced/erased op"); appendRewrite(block); // Unlink the block from its parent region. The block is kept in the rewrite @@ -1612,6 +1634,9 @@ void ConversionPatternRewriterImpl::eraseBlock(Block *block) { // allows us to keep the operations in the block live and undo the removal by // re-inserting the block. block->getParent()->getBlocks().remove(block); + + // Mark all nested ops as erased. + block->walk([&](Operation *op) { replacedOps.insert(op); }); } void ConversionPatternRewriterImpl::notifyBlockInserted( @@ -1709,13 +1734,6 @@ void ConversionPatternRewriter::eraseOp(Operation *op) { } void ConversionPatternRewriter::eraseBlock(Block *block) { - assert(!impl->wasOpReplaced(block->getParentOp()) && - "attempting to erase a block within a replaced/erased op"); - - // Mark all ops for erasure. - for (Operation &op : *block) - eraseOp(&op); - impl->eraseBlock(block); } diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir index 34948ae685f0a..204c8c1456826 100644 --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -461,12 +461,26 @@ func.func @convert_detached_signature() { // ----- +// CHECK: notifyOperationReplaced: test.erase_op +// CHECK: notifyOperationErased: test.dummy_op_lvl_2 +// CHECK: notifyBlockErased +// CHECK: notifyOperationErased: test.dummy_op_lvl_1 +// CHECK: notifyBlockErased +// CHECK: notifyOperationErased: test.erase_op +// CHECK: notifyOperationInserted: test.valid, was unlinked +// CHECK: notifyOperationReplaced: test.drop_operands_and_replace_with_valid +// CHECK: notifyOperationErased: test.drop_operands_and_replace_with_valid + // CHECK-LABEL: func @circular_mapping() // CHECK-NEXT: "test.valid"() : () -> () func.func @circular_mapping() { // Regression test that used to crash due to circular - // unrealized_conversion_cast ops. - %0 = "test.erase_op"() : () -> (i64) + // unrealized_conversion_cast ops. + %0 = "test.erase_op"() ({ + "test.dummy_op_lvl_1"() ({ + "test.dummy_op_lvl_2"() : () -> () + }) : () -> () + }): () -> (i64) "test.drop_operands_and_replace_with_valid"(%0) : (i64) -> () }