[RELAY,TOPI] Threefry PRNG: splittable and stateless #7083
Awesome work! cc @tqchen @junrushao1994 @MarisaKirisame @eric-haibin-lin, who may be interested. I think it may be worth discussing a high-level vs. low-level API for this, and what namespace it should live in. I wrote a few examples of how we might use this here: https://discuss.tvm.apache.org/t/rfc-handling-effect-in-tvm-and-relay/5946/25?u=altanh
worth noting that |
Naming
I propose we move everything PRNG to a new
Handling different PRNG kernels
The splitting, keygen, and bit-gen operations will be kernel-specific. However, AFAIK, most if not all of the commonly used random ops simply require random bits as inputs (i.e. they don't care how those bits were generated). I propose the following approach to handling different kernels (thanks to @jroesch for helpful discussion):
Problems.
Other notes
cc @antinucleon |
I've moved everything to a new |
src/relay/op/random/kernel.cc
@@ -25,21 +25,23 @@ namespace relay {

TVM_REGISTER_NODE_TYPE(ThreefryGenerateAttrs);

static const TensorType THREEFRY_KEY_TYPE = TensorType({10}, tvm::DataType::UInt(64));
As a rule of thumb, try to avoid static variables, since they can run into static initialization order issues. Use static functions that return these values instead; the construction is not that costly.
Got it. Is the naming style THREEFRY_KEY_TYPE() OK for a static function?
Usually CamelCase is preferred for a global or static function
bump @antinucleon @junrushao1994 @eric-haibin-lin @MarisaKirisame (please cc anyone else interested in PRNG for review, thanks!) |
@tkonolige I had some additional questions about TODOs in the Threefry kernel, maybe you can clarify?
# number of rounds is even, so out always contains the result
(out_buf, tmp) = (tmp, out_buf)
(out_offset, tmp_offset) = (tmp_offset, out_offset)
I see some TODOs in this function (_threefry). Do they affect the correctness of the algorithm?
The only one that matters is "TODO should be wrapping". I do not know whether TVM guarantees that unsigned integer arithmetic wraps (instead of saturating).
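To make the wrapping concern concrete, here is a minimal Python sketch of the difference between wrapping and saturating 64-bit unsigned addition, plus the left rotation that Threefry-style mixing relies on. The helper names are hypothetical and are not part of TVM or this PR; Python integers are unbounded, so the 64-bit behaviour is emulated with an explicit mask.

```python
# Illustrative only: these helpers are assumptions, not TVM APIs.
MASK64 = (1 << 64) - 1  # largest value representable in an unsigned 64-bit word


def wrapping_add_u64(a: int, b: int) -> int:
    """Add two u64 values modulo 2**64 (what the kernel needs)."""
    return (a + b) & MASK64


def saturating_add_u64(a: int, b: int) -> int:
    """Add two u64 values, clamping at the maximum (would break the hash)."""
    return min(a + b, MASK64)


def rotl64(x: int, r: int) -> int:
    """Rotate a u64 value left by r bits."""
    r %= 64
    return ((x << r) | (x >> (64 - r))) & MASK64


if __name__ == "__main__":
    print(wrapping_add_u64(MASK64, 1))    # wraps around to zero
    print(saturating_add_u64(MASK64, 1))  # stays at the maximum
    print(rotl64(1 << 63, 1))             # top bit rotates into the bottom bit
```

If the backend saturated instead of wrapping, every overflow would collapse to the same maximum value and the output would no longer match reference Threefry results.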
Thanks for the reviews, @junrushao1994 @MarisaKirisame @eric-haibin-lin if you have some cycles, input would be appreciated! |
Really well written and commented code, great work!

TVM_REGISTER_NODE_TYPE(ThreefryGenerateAttrs);

static TensorType ThreefryKeyType() { return TensorType({10}, tvm::DataType::UInt(64)); }
Sorry for not pointing this out earlier. Maybe a good thing to do is to wrap this in a newtype? (E.g., define a type ThreefryKey that cannot be used in any way except in random operations. This would prevent arithmetic on the random seed.)
Adding a new opaque type to TVM seems really involved. We would have to add a new case to every type visitor, which seems like it may cause issues with some passes. We'd also have to add a no-op function with implementations to satisfy the type checker, or add a wrapper struct with all the proper conversion functions. Given all this complication, I don't think it is a good idea.
Overall this looks good to me! I just have a few design questions and notes on comments and your tests.
:py:func:`threefry_generate`. **Do not use this key again after calling
this function.**

shape : Sequence[int]
Why does the total number of outputs need to be a multiple of four?
It is an implementation detail. Basically, threefry uses 4 64-bit words as its state, inputs, and outputs.
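As a rough illustration of that constraint, the sketch below computes how many four-word blocks a requested output shape would need and rejects shapes whose element count is not a multiple of four. The function name and error message are hypothetical; the actual check in the PR may be phrased differently.

```python
from math import prod
from typing import Sequence

# Each generated block holds four 64-bit words, as described in the reply above.
WORDS_PER_BLOCK = 4


def blocks_for_shape(shape: Sequence[int]) -> int:
    """Return the number of 4-word blocks needed to fill `shape`.

    Hypothetical helper: raises ValueError when the total element count
    is not a multiple of four.
    """
    total = prod(shape)
    if total % WORDS_PER_BLOCK != 0:
        raise ValueError(
            f"total number of outputs ({total}) must be a multiple of {WORDS_PER_BLOCK}"
        )
    return total // WORDS_PER_BLOCK


if __name__ == "__main__":
    print(blocks_for_shape((2, 8)))  # 16 outputs fill exactly 4 blocks
```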
def threefry_split(key):
    """Split an existing Threefry key into two new ones.

    This is useful if you have to subsequent calls which each need their own
Why wouldn't someone just create two separate Threefry keys using different seeds, and use them?
Creating separate keys has not been theoretically proven to be as random as splitting a single key. Maybe I should add a comment that you should only really create one key. On the other hand, these details might be better handled at a higher level interface (future work).
# there is no state to maintain, we can apply it to a sequence of numbers (0..N) to generate a
# sequence of random numbers in parallel. In order to make the PRNG splittable (that is we can
# generate a sequence of random numbers in one place, and another sequence in another), we add a
# path and key in addition to the counter. The path allows us to encode a sequence of splits (a 0 in
Can you elaborate on how path and key are used in number generation? You don't explain what the key is, either.
The last sentence explains what the key is: "To avoid continuously growing the path, we can compress an existing path into the key portion of the generator by hashing the current key, path, and counter to create the new key (this same technique is used if we run out of room for the counter)." I've added a comment on how it is initialized.
I've also added an explanation of how random numbers are generated (we apply the hash to key, path, and counter).
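The key/path/counter scheme described above can be sketched in plain Python. This is a toy model under loud assumptions: BLAKE2b from `hashlib` stands in for the Threefry hash, the `ToyKey` layout and the 64-bit path limit are made up for illustration, and none of these names match the PR's actual key format or API.

```python
import hashlib
import struct
from typing import List, NamedTuple, Tuple

MAX_PATH_BITS = 64  # assumed limit on the path length, for illustration only


class ToyKey(NamedTuple):
    key: bytes     # 32 bytes of key material
    path: int      # bits recording the sequence of splits (0 = left, 1 = right)
    path_len: int  # number of valid bits in `path`
    counter: int   # next block index to generate


def _hash(k: ToyKey, counter: int) -> bytes:
    """Stand-in for Threefry: hash key, path, and counter into 32 bytes."""
    h = hashlib.blake2b(digest_size=32)
    h.update(k.key)
    h.update(struct.pack("<QQQ", k.path, k.path_len, counter))
    return h.digest()


def toy_key(seed: int) -> ToyKey:
    """Initialize a key from a seed; path and counter start empty."""
    return ToyKey(struct.pack("<QQQQ", seed & (2**64 - 1), 0, 0, 0), 0, 0, 0)


def toy_split(k: ToyKey) -> Tuple[ToyKey, ToyKey]:
    """Return two child keys; the parent must not be used afterwards."""
    if k.path_len == MAX_PATH_BITS:
        # Path is full: compress key, path, and counter into a fresh key.
        k = ToyKey(_hash(k, k.counter), 0, 0, 0)
    left = k._replace(path=k.path << 1, path_len=k.path_len + 1)
    right = k._replace(path=(k.path << 1) | 1, path_len=k.path_len + 1)
    return left, right


def toy_generate(k: ToyKey, n_blocks: int) -> Tuple[ToyKey, List[int]]:
    """Generate n_blocks * 4 words and return the advanced key with them."""
    out: List[int] = []
    for i in range(n_blocks):
        out.extend(struct.unpack("<QQQQ", _hash(k, k.counter + i)))
    return k._replace(counter=k.counter + n_blocks), out


if __name__ == "__main__":
    left, right = toy_split(toy_key(0))
    _, a = toy_generate(left, 1)
    _, b = toy_generate(right, 1)
    print(a != b)  # the two children produce different streams
```

Because every output is a pure function of (key, path, counter), calling `toy_generate` twice with the same key gives the same numbers, which is why a key should not be reused after it has been split or consumed. Counter overflow is ignored here for brevity.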
Thank you @altanh @electriclilies @tqchen @jwfromm @MarisaKirisame for the reviews, the PR has been merged. |
* [RELAY,TOPI] Threefry PRNG: splittable and stateless
* Fix sphinx?
* Lint fixes
* sphinx fixes round 2
* fix inputs for tests
* reorganize to random, fix uninitialized memory bug
* silence linter
* silence linter even further
* s
* strengthen Threefry key type checking, add tests
* replace static variable with function for Threefry key type
* lint fix
* Remove old todos, improve assert messages
* describe how random number is generated
* add tests for incorrect output size. also vary test sizes
Co-authored-by: Altan Haan <[email protected]>
Nice work. Do we have a plan to support CUDA? It looks like it works on CPU only at the moment.
It should be easy to use on the GPU; just parallelize the outer loop.
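The reason the outer loop parallelizes easily is that each output block depends only on the key and its own counter value, not on any previous block. A minimal CPU-side sketch of that property follows, using a thread pool in place of GPU threads and BLAKE2b as a stand-in for the Threefry hash; both choices, and all function names, are assumptions for illustration rather than the PR's implementation.

```python
import hashlib
import struct
from concurrent.futures import ThreadPoolExecutor
from typing import List, Tuple

Block = Tuple[int, int, int, int]


def block_at(key: bytes, counter: int) -> Block:
    """Compute one 4-word block from (key, counter) alone; no shared state."""
    digest = hashlib.blake2b(key + struct.pack("<Q", counter), digest_size=32).digest()
    return struct.unpack("<QQQQ", digest)


def generate_serial(key: bytes, n_blocks: int) -> List[Block]:
    """Reference version: a plain loop over the counter."""
    return [block_at(key, i) for i in range(n_blocks)]


def generate_parallel(key: bytes, n_blocks: int, workers: int = 4) -> List[Block]:
    """Same loop, with iterations distributed over worker threads."""
    with ThreadPoolExecutor(max_workers=workers) as pool:
        # pool.map returns results in input order, so the output layout matches.
        return list(pool.map(lambda i: block_at(key, i), range(n_blocks)))


if __name__ == "__main__":
    key = b"\x00" * 32
    print(generate_serial(key, 16) == generate_parallel(key, 16))
```

On a GPU the same idea would map each counter value to its own thread, with each thread writing its block to a disjoint slice of the output buffer.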
This PR adds a fast PRNG to Relay for use in dropout and batch norm. The PRNG is stateless: for a given key, it always returns the same random number. It is also splittable: for a given key, we can split the key to generate two new ones.
JAX provides a good explanation of stateless and splittable: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#JAX-PRNG.
@altanh