diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 8937bb7b1016..2cfc8467a1cb 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -221,6 +221,7 @@ class GlobalVarNode : public RelayExprNode { void VisitAttrs(AttrVisitor* v) { v->Visit("name_hint", &name_hint); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } diff --git a/include/tvm/relay/adt.h b/include/tvm/relay/adt.h index b1d4d5975cb8..31dec2204146 100644 --- a/include/tvm/relay/adt.h +++ b/include/tvm/relay/adt.h @@ -299,6 +299,7 @@ class MatchNode : public ExprNode { v->Visit("data", &data); v->Visit("clauses", &clauses); v->Visit("complete", &complete); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 04dd9223719e..dcb7838a1b72 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -108,6 +108,7 @@ class TupleNode : public ExprNode { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } @@ -196,6 +197,7 @@ class VarNode : public ExprNode { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("vid", &vid); v->Visit("type_annotation", &type_annotation); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } @@ -319,6 +321,7 @@ class CallNode : public ExprNode { v->Visit("args", &args); v->Visit("attrs", &attrs); v->Visit("type_args", &type_args); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } @@ -425,6 +428,7 @@ class LetNode : public ExprNode { v->Visit("var", &var); v->Visit("value", &value); v->Visit("body", &body); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } @@ -516,6 +520,7 @@ class IfNode : public ExprNode { v->Visit("cond", &cond); v->Visit("true_branch", &true_branch); v->Visit("false_branch", &false_branch); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } @@ -589,6 +594,7 @@ class TupleGetItemNode : public ExprNode { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("tuple_value", &tuple); v->Visit("index", &index); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } @@ -652,6 +658,7 @@ class RefCreateNode : public ExprNode { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("value", &value); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } @@ -713,6 +720,7 @@ class RefReadNode : public ExprNode { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("ref", &ref); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } @@ -776,6 +784,7 @@ class RefWriteNode : public ExprNode { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("ref", &ref); v->Visit("value", &value); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } diff --git a/include/tvm/relay/function.h b/include/tvm/relay/function.h index d9bf7acaa037..5869f878aa85 100644 --- a/include/tvm/relay/function.h +++ b/include/tvm/relay/function.h @@ -64,6 +64,7 @@ class FunctionNode : public BaseFuncNode { v->Visit("ret_type", &ret_type); v->Visit("type_params", &type_params); v->Visit("attrs", &attrs); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } diff --git a/python/tvm/ir/json_compact.py b/python/tvm/ir/json_compact.py index a22d7d3ce108..9666475b8039 100644 --- a/python/tvm/ir/json_compact.py +++ b/python/tvm/ir/json_compact.py @@ -57,6 +57,53 @@ def _updater(data): return _updater +def create_updater_08_to_09(): + """ + Create an update to upgrade json from v0.8 to v0.9 + + Returns + ------- + fupdater : function + The updater function + """ + + def _initialize_virtual_device(item, _): + if "virtual_device_" not in item["attrs"]: + item["attrs"]["virtual_device_"] = "0" + return item + + node_map = { + # Base IR + "GlobalVar": _initialize_virtual_device, + "relay.Var": _initialize_virtual_device, + "relay.Function": _initialize_virtual_device, + "relay.Tuple": _initialize_virtual_device, + "relay.Call": _initialize_virtual_device, + "relay.Let": _initialize_virtual_device, + "relay.If": _initialize_virtual_device, + "relay.TupleGetItem": _initialize_virtual_device, + "relay.RefCreate": _initialize_virtual_device, + "relay.RefRead": _initialize_virtual_device, + "relay.RefWrite": _initialize_virtual_device, + "relay.Match": _initialize_virtual_device, + } + + return create_updater(node_map, "0.8", "0.9") + + +def create_updater_07_to_08(): + """Create an update to upgrade json from v0.7 to v0.8""" + + def _initialize_module_attributes(item, _): + assert item["type_key"] == "IRModule", "Only initialize the attributes for IRModules" + if "attrs" not in item["attrs"]: + item["attrs"]["attrs"] = "0" + return item + + node_map = {"IRModule": _initialize_module_attributes} + return create_updater(node_map, "0.7", "0.8") + + def create_updater_06_to_07(): """Create an update to upgrade json from v0.6 to v0.7 @@ -127,7 +174,7 @@ def _convert(item, nodes): "relay.IncompleteType": _rename("IncompleteType"), "relay.TypeRelation": _rename("TypeRelation"), "relay.TypeCall": _rename("TypeCall"), - "relay.Constructor": [_update_from_std_str("name_hint")], + "relay.Constructor": _update_from_std_str("name_hint"), "relay.Module": _rename("IRModule"), "relay.SourceName": _rename("SourceName"), "relay.Span": _rename("Span"), @@ -143,43 +190,43 @@ def _convert(item, nodes): "Variable": [_update_tir_var("tir.Var"), _update_from_std_str("name")], "SizeVar": [_update_tir_var("tir.SizeVar"), _update_from_std_str("name")], "StringImm": [_rename("tir.StringImm"), _update_from_std_str("value")], - "Cast": [_rename("tir.Cast")], - "Add": [_rename("tir.Add")], - "Sub": [_rename("tir.Sub")], - "Mul": [_rename("tir.Mul")], - "Div": [_rename("tir.Div")], - "Mod": [_rename("tir.Mod")], - "FloorDiv": [_rename("tir.FloorDiv")], - "FloorMod": [_rename("tir.FloorMod")], - "Min": [_rename("tir.Min")], - "Max": [_rename("tir.Max")], - "EQ": [_rename("tir.EQ")], - "NE": [_rename("tir.NE")], - "LT": [_rename("tir.LT")], - "LE": [_rename("tir.LE")], - "GT": [_rename("tir.GT")], - "GE": [_rename("tir.GE")], - "And": [_rename("tir.And")], - "Or": [_rename("tir.Or")], - "Not": [_rename("tir.Not")], - "Select": [_rename("tir.Select")], - "Load": [_rename("tir.Load")], - "BufferLoad": [_rename("tir.BufferLoad")], - "Ramp": [_rename("tir.Ramp")], - "Broadcast": [_rename("tir.Broadcast")], - "Shuffle": [_rename("tir.Shuffle")], + "Cast": _rename("tir.Cast"), + "Add": _rename("tir.Add"), + "Sub": _rename("tir.Sub"), + "Mul": _rename("tir.Mul"), + "Div": _rename("tir.Div"), + "Mod": _rename("tir.Mod"), + "FloorDiv": _rename("tir.FloorDiv"), + "FloorMod": _rename("tir.FloorMod"), + "Min": _rename("tir.Min"), + "Max": _rename("tir.Max"), + "EQ": _rename("tir.EQ"), + "NE": _rename("tir.NE"), + "LT": _rename("tir.LT"), + "LE": _rename("tir.LE"), + "GT": _rename("tir.GT"), + "GE": _rename("tir.GE"), + "And": _rename("tir.And"), + "Or": _rename("tir.Or"), + "Not": _rename("tir.Not"), + "Select": _rename("tir.Select"), + "Load": _rename("tir.Load"), + "BufferLoad": _rename("tir.BufferLoad"), + "Ramp": _rename("tir.Ramp"), + "Broadcast": _rename("tir.Broadcast"), + "Shuffle": _rename("tir.Shuffle"), "Call": [_rename("tir.Call"), _update_from_std_str("name")], - "Let": [_rename("tir.Let")], - "Any": [_rename("tir.Any")], - "LetStmt": [_rename("tir.LetStmt")], - "AssertStmt": [_rename("tir.AssertStmt")], - "Store": [_rename("tir.Store")], - "BufferStore": [_rename("tir.BufferStore")], - "BufferRealize": [_rename("tir.BufferRealize")], - "Allocate": [_rename("tir.Allocate")], - "IfThenElse": [_rename("tir.IfThenElse")], - "Evaluate": [_rename("tir.Evaluate")], - "Prefetch": [_rename("tir.Prefetch")], + "Let": _rename("tir.Let"), + "Any": _rename("tir.Any"), + "LetStmt": _rename("tir.LetStmt"), + "AssertStmt": _rename("tir.AssertStmt"), + "Store": _rename("tir.Store"), + "BufferStore": _rename("tir.BufferStore"), + "BufferRealize": _rename("tir.BufferRealize"), + "Allocate": _rename("tir.Allocate"), + "IfThenElse": _rename("tir.IfThenElse"), + "Evaluate": _rename("tir.Evaluate"), + "Prefetch": _rename("tir.Prefetch"), "AttrStmt": [_rename("tir.AttrStmt"), _update_from_std_str("attr_key")], "Layout": [_rename("tir.Layout"), _update_from_std_str("name")], "Buffer": [ @@ -206,8 +253,13 @@ def upgrade_json(json_str): """ data = json.loads(json_str) from_version = data["attrs"]["tvm_version"] + if from_version.startswith("0.6"): - data = create_updater_06_to_07()(data) + data = create_updater_08_to_09()(create_updater_07_to_08()(create_updater_06_to_07()(data))) + elif from_version.startswith("0.7"): + data = create_updater_08_to_09()(create_updater_07_to_08()(data)) + elif from_version.startswith("0.8"): + data = create_updater_08_to_09()(data) else: raise ValueError("Cannot update from version %s" % from_version) return json.dumps(data, indent=2) diff --git a/tests/python/relay/test_json_compact.py b/tests/python/relay/test_json_compact.py index 65efc306a347..5a7084eb53ab 100644 --- a/tests/python/relay/test_json_compact.py +++ b/tests/python/relay/test_json_compact.py @@ -21,6 +21,9 @@ import json +# 0.6 BACKWARDS COMPATIBILITY TESTS + + def test_type_var(): # type var in 0.6 nodes = [ @@ -206,6 +209,65 @@ def test_str_map(): assert bool(x["z"] == 2) +# 0.7 BACKWARDS COMPATIBILITY TESTS + + +def test_irmodule_attributes(): + nodes = [ + {"type_key": ""}, + { + "type_key": "IRModule", + "attrs": { + "functions": "0", + "global_type_var_map_": "0", + "global_var_map_": "0", + "source_map": "0", + "type_definitions": "0", + }, + }, + ] + data = { + "root": 1, + "nodes": nodes, + "attrs": {"tvm_version": "0.7.0"}, + "b64ndarrays": [], + } + mod = tvm.ir.load_json(json.dumps(data)) + assert isinstance(mod, tvm.ir.IRModule) + # IRModule attributes should defualt to null + assert not mod.attrs + + +# 0.8 BACKWARDS COMPATIBILITY TESTS + + +def test_virtual_device(): + nodes = [ + {"type_key": ""}, + { + "type_key": "relay.Function", + "attrs": { + "_checked_type_": "0", + "attrs": "0", + "body": "0", + "params": "0", + "ret_type": "0", + "span": "0", + "type_params": "0", + }, + }, + ] + data = { + "root": 1, + "nodes": nodes, + "attrs": {"tvm_version": "0.8.0"}, + "b64ndarrays": [], + } + func = tvm.ir.load_json(json.dumps(data)) + assert isinstance(func, relay.Function) + assert not func.virtual_device_ + + if __name__ == "__main__": test_op() test_type_var()