diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index 69c497264fd1e..f29eba90c3ceb 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -1237,6 +1237,10 @@ getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) { if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset))) return failure(); + auto isUnitDim = [](VectorType type, int dim) { + return type.getDimSize(dim) == 1 && !type.getScalableDims()[dim]; + }; + // According to vector.transfer_read/write semantics, the vector can be a // slice. Thus, we have to offset the check index with `rankDiff` in // `srcStrides` and source dim sizes. @@ -1247,8 +1251,7 @@ getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) { // It can be folded only if they are 1 and the stride is 1. int dim = vectorType.getRank() - i - 1; if (srcStrides[dim + rankDiff] != 1 || - srcType.getDimSize(dim + rankDiff) != 1 || - vectorType.getDimSize(dim) != 1) + srcType.getDimSize(dim + rankDiff) != 1 || !isUnitDim(vectorType, dim)) break; result++; } @@ -1292,7 +1295,8 @@ class DropInnerMostUnitDimsTransferRead auto resultTargetVecType = VectorType::get(targetType.getShape().drop_back(dimsToDrop), - targetType.getElementType()); + targetType.getElementType(), + targetType.getScalableDims().drop_back(dimsToDrop)); auto loc = readOp.getLoc(); SmallVector sizes = @@ -1378,7 +1382,8 @@ class DropInnerMostUnitDimsTransferWrite auto resultTargetVecType = VectorType::get(targetType.getShape().drop_back(dimsToDrop), - targetType.getElementType()); + targetType.getElementType(), + targetType.getScalableDims().drop_back(dimsToDrop)); Location loc = writeOp.getLoc(); SmallVector sizes = diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir index 477755b66c020..b4cb640108bae 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir @@ -174,3 +174,59 @@ func.func @non_unit_strides(%arg0: memref<512x16x1xf32, strided<[8192, 16, 4], o // The inner most unit dims can not be dropped if the strides are not ones. // CHECK: func.func @non_unit_strides // CHECK-NOT: memref.subview + +// ----- + +func.func @leading_scalable_dimension_transfer_read(%dest : memref<24x1xf32>) -> vector<[4]x1xf32> { + %c0 = arith.constant 0 : index + %pad = arith.constant 0.0 : f32 + %0 = vector.transfer_read %dest[%c0, %c0], %pad {in_bounds = [true, true]} : memref<24x1xf32>, vector<[4]x1xf32> + return %0 : vector<[4]x1xf32> +} +// CHECK: func.func @leading_scalable_dimension_transfer_read +// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] +// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[DEST]][0, 0] [24, 1] [1, 1] : memref<24x1xf32> to memref<24xf32, strided<[1]>> +// CHECK: %[[READ:.+]] = vector.transfer_read %[[SUBVIEW]]{{.*}} {in_bounds = [true]} : memref<24xf32, strided<[1]>>, vector<[4]xf32> +// CHECK: %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<[4]xf32> to vector<[4]x1xf32> +// CHECK: return %[[CAST]] + +// ----- + +// Negative test: [1] (scalable 1) is _not_ a unit dimension. +func.func @trailing_scalable_one_dim_transfer_read(%dest : memref<24x1xf32>) -> vector<4x[1]xf32> { + %c0 = arith.constant 0 : index + %pad = arith.constant 0.0 : f32 + %0 = vector.transfer_read %dest[%c0, %c0], %pad {in_bounds = [true, true]} : memref<24x1xf32>, vector<4x[1]xf32> + return %0 : vector<4x[1]xf32> +} +// CHECK: func.func @trailing_scalable_one_dim_transfer_read +// CHECK-NOT: vector.shape_cast +// CHECK: vector.transfer_read {{.*}} : memref<24x1xf32>, vector<4x[1]xf32> +// CHECK-NOT: vector.shape_cast + +// ----- + +func.func @leading_scalable_dimension_transfer_write(%dest : memref<24x1xf32>, %vec: vector<[4]x1xf32>) { + %c0 = arith.constant 0 : index + vector.transfer_write %vec, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[4]x1xf32>, memref<24x1xf32> + return +} +// CHECK: func.func @leading_scalable_dimension_transfer_write +// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]] +// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[DEST]][0, 0] [24, 1] [1, 1] : memref<24x1xf32> to memref<24xf32, strided<[1]>> +// CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<[4]x1xf32> to vector<[4]xf32> +// CHECK: vector.transfer_write %[[CAST]], %[[SUBVIEW]]{{.*}} {in_bounds = [true]} : vector<[4]xf32>, memref<24xf32, strided<[1]>> + +// ----- + +// Negative test: [1] (scalable 1) is _not_ a unit dimension. +func.func @trailing_scalable_one_dim_transfer_write(%dest : memref<24x1xf32>, %vec: vector<4x[1]xf32>, %index: index) { + %c0 = arith.constant 0 : index + vector.transfer_write %vec, %dest[%index, %c0] {in_bounds = [true, true]} : vector<4x[1]xf32>, memref<24x1xf32> + return +} +// CHECK: func.func @trailing_scalable_one_dim_transfer_write +// CHECK-NOT: vector.shape_cast +// CHECK: vector.transfer_write {{.*}} : vector<4x[1]xf32>, memref<24x1xf32> +// CHECK-NOT: vector.shape_cast