From 5717701550fb0de6ab6102ca0a0b79daf15e6e01 Mon Sep 17 00:00:00 2001 From: wrongtest Date: Sat, 31 Aug 2024 15:33:08 +0000 Subject: [PATCH] refine target kind identity --- src/target/target_kind.cc | 15 ++++++++++++++- tests/python/target/test_target_target.py | 16 ++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index fced74c3a559..979b755af846 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -35,7 +35,20 @@ namespace tvm { -TVM_REGISTER_NODE_TYPE(TargetKindNode); +// helper to get internal dev function in objectref. +struct TargetKind2ObjectPtr : public ObjectRef { + static ObjectPtr Get(const TargetKind& kind) { return GetDataPtr(kind); } +}; + +TVM_REGISTER_NODE_TYPE(TargetKindNode) + .set_creator([](const std::string& name) { + auto kind = TargetKind::Get(name); + ICHECK(kind.defined()) << "Cannot find target kind \'" << name << '\''; + return TargetKind2ObjectPtr::Get(kind.value()); + }) + .set_repr_bytes([](const Object* n) -> std::string { + return static_cast(n)->name; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& obj, ReprPrinter* p) { diff --git a/tests/python/target/test_target_target.py b/tests/python/target/test_target_target.py index e977ef10aae0..1a52a46da1fc 100644 --- a/tests/python/target/test_target_target.py +++ b/tests/python/target/test_target_target.py @@ -559,5 +559,21 @@ def test_target_from_device_opencl(input_device): assert target.thread_warp_size == dev.warp_size +def test_module_dict_from_deserialized_targets(): + target = Target("llvm") + + from tvm.script import tir as T + + @T.prim_func + def func(): + T.evaluate(0) + + func = func.with_attr("Target", target) + target2 = tvm.ir.load_json(tvm.ir.save_json(target)) + mod = tvm.IRModule({"main": func}) + lib = tvm.build({target2: mod}, target_host=target) + lib["func"]() + + if __name__ == "__main__": tvm.testing.main()