diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp index 30699ecdde0a2..15b8aca5e267f 100644 --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -261,6 +261,8 @@ loopScheduling(scf::ForOp forOp, return 1; }; + std::optional ubConstant = getConstantIntValue(forOp.getUpperBound()); + std::optional lbConstant = getConstantIntValue(forOp.getLowerBound()); DenseMap opCycles; std::map> wrappedSchedule; for (Operation &op : forOp.getBody()->getOperations()) { @@ -271,7 +273,14 @@ loopScheduling(scf::ForOp forOp, Operation *def = operand.getDefiningOp(); if (!def) continue; - earlyCycle = std::max(earlyCycle, opCycles[def] + getLatency(def)); + if (ubConstant && lbConstant) { + unsigned ubInt = ubConstant.value(); + unsigned lbInt = lbConstant.value(); + auto minLatency = std::min(ubInt - lbInt - 1, getLatency(def)); + earlyCycle = std::max(earlyCycle, opCycles[def] + minLatency); + } else { + earlyCycle = std::max(earlyCycle, opCycles[def] + getLatency(def)); + } } opCycles[&op] = earlyCycle; wrappedSchedule[earlyCycle % iterationInterval].push_back(&op); diff --git a/mlir/test/Dialect/SCF/transform-ops.mlir b/mlir/test/Dialect/SCF/transform-ops.mlir index a4daa86583c3d..2d6b48fd3e57c 100644 --- a/mlir/test/Dialect/SCF/transform-ops.mlir +++ b/mlir/test/Dialect/SCF/transform-ops.mlir @@ -300,3 +300,60 @@ module attributes {transform.with_named_sequence} { transform.yield } } + + +// ----- + +// CHECK-LABEL: func.func @loop_pipeline +func.func @loop_pipeline(%arg0: memref<4x16xf32>, %arg1: vector<16xf32>) -> vector<16xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %cst = arith.constant 0.000000e+00 : f32 + %c3 = arith.constant 3 : index + // CHECK: vector.transfer_read + // CHECK: vector.transfer_read + // CHECK: vector.transfer_read + // CHECK: arith.addf + // CHECK: arith.addf + // CHECK: arith.addf + %0 = scf.for %arg2 = %c0 to %c3 step %c1 iter_args(%arg3 = %arg1) -> (vector<16xf32>) { + %1 = vector.transfer_read %arg0[%arg2, %c0], %cst {in_bounds = [true]} : memref<4x16xf32>, vector<16xf32> + %2 = arith.addf %1, %arg3 : vector<16xf32> + scf.yield %2 : vector<16xf32> + } + return %0 : vector<16xf32> +} +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["scf.for"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.for"> + %1 = transform.loop.pipeline %0 {iteration_interval = 1 : i64, read_latency = 5 : i64, scheduling_type = "full-loops"} : (!transform.op<"scf.for">) -> !transform.any_op + transform.yield + } +} + + +// ----- + +// CHECK-LABEL: func.func @loop_pipeline_lb_gt_0 +func.func @loop_pipeline_lb_gt_0(%arg0: memref<4x16xf32>, %arg1: vector<16xf32>) -> vector<16xf32> { + %c1 = arith.constant 1 : index + %cst = arith.constant 0.000000e+00 : f32 + %c3 = arith.constant 3 : index + // CHECK: vector.transfer_read + // CHECK: vector.transfer_read + // CHECK: arith.addf + // CHECK: arith.addf + %0 = scf.for %arg2 = %c1 to %c3 step %c1 iter_args(%arg3 = %arg1) -> (vector<16xf32>) { + %1 = vector.transfer_read %arg0[%arg2, %c1], %cst {in_bounds = [true]} : memref<4x16xf32>, vector<16xf32> + %2 = arith.addf %1, %arg3 : vector<16xf32> + scf.yield %2 : vector<16xf32> + } + return %0 : vector<16xf32> +} +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["scf.for"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.for"> + %1 = transform.loop.pipeline %0 {iteration_interval = 1 : i64, read_latency = 5 : i64, scheduling_type = "full-loops"} : (!transform.op<"scf.for">) -> !transform.any_op + transform.yield + } +}