-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][TOSA] Remove rollback from TOSA -> Linalg patterns #136308
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-mlir-tosa @llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) Changes: Reorganize the implementation slightly, such that patterns check all preconditions before starting the actual rewrite. I.e., patterns no longer start rewriting and then abort, which would cause a pattern rollback. Pattern rollbacks are expensive and will be disallowed as part of the One-Shot Dialect Conversion refactoring. Full diff: https://github.com/llvm/llvm-project/pull/136308.diff 1 File Affected:
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 9ca93ab28daed..bc4ef58cbcd62 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -91,6 +91,50 @@ createConstOpFromZpVal(Operation *op, const int64_t &zp, Type requiredAttrType,
op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
}
+/// Return "failure" if the given elementwise operation cannot be converted.
+static LogicalResult
+isSupportedElementwiseOperation(ConversionPatternRewriter &rewriter,
+ Operation *op, RankedTensorType resultType) {
+ auto elementTy =
+ cast<ShapedType>(op->getOperand(0).getType()).getElementType();
+
+ // tosa::MulOp
+ if (isa<tosa::MulOp>(op)) {
+ auto shiftVal = cast<tosa::MulOp>(op).getShift();
+ DenseElementsAttr shiftElem;
+ if (!matchPattern(shiftVal, m_Constant(&shiftElem)))
+ return rewriter.notifyMatchFailure(op, "shift value of mul not found");
+
+ int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
+ if (isa<FloatType>(elementTy) && shift != 0)
+ return rewriter.notifyMatchFailure(op,
+ "Cannot have shift value for float");
+ return success();
+ }
+
+ // tosa::NegateOp
+ if (isa<tosa::NegateOp>(op)) {
+ auto negate = cast<tosa::NegateOp>(op);
+ if (failed(negate.getInput1ZeroPoint()))
+ return rewriter.notifyMatchFailure(
+ op, "input1 zero point cannot be statically determined");
+ if (failed(negate.getOutputZeroPoint()))
+ return rewriter.notifyMatchFailure(
+ op, "output zero point cannot be statically determined");
+ return success();
+ }
+
+ // tosa::CastOp
+ if (isa<tosa::CastOp>(op)) {
+ if (!elementTy.isIntOrFloat() ||
+ !resultType.getElementType().isIntOrFloat())
+ return rewriter.notifyMatchFailure(op, "unsupported type");
+ return success();
+ }
+
+ return success();
+}
+
static Value createLinalgBodyCalculationForElementwiseOp(
Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
ConversionPatternRewriter &rewriter) {
@@ -139,17 +183,14 @@ static Value createLinalgBodyCalculationForElementwiseOp(
auto shiftVal = cast<tosa::MulOp>(op).getShift();
DenseElementsAttr shiftElem;
if (!matchPattern(shiftVal, m_Constant(&shiftElem))) {
- (void)rewriter.notifyMatchFailure(op, "shift value of mul not found");
- return nullptr;
+ llvm_unreachable("shift value of mul not found");
}
int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
if (isa<FloatType>(elementTy)) {
if (shift != 0) {
- (void)rewriter.notifyMatchFailure(op,
- "Cannot have shift value for float");
- return nullptr;
+ llvm_unreachable("Cannot have shift value for float");
}
return rewriter.create<arith::MulFOp>(loc, resultTypes, args[0], args[1]);
}
@@ -196,16 +237,12 @@ static Value createLinalgBodyCalculationForElementwiseOp(
FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
if (failed(maybeInZp)) {
- (void)rewriter.notifyMatchFailure(
- op, "input1 zero point cannot be statically determined");
- return nullptr;
+ llvm_unreachable("input1 zero point cannot be statically determined");
}
FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
if (failed(maybeOutZp)) {
- (void)rewriter.notifyMatchFailure(
- op, "output zero point cannot be statically determined");
- return nullptr;
+ llvm_unreachable("output zero point cannot be statically determined");
}
int64_t inZp = *maybeInZp;
@@ -548,10 +585,8 @@ static Value createLinalgBodyCalculationForElementwiseOp(
if (isa<tosa::CastOp>(op)) {
Type srcTy = elementTy;
Type dstTy = resultTypes.front();
- if (!srcTy.isIntOrFloat() || !dstTy.isIntOrFloat()) {
- (void)rewriter.notifyMatchFailure(op, "unsupported type");
- return nullptr;
- }
+ if (!srcTy.isIntOrFloat() || !dstTy.isIntOrFloat())
+ llvm_unreachable("unsupported type");
bool bitExtend =
srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();
@@ -706,8 +741,8 @@ static Value createLinalgBodyCalculationForElementwiseOp(
}
}
- (void)rewriter.notifyMatchFailure(
- op, "unhandled op for linalg body calculation for elementwise op");
+ llvm_unreachable(
+ "unhandled op for linalg body calculation for elementwise op");
return nullptr;
}
@@ -930,17 +965,11 @@ broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
});
}
-static LogicalResult
-emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
- Operation *operation, ValueRange operands,
- ArrayRef<OpFoldResult> targetShape,
- const TypeConverter &converter) {
+static LogicalResult emitElementwiseComputation(
+ ConversionPatternRewriter &rewriter, Location loc, Operation *operation,
+ ValueRange operands, ArrayRef<OpFoldResult> targetShape,
+ const TypeConverter &converter, RankedTensorType resultType) {
// Generate output tensor
- auto resultType = cast_or_null<RankedTensorType>(
- converter.convertType(operation->getResultTypes().front()));
- if (!resultType) {
- return rewriter.notifyMatchFailure(operation, "failed to convert type");
- }
Value outputTensor = rewriter.create<tensor::EmptyOp>(
loc, targetShape, resultType.getElementType());
@@ -967,7 +996,6 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
// Emit 'linalg.generic' op
- bool encounteredError = false;
auto linalgOp = rewriter.create<linalg::GenericOp>(
loc, outputTensor.getType(), operands, outputTensor, affineMaps,
getNParallelLoopsAttrs(rank),
@@ -975,15 +1003,10 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
Value opResult = createLinalgBodyCalculationForElementwiseOp(
operation, blockArgs.take_front(operation->getNumOperands()),
{resultType.getElementType()}, rewriter);
- if (!opResult) {
- encounteredError = true;
- return;
- }
+ assert(opResult &&
+ "unable to create linalg.generic body for elementwise op");
opBuilder.create<linalg::YieldOp>(loc, opResult);
});
- if (encounteredError)
- return rewriter.notifyMatchFailure(
- operation, "unable to create linalg.generic body for elementwise op");
// Cast 'linalg.generic' result into original result type if needed
auto castResult = rewriter.createOrFold<tensor::CastOp>(
@@ -1008,13 +1031,20 @@ elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
ConversionPatternRewriter &rewriter,
const TypeConverter &converter) {
- // Collect op properties
+ // Check if operation is supported.
assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
assert(operation->getNumOperands() >= 1 &&
"elementwise op expects at least 1 operand");
if (!operandsAndResultsRanked(operation))
return rewriter.notifyMatchFailure(operation,
"Unranked tensors not supported");
+ auto resultType = cast_or_null<RankedTensorType>(
+ converter.convertType(operation->getResultTypes().front()));
+ if (!resultType) {
+ return rewriter.notifyMatchFailure(operation, "failed to convert type");
+ }
+ if (failed(isSupportedElementwiseOperation(rewriter, operation, resultType)))
+ return failure();
// Lower operation
IndexPool indexPool;
@@ -1026,7 +1056,7 @@ elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
broadcastDynamicDimensions(rewriter, loc, indexPool, operandsToBroadcast,
targetShape, masterOperands);
return emitElementwiseComputation(rewriter, loc, operation, broadcastOperands,
- targetShape, converter);
+ targetShape, converter, resultType);
}
// Returns the constant initial value for a given reduction operation. The
@@ -1126,7 +1156,7 @@ static Value createLinalgBodyCalculationForReduceOp(Operation *op,
if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
return rewriter.create<arith::OrIOp>(loc, args);
- return {};
+ llvm_unreachable("unhandled reduction op");
}
// Performs the match and rewrite for reduction operations. This includes
@@ -1142,6 +1172,10 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
return rewriter.notifyMatchFailure(op, "unranked tensors not supported");
auto elementTy = resultTy.getElementType();
+ auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
+ if (!fillValueAttr)
+ return rewriter.notifyMatchFailure(
+ op, "No initial value found for reduction operation");
Value input = op->getOperand(0);
SmallVector<int64_t> reduceShape;
@@ -1164,11 +1198,6 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
dynDims)
.getResult();
- auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
- if (!fillValueAttr)
- return rewriter.notifyMatchFailure(
- op, "No initial value found for reduction operation");
-
auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
auto filledTensor = rewriter
.create<linalg::FillOp>(loc, ValueRange{fillValue},
@@ -1212,7 +1241,6 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
}
}
- bool didEncounterError = false;
linalg::LinalgOp linalgOp = rewriter.create<linalg::ReduceOp>(
loc, inputs, outputs, axis,
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
@@ -1220,8 +1248,7 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};
auto result = createLinalgBodyCalculationForReduceOp(
op, binaryArgs, elementTy, rewriter);
- if (result)
- didEncounterError = true;
+ assert(result && "could not create reduction body");
SmallVector<Value> resultsToYield;
if (isNanIgnoreMode) {
@@ -1247,10 +1274,6 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
nestedBuilder.create<linalg::YieldOp>(loc, resultsToYield);
});
- if (!didEncounterError)
- return rewriter.notifyMatchFailure(
- op, "unable to create linalg.generic body for reduce op");
-
if (isNanIgnoreMode) {
// Materialize a check to see whether we encountered any non-NaN values, if
// we didn't we need to select a tensor of NaNs since the result will just
@@ -1358,13 +1381,6 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
if (!isa<IntegerType>(inputTy.getElementType()))
return rewriter.notifyMatchFailure(op, "only support integer type");
- SmallVector<Value> dynDims;
- for (int i = 0; i < outputTy.getRank(); i++) {
- if (outputTy.isDynamicDim(i)) {
- dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
- }
- }
-
// The shift and multiplier values.
DenseElementsAttr shiftElems;
if (!matchPattern(op.getShift(), m_Constant(&shiftElems)))
@@ -1376,6 +1392,21 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
return rewriter.notifyMatchFailure(
op, "tosa.rescale requires constant multiplier input values");
+ if (failed(op.getInputZeroPoint()))
+ return rewriter.notifyMatchFailure(
+ op, "input zero point cannot be statically determined");
+
+ if (failed(op.getOutputZeroPoint()))
+ return rewriter.notifyMatchFailure(
+ op, "output zero point cannot be statically determined");
+
+ SmallVector<Value> dynDims;
+ for (int i = 0; i < outputTy.getRank(); i++) {
+ if (outputTy.isDynamicDim(i)) {
+ dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
+ }
+ }
+
llvm::SmallVector<int8_t> shiftValues =
llvm::to_vector(shiftElems.getValues<int8_t>());
// explicit cast is required here
@@ -1473,23 +1504,10 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;
FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
- if (failed(maybeIZp)) {
- (void)rewriter.notifyMatchFailure(
- op, "input zero point cannot be statically determined");
- return;
- }
-
auto inputZp = createConstOpFromZpVal<int32_t>(
op, *maybeIZp, nestedBuilder.getIntegerType(inBitwidth),
nestedBuilder);
-
FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
- if (failed(maybeOZp)) {
- (void)rewriter.notifyMatchFailure(
- op, "output zero point cannot be statically determined");
- return;
- };
-
auto outputZp = createConstOpFromZpVal<int32_t>(
op, *maybeOZp, nestedBuilder.getI32Type(), nestedBuilder);
@@ -1783,6 +1801,15 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
return rewriter.notifyMatchFailure(
op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
+ SmallVector<int64_t> scale, offset, border;
+ if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale) ||
+ !tosa::getConstShapeValues(op.getOffset().getDefiningOp(), offset) ||
+ !tosa::getConstShapeValues(op.getBorder().getDefiningOp(), border)) {
+ return rewriter.notifyMatchFailure(
+ op, "tosa.resize scale/offset/border should have compile time "
+ "constant values.");
+ }
+
SmallVector<AffineMap, 2> affineMaps = {
rewriter.getMultiDimIdentityMap(resultTy.getRank())};
auto emptyTensor = b.create<tensor::EmptyOp>(resultTy.getShape(), resultETy,
@@ -1810,15 +1837,6 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
Value inY = b.create<arith::IndexCastOp>(b.getI32Type(), y);
Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x);
- SmallVector<int64_t> scale, offset, border;
- if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale) ||
- !tosa::getConstShapeValues(op.getOffset().getDefiningOp(), offset) ||
- !tosa::getConstShapeValues(op.getBorder().getDefiningOp(), border)) {
- return rewriter.notifyMatchFailure(
- op, "tosa.resize scale/offset/border should have compile time "
- "constant values.");
- }
-
Value yScaleN, yScaleD, xScaleN, xScaleD;
yScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[0]));
yScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[1]));
@@ -2204,6 +2222,9 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
auto inputTy = cast<ShapedType>(input.getType());
auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType());
auto inElementTy = inputTy.getElementType();
+ if (!isa<IntegerType, FloatType>(inElementTy))
+ return rewriter.notifyMatchFailure(
+ argmaxOp, "unsupported tosa.argmax element type");
auto outElementTy = resultTy.getElementType();
int axis = argmaxOp.getAxis();
auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);
@@ -2213,6 +2234,12 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
argmaxOp,
"tosa.arg_max to linalg.* requires integer-like result type");
+ auto fillValueMaxAttr =
+ createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);
+ if (!fillValueMaxAttr)
+ return rewriter.notifyMatchFailure(
+ argmaxOp, "unsupported tosa.argmax element type");
+
SmallVector<Value> dynDims;
for (int i = 0; i < inputTy.getRank(); i++) {
if (inputTy.isDynamicDim(i) && i != axis) {
@@ -2238,12 +2265,6 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
.create<tensor::EmptyOp>(loc, resultTy.getShape(),
inElementTy, dynDims)
.getResult();
- auto fillValueMaxAttr =
- createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);
-
- if (!fillValueMaxAttr)
- return rewriter.notifyMatchFailure(
- argmaxOp, "unsupported tosa.argmax element type");
auto fillValueMax =
rewriter.create<arith::ConstantOp>(loc, fillValueMaxAttr);
@@ -2267,7 +2288,6 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
}
- bool didEncounterError = false;
auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs},
rewriter.getContext());
auto linalgOp = rewriter.create<linalg::GenericOp>(
@@ -2305,8 +2325,7 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
predicate = rewriter.create<arith::CmpIOp>(
nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
} else {
- didEncounterError = true;
- return;
+ llvm_unreachable("unsupported tosa.argmax element type");
}
auto resultMax = rewriter.create<arith::SelectOp>(
@@ -2317,10 +2336,6 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
nestedLoc, ValueRange({resultIndex, resultMax}));
});
- if (didEncounterError)
- return rewriter.notifyMatchFailure(
- argmaxOp, "unsupported tosa.argmax element type");
-
rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
return success();
}
@@ -2416,6 +2431,15 @@ class TableConverter : public OpRewritePattern<tosa::TableOp> {
auto tableElementTy = tableTy.getElementType();
auto resultElementTy = resultTy.getElementType();
+ bool isI8_8_8 = inputElementTy.isInteger(8) &&
+ tableElementTy.isInteger(8) && resultElementTy.isInteger(8);
+ bool isI16_16_32 = inputElementTy.isInteger(16) &&
+ tableElementTy.isInteger(16) &&
+ resultElementTy.isInteger(32);
+ if (!isI8_8_8 && !isI16_16_32)
+ return rewriter.notifyMatchFailure(
+ op, "unable to create body for tosa.table op");
+
SmallVector<Value> dynDims;
for (int i = 0; i < resultTy.getRank(); ++i) {
if (inputTy.isDynamicDim(i)) {
@@ -2446,8 +2470,7 @@ class TableConverter : public OpRewritePattern<tosa::TableOp> {
auto inputValue = block->getArgument(0);
rewriter.setInsertionPointToStart(block);
- if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
- resultElementTy.isInteger(8)) {
+ if (isI8_8_8) {
Value index = rewriter.create<arith::IndexCastOp>(
loc, rewriter.getIndexType(), inputValue);
Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 128);
@@ -2459,8 +2482,7 @@ class TableConverter : public OpRewritePattern<tosa::TableOp> {
return success();
}
- if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
- resultElementTy.isInteger(32)) {
+ if (isI16_16_32) {
Value extend = rewriter.create<arith::ExtSIOp>(
loc, rewriter.getI32Type(), inputValue);
@@ -2516,8 +2538,7 @@ class TableConverter : public OpRewritePattern<tosa::TableOp> {
}
}
- return rewriter.notifyMatchFailure(
- op, "unable to create body for tosa.table op");
+ llvm_unreachable("unable to create body for tosa.table op");
}
};
|
Reorganize the implementation slightly, such that patterns check all preconditions before starting the actual rewrite. I.e., patterns no longer start rewriting and then abort, which would cause a pattern rollback. Pattern rollbacks are expensive and will be disallowed as part of the One-Shot Dialect Conversion refactoring.