Skip to content
This repository has been archived by the owner on May 22, 2023. It is now read-only.

Commit

Permalink
fix example nn_module
Browse files Browse the repository at this point in the history
  • Loading branch information
Hzfengsy committed Dec 23, 2022
1 parent a15b997 commit 3a77ed4
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions python/tvm/relax/testing/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,14 +119,14 @@ def _unpack_params(value: object) -> List[relax.Var]:

def init_params(mod: tvm.IRModule) -> List[tvm.nd.array]:
"""Utility function to initialize model's parameters."""
shape_dict = {v.name_hint: v.shape_ for v in mod["main"].params}
sinfo_dict = {v.name_hint: v.struct_info for v in mod["main"].params}
params = []
for k, v in shape_dict.items():
for k, v in sinfo_dict.items():
if k.startswith("data"):
continue
if isinstance(v, relax.ShapeExpr):
if isinstance(v, relax.TensorStructInfo) and isinstance(v.shape, relax.ShapeExpr):
shape = []
for i in v:
for i in v.shape:
if isinstance(i, tir.IntImm):
shape.append(int(i))
else:
Expand Down

0 comments on commit 3a77ed4

Please sign in to comment.