-
Notifications
You must be signed in to change notification settings - Fork 14.2k
[mlir][Transforms][NFC] Dialect Conversion: Keep unresolvedMaterializations
up to date
#144254
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Transforms][NFC] Dialect Conversion: Keep unresolvedMaterializations
up to date
#144254
Conversation
@llvm/pr-subscribers-mlir-core Author: Matthias Springer (matthias-springer) Changes
With this commit, This commit is in preparation of the One-Shot Dialect Conversion refactoring: Full diff: https://github.com/llvm/llvm-project/pull/144254.diff 1 Files Affected:
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 7de26d7cfa84d..b5345fb1a2dcb 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -848,7 +848,7 @@ namespace detail {
struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
const ConversionConfig &config)
- : context(ctx), eraseRewriter(ctx), config(config) {}
+ : context(ctx), config(config) {}
//===--------------------------------------------------------------------===//
// State Management
@@ -981,8 +981,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// no new IR is created between calls to `eraseOp`/`eraseBlock`.
struct SingleEraseRewriter : public RewriterBase, RewriterBase::Listener {
public:
- SingleEraseRewriter(MLIRContext *context)
- : RewriterBase(context, /*listener=*/this) {}
+ SingleEraseRewriter(
+ MLIRContext *context,
+ llvm::function_ref<void(Operation *)> opErasedCallback = nullptr)
+ : RewriterBase(context, /*listener=*/this),
+ opErasedCallback(opErasedCallback) {}
/// Erase the given op (unless it was already erased).
void eraseOp(Operation *op) override {
@@ -1003,13 +1006,20 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
bool wasErased(void *ptr) const { return erased.contains(ptr); }
- void notifyOperationErased(Operation *op) override { erased.insert(op); }
+ void notifyOperationErased(Operation *op) override {
+ erased.insert(op);
+ if (opErasedCallback)
+ opErasedCallback(op);
+ }
void notifyBlockErased(Block *block) override { erased.insert(block); }
private:
/// Pointers to all erased operations and blocks.
DenseSet<void *> erased;
+
+ /// A callback that is invoked when an operation is erased.
+ llvm::function_ref<void(Operation *)> opErasedCallback;
};
//===--------------------------------------------------------------------===//
@@ -1019,11 +1029,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// MLIR context.
MLIRContext *context;
- /// A rewriter that keeps track of ops/block that were already erased and
- /// skips duplicate op/block erasures. This rewriter is used during the
- /// "cleanup" phase.
- SingleEraseRewriter eraseRewriter;
-
// Mapping between replaced values that differ in type. This happens when
// replacing a value with one of a different type.
ConversionValueMapping mapping;
@@ -1195,6 +1200,11 @@ void ConversionPatternRewriterImpl::applyRewrites() {
rewrites[i]->commit(rewriter);
// Clean up all rewrites.
+ SingleEraseRewriter eraseRewriter(
+ context, /*opErasedCallback=*/[&](Operation *op) {
+ if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
+ unresolvedMaterializations.erase(castOp);
+ });
for (auto &rewrite : rewrites)
rewrite->cleanup(eraseRewriter);
}
@@ -2714,11 +2724,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
SmallVector<UnrealizedConversionCastOp> allCastOps;
const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
&materializations = rewriterImpl.unresolvedMaterializations;
- for (auto it : materializations) {
- if (rewriterImpl.eraseRewriter.wasErased(it.first))
- continue;
+ for (auto it : materializations)
allCastOps.push_back(it.first);
- }
// Reconcile all UnrealizedConversionCastOps that were inserted by the
// dialect conversion frameworks. (Not the one that were inserted by
|
@llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) Changes
With this commit, This commit is in preparation of the One-Shot Dialect Conversion refactoring: Full diff: https://github.com/llvm/llvm-project/pull/144254.diff 1 Files Affected:
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 7de26d7cfa84d..b5345fb1a2dcb 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -848,7 +848,7 @@ namespace detail {
struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
const ConversionConfig &config)
- : context(ctx), eraseRewriter(ctx), config(config) {}
+ : context(ctx), config(config) {}
//===--------------------------------------------------------------------===//
// State Management
@@ -981,8 +981,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// no new IR is created between calls to `eraseOp`/`eraseBlock`.
struct SingleEraseRewriter : public RewriterBase, RewriterBase::Listener {
public:
- SingleEraseRewriter(MLIRContext *context)
- : RewriterBase(context, /*listener=*/this) {}
+ SingleEraseRewriter(
+ MLIRContext *context,
+ llvm::function_ref<void(Operation *)> opErasedCallback = nullptr)
+ : RewriterBase(context, /*listener=*/this),
+ opErasedCallback(opErasedCallback) {}
/// Erase the given op (unless it was already erased).
void eraseOp(Operation *op) override {
@@ -1003,13 +1006,20 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
bool wasErased(void *ptr) const { return erased.contains(ptr); }
- void notifyOperationErased(Operation *op) override { erased.insert(op); }
+ void notifyOperationErased(Operation *op) override {
+ erased.insert(op);
+ if (opErasedCallback)
+ opErasedCallback(op);
+ }
void notifyBlockErased(Block *block) override { erased.insert(block); }
private:
/// Pointers to all erased operations and blocks.
DenseSet<void *> erased;
+
+ /// A callback that is invoked when an operation is erased.
+ llvm::function_ref<void(Operation *)> opErasedCallback;
};
//===--------------------------------------------------------------------===//
@@ -1019,11 +1029,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// MLIR context.
MLIRContext *context;
- /// A rewriter that keeps track of ops/block that were already erased and
- /// skips duplicate op/block erasures. This rewriter is used during the
- /// "cleanup" phase.
- SingleEraseRewriter eraseRewriter;
-
// Mapping between replaced values that differ in type. This happens when
// replacing a value with one of a different type.
ConversionValueMapping mapping;
@@ -1195,6 +1200,11 @@ void ConversionPatternRewriterImpl::applyRewrites() {
rewrites[i]->commit(rewriter);
// Clean up all rewrites.
+ SingleEraseRewriter eraseRewriter(
+ context, /*opErasedCallback=*/[&](Operation *op) {
+ if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
+ unresolvedMaterializations.erase(castOp);
+ });
for (auto &rewrite : rewrites)
rewrite->cleanup(eraseRewriter);
}
@@ -2714,11 +2724,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
SmallVector<UnrealizedConversionCastOp> allCastOps;
const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
&materializations = rewriterImpl.unresolvedMaterializations;
- for (auto it : materializations) {
- if (rewriterImpl.eraseRewriter.wasErased(it.first))
- continue;
+ for (auto it : materializations)
allCastOps.push_back(it.first);
- }
// Reconcile all UnrealizedConversionCastOps that were inserted by the
// dialect conversion frameworks. (Not the one that were inserted by
|
b780a86
to
7596373
Compare
…zations` up to date
7596373
to
c6ca8c8
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
…zations` up to date (llvm#144254) `unresolvedMaterializations` is a mapping from `UnrealizedConversionCastOp` to `UnresolvedMaterializationRewrite`. This mapping is needed to find the correct type converter for an unresolved materialization. With this commit, `unresolvedMaterializations` is updated immediately when an op is being erased. This also cleans up the code base a bit: `SingleEraseRewriter` is now used only during the "cleanup" phase and no longer needed as a field of `ConversionRewriterImpl`. This commit is in preparation of the One-Shot Dialect Conversion refactoring: `allowPatternRollback = false` will in the future trigger immediate materialization of all IR changes.
unresolvedMaterializations
is a mapping fromUnrealizedConversionCastOp
toUnresolvedMaterializationRewrite
. This mapping is needed to find the correct type converter for an unresolved materialization.With this commit,
unresolvedMaterializations
is updated immediately when an op is being erased. This also cleans up the code base a bit:SingleEraseRewriter
is now used only during the "cleanup" phase and no longer needed as a field ofConversionRewriterImpl
.This commit is in preparation of the One-Shot Dialect Conversion refactoring:
allowPatternRollback = false
will in the future trigger immediate materialization of all IR changes.