diff --git a/nnvm/python/nnvm/compiler/build_module.py b/nnvm/python/nnvm/compiler/build_module.py index 0c1bc5047870..a49e9741a901 100644 --- a/nnvm/python/nnvm/compiler/build_module.py +++ b/nnvm/python/nnvm/compiler/build_module.py @@ -171,7 +171,7 @@ def _update_shape_dtype(shape, dtype, params): shape.update({k : v.shape for k, v in params.items()}) if isinstance(dtype, str): for k, v in params.items(): - if v.dtype != dtype: + if v.dtype != dtype and v.shape: raise ValueError( "%s: dtype not expected %s vs %s" % (k, dtype, v.dtype)) else: