diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
index d3dca1427e517..5934d85373b03 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
@@ -161,29 +161,85 @@ static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
 }
 
 /// Prepends operations of firstPloop's body into secondPloop's body.
-static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop,
-                        OpBuilder b,
+/// Updates secondPloop with new loop.
+static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
+                        OpBuilder builder,
                         llvm::function_ref<bool(Value, Value)> mayAlias) {
+  Block *block1 = firstPloop.getBody();
+  Block *block2 = secondPloop.getBody();
   IRMapping firstToSecondPloopIndices;
-  firstToSecondPloopIndices.map(firstPloop.getBody()->getArguments(),
-                                secondPloop.getBody()->getArguments());
+  firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments());
 
   if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
                      mayAlias))
     return;
 
-  b.setInsertionPointToStart(secondPloop.getBody());
-  for (auto &op : firstPloop.getBody()->without_terminator())
-    b.clone(op, firstToSecondPloopIndices);
+  DominanceInfo dom;
+  // We are fusing first loop into second, make sure there are no users of the
+  // first loop results between loops.
+  for (Operation *user : firstPloop->getUsers())
+    if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false))
+      return;
+
+  ValueRange inits1 = firstPloop.getInitVals();
+  ValueRange inits2 = secondPloop.getInitVals();
+
+  SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
+  newInitVars.append(inits2.begin(), inits2.end());
+
+  IRRewriter b(builder);
+  b.setInsertionPoint(secondPloop);
+  auto newSecondPloop = b.create<ParallelOp>(
+      secondPloop.getLoc(), secondPloop.getLowerBound(),
+      secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);
+
+  Block *newBlock = newSecondPloop.getBody();
+  auto term1 = cast<ReduceOp>(block1->getTerminator());
+  auto term2 = cast<ReduceOp>(block2->getTerminator());
+
+  b.inlineBlockBefore(block2, newBlock, newBlock->begin(),
+                      newBlock->getArguments());
+  b.inlineBlockBefore(block1, newBlock, newBlock->begin(),
+                      newBlock->getArguments());
+
+  ValueRange results = newSecondPloop.getResults();
+  if (!results.empty()) {
+    b.setInsertionPointToEnd(newBlock);
+
+    ValueRange reduceArgs1 = term1.getOperands();
+    ValueRange reduceArgs2 = term2.getOperands();
+    SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
+    newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());
+
+    auto newReduceOp = b.create<ReduceOp>(term2.getLoc(), newReduceArgs);
+
+    for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
+             term1.getReductions(), term2.getReductions()))) {
+      Block &oldRedBlock = reg.front();
+      Block &newRedBlock = newReduceOp.getReductions()[i].front();
+      b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(),
+                          newRedBlock.getArguments());
+    }
+
+    firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
+    secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
+  }
+  term1->erase();
+  term2->erase();
   firstPloop.erase();
+  secondPloop.erase();
+  secondPloop = newSecondPloop;
 }
 
 void mlir::scf::naivelyFuseParallelOps(
     Region &region, llvm::function_ref<bool(Value, Value)> mayAlias) {
   OpBuilder b(region);
   // Consider every single block and attempt to fuse adjacent loops.
+  SmallVector<SmallVector<ParallelOp, 8>, 1> ploopChains;
   for (auto &block : region) {
-    SmallVector<SmallVector<ParallelOp, 8>, 1> ploopChains{{}};
+    ploopChains.clear();
+    ploopChains.push_back({});
+
     // Not using `walk()` to traverse only top-level parallel loops and also
     // make sure that there are no side-effecting ops between the parallel
     // loops.
@@ -201,7 +257,7 @@ void mlir::scf::naivelyFuseParallelOps(
       // TODO: Handle region side effects properly.
       noSideEffects &= isMemoryEffectFree(&op) && op.getNumRegions() == 0;
     }
-    for (ArrayRef<ParallelOp> ploops : ploopChains) {
+    for (MutableArrayRef<ParallelOp> ploops : ploopChains) {
      for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
        fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias);
    }
diff --git a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
index 9c136bb635658..0d4ea6f20e8d9 100644
--- a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
+++ b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
@@ -24,6 +24,32 @@ func.func @fuse_empty_loops() {
 // -----
 
+func.func @fuse_ops_between(%A: f32, %B: f32) -> f32 {
+  %c2 = arith.constant 2 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    scf.reduce
+  }
+  %res = arith.addf %A, %B : f32
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    scf.reduce
+  }
+  return %res : f32
+}
+// CHECK-LABEL: func @fuse_ops_between
+// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
+// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
+// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index
+// CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : f32
+// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
+// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
+// CHECK: scf.reduce
+// CHECK: }
+// CHECK-NOT: scf.parallel
+
+// -----
+
 func.func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
   %c2 = arith.constant 2 : index
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
@@ -89,7 +115,7 @@ func.func @fuse_three(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
     memref.store %product_elem, %prod[%i, %j] : memref<2x2xf32>
     scf.reduce
   }
-  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { 
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
     %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
     %res_elem = arith.addf %A_elem, %c2fp : f32
     memref.store %res_elem, %B[%i, %j] : memref<2x2xf32>
@@ -575,3 +601,215 @@ func.func @do_not_fuse_affine_apply_to_non_ind_var(
 // CHECK-NEXT: }
 // CHECK-NEXT: memref.dealloc %[[ALLOC]] : memref<2x3xf32>
 // CHECK-NEXT: return
+
+// -----
+
+func.func @fuse_reductions_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
+  %c2 = arith.constant 2 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %init1 = arith.constant 1.0 : f32
+  %init2 = arith.constant 2.0 : f32
+  %res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
+    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+    scf.reduce(%A_elem : f32) {
+      ^bb0(%lhs: f32, %rhs: f32):
+        %1 = arith.addf %lhs, %rhs : f32
+        scf.reduce.return %1 : f32
+    }
+  }
+  %res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 {
+    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+    scf.reduce(%B_elem : f32) {
+      ^bb0(%lhs: f32, %rhs: f32):
+        %1 = arith.mulf %lhs, %rhs : f32
+        scf.reduce.return %1 : f32
+    }
+  }
+  return %res1, %res2 : f32, f32
+}
+
+// CHECK-LABEL: func @fuse_reductions_two
+// CHECK-SAME: (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>) -> (f32, f32)
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[INIT1:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-DAG: %[[INIT2:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK: %[[RES:.*]]:2 = scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]])
+// CHECK-SAME: to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]])
+// CHECK-SAME: init (%[[INIT1]], %[[INIT2]]) -> (f32, f32)
+// CHECK: %[[VAL_A:.*]] = memref.load %[[A]][%[[I]], %[[J]]]
+// CHECK: %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]]
+// CHECK: scf.reduce(%[[VAL_A]], %[[VAL_B]] : f32, f32) {
+// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
+// CHECK: %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
+// CHECK: scf.reduce.return %[[R]] : f32
+// CHECK: }
+// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
+// CHECK: %[[R:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f32
+// CHECK: scf.reduce.return %[[R]] : f32
+// CHECK: }
+// CHECK: return %[[RES]]#0, %[[RES]]#1 : f32, f32
+
+// -----
+
+func.func @fuse_reductions_three(%A: memref<2x2xf32>, %B: memref<2x2xf32>, %C: memref<2x2xf32>) -> (f32, f32, f32) {
+  %c2 = arith.constant 2 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %init1 = arith.constant 1.0 : f32
+  %init2 = arith.constant 2.0 : f32
+  %init3 = arith.constant 3.0 : f32
+  %res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
+    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+    scf.reduce(%A_elem : f32) {
+      ^bb0(%lhs: f32, %rhs: f32):
+        %1 = arith.addf %lhs, %rhs : f32
+        scf.reduce.return %1 : f32
+    }
+  }
+  %res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 {
+    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+    scf.reduce(%B_elem : f32) {
+      ^bb0(%lhs: f32, %rhs: f32):
+        %1 = arith.mulf %lhs, %rhs : f32
+        scf.reduce.return %1 : f32
+    }
+  }
+  %res3 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init3) -> f32 {
+    %A_elem = memref.load %C[%i, %j] : memref<2x2xf32>
+    scf.reduce(%A_elem : f32) {
+      ^bb0(%lhs: f32, %rhs: f32):
+        %1 = arith.addf %lhs, %rhs : f32
+        scf.reduce.return %1 : f32
+    }
+  }
+  return %res1, %res2, %res3 : f32, f32, f32
+}
+
+// CHECK-LABEL: func @fuse_reductions_three
+// CHECK-SAME: (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>, %[[C:.*]]: memref<2x2xf32>) -> (f32, f32, f32)
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[INIT1:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-DAG: %[[INIT2:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK-DAG: %[[INIT3:.*]] = arith.constant 3.000000e+00 : f32
+// CHECK: %[[RES:.*]]:3 = scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]])
+// CHECK-SAME: to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]])
+// CHECK-SAME: init (%[[INIT1]], %[[INIT2]], %[[INIT3]]) -> (f32, f32, f32)
+// CHECK: %[[VAL_A:.*]] = memref.load %[[A]][%[[I]], %[[J]]]
+// CHECK: %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]]
+// CHECK: %[[VAL_C:.*]] = memref.load %[[C]][%[[I]], %[[J]]]
+// CHECK: scf.reduce(%[[VAL_A]], %[[VAL_B]], %[[VAL_C]] : f32, f32, f32) {
+// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
+// CHECK: %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
+// CHECK: scf.reduce.return %[[R]] : f32
+// CHECK: }
+// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
+// CHECK: %[[R:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f32
+// CHECK: scf.reduce.return %[[R]] : f32
+// CHECK: }
+// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
+// CHECK: %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
+// CHECK: scf.reduce.return %[[R]] : f32
+// CHECK: }
+// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : f32, f32, f32
+
+// -----
+
+func.func @reductions_use_res(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
+  %c2 = arith.constant 2 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %init1 = arith.constant 1.0 : f32
+  %res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
+    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+    scf.reduce(%A_elem : f32) {
+      ^bb0(%lhs: f32, %rhs: f32):
+        %1 = arith.addf %lhs, %rhs : f32
+        scf.reduce.return %1 : f32
+    }
+  }
+  %res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%res1) -> f32 {
+    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+    scf.reduce(%B_elem : f32) {
+      ^bb0(%lhs: f32, %rhs: f32):
+        %1 = arith.mulf %lhs, %rhs : f32
+        scf.reduce.return %1 : f32
+    }
+  }
+  return %res1, %res2 : f32, f32
+}
+
+// %res1 is used as second scf.parallel arg, cannot fuse
+// CHECK-LABEL: func @reductions_use_res
+// CHECK: scf.parallel
+// CHECK: scf.parallel
+
+// -----
+
+func.func @reductions_use_res_inside(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
+  %c2 = arith.constant 2 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %init1 = arith.constant 1.0 : f32
+  %init2 = arith.constant 2.0 : f32
+  %res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
+    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+    scf.reduce(%A_elem : f32) {
+      ^bb0(%lhs: f32, %rhs: f32):
+        %1 = arith.addf %lhs, %rhs : f32
+        scf.reduce.return %1 : f32
+    }
+  }
+  %res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 {
+    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+    %sum = arith.addf %B_elem, %res1 : f32
+    scf.reduce(%sum : f32) {
+      ^bb0(%lhs: f32, %rhs: f32):
+        %1 = arith.mulf %lhs, %rhs : f32
+        scf.reduce.return %1 : f32
+    }
+  }
+  return %res1, %res2 : f32, f32
+}
+
+// %res1 is used inside second scf.parallel, cannot fuse
+// CHECK-LABEL: func @reductions_use_res_inside
+// CHECK: scf.parallel
+// CHECK: scf.parallel
+
+// -----
+
+func.func @reductions_use_res_between(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32, f32) {
+  %c2 = arith.constant 2 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %init1 = arith.constant 1.0 : f32
+  %init2 = arith.constant 2.0 : f32
+  %res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
+    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+    scf.reduce(%A_elem : f32) {
+      ^bb0(%lhs: f32, %rhs: f32):
+        %1 = arith.addf %lhs, %rhs : f32
+        scf.reduce.return %1 : f32
+    }
+  }
+  %res3 = arith.addf %res1, %init2 : f32
+  %res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 {
+    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+    scf.reduce(%B_elem : f32) {
+      ^bb0(%lhs: f32, %rhs: f32):
+        %1 = arith.mulf %lhs, %rhs : f32
+        scf.reduce.return %1 : f32
+    }
+  }
+  return %res1, %res2, %res3 : f32, f32, f32
+}
+
+// instruction in between the loops uses the first loop result
+// CHECK-LABEL: func @reductions_use_res_between
+// CHECK: scf.parallel
+// CHECK: scf.parallel
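Note (not part of the patch): a minimal sketch of how the updated entry point might be driven by a client, assuming only the signature visible in the hunk above (`Region &region`, `llvm::function_ref<bool(Value, Value)> mayAlias`). The helper name, the header path, and the always-true aliasing callback below are illustrative assumptions, not code from this change.

  // Hypothetical client code, for illustration only.
  #include "mlir/Dialect/SCF/Transforms/Transforms.h" // assumed to declare naivelyFuseParallelOps
  #include "mlir/IR/Operation.h"

  // Fuse sibling scf.parallel loops in every region of `op`. Returning true
  // from the callback means "may alias", the conservative answer that keeps
  // the legality checks strict; a real client would plug in alias analysis
  // results instead.
  static void fuseParallelLoopsNaively(mlir::Operation *op) {
    auto mayAlias = [](mlir::Value, mlir::Value) { return true; };
    for (mlir::Region &region : op->getRegions())
      mlir::scf::naivelyFuseParallelOps(region, mayAlias);
  }

With the reduction handling added in fuseIfLegal, two adjacent scf.parallel loops that each yield reduced values are rewritten into a single loop carrying the concatenated init values and reduction regions, which is what the fuse_reductions_* tests above check.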