Skip to content

Commit

Permalink
[Unity] Fix StructInfo Infer for vm.alloc_tensor (#14283)
Browse files Browse the repository at this point in the history
A hot fix for the struct info deduction for `vm.alloc_tensor`
  • Loading branch information
Hzfengsy authored Mar 13, 2023
1 parent d074580 commit 78884cc
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
4 changes: 1 addition & 3 deletions src/relax/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -444,15 +444,13 @@ TVM_REGISTER_GLOBAL("relax.op.vm.alloc_storage").set_body_typed(MakeVMAllocStora

// vm alloc_tensor

Expr InferShapeVMAllocTensor(const Call& call, DiagnosticContext diag_ctx) { return call->args[1]; }

StructInfo InferStructInfoVMAllocTensor(const Call& call, const BlockBuilder& ctx) {
DataType out_dtype;
if (const auto* dtype_node = call->args[3].as<DataTypeImmNode>()) {
const DataTypeImm dtype_imm = GetRef<DataTypeImm>(dtype_node);
out_dtype = dtype_imm->value;
}
if (const auto* output_shape = call->args[1].as<ShapeExprNode>()) {
if (const auto* output_shape = call->args[2].as<ShapeExprNode>()) {
return TensorStructInfo(GetRef<Expr>(output_shape), out_dtype);
}
return TensorStructInfo(out_dtype, kUnknownNDim);
Expand Down
8 changes: 8 additions & 0 deletions tests/python/relax/test_op_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,5 +95,13 @@ def _check_call(expr, op_name: str):
assert isinstance(x[1][0], rx.TupleGetItem)


def test_vm_alloc_tensor():
bb = rx.BlockBuilder()
storage = rx.Var("storage", rx.TensorStructInfo(dtype="float32"))
alloc = rx.op.vm.alloc_tensor(storage, offset=0, shape=rx.ShapeExpr([4, 5]), dtype="float32")
alloc = bb.normalize(alloc)
tvm.ir.assert_structural_equal(alloc.struct_info, R.Tensor([4, 5], "float32"))


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 78884cc

Please sign in to comment.