@@ -80,37 +80,30 @@ bool BackwardSliceMatcher<Matcher>::matches(
80
80
BackwardSliceOptions &options, int64_t maxDepth) {
81
81
backwardSlice.clear ();
82
82
llvm::DenseMap<Operation *, int64_t > opDepths;
83
- // The starting point is the root op; therefore, we set its depth to 0.
83
+ // Initializing the root op with a depth of 0
84
84
opDepths[rootOp] = 0 ;
85
85
options.filter = [&](Operation *subOp) {
86
- // If the subOp's depth exceeds maxDepth, we stop further slicing for this
87
- // branch .
88
- if (opDepths[ subOp] > maxDepth )
86
+ // If the subOp hasn't been recorded in opDepths, it is deeper than
87
+ // maxDepth .
88
+ if (! opDepths. contains ( subOp) )
89
89
return false ;
90
90
// Examine subOp's operands to compute depths of their defining operations.
91
91
for (auto operand : subOp->getOperands ()) {
92
+ int64_t newDepth = opDepths[subOp] + 1 ;
93
+ // If the newDepth is greater than maxDepth, further computation can be
94
+ // skipped.
95
+ if (newDepth > maxDepth)
96
+ continue ;
92
97
if (auto definingOp = operand.getDefiningOp ()) {
93
- // Set the defining operation's depth to one level greater than
94
- // subOp's depth.
95
- int64_t newDepth = opDepths[subOp] + 1 ;
96
- if (!opDepths.contains (definingOp)) {
98
+ if (!opDepths.contains (definingOp) || newDepth < opDepths[definingOp])
97
99
opDepths[definingOp] = newDepth;
98
- } else {
99
- opDepths[definingOp] = std::min (opDepths[definingOp], newDepth);
100
- }
101
- return !(opDepths[subOp] > maxDepth);
102
100
} else {
103
101
auto blockArgument = cast<BlockArgument>(operand);
104
102
Operation *parentOp = blockArgument.getOwner ()->getParentOp ();
105
103
if (!parentOp)
106
104
continue ;
107
- int64_t newDepth = opDepths[subOp] + 1 ;
108
- if (!opDepths.contains (parentOp)) {
105
+ if (!opDepths.contains (parentOp) || newDepth < opDepths[parentOp])
109
106
opDepths[parentOp] = newDepth;
110
- } else {
111
- opDepths[parentOp] = std::min (opDepths[parentOp], newDepth);
112
- }
113
- return !(opDepths[parentOp] > maxDepth);
114
107
}
115
108
}
116
109
return true ;
@@ -130,6 +123,14 @@ m_GetDefinitions(Matcher innerMatcher, int64_t maxDepth, bool inclusive,
130
123
omitUsesFromAbove);
131
124
}
132
125
126
+ template <typename Matcher>
127
+ inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions (Matcher innerMatcher,
128
+ int64_t maxDepth) {
129
+ assert (maxDepth >= 0 && " maxDepth must be non-negative" );
130
+ return BackwardSliceMatcher<Matcher>(std::move (innerMatcher), maxDepth, true ,
131
+ false , false );
132
+ }
133
+
133
134
} // namespace mlir::query::matcher
134
135
135
136
#endif // MLIR_TOOLS_MLIRQUERY_MATCHERS_EXTRAMATCHERS_H
0 commit comments