diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index c8373e06f0db77..389a4621c14e59 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -3207,7 +3207,18 @@ void mlir::python::populateIRCore(py::module &m) { "Inserts an operation.") .def_property_readonly( "block", [](PyInsertionPoint &self) { return self.getBlock(); }, - "Returns the block that this InsertionPoint points to."); + "Returns the block that this InsertionPoint points to.") + .def_property_readonly( + "ref_operation", + [](PyInsertionPoint &self) -> py::object { + auto ref_operation = self.getRefOperation(); + if (ref_operation) + return ref_operation->getObject(); + return py::none(); + }, + "The reference operation before which new operations are " + "inserted, or None if the insertion point is at the end of " + "the block"); //---------------------------------------------------------------------------- // Mapping of PyAttribute. diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 3ca7dd851961a4..c5412e735dddcb 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -833,6 +833,7 @@ class PyInsertionPoint { const pybind11::object &excTb); PyBlock &getBlock() { return block; } + std::optional &getRefOperation() { return refOperation; } private: // Trampoline constructor that avoids null initializing members while diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index e8f4440d216eeb..2609117dd220be 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -755,6 +755,8 @@ class InsertionPoint: def __exit__(self, arg0: object, arg1: object, arg2: object) -> None: ... @property def block(self) -> Block: ... + @property + def ref_operation(self) -> Optional[_OperationBase]: ... # TODO: Auto-generated. Audit and fix. class IntegerAttr(Attribute): diff --git a/mlir/test/python/ir/insertion_point.py b/mlir/test/python/ir/insertion_point.py index 0dc7d757f56d19..268d2e77d036f5 100644 --- a/mlir/test/python/ir/insertion_point.py +++ b/mlir/test/python/ir/insertion_point.py @@ -27,6 +27,8 @@ def test_insert_at_block_end(): ) entry_block = module.body.operations[0].regions[0].blocks[0] ip = InsertionPoint(entry_block) + assert ip.block == entry_block + assert ip.ref_operation is None ip.insert(Operation.create("custom.op2")) # CHECK: "custom.op1" # CHECK: "custom.op2" @@ -51,6 +53,8 @@ def test_insert_before_operation(): ) entry_block = module.body.operations[0].regions[0].blocks[0] ip = InsertionPoint(entry_block.operations[1]) + assert ip.block == entry_block + assert ip.ref_operation == entry_block.operations[1] ip.insert(Operation.create("custom.op3")) # CHECK: "custom.op1" # CHECK: "custom.op3" @@ -75,6 +79,8 @@ def test_insert_at_block_begin(): ) entry_block = module.body.operations[0].regions[0].blocks[0] ip = InsertionPoint.at_block_begin(entry_block) + assert ip.block == entry_block + assert ip.ref_operation == entry_block.operations[0] ip.insert(Operation.create("custom.op1")) # CHECK: "custom.op1" # CHECK: "custom.op2" @@ -108,6 +114,8 @@ def test_insert_at_terminator(): ) entry_block = module.body.operations[0].regions[0].blocks[0] ip = InsertionPoint.at_block_terminator(entry_block) + assert ip.block == entry_block + assert ip.ref_operation == entry_block.operations[1] ip.insert(Operation.create("custom.op2")) # CHECK: "custom.op1" # CHECK: "custom.op2"