[Relay][PRNG] Add uniform distribution generator wrt threefry PRNG #8041
Conversation
python/tvm/topi/random/kernel.py
Outdated
    Parameters
    ----------
    gen : Tensor[10, uint64]
What is the meaning of 10?
This is the ThreefryKeyType introduced in #7083. Please refer to tvm/src/relay/op/random/kernel.cc, line 28 in c999a84:

    static TensorType ThreefryKeyType() { return TensorType({10}, tvm::DataType::UInt(64)); }
If so, let us add a comment describing the meaning of 10.
You could probably say ThreefryKey instead of Tensor[10, uint64].
Fixed.
    less than high.

    out_shape : Sequence[int]
        Output shape of the random numbers. Product of all dimensions must be a multiple of 4.
What is the reason that the product must be a multiple of 4?
It's the property of the threefry key. Please refer to this comment: #7083 (comment)
Sorry, I've rethought this problem. There shouldn't be any restriction on the output shape... We could change the input restriction of threefry_generate in another PR.
Do you mind sending a PR for updating the threefry_generate output? Or rather, what approach do you have in mind? I tried to avoid this problem by truncating the output buffer, but that required an extra copy; I wonder if you have something else.
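For reference, the truncation workaround can be sketched in plain NumPy. This is only an illustration of the idea, not TVM code: gen_uniform_truncated and its integer generator are hypothetical stand-ins for the real threefry_generate pipeline.

```python
import numpy as np

def gen_uniform_truncated(n, rng):
    # Stand-in for threefry_generate: the kernel emits whole blocks of 4
    # random words, so round the request up to a multiple of 4 ...
    padded = ((n + 3) // 4) * 4
    bits = rng.integers(0, 2**32, size=padded, dtype=np.uint32)
    # ... turn the low 23 bits into float32 values in [0, 1) ...
    vals = (bits & np.uint32((1 << 23) - 1)).astype(np.float32) / np.float32(1 << 23)
    # ... and truncate to the requested length; this slice-copy is the
    # extra copy mentioned above.
    return vals[:n].copy()

out = gen_uniform_truncated(7, np.random.default_rng(42))
```

The cost is one extra buffer copy of the full output, which is what the two-call _threefry suggestion below this comment tries to avoid.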
@altanh Sorry, I'm not familiar with the threefry algorithm. Is it possible to call _threefry twice in threefry_generate, in the following form? Something like:
out_array = irb.buffer_ptr(out_array_ptr)
# deal with the part of the array that fills whole blocks of 4
_threefry(irb, tmp, 0, tmp, 4, out_array, 0, out_len // 4)
if out_len % 4 != 0:
    # generate the remainder into a small 4-element tmp buffer
    tmp_array = irb.allocate(gen.dtype, 4, name="tmp", scope="global")
    # may need to update the tmp key/counter in between
    # ...
    _threefry(irb, tmp, 0, tmp, 4, tmp_array, 0, 1)
    # only copy the tail elements out of the tmp buffer
    for i in range(out_len // 4 * 4, out_len):
        out_array[i] = tmp_array[i % 4]
In this way, we could avoid copying the whole generated tensor.
Yeah, you could do that. Maybe submit it in a new PR?
@tkonolige Sure, I will submit one. Could you tell me what kind of update on the key tmp we need before the second _threefry? I can only think of updating the increment counter (tmp[7]).
You'll need to update the counter buffer to be equal to out_len.
tests/python/relay/test_prng.py
Outdated

@@ -103,6 +103,19 @@ def test_threefry_split_infer():
    assert tvm.ir.structural_equal(f.ret_type, expected_type)


def test_uniform_infer():
    oshape = (12,)
    odtype = "float32"
Should cover more types, for example the float64 you have implemented.
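As a sketch of what covering both dtypes could look like, here is the value-level property in plain NumPy rather than the relay test itself (check_uniform_dtype is a hypothetical helper; the 23/52 fraction-bit split matches this PR's description):

```python
import numpy as np

# fraction bits used per dtype, matching the 23/52 split described in this PR
FRAC_BITS = {"float32": 23, "float64": 52}

def check_uniform_dtype(odtype):
    nbits = FRAC_BITS[odtype]
    bits = np.random.default_rng(0).integers(0, 2**63, size=16, dtype=np.uint64)
    # keep the low `nbits` bits and scale them into [0, 1)
    frac = bits & np.uint64((1 << nbits) - 1)
    vals = frac.astype(odtype) / np.array(2.0**nbits, dtype=odtype)
    assert vals.dtype == np.dtype(odtype)
    assert np.all((vals >= 0) & (vals < 1))

for odtype in FRAC_BITS:
    check_uniform_dtype(odtype)
```

Looping (or pytest-parametrizing) over a dtype table like this keeps the float32 and float64 paths covered by the same assertions.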
standard_uniform_values = tvm.te.compute(out_shape, lambda *i: uniform_scalar(random_bits(*i)))

uniform_values = tvm.topi.add(tvm.topi.multiply(standard_uniform_values, high - low), low)
How well does this approach work when we have a large range (high - low)? It seems like we would be losing a lot of potential randomness.
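To make the concern concrete: only 23 fraction bits feed the float32 standard uniform, so whatever the width of (high - low), the output takes at most 2**23 distinct values. A small NumPy illustration of the same u * (high - low) + low transform (not the TVM code itself):

```python
import numpy as np

# smallest nonzero value a 23-bit standard uniform can produce
step = np.float32(2.0**-23)

low, high = np.float32(0.0), np.float32(2.0**30)
smallest_nonzero = step * (high - low) + low
# 2**-23 * 2**30 = 128: for this range nothing can land strictly
# between 0 and 128, so the spacing grows linearly with the range
```

So the absolute resolution degrades as (high - low)/2**23, even though the relative resolution stays the same.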
Thanks for this PR! I will be reading it soon, and just wanted to point you to a branch I worked on a while ago where I hacked a uniform op + dropout support: https://github.com/altanh/tvm/commits/prng (just in case it might be useful for you to check and compare).
Perhaps this is the operation you're looking for? https://github.com/altanh/tvm/blob/2d9ac7710ab055d4f20e8b5a0a3580836723efac/python/tvm/topi/generic/algorithm.py#L465 Thanks! |
@FrozenGene @tkonolige @altanh Thank you for your reviews. I've updated this PR based on them.
@FrozenGene Thank you for the clue! However, I haven't found how to restrict the dtype attributes in …
@tkonolige As this approach only uses the fraction bits to represent the float, there will be a loss of randomness for all floats, at least …
@altanh Thank you for your references! The …
Suggest …
@FrozenGene Thank you. I've added the type restriction.
Looks pretty good to me. Just a couple of minor fixes.
Overall LGTM with some minor comments. I did want to request that we keep the output shape restriction in the documentation for now, until a follow-up PR that relaxes it is merged. Thanks for the work!
@FrozenGene @altanh @tkonolige I've updated the PR upon the reviews. Could you take another look? Thank you~ |
Looks good to me. Just some small comments.
LGTM! I'm a bit uneasy about introducing a nondeterministic test based on averaging the random numbers, but I imagine it will almost never fail. Also left a comment about comparing the min/max of the generated numbers: can we always guarantee <= or >= on the output, or will there be some floating point inaccuracy cases where this might be violated?
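The floating-point concern is real for the upper bound: under round-to-nearest-even, the largest standard-uniform value below 1 can round up to exactly high after the scale-and-shift. A NumPy example (independent of the TVM kernel, but using the same u * (high - low) + low transform):

```python
import numpy as np

u = np.float32(1.0) - np.float32(2.0**-24)   # largest float32 strictly below 1.0
low, high = np.float32(1.0), np.float32(2.0)

out = np.float32(u * (high - low) + low)
# u < 1, yet the exact sum 1 + (1 - 2**-24) = 2 - 2**-24 is a
# round-to-even tie that rounds up, so out lands exactly on high
```

So a test should assert low <= x <= high (or clamp the output) rather than a strict x < high.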
@FrozenGene Could you take another look at this PR? Thank you~
@FrozenGene Could you have another look at this PR? Thank you!
Thanks @zhuzilin @altanh @tkonolige, merged now
…pache#8041)
* Add uniform distribution generator wrt threefry PRNG
* fix lint
* remove the redundant print
* modifications based on review
* update docs
* update uniform algorithm to use bit operations only
* add type restrictions
* minor fix upon review
* update test and error information
This PR adds a uniform distribution generator using the threefry PRNG introduced in #7083. We would need uniform to develop training-phase dropout, per the following roadmap:
The algorithm used is basically the same as the one in jax: use the random bits generated from threefry_generate as the fraction section of the float32 or float64. To be specific, I use the last 23 bits of the random bits for float32 and the last 52 for float64. There is one difference from the jax implementation: jax uses a bitcast to turn the uint into a float. However, as I haven't found a bitcast in te or topi, I use a divide to cast the type, which may be slower.
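For the record, the two conversions can be compared in plain NumPy. This is an illustration of the idea rather than the te/topi code: both routes turn 23 random mantissa bits into the same float32 in [0, 1).

```python
import numpy as np

bits = np.random.default_rng(0).integers(0, 2**32, size=8, dtype=np.uint32)
frac = bits & np.uint32((1 << 23) - 1)            # keep the low 23 bits

# jax-style bitcast: splice the bits into the mantissa of 1.0
# (exponent field 127) to get a float32 in [1, 2), then subtract 1
bitcast = (frac | np.uint32(127 << 23)).view(np.float32) - np.float32(1.0)

# divide-style (this PR): treat the bits as an integer and scale down
divided = frac.astype(np.float32) / np.float32(1 << 23)
```

The two agree bit-for-bit because (1 + k * 2**-23) - 1 is exact for k < 2**23; the bitcast version just avoids the int-to-float conversion and the divide.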
Thank you for your time reviewing this PR. I may not be familiar enough with the tvm codebase at the moment, so I'm sorry if I've broken any conventions in the community, and I'd love to fix them :).
Gently ping @tqchen @altanh @tkonolige