Skip to content
This repository has been archived by the owner on May 22, 2023. It is now read-only.

Commit

Permalink
[Parser][Printer] explicitly parse and print attrs_type_key in calls (#…
Browse files Browse the repository at this point in the history
…19)

* relax call_packed arity, return IRModule factory, print IRModule PrimFuncs

* explicitly parse and print attrs_type_key on calls

* print type even when attrs has no fields
  • Loading branch information
altanh authored and junrushao committed Feb 9, 2023
1 parent 14cfb84 commit b4371eb
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 10 deletions.
27 changes: 18 additions & 9 deletions python/tvm/relax/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,16 +871,25 @@ def parse_call(self, expr: ast.Call) -> Union[tir.PrimExpr, rx.Expr]:
self.report_error(f"unsupported function in call: {op}", expr.func_name.span)

# parse call attributes if applicable
if isinstance(op, rx.ExternFunc) or (isinstance(op, tvm.ir.Op) and op.attrs_type_key != ""):
attrs_type_key = "DictAttrs" if isinstance(op, rx.ExternFunc) else op.attrs_type_key
kwargs = {}
for key, val in expr.keyword_params.items():
assert isinstance(key, ast.Constant) and isinstance(key.value, str)
# TODO(@altanh): might need separate attribute parsing eventually
kwargs[key.value] = self.transform_expr(val)
attrs = tvm.ir.attrs.make_node(attrs_type_key, **kwargs)
kwargs = {}
for key, val in expr.keyword_params.items():
assert isinstance(key, ast.Constant) and isinstance(key.value, str)
# TODO(@altanh): might need separate attribute parsing eventually
kwargs[key.value] = self.transform_expr(val)

is_default = False
if "attrs_type_key" in kwargs:
attrs_type_key = kwargs["attrs_type_key"]
kwargs.pop("attrs_type_key")
elif isinstance(op, tvm.ir.Op) and op.attrs_type_key != "":
attrs_type_key = op.attrs_type_key
else:
attrs = None
attrs_type_key = "DictAttrs"
is_default = True

attrs = None
if kwargs or not is_default:
attrs = tvm.ir.attrs.make_node(attrs_type_key, **kwargs)

return relay.Call(op, args, attrs=attrs, span=self.to_tvm_span(expr.span))

Expand Down
3 changes: 3 additions & 0 deletions src/relay/printer/relax_script_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,9 @@ Doc RelaxScriptPrinter::VisitNode_(const relay::CallNode* op) {
doc << "(" << Doc::Concat(args, Doc::Text(", "));

std::vector<Doc> attrs = PrintAttrs(op->attrs);
if (op->attrs.defined()) {
attrs.push_back(Doc::Text("attrs_type_key=") << Doc::StrLiteral(op->attrs->GetTypeKey()));
}
if (!attrs.empty()) {
doc << ", " << Doc::Concat(attrs);
}
Expand Down
7 changes: 6 additions & 1 deletion tests/python/relax/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,17 +435,22 @@ def test_call_packed():
def f(x: Tensor[(3, 3), "float32"]):
# test that we can intro dim vars
z: Tensor[(n, m), "float32"] = relax.call_packed("contrib.my_matmul", x, x, mp=False)
w = relax.call_packed(
"contrib.my_shape_of", x, dtype="int32", attrs_type_key="relay.attrs.ShapeOfAttrs"
)
return z

x = f.params[0]
(z_bind,) = f.body.blocks[0].bindings
(z_bind, w_bind) = f.body.blocks[0].bindings
check_tensor_var(z_bind.var, ("n", "m"), "float32")

assert isinstance(z_bind.value.op, rx.ExternFunc)
assert z_bind.value.op.global_symbol == "contrib.my_matmul"
assert "mp" in z_bind.value.attrs and z_bind.value.attrs["mp"] == False
assert structural_equal(z_bind.value.args, [x, x])

assert isinstance(w_bind.value.attrs, relay.op.op_attrs.ShapeOfAttrs)


def test_primexpr_arithmetic():
@rx.script
Expand Down
3 changes: 3 additions & 0 deletions tests/python/relax/test_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ def test_call_packed():
def foo(x: Tensor[(3, 3), "float32"]):
# test that we can intro dim vars
z: Tensor[(n, m), "float32"] = relax.call_packed("contrib.my_matmul", x, x, mp=False)
w = relax.call_packed(
"contrib.my_shape_of", x, dtype="int32", attrs_type_key="relay.attrs.ShapeOfAttrs"
)
return z

check_roundtrip(foo)
Expand Down

0 comments on commit b4371eb

Please sign in to comment.