Skip to content

[mlir][Transforms][NFC] Dialect Conversion: Keep unresolvedMaterializations up to date #144254

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 18, 2025

Conversation

matthias-springer
Copy link
Member

unresolvedMaterializations is a mapping from UnrealizedConversionCastOp to UnresolvedMaterializationRewrite. This mapping is needed to find the correct type converter for an unresolved materialization.

With this commit, unresolvedMaterializations is updated immediately when an op is being erased. This also cleans up the code base a bit: SingleEraseRewriter is now used only during the "cleanup" phase and no longer needed as a field of ConversionRewriterImpl.

This commit is in preparation of the One-Shot Dialect Conversion refactoring: allowPatternRollback = false will in the future trigger immediate materialization of all IR changes.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Jun 15, 2025
@llvmbot
Copy link
Member

llvmbot commented Jun 15, 2025

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

unresolvedMaterializations is a mapping from UnrealizedConversionCastOp to UnresolvedMaterializationRewrite. This mapping is needed to find the correct type converter for an unresolved materialization.

With this commit, unresolvedMaterializations is updated immediately when an op is being erased. This also cleans up the code base a bit: SingleEraseRewriter is now used only during the "cleanup" phase and no longer needed as a field of ConversionRewriterImpl.

This commit is in preparation of the One-Shot Dialect Conversion refactoring: allowPatternRollback = false will in the future trigger immediate materialization of all IR changes.


Full diff: https://github.com/llvm/llvm-project/pull/144254.diff

1 Files Affected:

  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+20-13)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 7de26d7cfa84d..b5345fb1a2dcb 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -848,7 +848,7 @@ namespace detail {
 struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
                                          const ConversionConfig &config)
-      : context(ctx), eraseRewriter(ctx), config(config) {}
+      : context(ctx), config(config) {}
 
   //===--------------------------------------------------------------------===//
   // State Management
@@ -981,8 +981,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// no new IR is created between calls to `eraseOp`/`eraseBlock`.
   struct SingleEraseRewriter : public RewriterBase, RewriterBase::Listener {
   public:
-    SingleEraseRewriter(MLIRContext *context)
-        : RewriterBase(context, /*listener=*/this) {}
+    SingleEraseRewriter(
+        MLIRContext *context,
+        llvm::function_ref<void(Operation *)> opErasedCallback = nullptr)
+        : RewriterBase(context, /*listener=*/this),
+          opErasedCallback(opErasedCallback) {}
 
     /// Erase the given op (unless it was already erased).
     void eraseOp(Operation *op) override {
@@ -1003,13 +1006,20 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
 
     bool wasErased(void *ptr) const { return erased.contains(ptr); }
 
-    void notifyOperationErased(Operation *op) override { erased.insert(op); }
+    void notifyOperationErased(Operation *op) override {
+      erased.insert(op);
+      if (opErasedCallback)
+        opErasedCallback(op);
+    }
 
     void notifyBlockErased(Block *block) override { erased.insert(block); }
 
   private:
     /// Pointers to all erased operations and blocks.
     DenseSet<void *> erased;
+
+    /// A callback that is invoked when an operation is erased.
+    llvm::function_ref<void(Operation *)> opErasedCallback;
   };
 
   //===--------------------------------------------------------------------===//
@@ -1019,11 +1029,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// MLIR context.
   MLIRContext *context;
 
-  /// A rewriter that keeps track of ops/block that were already erased and
-  /// skips duplicate op/block erasures. This rewriter is used during the
-  /// "cleanup" phase.
-  SingleEraseRewriter eraseRewriter;
-
   // Mapping between replaced values that differ in type. This happens when
   // replacing a value with one of a different type.
   ConversionValueMapping mapping;
@@ -1195,6 +1200,11 @@ void ConversionPatternRewriterImpl::applyRewrites() {
     rewrites[i]->commit(rewriter);
 
   // Clean up all rewrites.
+  SingleEraseRewriter eraseRewriter(
+      context, /*opErasedCallback=*/[&](Operation *op) {
+        if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
+          unresolvedMaterializations.erase(castOp);
+      });
   for (auto &rewrite : rewrites)
     rewrite->cleanup(eraseRewriter);
 }
@@ -2714,11 +2724,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
   SmallVector<UnrealizedConversionCastOp> allCastOps;
   const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
       &materializations = rewriterImpl.unresolvedMaterializations;
-  for (auto it : materializations) {
-    if (rewriterImpl.eraseRewriter.wasErased(it.first))
-      continue;
+  for (auto it : materializations)
     allCastOps.push_back(it.first);
-  }
 
   // Reconcile all UnrealizedConversionCastOps that were inserted by the
   // dialect conversion frameworks. (Not the one that were inserted by

@llvmbot
Copy link
Member

llvmbot commented Jun 15, 2025

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

unresolvedMaterializations is a mapping from UnrealizedConversionCastOp to UnresolvedMaterializationRewrite. This mapping is needed to find the correct type converter for an unresolved materialization.

With this commit, unresolvedMaterializations is updated immediately when an op is being erased. This also cleans up the code base a bit: SingleEraseRewriter is now used only during the "cleanup" phase and no longer needed as a field of ConversionRewriterImpl.

This commit is in preparation of the One-Shot Dialect Conversion refactoring: allowPatternRollback = false will in the future trigger immediate materialization of all IR changes.


Full diff: https://github.com/llvm/llvm-project/pull/144254.diff

1 Files Affected:

  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+20-13)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 7de26d7cfa84d..b5345fb1a2dcb 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -848,7 +848,7 @@ namespace detail {
 struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
                                          const ConversionConfig &config)
-      : context(ctx), eraseRewriter(ctx), config(config) {}
+      : context(ctx), config(config) {}
 
   //===--------------------------------------------------------------------===//
   // State Management
@@ -981,8 +981,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// no new IR is created between calls to `eraseOp`/`eraseBlock`.
   struct SingleEraseRewriter : public RewriterBase, RewriterBase::Listener {
   public:
-    SingleEraseRewriter(MLIRContext *context)
-        : RewriterBase(context, /*listener=*/this) {}
+    SingleEraseRewriter(
+        MLIRContext *context,
+        llvm::function_ref<void(Operation *)> opErasedCallback = nullptr)
+        : RewriterBase(context, /*listener=*/this),
+          opErasedCallback(opErasedCallback) {}
 
     /// Erase the given op (unless it was already erased).
     void eraseOp(Operation *op) override {
@@ -1003,13 +1006,20 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
 
     bool wasErased(void *ptr) const { return erased.contains(ptr); }
 
-    void notifyOperationErased(Operation *op) override { erased.insert(op); }
+    void notifyOperationErased(Operation *op) override {
+      erased.insert(op);
+      if (opErasedCallback)
+        opErasedCallback(op);
+    }
 
     void notifyBlockErased(Block *block) override { erased.insert(block); }
 
   private:
     /// Pointers to all erased operations and blocks.
     DenseSet<void *> erased;
+
+    /// A callback that is invoked when an operation is erased.
+    llvm::function_ref<void(Operation *)> opErasedCallback;
   };
 
   //===--------------------------------------------------------------------===//
@@ -1019,11 +1029,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// MLIR context.
   MLIRContext *context;
 
-  /// A rewriter that keeps track of ops/block that were already erased and
-  /// skips duplicate op/block erasures. This rewriter is used during the
-  /// "cleanup" phase.
-  SingleEraseRewriter eraseRewriter;
-
   // Mapping between replaced values that differ in type. This happens when
   // replacing a value with one of a different type.
   ConversionValueMapping mapping;
@@ -1195,6 +1200,11 @@ void ConversionPatternRewriterImpl::applyRewrites() {
     rewrites[i]->commit(rewriter);
 
   // Clean up all rewrites.
+  SingleEraseRewriter eraseRewriter(
+      context, /*opErasedCallback=*/[&](Operation *op) {
+        if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
+          unresolvedMaterializations.erase(castOp);
+      });
   for (auto &rewrite : rewrites)
     rewrite->cleanup(eraseRewriter);
 }
@@ -2714,11 +2724,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
   SmallVector<UnrealizedConversionCastOp> allCastOps;
   const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
       &materializations = rewriterImpl.unresolvedMaterializations;
-  for (auto it : materializations) {
-    if (rewriterImpl.eraseRewriter.wasErased(it.first))
-      continue;
+  for (auto it : materializations)
     allCastOps.push_back(it.first);
-  }
 
   // Reconcile all UnrealizedConversionCastOps that were inserted by the
   // dialect conversion frameworks. (Not the one that were inserted by

@matthias-springer matthias-springer force-pushed the users/matthias-springer/dialect_conv_erase branch from b780a86 to 7596373 Compare June 18, 2025 09:10
Copy link
Member

@zero9178 zero9178 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@matthias-springer matthias-springer merged commit 66580f7 into main Jun 18, 2025
7 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/dialect_conv_erase branch June 18, 2025 12:42
fschlimb pushed a commit to fschlimb/llvm-project that referenced this pull request Jun 18, 2025
…zations` up to date (llvm#144254)

`unresolvedMaterializations` is a mapping from
`UnrealizedConversionCastOp` to `UnresolvedMaterializationRewrite`. This
mapping is needed to find the correct type converter for an unresolved
materialization.

With this commit, `unresolvedMaterializations` is updated immediately
when an op is being erased. This also cleans up the code base a bit:
`SingleEraseRewriter` is now used only during the "cleanup" phase and no
longer needed as a field of `ConversionRewriterImpl`.

This commit is in preparation of the One-Shot Dialect Conversion
refactoring: `allowPatternRollback = false` will in the future trigger
immediate materialization of all IR changes.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants