diff --git a/mlir/include/mlir/Query/Matcher/Marshallers.h b/mlir/include/mlir/Query/Matcher/Marshallers.h index 6ed35ac0ddccc..012bf7b9ec4a9 100644 --- a/mlir/include/mlir/Query/Matcher/Marshallers.h +++ b/mlir/include/mlir/Query/Matcher/Marshallers.h @@ -50,6 +50,36 @@ struct ArgTypeTraits { } }; +template <> +struct ArgTypeTraits { + static bool hasCorrectType(const VariantValue &value) { + return value.isSigned(); + } + + static unsigned get(const VariantValue &value) { return value.getSigned(); } + + static ArgKind getKind() { return ArgKind::Signed; } + + static std::optional getBestGuess(const VariantValue &) { + return std::nullopt; + } +}; + +template <> +struct ArgTypeTraits { + static bool hasCorrectType(const VariantValue &value) { + return value.isBoolean(); + } + + static unsigned get(const VariantValue &value) { return value.getBoolean(); } + + static ArgKind getKind() { return ArgKind::Boolean; } + + static std::optional getBestGuess(const VariantValue &) { + return std::nullopt; + } +}; + template <> struct ArgTypeTraits { diff --git a/mlir/include/mlir/Query/Matcher/MatchFinder.h b/mlir/include/mlir/Query/Matcher/MatchFinder.h index b008a21f53ae2..f8abf20ef60bb 100644 --- a/mlir/include/mlir/Query/Matcher/MatchFinder.h +++ b/mlir/include/mlir/Query/Matcher/MatchFinder.h @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// // // This file contains the MatchFinder class, which is used to find operations -// that match a given matcher. +// that match a given matcher and print them. // //===----------------------------------------------------------------------===// @@ -15,25 +15,43 @@ #define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H #include "MatchersInternal.h" +#include "mlir/Query/Query.h" +#include "mlir/Query/QuerySession.h" +#include "llvm/ADT/SetVector.h" namespace mlir::query::matcher { -// MatchFinder is used to find all operations that match a given matcher. +/// A class that provides utilities to find operations in the IR. class MatchFinder { + public: - // Returns all operations that match the given matcher. - static std::vector getMatches(Operation *root, - DynMatcher matcher) { - std::vector matches; - - // Simple match finding with walk. - root->walk([&](Operation *subOp) { - if (matcher.match(subOp)) - matches.push_back(subOp); - }); - - return matches; - } + /// A subclass which preserves the matching information. Each instance + /// contains the `rootOp` along with the matching environment. + struct MatchResult { + MatchResult() = default; + MatchResult(Operation *rootOp, std::vector matchedOps); + + Operation *rootOp = nullptr; + /// Contains the matching environment. + std::vector matchedOps; + }; + + /// Traverses the IR and returns a vector of `MatchResult` for each match of + /// the `matcher`. + std::vector collectMatches(Operation *root, + DynMatcher matcher) const; + + /// Prints the matched operation. + void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op) const; + + /// Labels the matched operation with the given binding (e.g., `"root"`) and + /// prints it. + void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op, + const std::string &binding) const; + + /// Flattens a vector of `MatchResult` into a vector of operations. + std::vector + flattenMatchedOps(std::vector &matches) const; }; } // namespace mlir::query::matcher diff --git a/mlir/include/mlir/Query/Matcher/MatchersInternal.h b/mlir/include/mlir/Query/Matcher/MatchersInternal.h index 117f7d4edef9e..183b2514e109f 100644 --- a/mlir/include/mlir/Query/Matcher/MatchersInternal.h +++ b/mlir/include/mlir/Query/Matcher/MatchersInternal.h @@ -8,8 +8,9 @@ // // Implements the base layer of the matcher framework. // -// Matchers are methods that return a Matcher which provides a method -// match(Operation *op) +// Matchers are methods that return a Matcher which provides a method one of the +// following methods: match(Operation *op), match(Operation *op, +// SetVector &matchedOps) // // The matcher functions are defined in include/mlir/IR/Matchers.h. // This file contains the wrapper classes needed to construct matchers for @@ -25,6 +26,31 @@ namespace mlir::query::matcher { +// Defaults to false if T has no match() method with the signature: +// match(Operation* op). +template +struct has_simple_match : std::false_type {}; + +// Specialized type trait that evaluates to true if T has a match() method +// with the signature: match(Operation* op). +template +struct has_simple_match().match( + std::declval()))>> + : std::true_type {}; + +// Defaults to false if T has no match() method with the signature: +// match(Operation* op, SetVector&). +template +struct has_bound_match : std::false_type {}; + +// Specialized type trait that evaluates to true if T has a match() method +// with the signature: match(Operation* op, SetVector&). +template +struct has_bound_match().match( + std::declval(), + std::declval &>()))>> + : std::true_type {}; + // Generic interface for matchers on an MLIR operation. class MatcherInterface : public llvm::ThreadSafeRefCountedBase { @@ -32,6 +58,7 @@ class MatcherInterface virtual ~MatcherInterface() = default; virtual bool match(Operation *op) = 0; + virtual bool match(Operation *op, SetVector &matchedOps) = 0; }; // MatcherFnImpl takes a matcher function object and implements @@ -40,14 +67,25 @@ template class MatcherFnImpl : public MatcherInterface { public: MatcherFnImpl(MatcherFn &matcherFn) : matcherFn(matcherFn) {} - bool match(Operation *op) override { return matcherFn.match(op); } + + bool match(Operation *op) override { + if constexpr (has_simple_match::value) + return matcherFn.match(op); + return false; + } + + bool match(Operation *op, SetVector &matchedOps) override { + if constexpr (has_bound_match::value) + return matcherFn.match(op, matchedOps); + return false; + } private: MatcherFn matcherFn; }; -// Matcher wraps a MatcherInterface implementation and provides a match() -// method that redirects calls to the underlying implementation. +// Matcher wraps a MatcherInterface implementation and provides match() +// methods that redirect calls to the underlying implementation. class DynMatcher { public: // Takes ownership of the provided implementation pointer. @@ -62,12 +100,13 @@ class DynMatcher { } bool match(Operation *op) const { return implementation->match(op); } + bool match(Operation *op, SetVector &matchedOps) const { + return implementation->match(op, matchedOps); + } - void setFunctionName(StringRef name) { functionName = name.str(); }; - - bool hasFunctionName() const { return !functionName.empty(); }; - - StringRef getFunctionName() const { return functionName; }; + void setFunctionName(StringRef name) { functionName = name.str(); } + bool hasFunctionName() const { return !functionName.empty(); } + StringRef getFunctionName() const { return functionName; } private: llvm::IntrusiveRefCntPtr implementation; diff --git a/mlir/include/mlir/Query/Matcher/SliceMatchers.h b/mlir/include/mlir/Query/Matcher/SliceMatchers.h new file mode 100644 index 0000000000000..5bb8251672eb7 --- /dev/null +++ b/mlir/include/mlir/Query/Matcher/SliceMatchers.h @@ -0,0 +1,141 @@ +//===- SliceMatchers.h - Matchers for slicing analysis ----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file provides matchers for MLIRQuery that peform slicing analysis +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_MLIRQUERY_MATCHERS_SLICEMATCHERS_H +#define MLIR_TOOLS_MLIRQUERY_MATCHERS_SLICEMATCHERS_H + +#include "mlir/Analysis/SliceAnalysis.h" + +/// A matcher encapsulating `getBackwardSlice` method from SliceAnalysis.h. +/// Additionally, it limits the slice computation to a certain depth level using +/// a custom filter. +/// +/// Example: starting from node 9, assuming the matcher +/// computes the slice for the first two depth levels: +/// ============================ +/// 1 2 3 4 +/// |_______| |______| +/// | | | +/// | 5 6 +/// |___|_____________| +/// | | +/// 7 8 +/// |_______________| +/// | +/// 9 +/// +/// Assuming all local orders match the numbering order: +/// {5, 7, 6, 8, 9} +namespace mlir::query::matcher { + +template +class BackwardSliceMatcher { +public: + BackwardSliceMatcher(Matcher innerMatcher, int64_t maxDepth, bool inclusive, + bool omitBlockArguments, bool omitUsesFromAbove) + : innerMatcher(std::move(innerMatcher)), maxDepth(maxDepth), + inclusive(inclusive), omitBlockArguments(omitBlockArguments), + omitUsesFromAbove(omitUsesFromAbove) {} + + bool match(Operation *rootOp, SetVector &backwardSlice) { + BackwardSliceOptions options; + options.inclusive = inclusive; + options.omitUsesFromAbove = omitUsesFromAbove; + options.omitBlockArguments = omitBlockArguments; + return (innerMatcher.match(rootOp) && + matches(rootOp, backwardSlice, options, maxDepth)); + } + +private: + bool matches(Operation *rootOp, llvm::SetVector &backwardSlice, + BackwardSliceOptions &options, int64_t maxDepth); + +private: + // The outer matcher (e.g., BackwardSliceMatcher) relies on the innerMatcher + // to determine whether we want to traverse the IR or not. For example, we + // want to explore the IR only if the top-level operation name is + // `"arith.addf"`. + Matcher innerMatcher; + // `maxDepth` specifies the maximum depth that the matcher can traverse the + // IR. For example, if `maxDepth` is 2, the matcher will explore the defining + // operations of the top-level op up to 2 levels. + int64_t maxDepth; + bool inclusive; + bool omitBlockArguments; + bool omitUsesFromAbove; +}; + +template +bool BackwardSliceMatcher::matches( + Operation *rootOp, llvm::SetVector &backwardSlice, + BackwardSliceOptions &options, int64_t maxDepth) { + backwardSlice.clear(); + llvm::DenseMap opDepths; + // Initializing the root op with a depth of 0 + opDepths[rootOp] = 0; + options.filter = [&](Operation *subOp) { + // If the subOp hasn't been recorded in opDepths, it is deeper than + // maxDepth. + if (!opDepths.contains(subOp)) + return false; + // Examine subOp's operands to compute depths of their defining operations. + for (auto operand : subOp->getOperands()) { + int64_t newDepth = opDepths[subOp] + 1; + // If the newDepth is greater than maxDepth, further computation can be + // skipped. + if (newDepth > maxDepth) + continue; + + if (auto definingOp = operand.getDefiningOp()) { + // Registers the minimum depth + if (!opDepths.contains(definingOp) || newDepth < opDepths[definingOp]) + opDepths[definingOp] = newDepth; + } else { + auto blockArgument = cast(operand); + Operation *parentOp = blockArgument.getOwner()->getParentOp(); + if (!parentOp) + continue; + + if (!opDepths.contains(parentOp) || newDepth < opDepths[parentOp]) + opDepths[parentOp] = newDepth; + } + } + return true; + }; + getBackwardSlice(rootOp, &backwardSlice, options); + return options.inclusive ? backwardSlice.size() > 1 + : backwardSlice.size() >= 1; +} + +/// Matches transitive defs of a top-level operation up to N levels. +template +inline BackwardSliceMatcher +m_GetDefinitions(Matcher innerMatcher, int64_t maxDepth, bool inclusive, + bool omitBlockArguments, bool omitUsesFromAbove) { + assert(maxDepth >= 0 && "maxDepth must be non-negative"); + return BackwardSliceMatcher(std::move(innerMatcher), maxDepth, + inclusive, omitBlockArguments, + omitUsesFromAbove); +} + +/// Matches all transitive defs of a top-level operation up to N levels +template +inline BackwardSliceMatcher m_GetAllDefinitions(Matcher innerMatcher, + int64_t maxDepth) { + assert(maxDepth >= 0 && "maxDepth must be non-negative"); + return BackwardSliceMatcher(std::move(innerMatcher), maxDepth, true, + false, false); +} + +} // namespace mlir::query::matcher + +#endif // MLIR_TOOLS_MLIRQUERY_MATCHERS_SLICEMATCHERS_H diff --git a/mlir/include/mlir/Query/Matcher/VariantValue.h b/mlir/include/mlir/Query/Matcher/VariantValue.h index 449f8b3a01e02..98c0a18e25101 100644 --- a/mlir/include/mlir/Query/Matcher/VariantValue.h +++ b/mlir/include/mlir/Query/Matcher/VariantValue.h @@ -21,7 +21,7 @@ namespace mlir::query::matcher { // All types that VariantValue can contain. -enum class ArgKind { Matcher, String }; +enum class ArgKind { Boolean, Matcher, Signed, String }; // A variant matcher object to abstract simple and complex matchers into a // single object type. @@ -81,6 +81,8 @@ class VariantValue { // Specific constructors for each supported type. VariantValue(const llvm::StringRef string); VariantValue(const VariantMatcher &matcher); + VariantValue(int64_t signedValue); + VariantValue(bool setBoolean); // String value functions. bool isString() const; @@ -92,21 +94,36 @@ class VariantValue { const VariantMatcher &getMatcher() const; void setMatcher(const VariantMatcher &matcher); + // Signed value functions. + bool isSigned() const; + int64_t getSigned() const; + void setSigned(int64_t signedValue); + + // Boolean value functions. + bool isBoolean() const; + bool getBoolean() const; + void setBoolean(bool booleanValue); // String representation of the type of the value. std::string getTypeAsString() const; + explicit operator bool() const { return hasValue(); } + bool hasValue() const { return type != ValueType::Nothing; } private: void reset(); // All supported value types. enum class ValueType { + Boolean, + Matcher, Nothing, + Signed, String, - Matcher, }; // All supported value types. union AllValues { + bool Boolean; + int64_t Signed; llvm::StringRef *String; VariantMatcher *Matcher; }; diff --git a/mlir/lib/Query/Matcher/CMakeLists.txt b/mlir/lib/Query/Matcher/CMakeLists.txt index 3adff9f99243f..629479bf7adc1 100644 --- a/mlir/lib/Query/Matcher/CMakeLists.txt +++ b/mlir/lib/Query/Matcher/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_library(MLIRQueryMatcher + MatchFinder.cpp Parser.cpp RegistryManager.cpp VariantValue.cpp diff --git a/mlir/lib/Query/Matcher/MatchFinder.cpp b/mlir/lib/Query/Matcher/MatchFinder.cpp new file mode 100644 index 0000000000000..1d4817e32417d --- /dev/null +++ b/mlir/lib/Query/Matcher/MatchFinder.cpp @@ -0,0 +1,68 @@ +//===- MatchFinder.cpp - --------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the method definitions for the `MatchFinder` class +// +//===----------------------------------------------------------------------===// + +#include "mlir/Query/Matcher/MatchFinder.h" +namespace mlir::query::matcher { + +MatchFinder::MatchResult::MatchResult(Operation *rootOp, + std::vector matchedOps) + : rootOp(rootOp), matchedOps(std::move(matchedOps)) {} + +std::vector +MatchFinder::collectMatches(Operation *root, DynMatcher matcher) const { + std::vector results; + llvm::SetVector tempStorage; + root->walk([&](Operation *subOp) { + if (matcher.match(subOp)) { + MatchResult match; + match.rootOp = subOp; + match.matchedOps.push_back(subOp); + results.push_back(std::move(match)); + } else if (matcher.match(subOp, tempStorage)) { + results.emplace_back(subOp, std::vector(tempStorage.begin(), + tempStorage.end())); + } + tempStorage.clear(); + }); + return results; +} + +void MatchFinder::printMatch(llvm::raw_ostream &os, QuerySession &qs, + Operation *op) const { + auto fileLoc = cast(op->getLoc()); + SMLoc smloc = qs.getSourceManager().FindLocForLineAndColumn( + qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn()); + llvm::SMDiagnostic diag = + qs.getSourceManager().GetMessage(smloc, llvm::SourceMgr::DK_Note, ""); + diag.print("", os, true, false, true); +} + +void MatchFinder::printMatch(llvm::raw_ostream &os, QuerySession &qs, + Operation *op, const std::string &binding) const { + auto fileLoc = cast(op->getLoc()); + auto smloc = qs.getSourceManager().FindLocForLineAndColumn( + qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn()); + qs.getSourceManager().PrintMessage(os, smloc, llvm::SourceMgr::DK_Note, + "\"" + binding + "\" binds here"); +} + +std::vector +MatchFinder::flattenMatchedOps(std::vector &matches) const { + std::vector newVector; + for (auto &result : matches) { + newVector.insert(newVector.end(), result.matchedOps.begin(), + result.matchedOps.end()); + } + return newVector; +} + +} // namespace mlir::query::matcher diff --git a/mlir/lib/Query/Matcher/Parser.cpp b/mlir/lib/Query/Matcher/Parser.cpp index 3609e24f9939f..e392a885c511b 100644 --- a/mlir/lib/Query/Matcher/Parser.cpp +++ b/mlir/lib/Query/Matcher/Parser.cpp @@ -135,6 +135,18 @@ class Parser::CodeTokenizer { case '\'': consumeStringLiteral(&result); break; + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + consumeNumberLiteral(&result); + break; default: parseIdentifierOrInvalid(&result); break; @@ -144,6 +156,18 @@ class Parser::CodeTokenizer { return result; } + void consumeNumberLiteral(TokenInfo *result) { + StringRef original = code; + unsigned value = 0; + if (!code.consumeInteger(0, value)) { + size_t numConsumed = original.size() - code.size(); + result->text = original.take_front(numConsumed); + result->kind = TokenKind::Literal; + result->value = static_cast(value); + return; + } + } + // Consume a string literal, handle escape sequences and missing closing // quote. void consumeStringLiteral(TokenInfo *result) { @@ -195,9 +219,22 @@ class Parser::CodeTokenizer { break; ++tokenLength; } - result->kind = TokenKind::Ident; - result->text = code.substr(0, tokenLength); + llvm::StringRef token = code.substr(0, tokenLength); code = code.drop_front(tokenLength); + // Check if the identifier is a boolean literal + if (token == "true") { + result->text = "false"; + result->kind = TokenKind::Literal; + result->value = true; + } else if (token == "false") { + result->text = "false"; + result->kind = TokenKind::Literal; + result->value = false; + } else { + // Otherwise it is treated as a normal identifier + result->kind = TokenKind::Ident; + result->text = token; + } } else { result->kind = TokenKind::InvalidChar; result->text = code.substr(0, 1); @@ -257,13 +294,19 @@ bool Parser::parseIdentifierPrefixImpl(VariantValue *value) { if (tokenizer->nextTokenKind() != TokenKind::OpenParen) { // Parse as a named value. - auto namedValue = - namedValues ? namedValues->lookup(nameToken.text) : VariantValue(); + if (auto namedValue = namedValues ? namedValues->lookup(nameToken.text) + : VariantValue()) { - if (!namedValue.isMatcher()) { - error->addError(tokenizer->peekNextToken().range, - ErrorType::ParserNotAMatcher); - return false; + if (tokenizer->nextTokenKind() != TokenKind::Period) { + *value = namedValue; + return true; + } + + if (!namedValue.isMatcher()) { + error->addError(tokenizer->peekNextToken().range, + ErrorType::ParserNotAMatcher); + return false; + } } if (tokenizer->nextTokenKind() == TokenKind::NewLine) { diff --git a/mlir/lib/Query/Matcher/Parser.h b/mlir/lib/Query/Matcher/Parser.h index 58968023022d5..2199a2335ba9c 100644 --- a/mlir/lib/Query/Matcher/Parser.h +++ b/mlir/lib/Query/Matcher/Parser.h @@ -16,8 +16,11 @@ // provided to the parser. // // The grammar for the supported expressions is as follows: -// := | +// := | +// := | | // := "quoted string" +// := "true" | "false" +// := [0-9]+ // := () // := [a-zA-Z]+ // := | , diff --git a/mlir/lib/Query/Matcher/RegistryManager.cpp b/mlir/lib/Query/Matcher/RegistryManager.cpp index 645db7109c2de..4b511c5f009e7 100644 --- a/mlir/lib/Query/Matcher/RegistryManager.cpp +++ b/mlir/lib/Query/Matcher/RegistryManager.cpp @@ -19,16 +19,15 @@ namespace mlir::query::matcher { namespace { -// This is needed because these matchers are defined as overloaded functions. -using IsConstantOp = detail::constant_op_matcher(); -using HasOpAttrName = detail::AttrOpMatcher(llvm::StringRef); -using HasOpName = detail::NameOpMatcher(llvm::StringRef); - // Enum to string for autocomplete. static std::string asArgString(ArgKind kind) { switch (kind) { + case ArgKind::Boolean: + return "Boolean"; case ArgKind::Matcher: return "Matcher"; + case ArgKind::Signed: + return "Signed"; case ArgKind::String: return "String"; } diff --git a/mlir/lib/Query/Matcher/VariantValue.cpp b/mlir/lib/Query/Matcher/VariantValue.cpp index 65bd4bd77bcf8..1cb2d48f9d56f 100644 --- a/mlir/lib/Query/Matcher/VariantValue.cpp +++ b/mlir/lib/Query/Matcher/VariantValue.cpp @@ -56,6 +56,14 @@ VariantValue::VariantValue(const VariantMatcher &matcher) value.Matcher = new VariantMatcher(matcher); } +VariantValue::VariantValue(int64_t signedValue) : type(ValueType::Signed) { + value.Signed = signedValue; +} + +VariantValue::VariantValue(bool setBoolean) : type(ValueType::Boolean) { + value.Boolean = setBoolean; +} + VariantValue::~VariantValue() { reset(); } VariantValue &VariantValue::operator=(const VariantValue &other) { @@ -69,6 +77,12 @@ VariantValue &VariantValue::operator=(const VariantValue &other) { case ValueType::Matcher: setMatcher(other.getMatcher()); break; + case ValueType::Signed: + setSigned(other.getSigned()); + break; + case ValueType::Boolean: + setBoolean(other.getBoolean()); + break; case ValueType::Nothing: type = ValueType::Nothing; break; @@ -85,12 +99,34 @@ void VariantValue::reset() { delete value.Matcher; break; // Cases that do nothing. + case ValueType::Signed: + case ValueType::Boolean: case ValueType::Nothing: break; } type = ValueType::Nothing; } +// Signed +bool VariantValue::isSigned() const { return type == ValueType::Signed; } + +int64_t VariantValue::getSigned() const { return value.Signed; } + +void VariantValue::setSigned(int64_t newValue) { + type = ValueType::Signed; + value.Signed = newValue; +} + +// Boolean +bool VariantValue::isBoolean() const { return type == ValueType::Boolean; } + +bool VariantValue::getBoolean() const { return value.Signed; } + +void VariantValue::setBoolean(bool newValue) { + type = ValueType::Boolean; + value.Signed = newValue; +} + bool VariantValue::isString() const { return type == ValueType::String; } const llvm::StringRef &VariantValue::getString() const { @@ -123,6 +159,10 @@ std::string VariantValue::getTypeAsString() const { return "String"; case ValueType::Matcher: return "Matcher"; + case ValueType::Signed: + return "Signed"; + case ValueType::Boolean: + return "Boolean"; case ValueType::Nothing: return "Nothing"; } diff --git a/mlir/lib/Query/Query.cpp b/mlir/lib/Query/Query.cpp index 73f313cd37fd0..803284d6df86a 100644 --- a/mlir/lib/Query/Query.cpp +++ b/mlir/lib/Query/Query.cpp @@ -12,6 +12,7 @@ #include "mlir/IR/IRMapping.h" #include "mlir/Query/Matcher/MatchFinder.h" #include "mlir/Query/QuerySession.h" +#include "llvm/ADT/SetVector.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/raw_ostream.h" @@ -26,15 +27,6 @@ complete(llvm::StringRef line, size_t pos, const QuerySession &qs) { return QueryParser::complete(line, pos, qs); } -static void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op, - const std::string &binding) { - auto fileLoc = op->getLoc()->findInstanceOf(); - auto smloc = qs.getSourceManager().FindLocForLineAndColumn( - qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn()); - qs.getSourceManager().PrintMessage(os, smloc, llvm::SourceMgr::DK_Note, - "\"" + binding + "\" binds here"); -} - // TODO: Extract into a helper function that can be reused outside query // context. static Operation *extractFunction(std::vector &ops, @@ -126,28 +118,34 @@ LogicalResult QuitQuery::run(llvm::raw_ostream &os, QuerySession &qs) const { LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const { Operation *rootOp = qs.getRootOp(); int matchCount = 0; - std::vector matches = - matcher::MatchFinder().getMatches(rootOp, matcher); + matcher::MatchFinder finder; + auto matches = finder.collectMatches(rootOp, std::move(matcher)); // An extract call is recognized by considering if the matcher has a name. // TODO: Consider making the extract more explicit. if (matcher.hasFunctionName()) { auto functionName = matcher.getFunctionName(); + std::vector flattenedMatches = + finder.flattenMatchedOps(matches); Operation *function = - extractFunction(matches, rootOp->getContext(), functionName); + extractFunction(flattenedMatches, rootOp->getContext(), functionName); os << "\n" << *function << "\n\n"; function->erase(); return mlir::success(); } os << "\n"; - for (Operation *op : matches) { + for (auto &results : matches) { os << "Match #" << ++matchCount << ":\n\n"; - // Placeholder "root" binding for the initial draft. - printMatch(os, qs, op, "root"); + for (auto op : results.matchedOps) { + if (op == results.rootOp) { + finder.printMatch(os, qs, op, "root"); + } else { + finder.printMatch(os, qs, op); + } + } } os << matchCount << (matchCount == 1 ? " match.\n\n" : " matches.\n\n"); - return mlir::success(); } diff --git a/mlir/test/mlir-query/complex-test.mlir b/mlir/test/mlir-query/complex-test.mlir new file mode 100644 index 0000000000000..ad96f03747a43 --- /dev/null +++ b/mlir/test/mlir-query/complex-test.mlir @@ -0,0 +1,32 @@ +// RUN: mlir-query %s -c "m getAllDefinitions(hasOpName(\"arith.addf\"),2)" | FileCheck %s + +#map = affine_map<(d0, d1) -> (d0, d1)> +func.func @slice_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) { + %0 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) { + ^bb0(%in: f32, %out: f32): + %2 = arith.addf %in, %in : f32 + linalg.yield %2 : f32 + } -> tensor<5x5xf32> + %collapsed = tensor.collapse_shape %0 [[0, 1]] : tensor<5x5xf32> into tensor<25xf32> + %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) { + ^bb0(%in: f32, %out: f32): + %c2 = arith.constant 2 : index + %extracted = tensor.extract %collapsed[%c2] : tensor<25xf32> + %2 = arith.addf %extracted, %extracted : f32 + linalg.yield %2 : f32 + } -> tensor<5x5xf32> + return +} + +// CHECK: Match #1: + +// CHECK: %[[LINALG:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} +// CHECK-SAME: ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) +// CHECK: %[[ADDF1:.*]] = arith.addf %in, %in : f32 + +// CHECK: Match #2: + +// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[LINALG]] {{\[\[.*\]\]}} : tensor<5x5xf32> into tensor<25xf32> +// CHECK: %[[C2:.*]] = arith.constant {{.*}} : index +// CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[COLLAPSED]][%[[C2]]] : tensor<25xf32> +// CHECK: %[[ADDF2:.*]] = arith.addf %[[EXTRACTED]], %[[EXTRACTED]] : f32 diff --git a/mlir/tools/mlir-query/mlir-query.cpp b/mlir/tools/mlir-query/mlir-query.cpp index 0ed4f94d5802b..78c0ec97c0cdf 100644 --- a/mlir/tools/mlir-query/mlir-query.cpp +++ b/mlir/tools/mlir-query/mlir-query.cpp @@ -16,6 +16,7 @@ #include "mlir/IR/Matchers.h" #include "mlir/InitAllDialects.h" #include "mlir/Query/Matcher/Registry.h" +#include "mlir/Query/Matcher/SliceMatchers.h" #include "mlir/Tools/mlir-query/MlirQueryMain.h" using namespace mlir; @@ -39,6 +40,12 @@ int main(int argc, char **argv) { query::matcher::Registry matcherRegistry; // Matchers registered in alphabetical order for consistency: + matcherRegistry.registerMatcher( + "getDefinitions", + query::matcher::m_GetDefinitions); + matcherRegistry.registerMatcher( + "getAllDefinitions", + query::matcher::m_GetAllDefinitions); matcherRegistry.registerMatcher("hasOpAttrName", static_cast(m_Attr)); matcherRegistry.registerMatcher("hasOpName", static_cast(m_Op));