From cc6bfa7e6bcd07ae669a35c03abc6d9e926f4bf6 Mon Sep 17 00:00:00 2001 From: Leandro Nunes Date: Wed, 9 Jun 2021 20:07:45 +0100 Subject: [PATCH] Expose list of PassContext configurations to the Python APIs (#8212) * Expose C++ PassContext::ListAllConfigs via its Python counterpart tvm.ir.transform.PassContext.list_configs() * Add unit tests for the C++ and Python layers --- include/tvm/ir/transform.h | 6 ++++++ python/tvm/ir/transform.py | 5 +++++ src/ir/transform.cc | 14 ++++++++++++++ tests/cpp/relay_transform_sequential_test.cc | 7 +++++++ tests/python/relay/test_pass_instrument.py | 7 +++++++ 5 files changed, 39 insertions(+) diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index ce5ae280e176..d5b50a7f7a6a 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -183,6 +183,12 @@ class PassContext : public ObjectRef { */ TVM_DLL static PassContext Current(); + /*! + * \brief Get all supported configuration names, registered within the PassContext. + * \return List of all configuration names. + */ + TVM_DLL static Array ListConfigNames(); + /*! * \brief Call instrument implementations' callbacks when entering PassContext. * The callbacks are called in order, and if one raises an exception, the rest will not be diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py index 3a3ac16be677..7a0ea825939a 100644 --- a/python/tvm/ir/transform.py +++ b/python/tvm/ir/transform.py @@ -120,6 +120,11 @@ def current(): """Return the current pass context.""" return _ffi_transform_api.GetCurrentPassContext() + @staticmethod + def list_config_names(): + """List all registered `PassContext` configuration names""" + return list(_ffi_transform_api.ListConfigNames()) + @tvm._ffi.register_object("transform.Pass") class Pass(tvm.runtime.Object): diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 9537ef532b44..a8541b1f42a7 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -145,6 +145,14 @@ class PassConfigManager { } } + Array ListConfigNames() { + Array config_keys; + for (const auto& kv : key2vtype_) { + config_keys.push_back(kv.first); + } + return config_keys; + } + static PassConfigManager* Global() { static auto* inst = new PassConfigManager(); return inst; @@ -163,6 +171,10 @@ void PassContext::RegisterConfigOption(const char* key, uint32_t value_type_inde PassConfigManager::Global()->Register(key, value_type_index); } +Array PassContext::ListConfigNames() { + return PassConfigManager::Global()->ListConfigNames(); +} + PassContext PassContext::Create() { return PassContext(make_object()); } void PassContext::InstrumentEnterPassContext() { @@ -607,5 +619,7 @@ Pass PrintIR(String header, bool show_meta_data) { TVM_REGISTER_GLOBAL("transform.PrintIR").set_body_typed(PrintIR); +TVM_REGISTER_GLOBAL("transform.ListConfigNames").set_body_typed(PassContext::ListConfigNames); + } // namespace transform } // namespace tvm diff --git a/tests/cpp/relay_transform_sequential_test.cc b/tests/cpp/relay_transform_sequential_test.cc index 289574aef1e2..16e9438821ec 100644 --- a/tests/cpp/relay_transform_sequential_test.cc +++ b/tests/cpp/relay_transform_sequential_test.cc @@ -121,6 +121,13 @@ TEST(Relay, Sequential) { ICHECK(tvm::StructuralEqual()(f, expected)); } +TEST(PassContextListConfigNames, Basic) { + Array configs = relay::transform::PassContext::ListConfigNames(); + ICHECK_EQ(configs.empty(), false); + ICHECK_EQ(std::count(std::begin(configs), std::end(configs), "relay.backend.use_auto_scheduler"), + 1); +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; diff --git a/tests/python/relay/test_pass_instrument.py b/tests/python/relay/test_pass_instrument.py index 86283fd31819..c7405ae04169 100644 --- a/tests/python/relay/test_pass_instrument.py +++ b/tests/python/relay/test_pass_instrument.py @@ -182,6 +182,13 @@ def run_after_pass(self, mod, info): assert passes_counter.run_after_count == 0 +def test_list_pass_configs(): + config_names = tvm.transform.PassContext.list_config_names() + + assert len(config_names) > 0 + assert "relay.backend.use_auto_scheduler" in config_names + + def test_enter_pass_ctx_exception(): events = []