diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td index d005a4cc6859c..b96682843538c 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td @@ -127,4 +127,18 @@ def TosaReduceTransposes : Pass<"tosa-reduce-transposes", "func::FuncOp"> { }]; } +def TosaConvertIntegerTypeToSignless : Pass<"tosa-convert-integer-type-to-signless", "func::FuncOp"> { + let summary = "Convert integer types to signless"; + let description = [{ + This pass converts signed or unsigned integer types to signless. It + currently does this greedily for all operators and can also change the + signature of the function. Should the signature of the entrypoint + function change, it will be the responsibility of the user to carry + signedness information of the inputs and outputs independently. + + This can be a useful transformation for conversion to other formats + that require strict adherence to the TOSA specification. 
+ }]; +} + #endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt index bbf079faea3d0..803993bb1008d 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_dialect_library(MLIRTosaTransforms + TosaConvertIntegerTypeToSignless.cpp TosaDecomposeTransposeConv.cpp TosaDecomposeDepthwise.cpp TosaFolders.cpp diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaConvertIntegerTypeToSignless.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaConvertIntegerTypeToSignless.cpp new file mode 100644 index 0000000000000..3085e56ceebc0 --- /dev/null +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaConvertIntegerTypeToSignless.cpp @@ -0,0 +1,134 @@ +//===- TosaConvertIntegerTypeToSignless.cpp +//-------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===-------------------------------------------------------------------------------===// + +// ----------- +// Motivation: +// ----------- + +// The TOSA specification uses a signless type system, which means that +// information about signedness must be encapsulated by the operations +// themselves. For example, tosa.rescale provides the attributes +// `input_unsigned` and `output_unsigned` to indicate whether the input/output +// should be interpreted as unsigned or signed. + +// The TOSA dialect, on the other hand, allows the use of signed or unsigned +// types in addition to signless. As such, when converting from TOSA dialect to +// other formats, we need to ensure that we conform to the TOSA specification. + +// --------- +// Overview: +// --------- + +// This pass converts signed or unsigned integer types to signless. 
It currently +// does this greedily for all operators and can also change the signature of the +// function. Should the signature of the entrypoint function change, it will be +// the responsibility of the user to carry signedness information of the inputs +// and outputs independently. + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/Tosa/Transforms/Passes.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +namespace tosa { + +#define GEN_PASS_DEF_TOSACONVERTINTEGERTYPETOSIGNLESS +#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" + +namespace { +class ToSignlessTensorTypeConverter : public TypeConverter { + static Type convertType(Type type) { + const auto tensorType = dyn_cast(type); + if (!tensorType) + return type; + + const auto intType = dyn_cast(tensorType.getElementType()); + if (!intType || + intType.getSignedness() == IntegerType::SignednessSemantics::Signless) + return type; + + const auto signlessType = IntegerType::get( + intType.getContext(), intType.getWidth(), IntegerType::Signless); + return tensorType.cloneWith(std::nullopt, signlessType); + } + +public: + explicit ToSignlessTensorTypeConverter() { addConversion(convertType); } +}; + +class ConvertGenericOpWithIntegerTensorType : public ConversionPattern { +public: + ConvertGenericOpWithIntegerTensorType(TypeConverter &typeConverter, + MLIRContext *context) + : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + // Convert integer types to signless + SmallVector resultTypes; + if (failed(typeConverter->convertTypes(op->getResultTypes(), resultTypes))) + return failure(); + + // Create new op with replaced operands and results + auto *newOp = Operation::create( + op->getLoc(), op->getName(), resultTypes, operands, 
op->getAttrs(), + op->getPropertiesStorage(), op->getSuccessors(), op->getNumRegions()); + + // Handle regions in e.g. tosa.cond_if and tosa.while_loop + for (auto regions : llvm::zip(op->getRegions(), newOp->getRegions())) { + Region &before = std::get<0>(regions); + Region &parent = std::get<1>(regions); + rewriter.inlineRegionBefore(before, parent, parent.end()); + if (failed(rewriter.convertRegionTypes(&parent, *typeConverter))) + return failure(); + } + + // Replace with rewritten op + rewriter.insert(newOp); + rewriter.replaceOp(op, newOp->getResults()); + return success(); + } +}; + +class TosaConvertIntegerTypeToSignless + : public impl::TosaConvertIntegerTypeToSignlessBase< + TosaConvertIntegerTypeToSignless> { +public: + void runOnOperation() override { + MLIRContext *context = &getContext(); + ConversionTarget target(*context); + ToSignlessTensorTypeConverter typeConverter; + + target.addDynamicallyLegalOp([&](func::FuncOp op) { + return typeConverter.isSignatureLegal(op.getFunctionType()) && + typeConverter.isLegal(&op.getBody()); + }); + target.markUnknownOpDynamicallyLegal([&](Operation *op) { + return typeConverter.isLegal(op->getOperandTypes()) && + typeConverter.isLegal(op->getResultTypes()); + }); + + RewritePatternSet patterns(context); + populateFunctionOpInterfaceTypeConversionPattern( + patterns, typeConverter); + patterns.add(typeConverter, context); + + if (failed( + applyFullConversion(getOperation(), target, std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace + +} // namespace tosa +} // namespace mlir diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index 229f42d3178b5..3f27849b8c90c 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -1320,13 +1320,14 @@ void TosaValidation::runOnOperation() { // validate operator element types: // - rescale operator is allowed to have 
ui8/ui16/ui32 - // operands/results + // operands/results when strictOpSpecAlignment is false // - perform valid element type check at the beginning to // protect rest of code against quantized element types - const bool opIsRescale = isa(op); + const bool allowUnsigned = + !strictOpSpecAlignment && isa(op); for (Value operand : op->getOperands()) { auto elementTy = getElementTypeOrSelf(operand); - if (!isValidElementType(elementTy, opIsRescale)) { + if (!isValidElementType(elementTy, allowUnsigned)) { op->emitOpError() << "is not profile-aligned: element type " << elementTy << " is not legal"; return signalPassFailure(); @@ -1334,7 +1335,7 @@ void TosaValidation::runOnOperation() { } for (Type resultTy : op->getResultTypes()) { auto elementTy = getElementTypeOrSelf(resultTy); - if (!isValidElementType(elementTy, opIsRescale)) { + if (!isValidElementType(elementTy, allowUnsigned)) { op->emitOpError() << "is not profile-aligned: element type " << elementTy << " is not legal"; return signalPassFailure(); diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index 805522799a6d8..e25b3b7ef3e3a 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -2000,6 +2000,7 @@ func.func @test_rescale_input_unsigned(%arg0: tensor<1x1xui8>) -> (tensor<1x1xi8 %1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32> %2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8> %3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8> + // expected-error@+1 {{'tosa.rescale' op is not profile-aligned: element type 'ui8' is not legal}} %r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xui8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xi8> return %r : tensor<1x1xi8> } @@ -2012,6 +2013,7 @@ func.func 
@test_rescale_output_unsigned(%arg0: tensor<1x1xi8>) -> (tensor<1x1xui %1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32> %2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8> %3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8> + // expected-error@+1 {{'tosa.rescale' op is not profile-aligned: element type 'ui8' is not legal}} %r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xui8> return %r : tensor<1x1xui8> } diff --git a/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir b/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir new file mode 100644 index 0000000000000..38ac8d8fb66d9 --- /dev/null +++ b/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir @@ -0,0 +1,73 @@ +// RUN: mlir-opt --split-input-file --tosa-convert-integer-type-to-signless %s | FileCheck %s + +// ----- + +// CHECK-LABEL: test_rescale_output_unsigned +// CHECK: %arg0: tensor<1x1xi8> +func.func @test_rescale_output_unsigned(%arg0: tensor<1x1xi8>) -> (tensor<1x1xui8>) { + %0 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8> + %1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32> + %2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8> + %3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8> + // CHECK: %[[RESCALE:.*]] = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xi8> + %r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, 
rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xui8> + // CHECK: return %[[RESCALE]] : tensor<1x1xi8> + return %r : tensor<1x1xui8> +} + +// ----- + +// CHECK-LABEL: test_rescale_input_unsigned +// CHECK: %arg0: tensor<1x1xi16> +func.func @test_rescale_input_unsigned(%arg0: tensor<1x1xui16>) -> (tensor<1x1xi8>) { + %0 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8> + %1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32> + %2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8> + %3 = "tosa.const"() <{values = dense<32768> : tensor<1xi16>}> : () -> tensor<1xi16> + // CHECK: %[[RESCALE:.*]] = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi16>, tensor<1xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<1x1xi8> + %r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xui16>, tensor<1xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<1x1xi8> + // CHECK: return %[[RESCALE]] : tensor<1x1xi8> + return %r : tensor<1x1xi8> +} + +// ----- + +// CHECK-LABEL: test_unsigned_function_signature +// CHECK: %arg0: tensor<1xi8>, %arg1: tensor<1xi8> +func.func @test_unsigned_function_signature(%arg0: tensor<1xui8>, %arg1: tensor<1xui8>) -> (tensor<1xui8>, tensor<1xui8>) { + // CHECK: return %arg0, %arg1 : tensor<1xi8>, tensor<1xi8> + return %arg0, %arg1 : tensor<1xui8>, tensor<1xui8> +} + +// ----- + +// CHECK-LABEL: test_no_change +// CHECK: %arg0: tensor<13x21x3xi8> +func.func @test_no_change(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> { + %0 = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<13x21x3xi8>) -> tensor<13x21x3xi8> + // CHECK: return %0 : tensor<13x21x3xi8> 
+ return %0 : tensor<13x21x3xi8> +} + +// ----- + +// CHECK-LABEL: test_regions +// CHECK: %arg0: tensor, %arg1: tensor +func.func @test_regions(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK: tosa.cond_if %arg2 -> (tensor) + %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + // CHECK: %1 = tosa.add %arg0, %arg1 : (tensor, tensor) -> tensor + %1 = tosa.add %arg0, %arg1 : (tensor, tensor) -> tensor + // CHECK: tosa.yield %1 : tensor + tosa.yield %1 : tensor + }, { + ^bb0(%arg3: tensor, %arg4: tensor): + // CHECK: %1 = tosa.sub %arg0, %arg1 : (tensor, tensor) -> tensor + %1 = tosa.sub %arg0, %arg1 : (tensor, tensor) -> tensor + // CHECK: tosa.yield %1 : tensor + tosa.yield %1 : tensor + }) : (tensor, tensor, tensor) -> tensor + // CHECK: return %0 : tensor + return %0 : tensor +} diff --git a/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir new file mode 100644 index 0000000000000..cab14201dc0ce --- /dev/null +++ b/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir @@ -0,0 +1,31 @@ +//-------------------------------------------------------------------------------------------------- +// Test valid IR in terms of the shape and type of tensor, and the argument type of +// operation. Excludes the profile compliance checking since it is performed earlier in the +// validation flow. 
+//-------------------------------------------------------------------------------------------------- + +// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" | FileCheck %s + +// ----- + +// CHECK-LABEL: test_rescale_input_unsigned +func.func @test_rescale_input_unsigned(%arg0: tensor<1x1xui8>) -> (tensor<1x1xi8>) { + %0 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8> + %1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32> + %2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8> + %3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8> + %r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xui8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xi8> + return %r : tensor<1x1xi8> +} + +// ----- + +// CHECK-LABEL: test_rescale_output_unsigned +func.func @test_rescale_output_unsigned(%arg0: tensor<1x1xi8>) -> (tensor<1x1xui8>) { + %0 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8> + %1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32> + %2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8> + %3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8> + %r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xui8> + return %r : tensor<1x1xui8> +}