diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 78015e3deeb96..e128cc71a5d62 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -406,19 +406,29 @@ struct WarpOpToScfIfPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
 static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
                                             WarpExecuteOnLane0Op warpOp,
                                             vector::TransferWriteOp writeOp,
-                                            VectorType targetType) {
+                                            VectorType targetType,
+                                            VectorType maybeMaskType) {
   assert(writeOp->getParentOp() == warpOp &&
          "write must be nested immediately under warp");
   OpBuilder::InsertionGuard g(rewriter);
   SmallVector<size_t> newRetIndices;
-  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-      rewriter, warpOp, ValueRange{{writeOp.getVector()}},
-      TypeRange{targetType}, newRetIndices);
+  WarpExecuteOnLane0Op newWarpOp;
+  if (maybeMaskType) {
+    newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, ValueRange{writeOp.getVector(), writeOp.getMask()},
+        TypeRange{targetType, maybeMaskType}, newRetIndices);
+  } else {
+    newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, ValueRange{{writeOp.getVector()}},
+        TypeRange{targetType}, newRetIndices);
+  }
   rewriter.setInsertionPointAfter(newWarpOp);
   auto newWriteOp =
       cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
   rewriter.eraseOp(writeOp);
   newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
+  if (maybeMaskType)
+    newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
   return newWriteOp;
 }
 
@@ -489,10 +499,25 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
     if (!targetType)
       return failure();
 
+    // 2.5 Compute the distributed type for the new mask.
+    VectorType maskType;
+    if (writeOp.getMask()) {
+      // TODO: Distribution of masked writes with non-trivial permutation
+      // maps requires the distribution of the mask to elementwise match the
+      // distribution of the permuted written vector. Currently, which lane
+      // is responsible for which element is captured strictly by shape
+      // information on the warp op, so supporting non-trivial permutations
+      // requires materializing the permutation in IR.
+      if (!writeOp.getPermutationMap().isMinorIdentity())
+        return failure();
+      maskType =
+          getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize());
+    }
+
     // 3. clone the write into a new WarpExecuteOnLane0Op to separate it from
     // the rest.
     vector::TransferWriteOp newWriteOp =
-        cloneWriteOp(rewriter, warpOp, writeOp, targetType);
+        cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);
 
     // 4. Reindex the write using the distribution map.
     auto newWarpOp =
@@ -561,10 +586,6 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
 
   LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                 PatternRewriter &rewriter) const override {
-    // Ops with mask not supported yet.
-    if (writeOp.getMask())
-      return failure();
-
     auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(writeOp->getParentOp());
     if (!warpOp)
       return failure();
@@ -575,8 +596,10 @@
       if (!isMemoryEffectFree(nextOp))
         return failure();
 
+    Value maybeMask = writeOp.getMask();
     if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
           return writeOp.getVector() == value ||
+                 (maybeMask && maybeMask == value) ||
                  warpOp.isDefinedOutsideOfRegion(value);
         }))
       return failure();
@@ -584,6 +607,10 @@
     if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
       return success();
 
+    // Masked writes not supported for extraction.
+    if (writeOp.getMask())
+      return failure();
+
     if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
       return success();
 
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 5ec02ce002ffb..f050bcd246e5e 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1253,3 +1253,31 @@ func.func @warp_propagate_uniform_transfer_read(%laneid: index, %src: memref<409
 // CHECK-PROP-SAME: (%{{.+}}: index, %[[SRC:.+]]: memref<4096xf32>, %[[INDEX:.+]]: index)
 // CHECK-PROP:   %[[READ:.+]] = vector.transfer_read %[[SRC]][%[[INDEX]]], %cst {in_bounds = [true]} : memref<4096xf32>, vector<1xf32>
 // CHECK-PROP:   return %[[READ]] : vector<1xf32>
+
+// -----
+
+func.func @warp_propagate_masked_write(%laneid: index, %dest: memref<4096xf32>) {
+  %c0 = arith.constant 0 : index
+  vector.warp_execute_on_lane_0(%laneid)[32] -> () {
+    %mask = "mask_def_0"() : () -> (vector<4096xi1>)
+    %mask2 = "mask_def_1"() : () -> (vector<32xi1>)
+    %0 = "some_def_0"() : () -> (vector<4096xf32>)
+    %1 = "some_def_1"() : () -> (vector<32xf32>)
+    vector.transfer_write %0, %dest[%c0], %mask : vector<4096xf32>, memref<4096xf32>
+    vector.transfer_write %1, %dest[%c0], %mask2 : vector<32xf32>, memref<4096xf32>
+    vector.yield
+  }
+  return
+}
+
+// CHECK-DIST-AND-PROP-LABEL: func.func @warp_propagate_masked_write(
+// CHECK-DIST-AND-PROP:   %[[W:.*]]:4 = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>, vector<1xi1>, vector<128xf32>, vector<128xi1>) {
+// CHECK-DIST-AND-PROP:     %[[M0:.*]] = "mask_def_0"
+// CHECK-DIST-AND-PROP:     %[[M1:.*]] = "mask_def_1"
+// CHECK-DIST-AND-PROP:     %[[V0:.*]] = "some_def_0"
+// CHECK-DIST-AND-PROP:     %[[V1:.*]] = "some_def_1"
+// CHECK-DIST-AND-PROP:     vector.yield %[[V1]], %[[M1]], %[[V0]], %[[M0]]
+// CHECK-DIST-AND-PROP-SAME:   vector<32xf32>, vector<32xi1>, vector<4096xf32>, vector<4096xi1>
+// CHECK-DIST-AND-PROP:   }
+// CHECK-DIST-AND-PROP: vector.transfer_write %[[W]]#2, {{.*}}, %[[W]]#3 {in_bounds = [true]} : vector<128xf32>, memref<4096xf32>
+// CHECK-DIST-AND-PROP: vector.transfer_write %[[W]]#0, {{.*}}, %[[W]]#1 {in_bounds = [true]} : vector<1xf32>, memref<4096xf32>
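
Note on the shape arithmetic exercised by the test above: for the minor-identity case this patch supports, getDistributedType splits the innermost dimension evenly across the warp, so the vector<4096xf32> write across 32 lanes becomes vector<128xf32> per lane and its vector<4096xi1> mask becomes vector<128xi1>, while the vector<32xf32> write becomes vector<1xf32> with a vector<1xi1> mask. A minimal standalone C++ sketch of that arithmetic, assuming a shape whose innermost dimension divides evenly by the warp size (distributeShape is a hypothetical illustration, not the MLIR helper, which works on an affine distribution map):

#include <cassert>
#include <cstdint>
#include <vector>

// Each of the warpSize lanes receives an equal slice of the innermost
// dimension. The mask distributes elementwise-identically to the written
// vector, which is why the pattern reuses the same distributed-type
// computation for both the value and the mask.
std::vector<int64_t> distributeShape(std::vector<int64_t> shape,
                                     int64_t warpSize) {
  assert(!shape.empty() && shape.back() % warpSize == 0 &&
         "distributed dim must be divisible by the warp size");
  shape.back() /= warpSize; // e.g. 4096 across 32 lanes -> 128 per lane
  return shape;
}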