Skip to content

Commit

Permalink
Add a JSON converter for 0.7 -> 0.8 and 0.8 -> 0.9 (apache#9874)
Browse files Browse the repository at this point in the history
* Add default to serialization

* revert changes in serialization.cc

* update 0.6 converter

* json updater working, except for cycles

* clean up code

* Fix tests

* formatting

* format
:

* nit
  • Loading branch information
electriclilies authored and ylc committed Feb 16, 2022
1 parent d7e2dbb commit 092a8bc
Show file tree
Hide file tree
Showing 6 changed files with 164 additions and 38 deletions.
1 change: 1 addition & 0 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ class GlobalVarNode : public RelayExprNode {

void VisitAttrs(AttrVisitor* v) {
v->Visit("name_hint", &name_hint);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down
1 change: 1 addition & 0 deletions include/tvm/relay/adt.h
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ class MatchNode : public ExprNode {
v->Visit("data", &data);
v->Visit("clauses", &clauses);
v->Visit("complete", &complete);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down
9 changes: 9 additions & 0 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ class TupleNode : public ExprNode {

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("fields", &fields);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down Expand Up @@ -196,6 +197,7 @@ class VarNode : public ExprNode {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("vid", &vid);
v->Visit("type_annotation", &type_annotation);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down Expand Up @@ -319,6 +321,7 @@ class CallNode : public ExprNode {
v->Visit("args", &args);
v->Visit("attrs", &attrs);
v->Visit("type_args", &type_args);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down Expand Up @@ -425,6 +428,7 @@ class LetNode : public ExprNode {
v->Visit("var", &var);
v->Visit("value", &value);
v->Visit("body", &body);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down Expand Up @@ -516,6 +520,7 @@ class IfNode : public ExprNode {
v->Visit("cond", &cond);
v->Visit("true_branch", &true_branch);
v->Visit("false_branch", &false_branch);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down Expand Up @@ -589,6 +594,7 @@ class TupleGetItemNode : public ExprNode {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("tuple_value", &tuple);
v->Visit("index", &index);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down Expand Up @@ -652,6 +658,7 @@ class RefCreateNode : public ExprNode {

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("value", &value);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down Expand Up @@ -713,6 +720,7 @@ class RefReadNode : public ExprNode {

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("ref", &ref);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down Expand Up @@ -776,6 +784,7 @@ class RefWriteNode : public ExprNode {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("ref", &ref);
v->Visit("value", &value);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down
1 change: 1 addition & 0 deletions include/tvm/relay/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class FunctionNode : public BaseFuncNode {
v->Visit("ret_type", &ret_type);
v->Visit("type_params", &type_params);
v->Visit("attrs", &attrs);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down
128 changes: 90 additions & 38 deletions python/tvm/ir/json_compact.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,53 @@ def _updater(data):
return _updater


def create_updater_08_to_09():
"""
Create an update to upgrade json from v0.8 to v0.9
Returns
-------
fupdater : function
The updater function
"""

def _initialize_virtual_device(item, _):
if "virtual_device_" not in item["attrs"]:
item["attrs"]["virtual_device_"] = "0"
return item

node_map = {
# Base IR
"GlobalVar": _initialize_virtual_device,
"relay.Var": _initialize_virtual_device,
"relay.Function": _initialize_virtual_device,
"relay.Tuple": _initialize_virtual_device,
"relay.Call": _initialize_virtual_device,
"relay.Let": _initialize_virtual_device,
"relay.If": _initialize_virtual_device,
"relay.TupleGetItem": _initialize_virtual_device,
"relay.RefCreate": _initialize_virtual_device,
"relay.RefRead": _initialize_virtual_device,
"relay.RefWrite": _initialize_virtual_device,
"relay.Match": _initialize_virtual_device,
}

return create_updater(node_map, "0.8", "0.9")


def create_updater_07_to_08():
"""Create an update to upgrade json from v0.7 to v0.8"""

def _initialize_module_attributes(item, _):
assert item["type_key"] == "IRModule", "Only initialize the attributes for IRModules"
if "attrs" not in item["attrs"]:
item["attrs"]["attrs"] = "0"
return item

node_map = {"IRModule": _initialize_module_attributes}
return create_updater(node_map, "0.7", "0.8")


def create_updater_06_to_07():
"""Create an update to upgrade json from v0.6 to v0.7
Expand Down Expand Up @@ -127,7 +174,7 @@ def _convert(item, nodes):
"relay.IncompleteType": _rename("IncompleteType"),
"relay.TypeRelation": _rename("TypeRelation"),
"relay.TypeCall": _rename("TypeCall"),
"relay.Constructor": [_update_from_std_str("name_hint")],
"relay.Constructor": _update_from_std_str("name_hint"),
"relay.Module": _rename("IRModule"),
"relay.SourceName": _rename("SourceName"),
"relay.Span": _rename("Span"),
Expand All @@ -143,43 +190,43 @@ def _convert(item, nodes):
"Variable": [_update_tir_var("tir.Var"), _update_from_std_str("name")],
"SizeVar": [_update_tir_var("tir.SizeVar"), _update_from_std_str("name")],
"StringImm": [_rename("tir.StringImm"), _update_from_std_str("value")],
"Cast": [_rename("tir.Cast")],
"Add": [_rename("tir.Add")],
"Sub": [_rename("tir.Sub")],
"Mul": [_rename("tir.Mul")],
"Div": [_rename("tir.Div")],
"Mod": [_rename("tir.Mod")],
"FloorDiv": [_rename("tir.FloorDiv")],
"FloorMod": [_rename("tir.FloorMod")],
"Min": [_rename("tir.Min")],
"Max": [_rename("tir.Max")],
"EQ": [_rename("tir.EQ")],
"NE": [_rename("tir.NE")],
"LT": [_rename("tir.LT")],
"LE": [_rename("tir.LE")],
"GT": [_rename("tir.GT")],
"GE": [_rename("tir.GE")],
"And": [_rename("tir.And")],
"Or": [_rename("tir.Or")],
"Not": [_rename("tir.Not")],
"Select": [_rename("tir.Select")],
"Load": [_rename("tir.Load")],
"BufferLoad": [_rename("tir.BufferLoad")],
"Ramp": [_rename("tir.Ramp")],
"Broadcast": [_rename("tir.Broadcast")],
"Shuffle": [_rename("tir.Shuffle")],
"Cast": _rename("tir.Cast"),
"Add": _rename("tir.Add"),
"Sub": _rename("tir.Sub"),
"Mul": _rename("tir.Mul"),
"Div": _rename("tir.Div"),
"Mod": _rename("tir.Mod"),
"FloorDiv": _rename("tir.FloorDiv"),
"FloorMod": _rename("tir.FloorMod"),
"Min": _rename("tir.Min"),
"Max": _rename("tir.Max"),
"EQ": _rename("tir.EQ"),
"NE": _rename("tir.NE"),
"LT": _rename("tir.LT"),
"LE": _rename("tir.LE"),
"GT": _rename("tir.GT"),
"GE": _rename("tir.GE"),
"And": _rename("tir.And"),
"Or": _rename("tir.Or"),
"Not": _rename("tir.Not"),
"Select": _rename("tir.Select"),
"Load": _rename("tir.Load"),
"BufferLoad": _rename("tir.BufferLoad"),
"Ramp": _rename("tir.Ramp"),
"Broadcast": _rename("tir.Broadcast"),
"Shuffle": _rename("tir.Shuffle"),
"Call": [_rename("tir.Call"), _update_from_std_str("name")],
"Let": [_rename("tir.Let")],
"Any": [_rename("tir.Any")],
"LetStmt": [_rename("tir.LetStmt")],
"AssertStmt": [_rename("tir.AssertStmt")],
"Store": [_rename("tir.Store")],
"BufferStore": [_rename("tir.BufferStore")],
"BufferRealize": [_rename("tir.BufferRealize")],
"Allocate": [_rename("tir.Allocate")],
"IfThenElse": [_rename("tir.IfThenElse")],
"Evaluate": [_rename("tir.Evaluate")],
"Prefetch": [_rename("tir.Prefetch")],
"Let": _rename("tir.Let"),
"Any": _rename("tir.Any"),
"LetStmt": _rename("tir.LetStmt"),
"AssertStmt": _rename("tir.AssertStmt"),
"Store": _rename("tir.Store"),
"BufferStore": _rename("tir.BufferStore"),
"BufferRealize": _rename("tir.BufferRealize"),
"Allocate": _rename("tir.Allocate"),
"IfThenElse": _rename("tir.IfThenElse"),
"Evaluate": _rename("tir.Evaluate"),
"Prefetch": _rename("tir.Prefetch"),
"AttrStmt": [_rename("tir.AttrStmt"), _update_from_std_str("attr_key")],
"Layout": [_rename("tir.Layout"), _update_from_std_str("name")],
"Buffer": [
Expand All @@ -206,8 +253,13 @@ def upgrade_json(json_str):
"""
data = json.loads(json_str)
from_version = data["attrs"]["tvm_version"]

if from_version.startswith("0.6"):
data = create_updater_06_to_07()(data)
data = create_updater_08_to_09()(create_updater_07_to_08()(create_updater_06_to_07()(data)))
elif from_version.startswith("0.7"):
data = create_updater_08_to_09()(create_updater_07_to_08()(data))
elif from_version.startswith("0.8"):
data = create_updater_08_to_09()(data)
else:
raise ValueError("Cannot update from version %s" % from_version)
return json.dumps(data, indent=2)
62 changes: 62 additions & 0 deletions tests/python/relay/test_json_compact.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
import json


# 0.6 BACKWARDS COMPATIBILITY TESTS


def test_type_var():
# type var in 0.6
nodes = [
Expand Down Expand Up @@ -206,6 +209,65 @@ def test_str_map():
assert bool(x["z"] == 2)


# 0.7 BACKWARDS COMPATIBILITY TESTS


def test_irmodule_attributes():
nodes = [
{"type_key": ""},
{
"type_key": "IRModule",
"attrs": {
"functions": "0",
"global_type_var_map_": "0",
"global_var_map_": "0",
"source_map": "0",
"type_definitions": "0",
},
},
]
data = {
"root": 1,
"nodes": nodes,
"attrs": {"tvm_version": "0.7.0"},
"b64ndarrays": [],
}
mod = tvm.ir.load_json(json.dumps(data))
assert isinstance(mod, tvm.ir.IRModule)
# IRModule attributes should defualt to null
assert not mod.attrs


# 0.8 BACKWARDS COMPATIBILITY TESTS


def test_virtual_device():
nodes = [
{"type_key": ""},
{
"type_key": "relay.Function",
"attrs": {
"_checked_type_": "0",
"attrs": "0",
"body": "0",
"params": "0",
"ret_type": "0",
"span": "0",
"type_params": "0",
},
},
]
data = {
"root": 1,
"nodes": nodes,
"attrs": {"tvm_version": "0.8.0"},
"b64ndarrays": [],
}
func = tvm.ir.load_json(json.dumps(data))
assert isinstance(func, relay.Function)
assert not func.virtual_device_


if __name__ == "__main__":
test_op()
test_type_var()
Expand Down

0 comments on commit 092a8bc

Please sign in to comment.