From 4f8c0f9c7f71926ade1b3beab494be26d7971572 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Thu, 6 Jan 2022 15:57:57 -0800 Subject: [PATCH 1/9] Add default to serialization --- include/tvm/ir/expr.h | 1 + include/tvm/relay/adt.h | 1 + include/tvm/relay/expr.h | 9 +++++++ include/tvm/relay/function.h | 1 + src/node/serialization.cc | 6 +++-- tests/python/relay/test_json_compact.py | 36 +++++++++++++++++++++++++ 6 files changed, 52 insertions(+), 2 deletions(-) diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 8937bb7b1016..2cfc8467a1cb 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -221,6 +221,7 @@ class GlobalVarNode : public RelayExprNode { void VisitAttrs(AttrVisitor* v) { v->Visit("name_hint", &name_hint); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } diff --git a/include/tvm/relay/adt.h b/include/tvm/relay/adt.h index b1d4d5975cb8..31dec2204146 100644 --- a/include/tvm/relay/adt.h +++ b/include/tvm/relay/adt.h @@ -299,6 +299,7 @@ class MatchNode : public ExprNode { v->Visit("data", &data); v->Visit("clauses", &clauses); v->Visit("complete", &complete); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 04dd9223719e..dcb7838a1b72 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -108,6 +108,7 @@ class TupleNode : public ExprNode { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } @@ -196,6 +197,7 @@ class VarNode : public ExprNode { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("vid", &vid); v->Visit("type_annotation", &type_annotation); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } @@ -319,6 +321,7 @@ class CallNode : public ExprNode { v->Visit("args", &args); v->Visit("attrs", &attrs); v->Visit("type_args", &type_args); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } @@ -425,6 +428,7 @@ class LetNode : public ExprNode { v->Visit("var", &var); v->Visit("value", &value); v->Visit("body", &body); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } @@ -516,6 +520,7 @@ class IfNode : public ExprNode { v->Visit("cond", &cond); v->Visit("true_branch", &true_branch); v->Visit("false_branch", &false_branch); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } @@ -589,6 +594,7 @@ class TupleGetItemNode : public ExprNode { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("tuple_value", &tuple); v->Visit("index", &index); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } @@ -652,6 +658,7 @@ class RefCreateNode : public ExprNode { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("value", &value); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } @@ -713,6 +720,7 @@ class RefReadNode : public ExprNode { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("ref", &ref); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } @@ -776,6 +784,7 @@ class RefWriteNode : public ExprNode { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("ref", &ref); v->Visit("value", &value); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } diff --git a/include/tvm/relay/function.h b/include/tvm/relay/function.h index d9bf7acaa037..5869f878aa85 100644 --- a/include/tvm/relay/function.h +++ b/include/tvm/relay/function.h @@ -64,6 +64,7 @@ class FunctionNode : public BaseFuncNode { v->Visit("ret_type", &ret_type); v->Visit("type_params", &type_params); v->Visit("attrs", &attrs); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } diff --git a/src/node/serialization.cc b/src/node/serialization.cc index 09eb02e10bfa..8134c895e389 100644 --- a/src/node/serialization.cc +++ b/src/node/serialization.cc @@ -315,7 +315,8 @@ class FieldDependencyFinder : public AttrVisitor { std::string GetValue(const char* key) const { auto it = jnode_->attrs.find(key); if (it == jnode_->attrs.end()) { - LOG(FATAL) << "JSONReader: cannot find field " << key; + // If we encounter a field that hasn't been set, initialize it to null. + return "0"; } return it->second; } @@ -372,7 +373,8 @@ class JSONAttrSetter : public AttrVisitor { std::string GetValue(const char* key) const { auto it = jnode_->attrs.find(key); if (it == jnode_->attrs.end()) { - LOG(FATAL) << "JSONReader: cannot find field " << key; + // If we encounter a field that hasn't been set, initialize it to null. + return "0"; } return it->second; } diff --git a/tests/python/relay/test_json_compact.py b/tests/python/relay/test_json_compact.py index 65efc306a347..3ca738ac8c8e 100644 --- a/tests/python/relay/test_json_compact.py +++ b/tests/python/relay/test_json_compact.py @@ -206,6 +206,42 @@ def test_str_map(): assert bool(x["z"] == 2) +def test_default_fields(): + # Node with all fields set + nodes = [ + {"type_key": ""}, + { + "type_key": "relay.GlobalVar", + "attrs": {"_checked_type_": "0", "name_hint": "x", "span": "0", "virtual_device_": "0"}, + }, + ] + data = { + "root": 1, + "nodes": nodes, + "attrs": {"tvm_version": "0.6.0"}, + "b64ndarrays": [], + } + tvar = tvm.ir.load_json(json.dumps(data)) + assert isinstance(tvar, tvm.ir.GlobalVar) + # Construct node without virtual_device_ field + nodes = [ + {"type_key": ""}, + { + "type_key": "relay.GlobalVar", + "attrs": {"_checked_type_": "0", "name_hint": "x", "span": "0"}, + }, + ] + data = { + "root": 1, + "nodes": nodes, + "attrs": {"tvm_version": "0.6.0"}, + "b64ndarrays": [], + } + tvar_default = tvm.ir.load_json(json.dumps(data)) + assert isinstance(tvar_default, tvm.ir.GlobalVar) + assert not tvar_default.virtual_device_ + + if __name__ == "__main__": test_op() test_type_var() From 993100abdc8ba06de7ab02780afcd351b9655ff9 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Fri, 7 Jan 2022 13:40:26 -0800 Subject: [PATCH 2/9] revert changes in serialization.cc --- src/node/serialization.cc | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/node/serialization.cc b/src/node/serialization.cc index 8134c895e389..09eb02e10bfa 100644 --- a/src/node/serialization.cc +++ b/src/node/serialization.cc @@ -315,8 +315,7 @@ class FieldDependencyFinder : public AttrVisitor { std::string GetValue(const char* key) const { auto it = jnode_->attrs.find(key); if (it == jnode_->attrs.end()) { - // If we encounter a field that hasn't been set, initialize it to null. - return "0"; + LOG(FATAL) << "JSONReader: cannot find field " << key; } return it->second; } @@ -373,8 +372,7 @@ class JSONAttrSetter : public AttrVisitor { std::string GetValue(const char* key) const { auto it = jnode_->attrs.find(key); if (it == jnode_->attrs.end()) { - // If we encounter a field that hasn't been set, initialize it to null. - return "0"; + LOG(FATAL) << "JSONReader: cannot find field " << key; } return it->second; } From 5edb222a3e125e2525ef800d51546262326c69b7 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Fri, 7 Jan 2022 14:19:58 -0800 Subject: [PATCH 3/9] update 0.6 converter --- python/tvm/ir/json_compact.py | 34 +++++++++++++++++++++++++++++----- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/python/tvm/ir/json_compact.py b/python/tvm/ir/json_compact.py index a22d7d3ce108..6d2d000f4bcc 100644 --- a/python/tvm/ir/json_compact.py +++ b/python/tvm/ir/json_compact.py @@ -56,8 +56,10 @@ def _updater(data): return _updater +def create_updater_08_to_09(): + pass -def create_updater_06_to_07(): +def create_updater_06_to_09(): """Create an update to upgrade json from v0.6 to v0.7 Returns @@ -91,6 +93,15 @@ def _convert(item, _): return _convert + def _initialize_virtual_device(item, _): + print(item) + item["attrs"]["virtual_device_"] = "0" + return item + + def _initialize_module_attributes(item, _): + item["attrs"]["attrs"] = "0" + return item + def _update_global_key(item, _): if "global_key" in item: item["repr_str"] = item["global_key"] @@ -128,17 +139,28 @@ def _convert(item, nodes): "relay.TypeRelation": _rename("TypeRelation"), "relay.TypeCall": _rename("TypeCall"), "relay.Constructor": [_update_from_std_str("name_hint")], - "relay.Module": _rename("IRModule"), + "relay.Module": [_rename("IRModule"), _initialize_module_attributes], "relay.SourceName": _rename("SourceName"), "relay.Span": _rename("Span"), - "relay.GlobalVar": [_rename("GlobalVar"), _update_from_std_str("name_hint")], - "GlobalVar": _update_from_std_str("name_hint"), + "relay.GlobalVar": [_rename("GlobalVar"), _update_from_std_str("name_hint"), _initialize_virtual_device], + "GlobalVar": [_update_from_std_str("name_hint"), _initialize_virtual_device], + "relay.Var": _initialize_virtual_device, + "relay.Function": _initialize_virtual_device, + "relay.Tuple": _initialize_virtual_device, + "relay.Call": _initialize_virtual_device, + "relay.Let": _initialize_virtual_device, + "relay.If": _initialize_virtual_device, + "relay.TupleGetItem": _initialize_virtual_device, + "relay.RefCreate": _initialize_virtual_device, + "relay.RefRead": _initialize_virtual_device, + "relay.RefWrite": _initialize_virtual_device, "relay.Pass": _rename("transform.Pass"), "relay.PassInfo": _rename("transform.PassInfo"), "relay.PassContext": _rename("transform.PassContext"), "relay.ModulePass": _rename("transform.ModulePass"), "relay.Sequential": _rename("transform.Sequential"), "StrMap": _rename("Map"), + # TIR "Variable": [_update_tir_var("tir.Var"), _update_from_std_str("name")], "SizeVar": [_update_tir_var("tir.SizeVar"), _update_from_std_str("name")], @@ -207,7 +229,9 @@ def upgrade_json(json_str): data = json.loads(json_str) from_version = data["attrs"]["tvm_version"] if from_version.startswith("0.6"): - data = create_updater_06_to_07()(data) + data = create_updater_06_to_09()(data) + elif from_version.startswith("0.8"): + data = create_updater_08_to_09()(data) else: raise ValueError("Cannot update from version %s" % from_version) return json.dumps(data, indent=2) From 32f75cb060a2b8443a03b32b43485ef8877b36be Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Fri, 7 Jan 2022 16:55:36 -0800 Subject: [PATCH 4/9] json updater working, except for cycles --- python/tvm/ir/base.py | 6 +- python/tvm/ir/json_compact.py | 164 +++++++++++++++--------- src/node/serialization.cc | 24 +++- tests/python/relay/test_json_compact.py | 68 +++++++--- 4 files changed, 177 insertions(+), 85 deletions(-) diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py index 00514b472d67..001f4d764c4e 100644 --- a/python/tvm/ir/base.py +++ b/python/tvm/ir/base.py @@ -135,9 +135,13 @@ def load_json(json_str): """ try: - return tvm.runtime._ffi_node_api.LoadJSON(json_str) + loaded = tvm.runtime._ffi_node_api.LoadJSON(json_str) + print("LOADED COMPLETE") + return loaded except tvm.error.TVMError: + print("Upgrading Json") json_str = json_compact.upgrade_json(json_str) + print("Loading again") return tvm.runtime._ffi_node_api.LoadJSON(json_str) diff --git a/python/tvm/ir/json_compact.py b/python/tvm/ir/json_compact.py index 6d2d000f4bcc..0ef6ced23b87 100644 --- a/python/tvm/ir/json_compact.py +++ b/python/tvm/ir/json_compact.py @@ -56,10 +56,57 @@ def _updater(data): return _updater + def create_updater_08_to_09(): - pass + """ + Create an update to upgrade json from v0.8 to v0.9 + + Returns + ------- + fupdater : function + The updater function + """ + + def _initialize_virtual_device(item, _): + print("Item before: ", item) + item["attrs"]["virtual_device_"] = "0" + print("Item after: ", item) + return item + + node_map = { + # Base IR + "GlobalVar": _initialize_virtual_device, + "relay.Var": _initialize_virtual_device, + "relay.Function": _initialize_virtual_device, + "relay.Tuple": _initialize_virtual_device, + "relay.Call": _initialize_virtual_device, + "relay.Let": _initialize_virtual_device, + "relay.If": _initialize_virtual_device, + "relay.TupleGetItem": _initialize_virtual_device, + "relay.RefCreate": _initialize_virtual_device, + "relay.RefRead": _initialize_virtual_device, + "relay.RefWrite": _initialize_virtual_device, + "relay.Match": _initialize_virtual_device, + } + + return create_updater(node_map, "0.8", "0.9") -def create_updater_06_to_09(): + +def create_updater_07_to_08(): + """Create an update to upgrade json from v0.7 to v0.8""" + + def _initialize_module_attributes(item, _): + assert item["type_key"] == "IRModule", "Only initialize the attributes for IRModules" + print("Module before: ", item) + item["attrs"]["attrs"] = "0" + print("Module after:", item) + return item + + node_map = {"IRModule": _initialize_module_attributes} + return create_updater(node_map, "0.7", "0.8") + + +def create_updater_06_to_07(): """Create an update to upgrade json from v0.6 to v0.7 Returns @@ -93,15 +140,6 @@ def _convert(item, _): return _convert - def _initialize_virtual_device(item, _): - print(item) - item["attrs"]["virtual_device_"] = "0" - return item - - def _initialize_module_attributes(item, _): - item["attrs"]["attrs"] = "0" - return item - def _update_global_key(item, _): if "global_key" in item: item["repr_str"] = item["global_key"] @@ -138,70 +176,59 @@ def _convert(item, nodes): "relay.IncompleteType": _rename("IncompleteType"), "relay.TypeRelation": _rename("TypeRelation"), "relay.TypeCall": _rename("TypeCall"), - "relay.Constructor": [_update_from_std_str("name_hint")], - "relay.Module": [_rename("IRModule"), _initialize_module_attributes], + "relay.Constructor": _update_from_std_str("name_hint"), + "relay.Module": _rename("IRModule"), "relay.SourceName": _rename("SourceName"), "relay.Span": _rename("Span"), - "relay.GlobalVar": [_rename("GlobalVar"), _update_from_std_str("name_hint"), _initialize_virtual_device], - "GlobalVar": [_update_from_std_str("name_hint"), _initialize_virtual_device], - "relay.Var": _initialize_virtual_device, - "relay.Function": _initialize_virtual_device, - "relay.Tuple": _initialize_virtual_device, - "relay.Call": _initialize_virtual_device, - "relay.Let": _initialize_virtual_device, - "relay.If": _initialize_virtual_device, - "relay.TupleGetItem": _initialize_virtual_device, - "relay.RefCreate": _initialize_virtual_device, - "relay.RefRead": _initialize_virtual_device, - "relay.RefWrite": _initialize_virtual_device, + "relay.GlobalVar": [_rename("GlobalVar"), _update_from_std_str("name_hint")], + "GlobalVar": _update_from_std_str("name_hint"), "relay.Pass": _rename("transform.Pass"), "relay.PassInfo": _rename("transform.PassInfo"), "relay.PassContext": _rename("transform.PassContext"), "relay.ModulePass": _rename("transform.ModulePass"), "relay.Sequential": _rename("transform.Sequential"), "StrMap": _rename("Map"), - # TIR "Variable": [_update_tir_var("tir.Var"), _update_from_std_str("name")], "SizeVar": [_update_tir_var("tir.SizeVar"), _update_from_std_str("name")], "StringImm": [_rename("tir.StringImm"), _update_from_std_str("value")], - "Cast": [_rename("tir.Cast")], - "Add": [_rename("tir.Add")], - "Sub": [_rename("tir.Sub")], - "Mul": [_rename("tir.Mul")], - "Div": [_rename("tir.Div")], - "Mod": [_rename("tir.Mod")], - "FloorDiv": [_rename("tir.FloorDiv")], - "FloorMod": [_rename("tir.FloorMod")], - "Min": [_rename("tir.Min")], - "Max": [_rename("tir.Max")], - "EQ": [_rename("tir.EQ")], - "NE": [_rename("tir.NE")], - "LT": [_rename("tir.LT")], - "LE": [_rename("tir.LE")], - "GT": [_rename("tir.GT")], - "GE": [_rename("tir.GE")], - "And": [_rename("tir.And")], - "Or": [_rename("tir.Or")], - "Not": [_rename("tir.Not")], - "Select": [_rename("tir.Select")], - "Load": [_rename("tir.Load")], - "BufferLoad": [_rename("tir.BufferLoad")], - "Ramp": [_rename("tir.Ramp")], - "Broadcast": [_rename("tir.Broadcast")], - "Shuffle": [_rename("tir.Shuffle")], + "Cast": _rename("tir.Cast"), + "Add": _rename("tir.Add"), + "Sub": _rename("tir.Sub"), + "Mul": _rename("tir.Mul"), + "Div": _rename("tir.Div"), + "Mod": _rename("tir.Mod"), + "FloorDiv": _rename("tir.FloorDiv"), + "FloorMod": _rename("tir.FloorMod"), + "Min": _rename("tir.Min"), + "Max": _rename("tir.Max"), + "EQ": _rename("tir.EQ"), + "NE": _rename("tir.NE"), + "LT": _rename("tir.LT"), + "LE": _rename("tir.LE"), + "GT": _rename("tir.GT"), + "GE": _rename("tir.GE"), + "And": _rename("tir.And"), + "Or": _rename("tir.Or"), + "Not": _rename("tir.Not"), + "Select": _rename("tir.Select"), + "Load": _rename("tir.Load"), + "BufferLoad": _rename("tir.BufferLoad"), + "Ramp": _rename("tir.Ramp"), + "Broadcast": _rename("tir.Broadcast"), + "Shuffle": _rename("tir.Shuffle"), "Call": [_rename("tir.Call"), _update_from_std_str("name")], - "Let": [_rename("tir.Let")], - "Any": [_rename("tir.Any")], - "LetStmt": [_rename("tir.LetStmt")], - "AssertStmt": [_rename("tir.AssertStmt")], - "Store": [_rename("tir.Store")], - "BufferStore": [_rename("tir.BufferStore")], - "BufferRealize": [_rename("tir.BufferRealize")], - "Allocate": [_rename("tir.Allocate")], - "IfThenElse": [_rename("tir.IfThenElse")], - "Evaluate": [_rename("tir.Evaluate")], - "Prefetch": [_rename("tir.Prefetch")], + "Let": _rename("tir.Let"), + "Any": _rename("tir.Any"), + "LetStmt": _rename("tir.LetStmt"), + "AssertStmt": _rename("tir.AssertStmt"), + "Store": _rename("tir.Store"), + "BufferStore": _rename("tir.BufferStore"), + "BufferRealize": _rename("tir.BufferRealize"), + "Allocate": _rename("tir.Allocate"), + "IfThenElse": _rename("tir.IfThenElse"), + "Evaluate": _rename("tir.Evaluate"), + "Prefetch": _rename("tir.Prefetch"), "AttrStmt": [_rename("tir.AttrStmt"), _update_from_std_str("attr_key")], "Layout": [_rename("tir.Layout"), _update_from_std_str("name")], "Buffer": [ @@ -227,10 +254,21 @@ def upgrade_json(json_str): The updated version. """ data = json.loads(json_str) + print("Completed loading") from_version = data["attrs"]["tvm_version"] + if from_version.startswith("0.6"): - data = create_updater_06_to_09()(data) + print("From 0.6") + data = create_updater_08_to_09()(create_updater_07_to_08()(create_updater_06_to_07()(data))) + elif from_version.startswith("0.7"): + print("From 0.7") + data1 = create_updater_07_to_08()(data) + print("First updater done") + data2 = create_updater_08_to_09()(data1) + print("2nd updater done") + data = data2 elif from_version.startswith("0.8"): + print("From 0.8") data = create_updater_08_to_09()(data) else: raise ValueError("Cannot update from version %s" % from_version) diff --git a/src/node/serialization.cc b/src/node/serialization.cc index 09eb02e10bfa..f909efaec8c8 100644 --- a/src/node/serialization.cc +++ b/src/node/serialization.cc @@ -313,14 +313,24 @@ class FieldDependencyFinder : public AttrVisitor { ReflectionVTable* reflection_ = ReflectionVTable::Global(); std::string GetValue(const char* key) const { + std::cout << "Dependency finder" << std::endl; + std::cout << "Key: " << key << std::endl; + std::cout << "All keys: ["; + + for (auto kv : jnode_->attrs) { + std::cout << kv.first << ", "; + } + std::cout << "]" << std::endl; + auto it = jnode_->attrs.find(key); if (it == jnode_->attrs.end()) { - LOG(FATAL) << "JSONReader: cannot find field " << key; + LOG(FATAL) << "JSONReader (DependencyFinder): cannot find field " << key; } return it->second; } template void ParseValue(const char* key, T* value) const { + std::cout << "ParseValue for " << key << std::endl; std::istringstream is(GetValue(key)); is >> *value; if (is.fail()) { @@ -337,6 +347,7 @@ class FieldDependencyFinder : public AttrVisitor { void Visit(const char* key, DataType* value) final {} void Visit(const char* key, runtime::NDArray* value) final {} void Visit(const char* key, ObjectRef* value) final { + std::cout << "Object: " << PrettyPrint(*value) << std::endl; size_t index; ParseValue(key, &index); jnode_->fields.push_back(index); @@ -370,9 +381,16 @@ class JSONAttrSetter : public AttrVisitor { ReflectionVTable* reflection_ = ReflectionVTable::Global(); std::string GetValue(const char* key) const { + std::cout << "Key: " << key << std::endl; + std::cout << "All keys: ["; + + for (auto kv : jnode_->attrs) { + std::cout << kv.first << ", "; + } + std::cout << "]" << std::endl; auto it = jnode_->attrs.find(key); if (it == jnode_->attrs.end()) { - LOG(FATAL) << "JSONReader: cannot find field " << key; + LOG(FATAL) << "JSONReader (AttrSetter): cannot find field " << key; } return it->second; } @@ -557,7 +575,7 @@ struct JSONGraph { } } } - ICHECK_EQ(topo_order.size(), n_nodes) << "Cyclic reference detected in JSON file"; + // ICHECK_EQ(topo_order.size(), n_nodes) << "Cyclic reference detected in JSON file"; std::reverse(std::begin(topo_order), std::end(topo_order)); return topo_order; } diff --git a/tests/python/relay/test_json_compact.py b/tests/python/relay/test_json_compact.py index 3ca738ac8c8e..2d8eab6e3b3b 100644 --- a/tests/python/relay/test_json_compact.py +++ b/tests/python/relay/test_json_compact.py @@ -21,6 +21,9 @@ import json +# 0.6 BACKWARDS COMPATIBILITY TESTS + + def test_type_var(): # type var in 0.6 nodes = [ @@ -206,41 +209,70 @@ def test_str_map(): assert bool(x["z"] == 2) -def test_default_fields(): - # Node with all fields set +# 0.7 BACKWARDS COMPATIBILITY TESTS + + +def test_irmodule_attributes(): nodes = [ - {"type_key": ""}, { - "type_key": "relay.GlobalVar", - "attrs": {"_checked_type_": "0", "name_hint": "x", "span": "0", "virtual_device_": "0"}, - }, + "type_key": "IRModule", + "attrs": { + "functions": "0", + "global_type_var_map_": "0", + "global_var_map_": "0", + "source_map": "0", + "type_definitions": "0", + }, + } ] data = { "root": 1, "nodes": nodes, - "attrs": {"tvm_version": "0.6.0"}, + "attrs": {"tvm_version": "0.7.0"}, "b64ndarrays": [], } - tvar = tvm.ir.load_json(json.dumps(data)) - assert isinstance(tvar, tvm.ir.GlobalVar) - # Construct node without virtual_device_ field + mod = tvm.ir.load_json(json.dumps(data)) + assert isinstance(mod, tvm.ir.IRModule) + # IRModule attributes should defualt to null + assert not mod.attrs + + +# 0.8 BACKWARDS COMPATIBILITY TESTS + +# Does this break with functions? Yes. Seems bad. Probably should remove json dep checker? +def test_func_cycle(): nodes = [ - {"type_key": ""}, { - "type_key": "relay.GlobalVar", - "attrs": {"_checked_type_": "0", "name_hint": "x", "span": "0"}, - }, + "type_key": "relay.Function", + "attrs": { + "_checked_type_": "0", + "attrs": "0", + "body": "0", + "params": "0", + "ret_type": "0", + "span": "0", + "type_params": "0", + }, + } ] data = { "root": 1, "nodes": nodes, - "attrs": {"tvm_version": "0.6.0"}, + "attrs": {"tvm_version": "0.8.0"}, "b64ndarrays": [], } - tvar_default = tvm.ir.load_json(json.dumps(data)) - assert isinstance(tvar_default, tvm.ir.GlobalVar) - assert not tvar_default.virtual_device_ + dump = json.dumps(data) + print("Done dumping") + func = tvm.ir.load_json(dump) + assert isinstance(func, relay.Function) + assert not func.virtual_device_ + + +# add module attributes and virtual device test + +# BACKWARD COMPAT WITH 0.8 TESTS +# add test module attrs and test virtual device if __name__ == "__main__": test_op() From 72f2b8524c2f98d5f5b4e526649b436363c935ed Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Fri, 7 Jan 2022 16:58:26 -0800 Subject: [PATCH 5/9] clean up code --- python/tvm/ir/base.py | 6 +----- python/tvm/ir/json_compact.py | 14 +------------- src/node/serialization.cc | 24 +++--------------------- tests/python/relay/test_json_compact.py | 1 - 4 files changed, 5 insertions(+), 40 deletions(-) diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py index 001f4d764c4e..00514b472d67 100644 --- a/python/tvm/ir/base.py +++ b/python/tvm/ir/base.py @@ -135,13 +135,9 @@ def load_json(json_str): """ try: - loaded = tvm.runtime._ffi_node_api.LoadJSON(json_str) - print("LOADED COMPLETE") - return loaded + return tvm.runtime._ffi_node_api.LoadJSON(json_str) except tvm.error.TVMError: - print("Upgrading Json") json_str = json_compact.upgrade_json(json_str) - print("Loading again") return tvm.runtime._ffi_node_api.LoadJSON(json_str) diff --git a/python/tvm/ir/json_compact.py b/python/tvm/ir/json_compact.py index 0ef6ced23b87..2be329057dfa 100644 --- a/python/tvm/ir/json_compact.py +++ b/python/tvm/ir/json_compact.py @@ -68,9 +68,7 @@ def create_updater_08_to_09(): """ def _initialize_virtual_device(item, _): - print("Item before: ", item) item["attrs"]["virtual_device_"] = "0" - print("Item after: ", item) return item node_map = { @@ -97,9 +95,7 @@ def create_updater_07_to_08(): def _initialize_module_attributes(item, _): assert item["type_key"] == "IRModule", "Only initialize the attributes for IRModules" - print("Module before: ", item) item["attrs"]["attrs"] = "0" - print("Module after:", item) return item node_map = {"IRModule": _initialize_module_attributes} @@ -254,21 +250,13 @@ def upgrade_json(json_str): The updated version. """ data = json.loads(json_str) - print("Completed loading") from_version = data["attrs"]["tvm_version"] if from_version.startswith("0.6"): - print("From 0.6") data = create_updater_08_to_09()(create_updater_07_to_08()(create_updater_06_to_07()(data))) elif from_version.startswith("0.7"): - print("From 0.7") - data1 = create_updater_07_to_08()(data) - print("First updater done") - data2 = create_updater_08_to_09()(data1) - print("2nd updater done") - data = data2 + data = create_updater_08_to_09()(create_updater_07_to_08()(data)) elif from_version.startswith("0.8"): - print("From 0.8") data = create_updater_08_to_09()(data) else: raise ValueError("Cannot update from version %s" % from_version) diff --git a/src/node/serialization.cc b/src/node/serialization.cc index f909efaec8c8..09eb02e10bfa 100644 --- a/src/node/serialization.cc +++ b/src/node/serialization.cc @@ -313,24 +313,14 @@ class FieldDependencyFinder : public AttrVisitor { ReflectionVTable* reflection_ = ReflectionVTable::Global(); std::string GetValue(const char* key) const { - std::cout << "Dependency finder" << std::endl; - std::cout << "Key: " << key << std::endl; - std::cout << "All keys: ["; - - for (auto kv : jnode_->attrs) { - std::cout << kv.first << ", "; - } - std::cout << "]" << std::endl; - auto it = jnode_->attrs.find(key); if (it == jnode_->attrs.end()) { - LOG(FATAL) << "JSONReader (DependencyFinder): cannot find field " << key; + LOG(FATAL) << "JSONReader: cannot find field " << key; } return it->second; } template void ParseValue(const char* key, T* value) const { - std::cout << "ParseValue for " << key << std::endl; std::istringstream is(GetValue(key)); is >> *value; if (is.fail()) { @@ -347,7 +337,6 @@ class FieldDependencyFinder : public AttrVisitor { void Visit(const char* key, DataType* value) final {} void Visit(const char* key, runtime::NDArray* value) final {} void Visit(const char* key, ObjectRef* value) final { - std::cout << "Object: " << PrettyPrint(*value) << std::endl; size_t index; ParseValue(key, &index); jnode_->fields.push_back(index); @@ -381,16 +370,9 @@ class JSONAttrSetter : public AttrVisitor { ReflectionVTable* reflection_ = ReflectionVTable::Global(); std::string GetValue(const char* key) const { - std::cout << "Key: " << key << std::endl; - std::cout << "All keys: ["; - - for (auto kv : jnode_->attrs) { - std::cout << kv.first << ", "; - } - std::cout << "]" << std::endl; auto it = jnode_->attrs.find(key); if (it == jnode_->attrs.end()) { - LOG(FATAL) << "JSONReader (AttrSetter): cannot find field " << key; + LOG(FATAL) << "JSONReader: cannot find field " << key; } return it->second; } @@ -575,7 +557,7 @@ struct JSONGraph { } } } - // ICHECK_EQ(topo_order.size(), n_nodes) << "Cyclic reference detected in JSON file"; + ICHECK_EQ(topo_order.size(), n_nodes) << "Cyclic reference detected in JSON file"; std::reverse(std::begin(topo_order), std::end(topo_order)); return topo_order; } diff --git a/tests/python/relay/test_json_compact.py b/tests/python/relay/test_json_compact.py index 2d8eab6e3b3b..5358e61a3b88 100644 --- a/tests/python/relay/test_json_compact.py +++ b/tests/python/relay/test_json_compact.py @@ -262,7 +262,6 @@ def test_func_cycle(): "b64ndarrays": [], } dump = json.dumps(data) - print("Done dumping") func = tvm.ir.load_json(dump) assert isinstance(func, relay.Function) assert not func.virtual_device_ From 435ad5f8ff2bd80d3827ea04245092f96b2590d0 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Fri, 7 Jan 2022 17:33:02 -0800 Subject: [PATCH 6/9] Fix tests --- python/tvm/ir/json_compact.py | 6 ++++-- tests/python/relay/test_json_compact.py | 9 +++++---- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/python/tvm/ir/json_compact.py b/python/tvm/ir/json_compact.py index 2be329057dfa..ec8cd6c0a4b2 100644 --- a/python/tvm/ir/json_compact.py +++ b/python/tvm/ir/json_compact.py @@ -68,7 +68,8 @@ def create_updater_08_to_09(): """ def _initialize_virtual_device(item, _): - item["attrs"]["virtual_device_"] = "0" + if ("virtual_device_" not in item["attrs"].keys()): + item["attrs"]["virtual_device_"] = "0" return item node_map = { @@ -95,7 +96,8 @@ def create_updater_07_to_08(): def _initialize_module_attributes(item, _): assert item["type_key"] == "IRModule", "Only initialize the attributes for IRModules" - item["attrs"]["attrs"] = "0" + if "attrs" not in item["attrs"].keys(): + item["attrs"]["attrs"] = "0" return item node_map = {"IRModule": _initialize_module_attributes} diff --git a/tests/python/relay/test_json_compact.py b/tests/python/relay/test_json_compact.py index 5358e61a3b88..528ee271091d 100644 --- a/tests/python/relay/test_json_compact.py +++ b/tests/python/relay/test_json_compact.py @@ -214,6 +214,7 @@ def test_str_map(): def test_irmodule_attributes(): nodes = [ + {"type_key": ""}, { "type_key": "IRModule", "attrs": { @@ -223,7 +224,7 @@ def test_irmodule_attributes(): "source_map": "0", "type_definitions": "0", }, - } + }, ] data = { "root": 1, @@ -239,9 +240,9 @@ def test_irmodule_attributes(): # 0.8 BACKWARDS COMPATIBILITY TESTS -# Does this break with functions? Yes. Seems bad. Probably should remove json dep checker? -def test_func_cycle(): +def test_virtual_device(): nodes = [ + {"type_key": ""}, { "type_key": "relay.Function", "attrs": { @@ -253,7 +254,7 @@ def test_func_cycle(): "span": "0", "type_params": "0", }, - } + }, ] data = { "root": 1, From c67e4e8104add537240ac9e2287c6e149b9b3d7f Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Fri, 7 Jan 2022 17:37:13 -0800 Subject: [PATCH 7/9] formatting --- python/tvm/ir/json_compact.py | 2 +- tests/python/relay/test_json_compact.py | 7 +------ 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/python/tvm/ir/json_compact.py b/python/tvm/ir/json_compact.py index ec8cd6c0a4b2..4f49e4a641eb 100644 --- a/python/tvm/ir/json_compact.py +++ b/python/tvm/ir/json_compact.py @@ -68,7 +68,7 @@ def create_updater_08_to_09(): """ def _initialize_virtual_device(item, _): - if ("virtual_device_" not in item["attrs"].keys()): + if "virtual_device_" not in item["attrs"].keys(): item["attrs"]["virtual_device_"] = "0" return item diff --git a/tests/python/relay/test_json_compact.py b/tests/python/relay/test_json_compact.py index 528ee271091d..b4418b043c8c 100644 --- a/tests/python/relay/test_json_compact.py +++ b/tests/python/relay/test_json_compact.py @@ -240,6 +240,7 @@ def test_irmodule_attributes(): # 0.8 BACKWARDS COMPATIBILITY TESTS + def test_virtual_device(): nodes = [ {"type_key": ""}, @@ -268,12 +269,6 @@ def test_virtual_device(): assert not func.virtual_device_ -# add module attributes and virtual device test - -# BACKWARD COMPAT WITH 0.8 TESTS - -# add test module attrs and test virtual device - if __name__ == "__main__": test_op() test_type_var() From 2bd0a01c883df06caa1b9d970e393b45d04f39ed Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Fri, 7 Jan 2022 17:40:20 -0800 Subject: [PATCH 8/9] format : --- tests/python/relay/test_json_compact.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/python/relay/test_json_compact.py b/tests/python/relay/test_json_compact.py index b4418b043c8c..5a7084eb53ab 100644 --- a/tests/python/relay/test_json_compact.py +++ b/tests/python/relay/test_json_compact.py @@ -263,8 +263,7 @@ def test_virtual_device(): "attrs": {"tvm_version": "0.8.0"}, "b64ndarrays": [], } - dump = json.dumps(data) - func = tvm.ir.load_json(dump) + func = tvm.ir.load_json(json.dumps(data)) assert isinstance(func, relay.Function) assert not func.virtual_device_ From 2615824e07c8207f8f17e0c60ef103165870ab63 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Mon, 10 Jan 2022 13:00:39 -0800 Subject: [PATCH 9/9] nit --- python/tvm/ir/json_compact.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/ir/json_compact.py b/python/tvm/ir/json_compact.py index 4f49e4a641eb..9666475b8039 100644 --- a/python/tvm/ir/json_compact.py +++ b/python/tvm/ir/json_compact.py @@ -68,7 +68,7 @@ def create_updater_08_to_09(): """ def _initialize_virtual_device(item, _): - if "virtual_device_" not in item["attrs"].keys(): + if "virtual_device_" not in item["attrs"]: item["attrs"]["virtual_device_"] = "0" return item @@ -96,7 +96,7 @@ def create_updater_07_to_08(): def _initialize_module_attributes(item, _): assert item["type_key"] == "IRModule", "Only initialize the attributes for IRModules" - if "attrs" not in item["attrs"].keys(): + if "attrs" not in item["attrs"]: item["attrs"]["attrs"] = "0" return item