@@ -483,34 +483,32 @@ func.func @do_not_fuse_multiple_stores_on_diff_indices(
483
483
484
484
// -----
485
485
486
- func.func @fuse_reductions (%A: memref <2 x2 xf32 >, %B: memref <2 x2 xf32 >) -> (f32 , f32 ) {
486
+ func.func @fuse_reductions_two (%A: memref <2 x2 xf32 >, %B: memref <2 x2 xf32 >) -> (f32 , f32 ) {
487
487
%c2 = arith.constant 2 : index
488
488
%c0 = arith.constant 0 : index
489
489
%c1 = arith.constant 1 : index
490
490
%init1 = arith.constant 1.0 : f32
491
491
%init2 = arith.constant 2.0 : f32
492
492
%res1 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%init1 ) -> f32 {
493
493
%A_elem = memref.load %A [%i , %j ] : memref <2 x2 xf32 >
494
- scf.reduce (%A_elem ) : f32 {
494
+ scf.reduce (%A_elem : f32 ) {
495
495
^bb0 (%lhs: f32 , %rhs: f32 ):
496
496
%1 = arith.addf %lhs , %rhs : f32
497
497
scf.reduce.return %1 : f32
498
498
}
499
- scf.yield
500
499
}
501
500
%res2 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%init2 ) -> f32 {
502
501
%B_elem = memref.load %B [%i , %j ] : memref <2 x2 xf32 >
503
- scf.reduce (%B_elem ) : f32 {
502
+ scf.reduce (%B_elem : f32 ) {
504
503
^bb0 (%lhs: f32 , %rhs: f32 ):
505
504
%1 = arith.mulf %lhs , %rhs : f32
506
505
scf.reduce.return %1 : f32
507
506
}
508
- scf.yield
509
507
}
510
508
return %res1 , %res2 : f32 , f32
511
509
}
512
510
513
- // CHECK-LABEL: func @fuse_reductions
511
+ // CHECK-LABEL: func @fuse_reductions_two
514
512
// CHECK-SAME: (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>) -> (f32, f32)
515
513
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
516
514
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
@@ -521,44 +519,105 @@ func.func @fuse_reductions(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f3
521
519
// CHECK-SAME: to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]])
522
520
// CHECK-SAME: init (%[[INIT1]], %[[INIT2]]) -> (f32, f32)
523
521
// CHECK: %[[VAL_A:.*]] = memref.load %[[A]][%[[I]], %[[J]]]
524
- // CHECK: scf.reduce(%[[VAL_A]]) : f32 {
522
+ // CHECK: %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]]
523
+ // CHECK: scf.reduce(%[[VAL_A]], %[[VAL_B]] : f32, f32) {
525
524
// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
526
525
// CHECK: %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
527
526
// CHECK: scf.reduce.return %[[R]] : f32
528
527
// CHECK: }
529
- // CHECK: %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]]
530
- // CHECK: scf.reduce(%[[VAL_B]]) : f32 {
531
528
// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
532
529
// CHECK: %[[R:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f32
533
530
// CHECK: scf.reduce.return %[[R]] : f32
534
531
// CHECK: }
535
- // CHECK: scf.yield
536
532
// CHECK: return %[[RES]]#0, %[[RES]]#1 : f32, f32
537
533
538
534
// -----
539
535
536
+ func.func @fuse_reductions_three (%A: memref <2 x2 xf32 >, %B: memref <2 x2 xf32 >, %C: memref <2 x2 xf32 >) -> (f32 , f32 , f32 ) {
537
+ %c2 = arith.constant 2 : index
538
+ %c0 = arith.constant 0 : index
539
+ %c1 = arith.constant 1 : index
540
+ %init1 = arith.constant 1.0 : f32
541
+ %init2 = arith.constant 2.0 : f32
542
+ %init3 = arith.constant 3.0 : f32
543
+ %res1 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%init1 ) -> f32 {
544
+ %A_elem = memref.load %A [%i , %j ] : memref <2 x2 xf32 >
545
+ scf.reduce (%A_elem : f32 ) {
546
+ ^bb0 (%lhs: f32 , %rhs: f32 ):
547
+ %1 = arith.addf %lhs , %rhs : f32
548
+ scf.reduce.return %1 : f32
549
+ }
550
+ }
551
+ %res2 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%init2 ) -> f32 {
552
+ %B_elem = memref.load %B [%i , %j ] : memref <2 x2 xf32 >
553
+ scf.reduce (%B_elem : f32 ) {
554
+ ^bb0 (%lhs: f32 , %rhs: f32 ):
555
+ %1 = arith.mulf %lhs , %rhs : f32
556
+ scf.reduce.return %1 : f32
557
+ }
558
+ }
559
+ %res3 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%init3 ) -> f32 {
560
+ %C_elem = memref.load %C [%i , %j ] : memref <2 x2 xf32 >
561
+ scf.reduce (%C_elem : f32 ) {
562
+ ^bb0 (%lhs: f32 , %rhs: f32 ):
563
+ %1 = arith.addf %lhs , %rhs : f32
564
+ scf.reduce.return %1 : f32
565
+ }
566
+ }
567
+ return %res1 , %res2 , %res3 : f32 , f32 , f32
568
+ }
569
+
570
+ // CHECK-LABEL: func @fuse_reductions_three
571
+ // CHECK-SAME: (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>, %[[C:.*]]: memref<2x2xf32>) -> (f32, f32, f32)
572
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
573
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
574
+ // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
575
+ // CHECK-DAG: %[[INIT1:.*]] = arith.constant 1.000000e+00 : f32
576
+ // CHECK-DAG: %[[INIT2:.*]] = arith.constant 2.000000e+00 : f32
577
+ // CHECK-DAG: %[[INIT3:.*]] = arith.constant 3.000000e+00 : f32
578
+ // CHECK: %[[RES:.*]]:3 = scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]])
579
+ // CHECK-SAME: to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]])
580
+ // CHECK-SAME: init (%[[INIT1]], %[[INIT2]], %[[INIT3]]) -> (f32, f32, f32)
581
+ // CHECK: %[[VAL_A:.*]] = memref.load %[[A]][%[[I]], %[[J]]]
582
+ // CHECK: %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]]
583
+ // CHECK: %[[VAL_C:.*]] = memref.load %[[C]][%[[I]], %[[J]]]
584
+ // CHECK: scf.reduce(%[[VAL_A]], %[[VAL_B]], %[[VAL_C]] : f32, f32, f32) {
585
+ // CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
586
+ // CHECK: %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
587
+ // CHECK: scf.reduce.return %[[R]] : f32
588
+ // CHECK: }
589
+ // CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
590
+ // CHECK: %[[R:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f32
591
+ // CHECK: scf.reduce.return %[[R]] : f32
592
+ // CHECK: }
593
+ // CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
594
+ // CHECK: %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
595
+ // CHECK: scf.reduce.return %[[R]] : f32
596
+ // CHECK: }
597
+ // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : f32, f32, f32
598
+
599
+ // -----
600
+
540
601
func.func @reductions_use_res (%A: memref <2 x2 xf32 >, %B: memref <2 x2 xf32 >) -> (f32 , f32 ) {
541
602
%c2 = arith.constant 2 : index
542
603
%c0 = arith.constant 0 : index
543
604
%c1 = arith.constant 1 : index
544
605
%init1 = arith.constant 1.0 : f32
545
606
%res1 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%init1 ) -> f32 {
546
607
%A_elem = memref.load %A [%i , %j ] : memref <2 x2 xf32 >
547
- scf.reduce (%A_elem ) : f32 {
608
+ scf.reduce (%A_elem : f32 ) {
548
609
^bb0 (%lhs: f32 , %rhs: f32 ):
549
610
%1 = arith.addf %lhs , %rhs : f32
550
611
scf.reduce.return %1 : f32
551
612
}
552
- scf.yield
553
613
}
554
614
%res2 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%res1 ) -> f32 {
555
615
%B_elem = memref.load %B [%i , %j ] : memref <2 x2 xf32 >
556
- scf.reduce (%B_elem ) : f32 {
616
+ scf.reduce (%B_elem : f32 ) {
557
617
^bb0 (%lhs: f32 , %rhs: f32 ):
558
618
%1 = arith.mulf %lhs , %rhs : f32
559
619
scf.reduce.return %1 : f32
560
620
}
561
- scf.yield
562
621
}
563
622
return %res1 , %res2 : f32 , f32
564
623
}
@@ -578,22 +637,20 @@ func.func @reductions_use_res_inside(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -
578
637
%init2 = arith.constant 2.0 : f32
579
638
%res1 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%init1 ) -> f32 {
580
639
%A_elem = memref.load %A [%i , %j ] : memref <2 x2 xf32 >
581
- scf.reduce (%A_elem ) : f32 {
640
+ scf.reduce (%A_elem : f32 ) {
582
641
^bb0 (%lhs: f32 , %rhs: f32 ):
583
642
%1 = arith.addf %lhs , %rhs : f32
584
643
scf.reduce.return %1 : f32
585
644
}
586
- scf.yield
587
645
}
588
646
%res2 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%init2 ) -> f32 {
589
647
%B_elem = memref.load %B [%i , %j ] : memref <2 x2 xf32 >
590
648
%sum = arith.addf %B_elem , %res1 : f32
591
- scf.reduce (%sum ) : f32 {
649
+ scf.reduce (%sum : f32 ) {
592
650
^bb0 (%lhs: f32 , %rhs: f32 ):
593
651
%1 = arith.mulf %lhs , %rhs : f32
594
652
scf.reduce.return %1 : f32
595
653
}
596
- scf.yield
597
654
}
598
655
return %res1 , %res2 : f32 , f32
599
656
}
0 commit comments