From b069eff6b633517ae848ce01d548b39b3101b5c2 Mon Sep 17 00:00:00 2001 From: Kazuaki Matsumura Date: Thu, 6 Jun 2024 15:56:31 -0700 Subject: [PATCH 1/4] [flang] Generate fir.do_loop reduce from DO CONCURRENT REDUCE clause --- flang/lib/Lower/Bridge.cpp | 61 +++++++++++++++++++++++++++++++++++-- flang/test/Lower/loops3.f90 | 23 ++++++++++++++ 2 files changed, 82 insertions(+), 2 deletions(-) create mode 100644 flang/test/Lower/loops3.f90 diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index 512c7a349ae21..d0a0a36500f61 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -104,7 +104,7 @@ struct IncrementLoopInfo { bool hasLocalitySpecs() const { return !localSymList.empty() || !localInitSymList.empty() || - !sharedSymList.empty(); + !reduceSymList.empty() || !sharedSymList.empty(); } // Data members common to both structured and unstructured loops. @@ -116,6 +116,9 @@ struct IncrementLoopInfo { bool isUnordered; // do concurrent, forall llvm::SmallVector localSymList; llvm::SmallVector localInitSymList; + llvm::SmallVector< + std::pair> + reduceSymList; llvm::SmallVector sharedSymList; mlir::Value loopVariable = nullptr; @@ -1741,6 +1744,36 @@ class FirConverter : public Fortran::lower::AbstractConverter { builder->create(loc); } + fir::ReduceOperationEnum + getReduceOperationEnum(const Fortran::parser::ReductionOperator &rOpr) { + switch (rOpr.v) { + case Fortran::parser::ReductionOperator::Operator::Plus: + return fir::ReduceOperationEnum::Add; + case Fortran::parser::ReductionOperator::Operator::Multiply: + return fir::ReduceOperationEnum::Multiply; + case Fortran::parser::ReductionOperator::Operator::And: + return fir::ReduceOperationEnum::AND; + case Fortran::parser::ReductionOperator::Operator::Or: + return fir::ReduceOperationEnum::OR; + case Fortran::parser::ReductionOperator::Operator::Eqv: + return fir::ReduceOperationEnum::EQV; + case Fortran::parser::ReductionOperator::Operator::Neqv: + return fir::ReduceOperationEnum::NEQV; + case Fortran::parser::ReductionOperator::Operator::Max: + return fir::ReduceOperationEnum::MAX; + case Fortran::parser::ReductionOperator::Operator::Min: + return fir::ReduceOperationEnum::MIN; + case Fortran::parser::ReductionOperator::Operator::Iand: + return fir::ReduceOperationEnum::IAND; + case Fortran::parser::ReductionOperator::Operator::Ior: + return fir::ReduceOperationEnum::IOR; + case Fortran::parser::ReductionOperator::Operator::Ieor: + return fir::ReduceOperationEnum::EIOR; + } + fir::emitFatalError(toLocation(), "illegal reduction operator"); + return fir::ReduceOperationEnum::Add; + } + /// Collect DO CONCURRENT or FORALL loop control information. IncrementLoopNestInfo getConcurrentControl( const Fortran::parser::ConcurrentHeader &header, @@ -1763,6 +1796,16 @@ class FirConverter : public Fortran::lower::AbstractConverter { std::get_if(&x.u)) for (const Fortran::parser::Name &x : localInitList->v) info.localInitSymList.push_back(x.symbol); + if (const auto *reduceList = + std::get_if(&x.u)) { + fir::ReduceOperationEnum reduce_operation = getReduceOperationEnum( + std::get(reduceList->t)); + for (const Fortran::parser::Name &x : + std::get>(reduceList->t)) { + info.reduceSymList.push_back( + std::make_pair(reduce_operation, x.symbol)); + } + } if (const auto *sharedList = std::get_if(&x.u)) for (const Fortran::parser::Name &x : sharedList->v) @@ -1955,9 +1998,23 @@ class FirConverter : public Fortran::lower::AbstractConverter { mlir::Type loopVarType = info.getLoopVariableType(); mlir::Value loopValue; if (info.isUnordered) { + llvm::SmallVector reduceOperands; + llvm::SmallVector reduceAttrs; + // Create DO CONCURRENT reduce operations and attributes + for (const auto reduceSym : info.reduceSymList) { + const fir::ReduceOperationEnum reduce_operation = reduceSym.first; + const Fortran::semantics::Symbol *sym = reduceSym.second; + fir::ExtendedValue exv = getSymbolExtendedValue(*sym, nullptr); + reduceOperands.push_back(fir::getBase(exv)); + auto reduce_attr = + fir::ReduceAttr::get(builder->getContext(), reduce_operation); + reduceAttrs.push_back(reduce_attr); + } // The loop variable value is explicitly updated. info.doLoop = builder->create( - loc, lowerValue, upperValue, stepValue, /*unordered=*/true); + loc, lowerValue, upperValue, stepValue, /*unordered=*/true, + /*finalCountValue=*/false, /*iterArgs=*/std::nullopt, + llvm::ArrayRef(reduceOperands), reduceAttrs); builder->setInsertionPointToStart(info.doLoop.getBody()); loopValue = builder->createConvert(loc, loopVarType, info.doLoop.getInductionVar()); diff --git a/flang/test/Lower/loops3.f90 b/flang/test/Lower/loops3.f90 new file mode 100644 index 0000000000000..dd24e26d72c31 --- /dev/null +++ b/flang/test/Lower/loops3.f90 @@ -0,0 +1,23 @@ +! Test do concurrent reduction +! RUN: bbc -emit-fir -hlfir=false -o - %s | FileCheck %s + +! CHECK-LABEL: loop_test +subroutine loop_test + integer(4) :: i, j, k, tmp, sum = 0 + real :: m + + i = 100 + j = 200 + k = 300 + + ! CHECK: %[[VAL_0:.*]] = fir.alloca f32 {bindc_name = "m", uniq_name = "_QFloop_testEm" + ! CHECK: %[[VAL_1:.*]] = fir.address_of(@_QFloop_testEsum) : !fir.ref + ! CHECK: fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} unordered { + ! CHECK: fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} unordered { + ! CHECK: fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} unordered reduce(#fir.reduce_attr -> %[[VAL_1:.*]] : !fir.ref, #fir.reduce_attr -> %[[VAL_0:.*]] : !fir.ref) { + do concurrent (i=1:5, j=1:5, k=1:5) local(tmp) reduce(+:sum) reduce(max:m) + tmp = i + j + k + sum = tmp + sum + m = max(m, sum) + enddo +end subroutine loop_test From c2607709f550caff3bdc6fe71329d941b87bd244 Mon Sep 17 00:00:00 2001 From: Kazuaki Matsumura Date: Thu, 6 Jun 2024 21:14:10 -0700 Subject: [PATCH 2/4] [flang] Close a brace --- flang/test/Lower/loops3.f90 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flang/test/Lower/loops3.f90 b/flang/test/Lower/loops3.f90 index dd24e26d72c31..2e62ee480ec8a 100644 --- a/flang/test/Lower/loops3.f90 +++ b/flang/test/Lower/loops3.f90 @@ -10,7 +10,7 @@ subroutine loop_test j = 200 k = 300 - ! CHECK: %[[VAL_0:.*]] = fir.alloca f32 {bindc_name = "m", uniq_name = "_QFloop_testEm" + ! CHECK: %[[VAL_0:.*]] = fir.alloca f32 {bindc_name = "m", uniq_name = "_QFloop_testEm"} ! CHECK: %[[VAL_1:.*]] = fir.address_of(@_QFloop_testEsum) : !fir.ref ! CHECK: fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} unordered { ! CHECK: fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} unordered { From 3512291f5da53d5f9b9a0340822352cc1326f7d6 Mon Sep 17 00:00:00 2001 From: Kazuaki Matsumura Date: Thu, 6 Jun 2024 21:39:51 -0700 Subject: [PATCH 3/4] [flang] Fix a comment --- flang/lib/Lower/Bridge.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index d0a0a36500f61..dca71256192ed 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -2000,7 +2000,7 @@ class FirConverter : public Fortran::lower::AbstractConverter { if (info.isUnordered) { llvm::SmallVector reduceOperands; llvm::SmallVector reduceAttrs; - // Create DO CONCURRENT reduce operations and attributes + // Create DO CONCURRENT reduce operands and attributes for (const auto reduceSym : info.reduceSymList) { const fir::ReduceOperationEnum reduce_operation = reduceSym.first; const Fortran::semantics::Symbol *sym = reduceSym.second; From f9756968593f777cf3f12027eea2b89c55a2ea63 Mon Sep 17 00:00:00 2001 From: Kazuaki Matsumura Date: Thu, 6 Jun 2024 22:30:52 -0700 Subject: [PATCH 4/4] [flang] Use llvm_unreachable --- flang/lib/Lower/Bridge.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index dca71256192ed..14e99757925ac 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -1770,8 +1770,7 @@ class FirConverter : public Fortran::lower::AbstractConverter { case Fortran::parser::ReductionOperator::Operator::Ieor: return fir::ReduceOperationEnum::EIOR; } - fir::emitFatalError(toLocation(), "illegal reduction operator"); - return fir::ReduceOperationEnum::Add; + llvm_unreachable("illegal reduction operator"); } /// Collect DO CONCURRENT or FORALL loop control information.