Skip to content

Commit 8cb630f

Browse files
committed
Update to new reductions format
1 parent 57cfd83 commit 8cb630f

File tree

2 files changed

+102
-25
lines changed

2 files changed

+102
-25
lines changed

mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp

Lines changed: 27 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -169,18 +169,38 @@ static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
169169
secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);
170170

171171
Block *newBlock = newSecondPloop.getBody();
172-
newBlock->getTerminator()->erase();
172+
auto term1 = cast<ReduceOp>(block1->getTerminator());
173+
auto term2 = cast<ReduceOp>(block2->getTerminator());
173174

174-
block1->getTerminator()->erase();
175-
176-
b.inlineBlockBefore(block1, newBlock, newBlock->end(),
175+
b.inlineBlockBefore(block2, newBlock, newBlock->begin(),
177176
newBlock->getArguments());
178-
b.inlineBlockBefore(block2, newBlock, newBlock->end(),
177+
b.inlineBlockBefore(block1, newBlock, newBlock->begin(),
179178
newBlock->getArguments());
180179

181180
ValueRange results = newSecondPloop.getResults();
182-
firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
183-
secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
181+
if (!results.empty()) {
182+
b.setInsertionPointToEnd(newBlock);
183+
184+
ValueRange reduceArgs1 = term1.getOperands();
185+
ValueRange reduceArgs2 = term2.getOperands();
186+
SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
187+
newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());
188+
189+
auto newReduceOp = b.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs);
190+
191+
for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
192+
term1.getReductions(), term2.getReductions()))) {
193+
Block &oldRedBlock = reg.front();
194+
Block &newRedBlock = newReduceOp.getReductions()[i].front();
195+
b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(),
196+
newRedBlock.getArguments());
197+
}
198+
199+
firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
200+
secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
201+
}
202+
term1->erase();
203+
term2->erase();
184204
firstPloop.erase();
185205
secondPloop.erase();
186206
secondPloop = newSecondPloop;

mlir/test/Dialect/SCF/parallel-loop-fusion.mlir

Lines changed: 75 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -483,34 +483,32 @@ func.func @do_not_fuse_multiple_stores_on_diff_indices(
483483

484484
// -----
485485

486-
func.func @fuse_reductions(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
486+
func.func @fuse_reductions_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
487487
%c2 = arith.constant 2 : index
488488
%c0 = arith.constant 0 : index
489489
%c1 = arith.constant 1 : index
490490
%init1 = arith.constant 1.0 : f32
491491
%init2 = arith.constant 2.0 : f32
492492
%res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
493493
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
494-
scf.reduce(%A_elem) : f32 {
494+
scf.reduce(%A_elem : f32) {
495495
^bb0(%lhs: f32, %rhs: f32):
496496
%1 = arith.addf %lhs, %rhs : f32
497497
scf.reduce.return %1 : f32
498498
}
499-
scf.yield
500499
}
501500
%res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 {
502501
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
503-
scf.reduce(%B_elem) : f32 {
502+
scf.reduce(%B_elem : f32) {
504503
^bb0(%lhs: f32, %rhs: f32):
505504
%1 = arith.mulf %lhs, %rhs : f32
506505
scf.reduce.return %1 : f32
507506
}
508-
scf.yield
509507
}
510508
return %res1, %res2 : f32, f32
511509
}
512510

513-
// CHECK-LABEL: func @fuse_reductions
511+
// CHECK-LABEL: func @fuse_reductions_two
514512
// CHECK-SAME: (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>) -> (f32, f32)
515513
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
516514
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
@@ -521,44 +519,105 @@ func.func @fuse_reductions(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f3
521519
// CHECK-SAME: to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]])
522520
// CHECK-SAME: init (%[[INIT1]], %[[INIT2]]) -> (f32, f32)
523521
// CHECK: %[[VAL_A:.*]] = memref.load %[[A]][%[[I]], %[[J]]]
524-
// CHECK: scf.reduce(%[[VAL_A]]) : f32 {
522+
// CHECK: %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]]
523+
// CHECK: scf.reduce(%[[VAL_A]], %[[VAL_B]] : f32, f32) {
525524
// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
526525
// CHECK: %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
527526
// CHECK: scf.reduce.return %[[R]] : f32
528527
// CHECK: }
529-
// CHECK: %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]]
530-
// CHECK: scf.reduce(%[[VAL_B]]) : f32 {
531528
// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
532529
// CHECK: %[[R:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f32
533530
// CHECK: scf.reduce.return %[[R]] : f32
534531
// CHECK: }
535-
// CHECK: scf.yield
536532
// CHECK: return %[[RES]]#0, %[[RES]]#1 : f32, f32
537533

538534
// -----
539535

536+
func.func @fuse_reductions_three(%A: memref<2x2xf32>, %B: memref<2x2xf32>, %C: memref<2x2xf32>) -> (f32, f32, f32) {
537+
%c2 = arith.constant 2 : index
538+
%c0 = arith.constant 0 : index
539+
%c1 = arith.constant 1 : index
540+
%init1 = arith.constant 1.0 : f32
541+
%init2 = arith.constant 2.0 : f32
542+
%init3 = arith.constant 3.0 : f32
543+
%res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
544+
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
545+
scf.reduce(%A_elem : f32) {
546+
^bb0(%lhs: f32, %rhs: f32):
547+
%1 = arith.addf %lhs, %rhs : f32
548+
scf.reduce.return %1 : f32
549+
}
550+
}
551+
%res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 {
552+
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
553+
scf.reduce(%B_elem : f32) {
554+
^bb0(%lhs: f32, %rhs: f32):
555+
%1 = arith.mulf %lhs, %rhs : f32
556+
scf.reduce.return %1 : f32
557+
}
558+
}
559+
%res3 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init3) -> f32 {
560+
%A_elem = memref.load %C[%i, %j] : memref<2x2xf32>
561+
scf.reduce(%A_elem : f32) {
562+
^bb0(%lhs: f32, %rhs: f32):
563+
%1 = arith.addf %lhs, %rhs : f32
564+
scf.reduce.return %1 : f32
565+
}
566+
}
567+
return %res1, %res2, %res3 : f32, f32, f32
568+
}
569+
570+
// CHECK-LABEL: func @fuse_reductions_three
571+
// CHECK-SAME: (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>, %[[C:.*]]: memref<2x2xf32>) -> (f32, f32, f32)
572+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
573+
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
574+
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
575+
// CHECK-DAG: %[[INIT1:.*]] = arith.constant 1.000000e+00 : f32
576+
// CHECK-DAG: %[[INIT2:.*]] = arith.constant 2.000000e+00 : f32
577+
// CHECK-DAG: %[[INIT3:.*]] = arith.constant 3.000000e+00 : f32
578+
// CHECK: %[[RES:.*]]:3 = scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]])
579+
// CHECK-SAME: to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]])
580+
// CHECK-SAME: init (%[[INIT1]], %[[INIT2]], %[[INIT3]]) -> (f32, f32, f32)
581+
// CHECK: %[[VAL_A:.*]] = memref.load %[[A]][%[[I]], %[[J]]]
582+
// CHECK: %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]]
583+
// CHECK: %[[VAL_C:.*]] = memref.load %[[C]][%[[I]], %[[J]]]
584+
// CHECK: scf.reduce(%[[VAL_A]], %[[VAL_B]], %[[VAL_C]] : f32, f32, f32) {
585+
// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
586+
// CHECK: %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
587+
// CHECK: scf.reduce.return %[[R]] : f32
588+
// CHECK: }
589+
// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
590+
// CHECK: %[[R:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f32
591+
// CHECK: scf.reduce.return %[[R]] : f32
592+
// CHECK: }
593+
// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
594+
// CHECK: %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
595+
// CHECK: scf.reduce.return %[[R]] : f32
596+
// CHECK: }
597+
// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : f32, f32, f32
598+
599+
// -----
600+
540601
func.func @reductions_use_res(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
541602
%c2 = arith.constant 2 : index
542603
%c0 = arith.constant 0 : index
543604
%c1 = arith.constant 1 : index
544605
%init1 = arith.constant 1.0 : f32
545606
%res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
546607
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
547-
scf.reduce(%A_elem) : f32 {
608+
scf.reduce(%A_elem : f32) {
548609
^bb0(%lhs: f32, %rhs: f32):
549610
%1 = arith.addf %lhs, %rhs : f32
550611
scf.reduce.return %1 : f32
551612
}
552-
scf.yield
553613
}
554614
%res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%res1) -> f32 {
555615
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
556-
scf.reduce(%B_elem) : f32 {
616+
scf.reduce(%B_elem : f32) {
557617
^bb0(%lhs: f32, %rhs: f32):
558618
%1 = arith.mulf %lhs, %rhs : f32
559619
scf.reduce.return %1 : f32
560620
}
561-
scf.yield
562621
}
563622
return %res1, %res2 : f32, f32
564623
}
@@ -578,22 +637,20 @@ func.func @reductions_use_res_inside(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -
578637
%init2 = arith.constant 2.0 : f32
579638
%res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 {
580639
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
581-
scf.reduce(%A_elem) : f32 {
640+
scf.reduce(%A_elem : f32) {
582641
^bb0(%lhs: f32, %rhs: f32):
583642
%1 = arith.addf %lhs, %rhs : f32
584643
scf.reduce.return %1 : f32
585644
}
586-
scf.yield
587645
}
588646
%res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 {
589647
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
590648
%sum = arith.addf %B_elem, %res1 : f32
591-
scf.reduce(%sum) : f32 {
649+
scf.reduce(%sum : f32) {
592650
^bb0(%lhs: f32, %rhs: f32):
593651
%1 = arith.mulf %lhs, %rhs : f32
594652
scf.reduce.return %1 : f32
595653
}
596-
scf.yield
597654
}
598655
return %res1, %res2 : f32, f32
599656
}

0 commit comments

Comments
 (0)