Skip to content

[mlir][tosa] Require signless types in validation and add corresponding conversion pass #144367

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -127,4 +127,18 @@ def TosaReduceTransposes : Pass<"tosa-reduce-transposes", "func::FuncOp"> {
}];
}

def TosaConvertIntegerTypeToSignless : Pass<"tosa-convert-integer-type-to-signless", "func::FuncOp"> {
let summary = "Convert integer types to signless";
let description = [{
This pass converts signed or unsigned integer types to signless. It
currently does this greedily for all operators and can also change the
signature of the function. Should the signature of the entrypoint
function change, it will be the responsibility of the user to carry
signedness information of the inputs and outputs independently.

This can be a useful transformation for conversion to other formats
that require strict adherence to the TOSA specification.
}];
}

#endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRTosaTransforms
TosaConvertIntegerTypeToSignless.cpp
TosaDecomposeTransposeConv.cpp
TosaDecomposeDepthwise.cpp
TosaFolders.cpp
Expand Down
134 changes: 134 additions & 0 deletions mlir/lib/Dialect/Tosa/Transforms/TosaConvertIntegerTypeToSignless.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
//===- TosaConvertIntegerTypeToSignless.cpp
//-------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===-------------------------------------------------------------------------------===//

// -----------
// Motivation:
// -----------

// The TOSA specification uses a signless type system, which means that
// information about signedness must be encapsulated by the operations
// themselves. For example, tosa.rescale provides the attrbutes `input_unsigned`
// and `output_unsigned` to indicate whether the input/output should be
// interpreted as unsigned or signed.

// The TOSA dialect, on the other hand, allows the use of signed or unsigned
// types in addition to signless. As such, when converting from TOSA dialect to
// other formats, we need to ensure that we conform to the TOSA specification.

// ---------
// Overview:
// ---------

// This pass converts signed or unsigned integer types to signless. It currently
// does this greedily for all operators and can also change the signature of the
// function. Should the signature of the entrypoint function change, it will be
// the responsibility of the user to carry signedness information of the inputs
// and outputs independently.

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace tosa {

#define GEN_PASS_DEF_TOSACONVERTINTEGERTYPETOSIGNLESS
#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"

namespace {
class ToSignlessTensorTypeConverter : public TypeConverter {
static Type convertType(Type type) {
const auto tensorType = dyn_cast<TensorType>(type);
if (!tensorType)
return type;

const auto intType = dyn_cast<IntegerType>(tensorType.getElementType());
if (!intType ||
intType.getSignedness() == IntegerType::SignednessSemantics::Signless)
return type;

const auto signlessType = IntegerType::get(
intType.getContext(), intType.getWidth(), IntegerType::Signless);
return tensorType.cloneWith(std::nullopt, signlessType);
}

public:
explicit ToSignlessTensorTypeConverter() { addConversion(convertType); }
};

class ConvertGenericOpWithIntegerTensorType : public ConversionPattern {
public:
ConvertGenericOpWithIntegerTensorType(TypeConverter &typeConverter,
MLIRContext *context)
: ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) {}

LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
// Convert integer types to signless
SmallVector<Type, 4> resultTypes;
if (failed(typeConverter->convertTypes(op->getResultTypes(), resultTypes)))
return failure();

// Create new op with replaced operands and results
auto *newOp = Operation::create(
op->getLoc(), op->getName(), resultTypes, operands, op->getAttrs(),
op->getPropertiesStorage(), op->getSuccessors(), op->getNumRegions());

// Handle regions in e.g. tosa.cond_if and tosa.while_loop
for (auto regions : llvm::zip(op->getRegions(), newOp->getRegions())) {
Region &before = std::get<0>(regions);
Region &parent = std::get<1>(regions);
rewriter.inlineRegionBefore(before, parent, parent.end());
if (failed(rewriter.convertRegionTypes(&parent, *typeConverter)))
return failure();
}

// Replace with rewritten op
rewriter.insert(newOp);
rewriter.replaceOp(op, newOp->getResults());
return success();
}
};

class TosaConvertIntegerTypeToSignless
: public impl::TosaConvertIntegerTypeToSignlessBase<
TosaConvertIntegerTypeToSignless> {
public:
void runOnOperation() override {
MLIRContext *context = &getContext();
ConversionTarget target(*context);
ToSignlessTensorTypeConverter typeConverter;

target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
return typeConverter.isSignatureLegal(op.getFunctionType()) &&
typeConverter.isLegal(&op.getBody());
});
target.markUnknownOpDynamicallyLegal([&](Operation *op) {
return typeConverter.isLegal(op->getOperandTypes()) &&
typeConverter.isLegal(op->getResultTypes());
});

RewritePatternSet patterns(context);
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
patterns, typeConverter);
patterns.add<ConvertGenericOpWithIntegerTensorType>(typeConverter, context);

if (failed(
applyFullConversion(getOperation(), target, std::move(patterns))))
signalPassFailure();
}
};

} // namespace

} // namespace tosa
} // namespace mlir
9 changes: 5 additions & 4 deletions mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1320,21 +1320,22 @@ void TosaValidation::runOnOperation() {

// validate operator element types:
// - rescale operator is allowed to have ui8/ui16/ui32
// operands/results
// operands/results when strictOpSpecAlignment is false
// - perform valid element type check at the beginning to
// protect rest of code against quantized element types
const bool opIsRescale = isa<tosa::RescaleOp>(op);
const bool allowUnsigned =
!strictOpSpecAlignment && isa<tosa::RescaleOp>(op);
for (Value operand : op->getOperands()) {
auto elementTy = getElementTypeOrSelf(operand);
if (!isValidElementType(elementTy, opIsRescale)) {
if (!isValidElementType(elementTy, allowUnsigned)) {
op->emitOpError() << "is not profile-aligned: element type "
<< elementTy << " is not legal";
return signalPassFailure();
}
}
for (Type resultTy : op->getResultTypes()) {
auto elementTy = getElementTypeOrSelf(resultTy);
if (!isValidElementType(elementTy, opIsRescale)) {
if (!isValidElementType(elementTy, allowUnsigned)) {
op->emitOpError() << "is not profile-aligned: element type "
<< elementTy << " is not legal";
return signalPassFailure();
Expand Down
2 changes: 2 additions & 0 deletions mlir/test/Dialect/Tosa/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2000,6 +2000,7 @@ func.func @test_rescale_input_unsigned(%arg0: tensor<1x1xui8>) -> (tensor<1x1xi8
%1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
%2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
%3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.rescale' op is not profile-aligned: element type 'ui8' is not legal}}
%r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xui8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xi8>
return %r : tensor<1x1xi8>
}
Expand All @@ -2012,6 +2013,7 @@ func.func @test_rescale_output_unsigned(%arg0: tensor<1x1xi8>) -> (tensor<1x1xui
%1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
%2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
%3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.rescale' op is not profile-aligned: element type 'ui8' is not legal}}
%r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xui8>
return %r : tensor<1x1xui8>
}
Expand Down
73 changes: 73 additions & 0 deletions mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// RUN: mlir-opt --split-input-file --tosa-convert-integer-type-to-signless %s | FileCheck %s

// -----

// CHECK-LABEL: test_rescale_output_unsigned
// CHECK: %arg0: tensor<1x1xi8>
func.func @test_rescale_output_unsigned(%arg0: tensor<1x1xi8>) -> (tensor<1x1xui8>) {
%0 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
%1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
%2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
%3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
// CHECK: %[[RESCALE:.*]] = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xi8>
%r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xui8>
// CHECK: return %[[RESCALE]] : tensor<1x1xi8>
return %r : tensor<1x1xui8>
}

// -----

// CHECK-LABEL: test_rescale_input_unsigned
// CHECK: %arg0: tensor<1x1xi16>
func.func @test_rescale_input_unsigned(%arg0: tensor<1x1xui16>) -> (tensor<1x1xi8>) {
%0 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
%1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
%2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
%3 = "tosa.const"() <{values = dense<32768> : tensor<1xi16>}> : () -> tensor<1xi16>
// CHECK: %[[RESCALE:.*]] = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi16>, tensor<1xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<1x1xi8>
%r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xui16>, tensor<1xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<1x1xi8>
// CHECK: return %[[RESCALE]] : tensor<1x1xi8>
return %r : tensor<1x1xi8>
}

// -----

// CHECK-LABEL: test_unsigned_function_signature
// CHECK: %arg0: tensor<1xi8>, %arg1: tensor<1xi8>
func.func @test_unsigned_function_signature(%arg0: tensor<1xui8>, %arg1: tensor<1xui8>) -> (tensor<1xui8>, tensor<1xui8>) {
// CHECK: return %arg0, %arg1 : tensor<1xi8>, tensor<1xi8>
return %arg0, %arg1 : tensor<1xui8>, tensor<1xui8>
}

// -----

// CHECK-LABEL: test_no_change
// CHECK: %arg0: tensor<13x21x3xi8>
func.func @test_no_change(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> {
%0 = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<13x21x3xi8>) -> tensor<13x21x3xi8>
// CHECK: return %0 : tensor<13x21x3xi8>
return %0 : tensor<13x21x3xi8>
}

// -----

// CHECK-LABEL: test_regions
// CHECK: %arg0: tensor<i8>, %arg1: tensor<i8>
func.func @test_regions(%arg0: tensor<ui8>, %arg1: tensor<ui8>, %arg2: tensor<i1>) -> tensor<ui8> {
// CHECK: tosa.cond_if %arg2 -> (tensor<i8>)
%0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
^bb0(%arg3: tensor<ui8>, %arg4: tensor<ui8>):
// CHECK: %1 = tosa.add %arg0, %arg1 : (tensor<i8>, tensor<i8>) -> tensor<i8>
%1 = tosa.add %arg0, %arg1 : (tensor<ui8>, tensor<ui8>) -> tensor<ui8>
// CHECK: tosa.yield %1 : tensor<i8>
tosa.yield %1 : tensor<ui8>
}, {
^bb0(%arg3: tensor<ui8>, %arg4: tensor<ui8>):
// CHECK: %1 = tosa.sub %arg0, %arg1 : (tensor<i8>, tensor<i8>) -> tensor<i8>
%1 = tosa.sub %arg0, %arg1 : (tensor<ui8>, tensor<ui8>) -> tensor<ui8>
// CHECK: tosa.yield %1 : tensor<i8>
tosa.yield %1 : tensor<ui8>
}) : (tensor<i1>, tensor<ui8>, tensor<ui8>) -> tensor<ui8>
// CHECK: return %0 : tensor<i8>
return %0 : tensor<ui8>
}
31 changes: 31 additions & 0 deletions mlir/test/Dialect/Tosa/tosa-validation-valid.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
//--------------------------------------------------------------------------------------------------
// Test valid IR in terms of the shape and type of tensor, and the argument type of
// operation. Excludes the profile compilance checking since it is performed earlier in the
// validation flow.
//--------------------------------------------------------------------------------------------------

// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" | FileCheck %s

// -----

// CHECK-LABEL: test_rescale_input_unsigned
func.func @test_rescale_input_unsigned(%arg0: tensor<1x1xui8>) -> (tensor<1x1xi8>) {
%0 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
%1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
%2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
%3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
%r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xui8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xi8>
return %r : tensor<1x1xi8>
}

// -----

// CHECK-LABEL: test_rescale_output_unsigned
func.func @test_rescale_output_unsigned(%arg0: tensor<1x1xi8>) -> (tensor<1x1xui8>) {
%0 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
%1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
%2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
%3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
%r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xui8>
return %r : tensor<1x1xui8>
}
Loading