Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a JSON converter for 0.7 -> 0.8 and 0.8 -> 0.9 #9874

Merged
merged 9 commits into from
Jan 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ class GlobalVarNode : public RelayExprNode {

void VisitAttrs(AttrVisitor* v) {
v->Visit("name_hint", &name_hint);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down
1 change: 1 addition & 0 deletions include/tvm/relay/adt.h
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ class MatchNode : public ExprNode {
v->Visit("data", &data);
v->Visit("clauses", &clauses);
v->Visit("complete", &complete);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down
9 changes: 9 additions & 0 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ class TupleNode : public ExprNode {

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("fields", &fields);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down Expand Up @@ -196,6 +197,7 @@ class VarNode : public ExprNode {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("vid", &vid);
v->Visit("type_annotation", &type_annotation);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down Expand Up @@ -319,6 +321,7 @@ class CallNode : public ExprNode {
v->Visit("args", &args);
v->Visit("attrs", &attrs);
v->Visit("type_args", &type_args);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down Expand Up @@ -425,6 +428,7 @@ class LetNode : public ExprNode {
v->Visit("var", &var);
v->Visit("value", &value);
v->Visit("body", &body);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down Expand Up @@ -516,6 +520,7 @@ class IfNode : public ExprNode {
v->Visit("cond", &cond);
v->Visit("true_branch", &true_branch);
v->Visit("false_branch", &false_branch);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down Expand Up @@ -589,6 +594,7 @@ class TupleGetItemNode : public ExprNode {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("tuple_value", &tuple);
v->Visit("index", &index);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down Expand Up @@ -652,6 +658,7 @@ class RefCreateNode : public ExprNode {

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("value", &value);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down Expand Up @@ -713,6 +720,7 @@ class RefReadNode : public ExprNode {

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("ref", &ref);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down Expand Up @@ -776,6 +784,7 @@ class RefWriteNode : public ExprNode {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("ref", &ref);
v->Visit("value", &value);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down
1 change: 1 addition & 0 deletions include/tvm/relay/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class FunctionNode : public BaseFuncNode {
v->Visit("ret_type", &ret_type);
v->Visit("type_params", &type_params);
v->Visit("attrs", &attrs);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down
128 changes: 90 additions & 38 deletions python/tvm/ir/json_compact.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,53 @@ def _updater(data):
return _updater


def create_updater_08_to_09():
"""
Create an update to upgrade json from v0.8 to v0.9

Returns
-------
fupdater : function
The updater function
"""

def _initialize_virtual_device(item, _):
if "virtual_device_" not in item["attrs"]:
item["attrs"]["virtual_device_"] = "0"
return item

node_map = {
# Base IR
"GlobalVar": _initialize_virtual_device,
"relay.Var": _initialize_virtual_device,
"relay.Function": _initialize_virtual_device,
"relay.Tuple": _initialize_virtual_device,
"relay.Call": _initialize_virtual_device,
"relay.Let": _initialize_virtual_device,
"relay.If": _initialize_virtual_device,
"relay.TupleGetItem": _initialize_virtual_device,
"relay.RefCreate": _initialize_virtual_device,
"relay.RefRead": _initialize_virtual_device,
"relay.RefWrite": _initialize_virtual_device,
"relay.Match": _initialize_virtual_device,
}

return create_updater(node_map, "0.8", "0.9")


def create_updater_07_to_08():
"""Create an update to upgrade json from v0.7 to v0.8"""

def _initialize_module_attributes(item, _):
assert item["type_key"] == "IRModule", "Only initialize the attributes for IRModules"
if "attrs" not in item["attrs"]:
item["attrs"]["attrs"] = "0"
return item

node_map = {"IRModule": _initialize_module_attributes}
return create_updater(node_map, "0.7", "0.8")


def create_updater_06_to_07():
"""Create an update to upgrade json from v0.6 to v0.7

Expand Down Expand Up @@ -127,7 +174,7 @@ def _convert(item, nodes):
"relay.IncompleteType": _rename("IncompleteType"),
"relay.TypeRelation": _rename("TypeRelation"),
"relay.TypeCall": _rename("TypeCall"),
"relay.Constructor": [_update_from_std_str("name_hint")],
"relay.Constructor": _update_from_std_str("name_hint"),
"relay.Module": _rename("IRModule"),
"relay.SourceName": _rename("SourceName"),
"relay.Span": _rename("Span"),
Expand All @@ -143,43 +190,43 @@ def _convert(item, nodes):
"Variable": [_update_tir_var("tir.Var"), _update_from_std_str("name")],
"SizeVar": [_update_tir_var("tir.SizeVar"), _update_from_std_str("name")],
"StringImm": [_rename("tir.StringImm"), _update_from_std_str("value")],
"Cast": [_rename("tir.Cast")],
"Add": [_rename("tir.Add")],
"Sub": [_rename("tir.Sub")],
"Mul": [_rename("tir.Mul")],
"Div": [_rename("tir.Div")],
"Mod": [_rename("tir.Mod")],
"FloorDiv": [_rename("tir.FloorDiv")],
"FloorMod": [_rename("tir.FloorMod")],
"Min": [_rename("tir.Min")],
"Max": [_rename("tir.Max")],
"EQ": [_rename("tir.EQ")],
"NE": [_rename("tir.NE")],
"LT": [_rename("tir.LT")],
"LE": [_rename("tir.LE")],
"GT": [_rename("tir.GT")],
"GE": [_rename("tir.GE")],
"And": [_rename("tir.And")],
"Or": [_rename("tir.Or")],
"Not": [_rename("tir.Not")],
"Select": [_rename("tir.Select")],
"Load": [_rename("tir.Load")],
"BufferLoad": [_rename("tir.BufferLoad")],
"Ramp": [_rename("tir.Ramp")],
"Broadcast": [_rename("tir.Broadcast")],
"Shuffle": [_rename("tir.Shuffle")],
"Cast": _rename("tir.Cast"),
"Add": _rename("tir.Add"),
"Sub": _rename("tir.Sub"),
"Mul": _rename("tir.Mul"),
"Div": _rename("tir.Div"),
"Mod": _rename("tir.Mod"),
"FloorDiv": _rename("tir.FloorDiv"),
"FloorMod": _rename("tir.FloorMod"),
"Min": _rename("tir.Min"),
"Max": _rename("tir.Max"),
"EQ": _rename("tir.EQ"),
"NE": _rename("tir.NE"),
"LT": _rename("tir.LT"),
"LE": _rename("tir.LE"),
"GT": _rename("tir.GT"),
"GE": _rename("tir.GE"),
"And": _rename("tir.And"),
"Or": _rename("tir.Or"),
"Not": _rename("tir.Not"),
"Select": _rename("tir.Select"),
"Load": _rename("tir.Load"),
"BufferLoad": _rename("tir.BufferLoad"),
"Ramp": _rename("tir.Ramp"),
"Broadcast": _rename("tir.Broadcast"),
"Shuffle": _rename("tir.Shuffle"),
"Call": [_rename("tir.Call"), _update_from_std_str("name")],
"Let": [_rename("tir.Let")],
"Any": [_rename("tir.Any")],
"LetStmt": [_rename("tir.LetStmt")],
"AssertStmt": [_rename("tir.AssertStmt")],
"Store": [_rename("tir.Store")],
"BufferStore": [_rename("tir.BufferStore")],
"BufferRealize": [_rename("tir.BufferRealize")],
"Allocate": [_rename("tir.Allocate")],
"IfThenElse": [_rename("tir.IfThenElse")],
"Evaluate": [_rename("tir.Evaluate")],
"Prefetch": [_rename("tir.Prefetch")],
"Let": _rename("tir.Let"),
"Any": _rename("tir.Any"),
"LetStmt": _rename("tir.LetStmt"),
"AssertStmt": _rename("tir.AssertStmt"),
"Store": _rename("tir.Store"),
"BufferStore": _rename("tir.BufferStore"),
"BufferRealize": _rename("tir.BufferRealize"),
"Allocate": _rename("tir.Allocate"),
"IfThenElse": _rename("tir.IfThenElse"),
"Evaluate": _rename("tir.Evaluate"),
"Prefetch": _rename("tir.Prefetch"),
"AttrStmt": [_rename("tir.AttrStmt"), _update_from_std_str("attr_key")],
"Layout": [_rename("tir.Layout"), _update_from_std_str("name")],
"Buffer": [
Expand All @@ -206,8 +253,13 @@ def upgrade_json(json_str):
"""
data = json.loads(json_str)
from_version = data["attrs"]["tvm_version"]

if from_version.startswith("0.6"):
data = create_updater_06_to_07()(data)
data = create_updater_08_to_09()(create_updater_07_to_08()(create_updater_06_to_07()(data)))
elif from_version.startswith("0.7"):
data = create_updater_08_to_09()(create_updater_07_to_08()(data))
elif from_version.startswith("0.8"):
data = create_updater_08_to_09()(data)
else:
raise ValueError("Cannot update from version %s" % from_version)
return json.dumps(data, indent=2)
62 changes: 62 additions & 0 deletions tests/python/relay/test_json_compact.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
import json


# 0.6 BACKWARDS COMPATIBILITY TESTS


def test_type_var():
# type var in 0.6
nodes = [
Expand Down Expand Up @@ -206,6 +209,65 @@ def test_str_map():
assert bool(x["z"] == 2)


# 0.7 BACKWARDS COMPATIBILITY TESTS


def test_irmodule_attributes():
nodes = [
{"type_key": ""},
{
"type_key": "IRModule",
"attrs": {
"functions": "0",
"global_type_var_map_": "0",
"global_var_map_": "0",
"source_map": "0",
"type_definitions": "0",
},
},
]
data = {
"root": 1,
"nodes": nodes,
"attrs": {"tvm_version": "0.7.0"},
"b64ndarrays": [],
}
mod = tvm.ir.load_json(json.dumps(data))
assert isinstance(mod, tvm.ir.IRModule)
# IRModule attributes should defualt to null
assert not mod.attrs


# 0.8 BACKWARDS COMPATIBILITY TESTS


def test_virtual_device():
nodes = [
{"type_key": ""},
{
"type_key": "relay.Function",
"attrs": {
"_checked_type_": "0",
"attrs": "0",
"body": "0",
"params": "0",
"ret_type": "0",
"span": "0",
"type_params": "0",
},
},
]
data = {
"root": 1,
"nodes": nodes,
"attrs": {"tvm_version": "0.8.0"},
"b64ndarrays": [],
}
func = tvm.ir.load_json(json.dumps(data))
assert isinstance(func, relay.Function)
assert not func.virtual_device_


if __name__ == "__main__":
test_op()
test_type_var()
Expand Down