diff --git a/mlir/lib/Transforms/SymbolDCE.cpp b/mlir/lib/Transforms/SymbolDCE.cpp index 93d9a6547883a..52b4d06c98e32 100644 --- a/mlir/lib/Transforms/SymbolDCE.cpp +++ b/mlir/lib/Transforms/SymbolDCE.cpp @@ -22,6 +22,8 @@ namespace mlir { using namespace mlir; +#define DEBUG_TYPE "symbol-dce" + namespace { struct SymbolDCE : public impl::SymbolDCEBase { void runOnOperation() override; @@ -84,6 +86,8 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp, SymbolTableCollection &symbolTable, bool symbolTableIsHidden, DenseSet &liveSymbols) { + LLVM_DEBUG(llvm::dbgs() << "computeLiveness: " << symbolTableOp->getName() + << "\n"); // A worklist of live operations to propagate uses from. SmallVector worklist; @@ -108,6 +112,7 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp, // that are referenced within. while (!worklist.empty()) { Operation *op = worklist.pop_back_val(); + LLVM_DEBUG(llvm::dbgs() << "processing: " << op->getName() << "\n"); // If this is a symbol table, recursively compute its liveness. if (op->hasTrait()) { @@ -115,8 +120,34 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp, // symbol, or if it is a private symbol. SymbolOpInterface symbol = dyn_cast(op); bool symIsHidden = symbolTableIsHidden || !symbol || symbol.isPrivate(); + LLVM_DEBUG(llvm::dbgs() << "\tsymbol table: " << op->getName() + << " is hidden: " << symIsHidden << "\n"); if (failed(computeLiveness(op, symbolTable, symIsHidden, liveSymbols))) return failure(); + } else { + LLVM_DEBUG(llvm::dbgs() + << "\tnon-symbol table: " << op->getName() << " is hidden\n"); + // If the op is not a symbol table, then, unless op itself is dead which + // would be handled by DCE, we need to check all the regions and blocks + // within the op to find the uses (e.g., consider visibility within op as + // if top level rather than relying on pure symbol table visibility). This + // is more conservative than SymbolTable::walkSymbolTables in the case + // where there is again SymbolTable information to take advantage of. + for (auto ®ion : op->getRegions()) { + for (auto &block : region.getBlocks()) { + for (Operation &op : block) { + SymbolOpInterface symbol = dyn_cast(&op); + if (!symbol) { + worklist.push_back(&op); + continue; + } + bool isDiscardable = + symbol.isPrivate() && symbol.canDiscardOnUseEmpty(); + if (!isDiscardable && liveSymbols.insert(&op).second) + worklist.push_back(&op); + } + } + } } // Collect the uses held by this operation. @@ -128,13 +159,27 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp, } SmallVector resolvedSymbols; + // Get the first parent symbol table op. + Operation *parentOp = op->getParentOp(); + while (parentOp && !parentOp->hasTrait()) { + parentOp = parentOp->getParentOp(); + } + assert(parentOp && "operation has no parent symbol table"); + + LLVM_DEBUG(llvm::dbgs() << "uses of " << op->getName() << "\n"); for (const SymbolTable::SymbolUse &use : *uses) { + LLVM_DEBUG(llvm::dbgs() << "\tuse: " << use.getUser() << "\n"); // Lookup the symbols referenced by this use. resolvedSymbols.clear(); - if (failed(symbolTable.lookupSymbolIn( - op->getParentOp(), use.getSymbolRef(), resolvedSymbols))) + if (failed(symbolTable.lookupSymbolIn(parentOp, use.getSymbolRef(), + resolvedSymbols))) // Ignore references to unknown symbols. continue; + LLVM_DEBUG({ + llvm::dbgs() << "\t\tresolved symbols: "; + llvm::interleaveComma(resolvedSymbols, llvm::dbgs()); + llvm::dbgs() << "\n"; + }); // Mark each of the resolved symbols as live. for (Operation *resolvedSymbol : resolvedSymbols) diff --git a/mlir/test/Transforms/test-symbol-dce.mlir b/mlir/test/Transforms/test-symbol-dce.mlir index 7bd784928e6f3..d44af1b93d241 100644 --- a/mlir/test/Transforms/test-symbol-dce.mlir +++ b/mlir/test/Transforms/test-symbol-dce.mlir @@ -98,3 +98,22 @@ module { // CHECK: "live.user"() {uses = [@unknown_symbol]} : () -> () "live.user"() {uses = [@unknown_symbol]} : () -> () } + +// ----- + +// Check that we don't DCE nested symbols if they are used even if nested inside +// an unnamed region. +// CHECK-LABEL: module attributes {test.nested_unnamed_region} +module attributes {test.nested_unnamed_region} { + "test.one_region_op"() ({ + "test.symbol_scope"() ({ + // CHECK: func @nested_function + func.func @nested_function() { + return + } + func.call @nested_function() : () -> () + "test.finish"() : () -> () + }) : () -> () + "test.finish"() : () -> () + }) : () -> () +}