diff --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h index 0a36e97c2ae68..a33190c380d37 100644 --- a/mlir/include/mlir-c/Bindings/Python/Interop.h +++ b/mlir/include/mlir-c/Bindings/Python/Interop.h @@ -39,6 +39,7 @@ #include "mlir-c/IR.h" #include "mlir-c/IntegerSet.h" #include "mlir-c/Pass.h" +#include "mlir-c/Rewrite.h" // The 'mlir' Python package is relocatable and supports co-existing in multiple // projects. Each project must define its outer package prefix with this define @@ -284,6 +285,26 @@ static inline MlirModule mlirPythonCapsuleToModule(PyObject *capsule) { return module; } +/** Creates a capsule object encapsulating the raw C-API + * MlirFrozenRewritePatternSet. + * The returned capsule does not extend or affect ownership of any Python + * objects that reference the module in any way. */ +static inline PyObject * +mlirPythonFrozenRewritePatternSetToCapsule(MlirFrozenRewritePatternSet pm) { + return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(pm), + MLIR_PYTHON_CAPSULE_PASS_MANAGER, NULL); +} + +/** Extracts an MlirFrozenRewritePatternSet from a capsule as produced from + * mlirPythonFrozenRewritePatternSetToCapsule. If the capsule is not of the + * right type, then a null module is returned. */ +static inline MlirFrozenRewritePatternSet +mlirPythonCapsuleToFrozenRewritePatternSet(PyObject *capsule) { + void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_PASS_MANAGER); + MlirFrozenRewritePatternSet pm = {ptr}; + return pm; +} + /** Creates a capsule object encapsulating the raw C-API MlirPassManager. * The returned capsule does not extend or affect ownership of any Python * objects that reference the module in any way. */ diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h new file mode 100644 index 0000000000000..45218a1cd4ebd --- /dev/null +++ b/mlir/include/mlir-c/Rewrite.h @@ -0,0 +1,60 @@ +//===-- mlir-c/Rewrite.h - Helpers for C API to Rewrites ----------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header declares the registration and creation method for +// rewrite patterns. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_REWRITE_H +#define MLIR_C_REWRITE_H + +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" +#include "mlir/Config/mlir-config.h" + +//===----------------------------------------------------------------------===// +/// Opaque type declarations (see mlir-c/IR.h for more details). +//===----------------------------------------------------------------------===// + +#define DEFINE_C_API_STRUCT(name, storage) \ + struct name { \ + storage *ptr; \ + }; \ + typedef struct name name + +DEFINE_C_API_STRUCT(MlirFrozenRewritePatternSet, void); +DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void); +DEFINE_C_API_STRUCT(MlirRewritePatternSet, void); + +MLIR_CAPI_EXPORTED MlirFrozenRewritePatternSet +mlirFreezeRewritePattern(MlirRewritePatternSet op); + +MLIR_CAPI_EXPORTED void +mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op); + +MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily( + MlirModule op, MlirFrozenRewritePatternSet patterns, + MlirGreedyRewriteDriverConfig); + +#if MLIR_ENABLE_PDL_IN_PATTERNMATCH +DEFINE_C_API_STRUCT(MlirPDLPatternModule, void); + +MLIR_CAPI_EXPORTED MlirPDLPatternModule +mlirPDLPatternModuleFromModule(MlirModule op); + +MLIR_CAPI_EXPORTED void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op); + +MLIR_CAPI_EXPORTED MlirRewritePatternSet +mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op); +#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH + +#undef DEFINE_C_API_STRUCT + +#endif // MLIR_C_REWRITE_H diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h index d8f22c7aa1709..ebf50109f72f2 100644 --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -198,6 +198,27 @@ struct type_caster { }; }; +/// Casts object <-> MlirFrozenRewritePatternSet. +template <> +struct type_caster { + PYBIND11_TYPE_CASTER(MlirFrozenRewritePatternSet, + _("MlirFrozenRewritePatternSet")); + bool load(handle src, bool) { + py::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr()); + return value.ptr != nullptr; + } + static handle cast(MlirFrozenRewritePatternSet v, return_value_policy, + handle) { + py::object capsule = py::reinterpret_steal( + mlirPythonFrozenRewritePatternSetToCapsule(v)); + return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("rewrite")) + .attr("FrozenRewritePatternSet") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + }; +}; + /// Casts object <-> MlirOperation. template <> struct type_caster { diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 8c34c11f70950..f49efcd506ee9 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -22,6 +22,7 @@ #include "mlir-c/Diagnostics.h" #include "mlir-c/IR.h" #include "mlir-c/IntegerSet.h" +#include "mlir-c/Transforms.h" #include "mlir/Bindings/Python/PybindAdaptors.h" #include "llvm/ADT/DenseMap.h" diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index 17272472ccca4..8da1ab16a4514 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -11,6 +11,7 @@ #include "Globals.h" #include "IRModule.h" #include "Pass.h" +#include "Rewrite.h" namespace py = pybind11; using namespace mlir; @@ -116,6 +117,9 @@ PYBIND11_MODULE(_mlir, m) { populateIRInterfaces(irModule); populateIRTypes(irModule); + auto rewriteModule = m.def_submodule("rewrite", "MLIR Rewrite Bindings"); + populateRewriteSubmodule(rewriteModule); + // Define and populate PassManager submodule. auto passModule = m.def_submodule("passmanager", "MLIR Pass Management Bindings"); diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp new file mode 100644 index 0000000000000..1d8128be9f082 --- /dev/null +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -0,0 +1,110 @@ +//===- Rewrite.cpp - Rewrite ----------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "Rewrite.h" + +#include "IRModule.h" +#include "mlir-c/Bindings/Python/Interop.h" +#include "mlir-c/Rewrite.h" +#include "mlir/Config/mlir-config.h" + +namespace py = pybind11; +using namespace mlir; +using namespace py::literals; +using namespace mlir::python; + +namespace { + +#if MLIR_ENABLE_PDL_IN_PATTERNMATCH +/// Owning Wrapper around a PDLPatternModule. +class PyPDLPatternModule { +public: + PyPDLPatternModule(MlirPDLPatternModule module) : module(module) {} + PyPDLPatternModule(PyPDLPatternModule &&other) noexcept + : module(other.module) { + other.module.ptr = nullptr; + } + ~PyPDLPatternModule() { + if (module.ptr != nullptr) + mlirPDLPatternModuleDestroy(module); + } + MlirPDLPatternModule get() { return module; } + +private: + MlirPDLPatternModule module; +}; +#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH + +/// Owning Wrapper around a FrozenRewritePatternSet. +class PyFrozenRewritePatternSet { +public: + PyFrozenRewritePatternSet(MlirFrozenRewritePatternSet set) : set(set) {} + PyFrozenRewritePatternSet(PyFrozenRewritePatternSet &&other) noexcept + : set(other.set) { + other.set.ptr = nullptr; + } + ~PyFrozenRewritePatternSet() { + if (set.ptr != nullptr) + mlirFrozenRewritePatternSetDestroy(set); + } + MlirFrozenRewritePatternSet get() { return set; } + + pybind11::object getCapsule() { + return py::reinterpret_steal( + mlirPythonFrozenRewritePatternSetToCapsule(get())); + } + + static pybind11::object createFromCapsule(pybind11::object capsule) { + MlirFrozenRewritePatternSet rawPm = + mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr()); + if (rawPm.ptr == nullptr) + throw py::error_already_set(); + return py::cast(PyFrozenRewritePatternSet(rawPm), + py::return_value_policy::move); + } + +private: + MlirFrozenRewritePatternSet set; +}; + +} // namespace + +/// Create the `mlir.rewrite` here. +void mlir::python::populateRewriteSubmodule(py::module &m) { + //---------------------------------------------------------------------------- + // Mapping of the top-level PassManager + //---------------------------------------------------------------------------- +#if MLIR_ENABLE_PDL_IN_PATTERNMATCH + py::class_(m, "PDLModule", py::module_local()) + .def(py::init<>([](MlirModule module) { + return mlirPDLPatternModuleFromModule(module); + }), + "module"_a, "Create a PDL module from the given module.") + .def("freeze", [](PyPDLPatternModule &self) { + return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern( + mlirRewritePatternSetFromPDLPatternModule(self.get()))); + }); +#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCg + py::class_(m, "FrozenRewritePatternSet", + py::module_local()) + .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, + &PyFrozenRewritePatternSet::getCapsule) + .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, + &PyFrozenRewritePatternSet::createFromCapsule); + m.def( + "apply_patterns_and_fold_greedily", + [](MlirModule module, MlirFrozenRewritePatternSet set) { + auto status = mlirApplyPatternsAndFoldGreedily(module, set, {}); + if (mlirLogicalResultIsFailure(status)) + // FIXME: Not sure this is the right error to throw here. + throw py::value_error("pattern application failed to converge"); + }, + "module"_a, "set"_a, + "Applys the given patterns to the given module greedily while folding " + "results."); +} diff --git a/mlir/lib/Bindings/Python/Rewrite.h b/mlir/lib/Bindings/Python/Rewrite.h new file mode 100644 index 0000000000000..997b80adda303 --- /dev/null +++ b/mlir/lib/Bindings/Python/Rewrite.h @@ -0,0 +1,22 @@ +//===- Rewrite.h - Rewrite Submodules of pybind module --------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_BINDINGS_PYTHON_REWRITE_H +#define MLIR_BINDINGS_PYTHON_REWRITE_H + +#include "PybindUtils.h" + +namespace mlir { +namespace python { + +void populateRewriteSubmodule(pybind11::module &m); + +} // namespace python +} // namespace mlir + +#endif // MLIR_BINDINGS_PYTHON_REWRITE_H diff --git a/mlir/lib/CAPI/Transforms/CMakeLists.txt b/mlir/lib/CAPI/Transforms/CMakeLists.txt index 2638025a8c359..6c67aa09fdf40 100644 --- a/mlir/lib/CAPI/Transforms/CMakeLists.txt +++ b/mlir/lib/CAPI/Transforms/CMakeLists.txt @@ -1,6 +1,9 @@ add_mlir_upstream_c_api_library(MLIRCAPITransforms Passes.cpp + Rewrite.cpp LINK_LIBS PUBLIC + MLIRIR MLIRTransforms + MLIRTransformUtils ) diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp new file mode 100644 index 0000000000000..0de1958398f63 --- /dev/null +++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp @@ -0,0 +1,83 @@ +//===- Rewrite.cpp - C API for Rewrite Patterns ---------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Rewrite.h" +#include "mlir-c/Transforms.h" +#include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Support.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; + +inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) { + assert(module.ptr && "unexpected null module"); + return *(static_cast(module.ptr)); +} + +inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) { + return {module}; +} + +inline mlir::FrozenRewritePatternSet * +unwrap(MlirFrozenRewritePatternSet module) { + assert(module.ptr && "unexpected null module"); + return static_cast(module.ptr); +} + +inline MlirFrozenRewritePatternSet wrap(mlir::FrozenRewritePatternSet *module) { + return {module}; +} + +MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet op) { + auto *m = new mlir::FrozenRewritePatternSet(std::move(unwrap(op))); + op.ptr = nullptr; + return wrap(m); +} + +void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op) { + delete unwrap(op); + op.ptr = nullptr; +} + +MlirLogicalResult +mlirApplyPatternsAndFoldGreedily(MlirModule op, + MlirFrozenRewritePatternSet patterns, + MlirGreedyRewriteDriverConfig) { + return wrap( + mlir::applyPatternsAndFoldGreedily(unwrap(op), *unwrap(patterns))); +} + +#if MLIR_ENABLE_PDL_IN_PATTERNMATCH +inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) { + assert(module.ptr && "unexpected null module"); + return static_cast(module.ptr); +} + +inline MlirPDLPatternModule wrap(mlir::PDLPatternModule *module) { + return {module}; +} + +MlirPDLPatternModule mlirPDLPatternModuleFromModule(MlirModule op) { + return wrap(new mlir::PDLPatternModule( + mlir::OwningOpRef(unwrap(op)))); +} + +void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op) { + delete unwrap(op); + op.ptr = nullptr; +} + +MlirRewritePatternSet +mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op) { + auto *m = new mlir::RewritePatternSet(std::move(*unwrap(op))); + op.ptr = nullptr; + return wrap(m); +} +#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index d8f2d1989fdea..d03036e17749d 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -21,6 +21,7 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python _mlir_libs/__init__.py ir.py passmanager.py + rewrite.py dialects/_ods_common.py # The main _mlir module has submodules: include stubs from each. @@ -448,6 +449,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core IRModule.cpp IRTypes.cpp Pass.cpp + Rewrite.cpp # Headers must be included explicitly so they are installed. Globals.h diff --git a/mlir/python/mlir/dialects/pdl.py b/mlir/python/mlir/dialects/pdl.py index db07dc50aabd7..b7b8430cebd07 100644 --- a/mlir/python/mlir/dialects/pdl.py +++ b/mlir/python/mlir/dialects/pdl.py @@ -6,7 +6,7 @@ from ._pdl_ops_gen import _Dialect from .._mlir_libs._mlirDialectsPDL import * from .._mlir_libs._mlirDialectsPDL import OperationType - +from ..extras.meta import region_op try: from ..ir import * @@ -127,6 +127,9 @@ def body(self): return self.regions[0].blocks[0] +pattern = region_op(PatternOp.__base__) + + @_ods_cext.register_operation(_Dialect, replace=True) class ReplaceOp(ReplaceOp): """Specialization for PDL replace op class.""" @@ -195,6 +198,9 @@ def body(self): return self.regions[0].blocks[0] +rewrite = region_op(RewriteOp) + + @_ods_cext.register_operation(_Dialect, replace=True) class TypeOp(TypeOp): """Specialization for PDL type op class.""" diff --git a/mlir/python/mlir/rewrite.py b/mlir/python/mlir/rewrite.py new file mode 100644 index 0000000000000..5bc1bba7ae9a7 --- /dev/null +++ b/mlir/python/mlir/rewrite.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._mlir_libs._mlir.rewrite import * diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py new file mode 100644 index 0000000000000..923af29a71ad7 --- /dev/null +++ b/mlir/test/python/integration/dialects/pdl.py @@ -0,0 +1,67 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +from mlir.dialects import arith, func, pdl +from mlir.dialects.builtin import module +from mlir.ir import * +from mlir.rewrite import * + + +def construct_and_print_in_module(f): + print("\nTEST:", f.__name__) + with Context(), Location.unknown(): + module = Module.create() + with InsertionPoint(module.body): + module = f(module) + if module is not None: + print(module) + return f + + +# CHECK-LABEL: TEST: test_add_to_mul +# CHECK: arith.muli +@construct_and_print_in_module +def test_add_to_mul(module_): + index_type = IndexType.get() + + # Create a test case. + @module(sym_name="ir") + def ir(): + @func.func(index_type, index_type) + def add_func(a, b): + return arith.addi(a, b) + + # Create a rewrite from add to mul. This will match + # - operation name is arith.addi + # - operands are index types. + # - there are two operands. + with Location.unknown(): + m = Module.create() + with InsertionPoint(m.body): + # Change all arith.addi with index types to arith.muli. + @pdl.pattern(benefit=1, sym_name="addi_to_mul") + def pat(): + # Match arith.addi with index types. + index_type = pdl.TypeOp(IndexType.get()) + operand0 = pdl.OperandOp(index_type) + operand1 = pdl.OperandOp(index_type) + op0 = pdl.OperationOp( + name="arith.addi", args=[operand0, operand1], types=[index_type] + ) + + # Replace the matched op with arith.muli. + @pdl.rewrite() + def rew(): + newOp = pdl.OperationOp( + name="arith.muli", args=[operand0, operand1], types=[index_type] + ) + pdl.ReplaceOp(op0, with_op=newOp) + + # Create a PDL module from module and freeze it. At this point the ownership + # of the module is transferred to the PDL module. This ownership transfer is + # not yet captured Python side/has sharp edges. So best to construct the + # module and PDL module in same scope. + # FIXME: This should be made more robust. + frozen = PDLModule(m).freeze() + # Could apply frozen pattern set multiple times. + apply_patterns_and_fold_greedily(module_, frozen) + return module_ diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 0254e127980e5..9eda1a2b4c7e1 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -420,6 +420,7 @@ mlir_c_api_cc_library( "include/mlir-c/Interfaces.h", "include/mlir-c/Pass.h", "include/mlir-c/RegisterEverything.h", + "include/mlir-c/Rewrite.h", "include/mlir-c/Support.h", "include/mlir/CAPI/AffineExpr.h", "include/mlir/CAPI/AffineMap.h", @@ -866,7 +867,10 @@ mlir_c_api_cc_library( mlir_c_api_cc_library( name = "CAPITransforms", - srcs = ["lib/CAPI/Transforms/Passes.cpp"], + srcs = [ + "lib/CAPI/Transforms/Passes.cpp", + "lib/CAPI/Transforms/Rewrite.cpp", + ], hdrs = ["include/mlir-c/Transforms.h"], capi_deps = [ ":CAPIIR", @@ -876,7 +880,10 @@ mlir_c_api_cc_library( ], includes = ["include"], deps = [ + ":IR", ":Pass", + ":Rewrite", + ":TransformUtils", ":Transforms", ], ) @@ -939,6 +946,7 @@ cc_library( textual_hdrs = glob(MLIR_BINDINGS_PYTHON_HEADERS), deps = [ ":CAPIIRHeaders", + ":CAPITransformsHeaders", "@local_config_python//:python_headers", "@pybind11", ], @@ -957,6 +965,7 @@ cc_library( textual_hdrs = glob(MLIR_BINDINGS_PYTHON_HEADERS), deps = [ ":CAPIIR", + ":CAPITransforms", "@local_config_python//:python_headers", "@pybind11", ], @@ -981,6 +990,7 @@ MLIR_PYTHON_BINDINGS_SOURCES = [ "lib/Bindings/Python/IRModule.cpp", "lib/Bindings/Python/IRTypes.cpp", "lib/Bindings/Python/Pass.cpp", + "lib/Bindings/Python/Rewrite.cpp", ] cc_library( diff --git a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel index add150de69faf..254cab0db4a5d 100644 --- a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel @@ -82,6 +82,13 @@ filegroup( ], ) +filegroup( + name = "RewritePyFiles", + srcs = [ + "mlir/rewrite.py", + ], +) + filegroup( name = "RuntimePyFiles", srcs = glob([