Skip to content

Commit

Permalink
[mlir][python] remove mixins (llvm#68853)
Browse files Browse the repository at this point in the history
This PR replaces the mixin `OpView` extension mechanism with the
standard inheritance mechanism.

Why? Firstly, mixins are not very pythonic (inheritance is usually used
for this), a little convoluted, and too "tight" (can only be used in the
immediately adjacent `_ext.py`). Secondly, it (mixins) are now blocking
are correct implementation of "value builders" (see
[here](llvm#68764)) where the
problem becomes how to choose the correct base class that the value
builder should call.

This PR looks big/complicated but appearances are deceiving; 4 things
were needed to make this work:

1. Drop `skipDefaultBuilders` in
`OpPythonBindingGen::emitDefaultOpBuilders`
2. Former mixin extension classes are converted to inherit from the
generated `OpView` instead of being "mixins"
a. extension classes that simply were calling into an already generated
`super().__init__` continue to do so
b. (almost all) extension classes that were calling `self.build_generic`
because of a lack of default builder being generated can now also just
call `super().__init__`
3. To handle the [lone single
use-case](https://sourcegraph.com/search?q=context%3Aglobal+select_opview_mixin&patternType=standard&sm=1&groupBy=repo)
of `select_opview_mixin`, namely
[linalg](https://github.com/llvm/llvm-project/blob/main/mlir/python/mlir/dialects/_linalg_ops_ext.py#L38),
only a small change was necessary in `opdsl/lang/emitter.py` (thanks to
the emission/generation of default builders/`__init__`s)
4. since the `extend_opview_class` decorator is removed, we need a way
to register extension classes as the desired `OpView` that `op.opview`
conjures into existence; so we do the standard thing and just enable
replacing the existing registered `OpView` i.e.,
`register_operation(_Dialect, replace=True)`.

Note, the upgrade path for the common case is to change an extension to
inherit from the generated builder and decorate it with
`register_operation(_Dialect, replace=True)`. In the slightly more
complicated case where `super().__init(self.build_generic(...))` is
called in the extension's `__init__`, this needs to be updated to call
`__init__` in `OpView`, i.e., the grandparent (see updated docs). 
Note, also `<DIALECT>_ext.py` files/modules will no longer be automatically loaded.

Note, the PR has 3 base commits that look funny but this was done for
the purpose of tracking the line history of moving the
`<DIALECT>_ops_ext.py` class into `<DIALECT>.py` and updating (commit
labeled "fix").
  • Loading branch information
makslevental authored Oct 19, 2023
1 parent a30095a commit a2288a8
Show file tree
Hide file tree
Showing 49 changed files with 2,814 additions and 2,920 deletions.
127 changes: 58 additions & 69 deletions mlir/docs/Bindings/Python.md
Original file line number Diff line number Diff line change
Expand Up @@ -1017,90 +1017,79 @@ very generic signature.

#### Extending Generated Op Classes

Note that this is a rather complex mechanism and this section errs on the side
of explicitness. Users are encouraged to find an example and duplicate it if
they don't feel the need to understand the subtlety. The `builtin` dialect
provides some relatively simple examples.

As mentioned above, the build system generates Python sources like
`_{DIALECT_NAMESPACE}_ops_gen.py` for each dialect with Python bindings. It is
often desirable to to use these generated classes as a starting point for
further customization, so an extension mechanism is provided to make this easy
(you are always free to do ad-hoc patching in your `{DIALECT_NAMESPACE}.py` file
but we prefer a more standard mechanism that is applied uniformly).
often desirable to use these generated classes as a starting point for
further customization, so an extension mechanism is provided to make this easy.
This mechanism uses conventional inheritance combined with `OpView` registration.
For example, the default builder for `arith.constant`

```python
class ConstantOp(_ods_ir.OpView):
OPERATION_NAME = "arith.constant"

_ODS_REGIONS = (0, True)

def __init__(self, value, *, loc=None, ip=None):
...
```

To provide extensions, add a `_{DIALECT_NAMESPACE}_ops_ext.py` file to the
`dialects` module (i.e. adjacent to your `{DIALECT_NAMESPACE}.py` top-level and
the `*_ops_gen.py` file). Using the `builtin` dialect and `FuncOp` as an
example, the generated code will include an import like this:
expects `value` to be a `TypedAttr` (e.g., `IntegerAttr` or `FloatAttr`).
Thus, a natural extension is a builder that accepts a MLIR type and a Python value and instantiates the appropriate `TypedAttr`:

```python
try:
from . import _builtin_ops_ext as _ods_ext_module
except ImportError:
_ods_ext_module = None
from typing import Union

from mlir.ir import Type, IntegerAttr, FloatAttr
from mlir.dialects._arith_ops_gen import _Dialect, ConstantOp
from mlir.dialects._ods_common import _cext

@_cext.register_operation(_Dialect, replace=True)
class ConstantOpExt(ConstantOp):
def __init__(
self, result: Type, value: Union[int, float], *, loc=None, ip=None
):
if isinstance(value, int):
super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
elif isinstance(value, float):
super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
else:
raise NotImplementedError(f"Building `arith.constant` not supported for {result=} {value=}")
```

Then for each generated concrete `OpView` subclass, it will apply a decorator
like:
which enables building an instance of `arith.constant` like so:

```python
@_ods_cext.register_operation(_Dialect)
@_ods_extend_opview_class(_ods_ext_module)
class FuncOp(_ods_ir.OpView):
from mlir.ir import F32Type

a = ConstantOpExt(F32Type.get(), 42.42)
b = ConstantOpExt(IntegerType.get_signless(32), 42)
```

See the `_ods_common.py` `extend_opview_class` function for details of the
mechanism. At a high level:

* If the extension module exists, locate an extension class for the op (in
this example, `FuncOp`):
* First by looking for an attribute with the exact name in the extension
module.
* Falling back to calling a `select_opview_mixin(parent_opview_cls)`
function defined in the extension module.
* If a mixin class is found, a new subclass is dynamically created that
multiply inherits from `({_builtin_ops_ext.FuncOp},
_builtin_ops_gen.FuncOp)`.

The mixin class should not inherit from anything (i.e. directly extends `object`
only). The facility is typically used to define custom `__init__` methods,
properties, instance methods and static methods. Due to the inheritance
ordering, the mixin class can act as though it extends the generated `OpView`
subclass in most contexts (i.e. `issubclass(_builtin_ops_ext.FuncOp, OpView)`
will return `False` but usage generally allows you treat it as duck typed as an
`OpView`).

There are a couple of recommendations, given how the class hierarchy is defined:

* For static methods that need to instantiate the actual "leaf" op (which is
dynamically generated and would result in circular dependencies to try to
reference by name), prefer to use `@classmethod` and the concrete subclass
will be provided as your first `cls` argument. See
`_builtin_ops_ext.FuncOp.from_py_func` as an example.
* If seeking to replace the generated `__init__` method entirely, you may
actually want to invoke the super-super-class `mlir.ir.OpView` constructor
directly, as it takes an `mlir.ir.Operation`, which is likely what you are
constructing (i.e. the generated `__init__` method likely adds more API
constraints than you want to expose in a custom builder).

A pattern that comes up frequently is wanting to provide a sugared `__init__`
method which has optional or type-polymorphism/implicit conversions but to
otherwise want to invoke the default op building logic. For such cases, it is
recommended to use an idiom such as:
Note, three key aspects of the extension mechanism in this example:

1. `ConstantOpExt` directly inherits from the generated `ConstantOp`;
2. in this, simplest, case all that's required is a call to the super class' initializer, i.e., `super().__init__(...)`;
3. in order to register `ConstantOpExt` as the preferred `OpView` that is returned by `mlir.ir.Operation.opview` (see [Operations, Regions and Blocks](#operations-regions-and-blocks))
we decorate the class with `@_cext.register_operation(_Dialect, replace=True)`, **where the `replace=True` must be used**.

In some more complex cases it might be necessary to explicitly build the `OpView` through `OpView.build_generic` (see [Default Builder](#default-builder)), just as is performed by the generated builders.
I.e., we must call `OpView.build_generic` **and pass the result to `OpView.__init__`**, where the small issue becomes that the latter is already overridden by the generated builder.
Thus, we must call a method of a super class' super class (the "grandparent"); for example:

```python
def __init__(self, sugar, spice, *, loc=None, ip=None):
... massage into result_type, operands, attributes ...
OpView.__init__(self, self.build_generic(
results=[result_type],
operands=operands,
attributes=attributes,
loc=loc,
ip=ip))
from mlir.dialects._scf_ops_gen import _Dialect, ForOp
from mlir.dialects._ods_common import _cext

@_cext.register_operation(_Dialect, replace=True)
class ForOpExt(ForOp):
def __init__(self, lower_bound, upper_bound, step, iter_args, *, loc=None, ip=None):
...
super(ForOp, self).__init__(self.build_generic(...))
```

Refer to the documentation for `build_generic` for more information.
where `OpView.__init__` is called via `super(ForOp, self).__init__`.
Note, there are alternatives ways to implement this (e.g., explicitly writing `OpView.__init__`); see any discussion on Python inheritance.

## Providing Python bindings for a dialect

Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Bindings/Python/Globals.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,10 @@ class PyGlobals {
pybind11::object pyClass);

/// Adds a concrete implementation operation class.
/// Raises an exception if the mapping already exists.
/// Raises an exception if the mapping already exists and replace == false.
/// This is intended to be called by implementation code.
void registerOperationImpl(const std::string &operationName,
pybind11::object pyClass);
pybind11::object pyClass, bool replace = false);

/// Returns the custom Attribute builder for Attribute kind.
std::optional<pybind11::function>
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Bindings/Python/IRModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
}

void PyGlobals::registerOperationImpl(const std::string &operationName,
py::object pyClass) {
py::object pyClass, bool replace) {
py::object &found = operationClassMap[operationName];
if (found) {
if (found && !replace) {
throw std::runtime_error((llvm::Twine("Operation '") + operationName +
"' is already registered.")
.str());
Expand Down
11 changes: 6 additions & 5 deletions mlir/lib/Bindings/Python/MainModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ PYBIND11_MODULE(_mlir, m) {
"dialect_namespace"_a, "dialect_class"_a,
"Testing hook for directly registering a dialect")
.def("_register_operation_impl", &PyGlobals::registerOperationImpl,
"operation_name"_a, "operation_class"_a,
"operation_name"_a, "operation_class"_a, "replace"_a = false,
"Testing hook for directly registering an operation");

// Aside from making the globals accessible to python, having python manage
Expand All @@ -63,20 +63,21 @@ PYBIND11_MODULE(_mlir, m) {
"Class decorator for registering a custom Dialect wrapper");
m.def(
"register_operation",
[](const py::object &dialectClass) -> py::cpp_function {
[](const py::object &dialectClass, bool replace) -> py::cpp_function {
return py::cpp_function(
[dialectClass](py::object opClass) -> py::object {
[dialectClass, replace](py::object opClass) -> py::object {
std::string operationName =
opClass.attr("OPERATION_NAME").cast<std::string>();
PyGlobals::get().registerOperationImpl(operationName, opClass);
PyGlobals::get().registerOperationImpl(operationName, opClass,
replace);

// Dict-stuff the new opClass by name onto the dialect class.
py::object opClassName = opClass.attr("__name__");
dialectClass.attr(opClassName) = opClass;
return opClass;
});
},
"dialect_class"_a,
"dialect_class"_a, "replace"_a = false,
"Produce a class decorator for registering an Operation class as part of "
"a dialect");
m.def(
Expand Down
19 changes: 0 additions & 19 deletions mlir/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/AffineOps.td
SOURCES
dialects/affine.py
dialects/_affine_ops_ext.py
DIALECT_NAME affine
GEN_ENUM_BINDINGS)

Expand All @@ -78,7 +77,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/BufferizationOps.td
SOURCES
dialects/bufferization.py
dialects/_bufferization_ops_ext.py
DIALECT_NAME bufferization
GEN_ENUM_BINDINGS_TD_FILE
"../../include/mlir/Dialect/Bufferization/IR/BufferizationEnums.td"
Expand All @@ -90,7 +88,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/BuiltinOps.td
SOURCES
dialects/builtin.py
dialects/_builtin_ops_ext.py
DIALECT_NAME builtin)

declare_mlir_dialect_python_bindings(
Expand All @@ -115,7 +112,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/FuncOps.td
SOURCES
dialects/func.py
dialects/_func_ops_ext.py
DIALECT_NAME func)

declare_mlir_dialect_python_bindings(
Expand All @@ -131,7 +127,6 @@ declare_mlir_dialect_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/LinalgOps.td
SOURCES
dialects/_linalg_ops_ext.py
SOURCES_GLOB
dialects/linalg/*.py
DIALECT_NAME linalg
Expand All @@ -152,7 +147,6 @@ ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/TransformPDLExtensionOps.td
SOURCES
dialects/_transform_pdl_extension_ops_ext.py
dialects/transform/pdl.py
DIALECT_NAME transform
EXTENSION_NAME transform_pdl_extension)
Expand All @@ -162,7 +156,6 @@ declare_mlir_dialect_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/TransformOps.td
SOURCES
dialects/_transform_ops_ext.py
dialects/transform/__init__.py
_mlir_libs/_mlir/dialects/transform/__init__.pyi
DIALECT_NAME transform
Expand All @@ -175,7 +168,6 @@ declare_mlir_dialect_extension_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/BufferizationTransformOps.td
SOURCES
dialects/_bufferization_transform_ops_ext.py
dialects/transform/bufferization.py
DIALECT_NAME transform
EXTENSION_NAME bufferization_transform)
Expand All @@ -185,7 +177,6 @@ declare_mlir_dialect_extension_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/GPUTransformOps.td
SOURCES
dialects/_gpu_transform_ops_ext.py
dialects/transform/gpu.py
DIALECT_NAME transform
EXTENSION_NAME gpu_transform)
Expand All @@ -195,7 +186,6 @@ declare_mlir_dialect_extension_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/SCFLoopTransformOps.td
SOURCES
dialects/_loop_transform_ops_ext.py
dialects/transform/loop.py
DIALECT_NAME transform
EXTENSION_NAME loop_transform)
Expand All @@ -205,7 +195,6 @@ declare_mlir_dialect_extension_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/MemRefTransformOps.td
SOURCES
dialects/_memref_transform_ops_ext.py
dialects/transform/memref.py
DIALECT_NAME transform
EXTENSION_NAME memref_transform)
Expand All @@ -224,7 +213,6 @@ declare_mlir_dialect_extension_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/LinalgStructuredTransformOps.td
SOURCES
dialects/_structured_transform_ops_ext.py
dialects/transform/structured.py
DIALECT_NAME transform
EXTENSION_NAME structured_transform
Expand All @@ -246,7 +234,6 @@ declare_mlir_dialect_extension_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/TensorTransformOps.td
SOURCES
dialects/_tensor_transform_ops_ext.py
dialects/transform/tensor.py
DIALECT_NAME transform
EXTENSION_NAME tensor_transform)
Expand Down Expand Up @@ -276,7 +263,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/ArithOps.td
SOURCES
dialects/arith.py
dialects/_arith_ops_ext.py
DIALECT_NAME arith
GEN_ENUM_BINDINGS)

Expand All @@ -286,7 +272,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/MemRefOps.td
SOURCES
dialects/memref.py
dialects/_memref_ops_ext.py
DIALECT_NAME memref)

declare_mlir_dialect_python_bindings(
Expand All @@ -295,7 +280,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/MLProgramOps.td
SOURCES
dialects/ml_program.py
dialects/_ml_program_ops_ext.py
DIALECT_NAME ml_program)

declare_mlir_dialect_python_bindings(
Expand Down Expand Up @@ -339,7 +323,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/PDLOps.td
SOURCES
dialects/pdl.py
dialects/_pdl_ops_ext.py
_mlir_libs/_mlir/dialects/pdl.pyi
DIALECT_NAME pdl)

Expand All @@ -357,7 +340,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/SCFOps.td
SOURCES
dialects/scf.py
dialects/_scf_ops_ext.py
DIALECT_NAME scf)

declare_mlir_dialect_python_bindings(
Expand All @@ -383,7 +365,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/TensorOps.td
SOURCES
dialects/tensor.py
dialects/_tensor_ops_ext.py
DIALECT_NAME tensor)

declare_mlir_dialect_python_bindings(
Expand Down
Loading

0 comments on commit a2288a8

Please sign in to comment.