From d7ce21ac198d587eb659da0efc03fbe4836abd4e Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Tue, 19 Sep 2023 13:59:53 -0700 Subject: [PATCH 1/3] [mlir][python] Expose AsmState python side. This does basic plumbing, ideally want a context approach to reduce needing to thread these manually, but the current is useful even in that state. --- mlir/lib/Bindings/Python/IRCore.cpp | 38 ++++++++++++++++++++++------- mlir/lib/Bindings/Python/IRModule.h | 25 +++++++++++++++++++ mlir/test/python/ir/value.py | 12 ++++++++- 3 files changed, 65 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index af713547cccbb2..2ab1219016006d 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -3425,19 +3425,35 @@ void mlir::python::populateIRCore(py::module &m) { kValueDunderStrDocstring) .def( "get_name", - [](PyValue &self, bool useLocalScope) { + [](PyValue &self, std::optional useLocalScope, + std::optional> state) { PyPrintAccumulator printAccum; - MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); - if (useLocalScope) - mlirOpPrintingFlagsUseLocalScope(flags); - MlirAsmState state = mlirAsmStateCreateForValue(self.get(), flags); - mlirValuePrintAsOperand(self.get(), state, printAccum.getCallback(), + MlirOpPrintingFlags flags; + MlirAsmState valueState; + // Use state if provided, else create a new state. + if (state) { + valueState = state.value().get().get(); + // Don't allow setting using local scope and state at same time. + if (useLocalScope) + throw py::value_error( + "setting AsmState and local scope together not supported"); + } else { + flags = mlirOpPrintingFlagsCreate(); + if (useLocalScope.value_or(false)) + mlirOpPrintingFlagsUseLocalScope(flags); + valueState = mlirAsmStateCreateForValue(self.get(), flags); + } + mlirValuePrintAsOperand(self.get(), valueState, printAccum.getCallback(), printAccum.getUserData()); - mlirOpPrintingFlagsDestroy(flags); - mlirAsmStateDestroy(state); + // Release state if allocated locally. + if (!state) { + mlirOpPrintingFlagsDestroy(flags); + mlirAsmStateDestroy(valueState); + } return printAccum.join(); }, - py::arg("use_local_scope") = false, kGetNameAsOperand) + py::arg("use_local_scope") = std::nullopt, + py::arg("state") = std::nullopt, kGetNameAsOperand) .def_property_readonly( "type", [](PyValue &self) { return mlirValueGetType(self.get()); }) .def( @@ -3456,6 +3472,10 @@ void mlir::python::populateIRCore(py::module &m) { PyOpResult::bind(m); PyOpOperand::bind(m); + py::class_(m, "AsmState", py::module_local()) + .def(py::init(), py::arg("value"), + py::arg("use_local_scope")); + //---------------------------------------------------------------------------- // Mapping of SymbolTable. //---------------------------------------------------------------------------- diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index d1911730c1ede0..23338f7fdb38ad 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -748,6 +748,31 @@ class PyRegion { MlirRegion region; }; +/// Wrapper around an MlirAsmState. +class PyAsmState { + public: + PyAsmState(MlirValue value, bool useLocalScope) { + flags = mlirOpPrintingFlagsCreate(); + // The OpPrintingFlags are not exposed Python side, create locally and + // associate lifetime with the state. + if (useLocalScope) + mlirOpPrintingFlagsUseLocalScope(flags); + state = mlirAsmStateCreateForValue(value, flags); + } + ~PyAsmState() { + mlirOpPrintingFlagsDestroy(flags); + } + // Delete copy constructors. + PyAsmState(PyAsmState &other) = delete; + PyAsmState(const PyAsmState &other) = delete; + + MlirAsmState get() { return state; } + + private: + MlirAsmState state; + MlirOpPrintingFlags flags; +}; + /// Wrapper around an MlirBlock. /// Blocks are managed completely by their containing operation. Unlike the /// C++ API, the python API does not support detached blocks. diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py index 46a50ac5291e8d..2a47c8d820eaf4 100644 --- a/mlir/test/python/ir/value.py +++ b/mlir/test/python/ir/value.py @@ -1,4 +1,4 @@ -# RUN: %PYTHON %s | FileCheck %s +# RUN: %PYTHON %s | FileCheck %s --enable-var-scope=false import gc from mlir.ir import * @@ -199,6 +199,16 @@ def testValuePrintAsOperand(): # CHECK: %[[VAL4]] print(value4.get_name()) + print("With AsmState") + # CHECK-LABEL: With AsmState + state = AsmState(value3, use_local_scope=True) + # CHECK: %0 + print(value3.get_name(state=state)) + # CHECK: %1 + print(value4.get_name(state=state)) + + print("With use_local_scope") + # CHECK-LABEL: With use_local_scope # CHECK: %0 print(value3.get_name(use_local_scope=True)) # CHECK: %1 From 9de8ec41675aafea37a44f5847064e0eaa4b658f Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Tue, 19 Sep 2023 14:09:18 -0700 Subject: [PATCH 2/3] Set same default in constructor --- mlir/lib/Bindings/Python/IRCore.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 2ab1219016006d..cafcdd19ad9a0c 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -3474,7 +3474,7 @@ void mlir::python::populateIRCore(py::module &m) { py::class_(m, "AsmState", py::module_local()) .def(py::init(), py::arg("value"), - py::arg("use_local_scope")); + py::arg("use_local_scope") = false); //---------------------------------------------------------------------------- // Mapping of SymbolTable. From f860a6fad3391971d2b293e279c38c9281fab715 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Wed, 20 Sep 2023 11:44:04 -0700 Subject: [PATCH 3/3] Make clear checking if arg is set rather than value of it --- mlir/lib/Bindings/Python/IRCore.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index cafcdd19ad9a0c..fc80e193b1aac7 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -3434,7 +3434,7 @@ void mlir::python::populateIRCore(py::module &m) { if (state) { valueState = state.value().get().get(); // Don't allow setting using local scope and state at same time. - if (useLocalScope) + if (useLocalScope.has_value()) throw py::value_error( "setting AsmState and local scope together not supported"); } else {