Skip to content

Commit

Permalink
[FIX][RUNTIME] Convert container with function value type
Browse files Browse the repository at this point in the history
Prior to this PR, though the `convert` function is capable of
converting a single Python function/lambda to TVM func, it is not able
to convert a container whose values inside are functions to TVM
object.

This PR adds function conversion to `convert_to_object` and redirects
`convert` to `convert_to_object`, so that now the conversion is always
recursive, and therefore will work well on function container value
type.

Co-authored-by: Chaofan Lin <[email protected]>
  • Loading branch information
MasterJH5574 and SiriusNEO committed Feb 17, 2023
1 parent d12a636 commit ff1647d
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 7 deletions.
15 changes: 8 additions & 7 deletions python/tvm/runtime/object_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def asobject(self):
raise NotImplementedError()


ObjectTypes = (ObjectBase, NDArrayBase, Module, ObjectRValueRef, PyNativeObject)
ObjectTypes = (ObjectBase, NDArrayBase, Module, ObjectRValueRef, PackedFuncBase, PyNativeObject)


def convert_to_object(value, span=None):
Expand Down Expand Up @@ -79,6 +79,8 @@ def convert_to_object(value, span=None):
return _ffi_api.Map(*vlist)
if isinstance(value, ObjectGeneric):
return value.asobject()
if callable(value):
return convert_to_tvm_func(value)
if value is None:
return None

Expand All @@ -99,13 +101,12 @@ def convert(value, span=None):
-------
tvm_val : Object or Function
Converted value in TVM
"""
if isinstance(value, (PackedFuncBase, ObjectBase)):
return value

if callable(value):
return convert_to_tvm_func(value)
Note
----
This function is redirected to `convert_to_object` as it is widely used in
the codebase. We can choose one to keep and discard the other one later.
"""
return convert_to_object(value, span=span)


Expand Down
13 changes: 13 additions & 0 deletions tests/python/all-platform-minimal-test/test_runtime_packed_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,18 @@ def check(arr):
assert tvm.testing.object_use_count(x) == 1


def test_dict_function_value_type():
from tvm import tir # pylint: disable=import-outside-toplevel

te_func_dict = {"add": lambda a, b: a + b}

converted_dict = tvm.runtime.convert(te_func_dict)
f = converted_dict["add"]
a = tir.Var("a", "float32")
b = tir.Var("b", "float32")
tvm.ir.assert_structural_equal(f(a, b), tir.Add(a, b))


if __name__ == "__main__":
test_ndarray_args()
test_numpy_scalar()
Expand All @@ -164,3 +176,4 @@ def check(arr):
test_return_func()
test_byte_array()
test_device()
test_dict_function_value_type()

0 comments on commit ff1647d

Please sign in to comment.