diff --git a/python/tvm/runtime/object_generic.py b/python/tvm/runtime/object_generic.py index 05426dfb1aeb..de8e801cca64 100644 --- a/python/tvm/runtime/object_generic.py +++ b/python/tvm/runtime/object_generic.py @@ -35,7 +35,7 @@ def asobject(self): raise NotImplementedError() -ObjectTypes = (ObjectBase, NDArrayBase, Module, ObjectRValueRef, PyNativeObject) +ObjectTypes = (ObjectBase, NDArrayBase, Module, ObjectRValueRef, PackedFuncBase, PyNativeObject) def convert_to_object(value, span=None): @@ -79,6 +79,8 @@ def convert_to_object(value, span=None): return _ffi_api.Map(*vlist) if isinstance(value, ObjectGeneric): return value.asobject() + if callable(value): + return convert_to_tvm_func(value) if value is None: return None @@ -99,13 +101,12 @@ def convert(value, span=None): ------- tvm_val : Object or Function Converted value in TVM - """ - if isinstance(value, (PackedFuncBase, ObjectBase)): - return value - - if callable(value): - return convert_to_tvm_func(value) + Note + ---- + This function is redirected to `convert_to_object` as it is widely used in + the codebase. We can choose one to keep and discard the other one later. + """ return convert_to_object(value, span=span) diff --git a/tests/python/all-platform-minimal-test/test_runtime_packed_func.py b/tests/python/all-platform-minimal-test/test_runtime_packed_func.py index f905ef8e117e..bbfb8bd2db12 100644 --- a/tests/python/all-platform-minimal-test/test_runtime_packed_func.py +++ b/tests/python/all-platform-minimal-test/test_runtime_packed_func.py @@ -153,6 +153,18 @@ def check(arr): assert tvm.testing.object_use_count(x) == 1 +def test_dict_function_value_type(): + from tvm import tir # pylint: disable=import-outside-toplevel + + te_func_dict = {"add": lambda a, b: a + b} + + converted_dict = tvm.runtime.convert(te_func_dict) + f = converted_dict["add"] + a = tir.Var("a", "float32") + b = tir.Var("b", "float32") + tvm.ir.assert_structural_equal(f(a, b), tir.Add(a, b)) + + if __name__ == "__main__": test_ndarray_args() test_numpy_scalar() @@ -164,3 +176,4 @@ def check(arr): test_return_func() test_byte_array() test_device() + test_dict_function_value_type()