Skip to content

Commit

Permalink
fix lint
Browse files Browse the repository at this point in the history
  • Loading branch information
zhuzilin committed May 14, 2021
1 parent 7baf726 commit 78bc8b7
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 6 deletions.
3 changes: 2 additions & 1 deletion include/tvm/relay/attrs/random.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ struct UniformAttrs : public tvm::AttrsNode<UniformAttrs> {

TVM_DECLARE_ATTRS(UniformAttrs, "relay.attrs.UniformAttrs") {
TVM_ATTR_FIELD(out_shape).describe("Shape of random numbers to generate");
TVM_ATTR_FIELD(out_dtype).set_default(NullValue<DataType>())
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Data type of the generated numbers");
}
};
Expand Down
7 changes: 2 additions & 5 deletions python/tvm/topi/random/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,11 +512,8 @@ def uniform_scalar(bits):
standard_uniform = bits.astype(out_dtype) / float(1 << nfraction)
return standard_uniform

standard_uniform_values = tvm.te.compute(
out_shape, lambda *i : uniform_scalar(random_bits(*i)))
standard_uniform_values = tvm.te.compute(out_shape, lambda *i: uniform_scalar(random_bits(*i)))

uniform_values = tvm.topi.add(
tvm.topi.multiply(standard_uniform_values, high - low),
low)
uniform_values = tvm.topi.add(tvm.topi.multiply(standard_uniform_values, high - low), low)

return uniform_values

0 comments on commit 78bc8b7

Please sign in to comment.