
[Relay][PRNG] Add uniform distribution generator wrt threefry PRNG #8041

Merged: 9 commits merged into apache:main on May 21, 2021

Conversation

zhuzilin (Contributor):

This PR adds a uniform distribution generator using the threefry PRNG introduced in #7083. We need uniform in order to implement training-phase dropout, following this roadmap:

uniform -> bernoulli -> dropout

The algorithm used is basically the same as the one in jax: the random bits generated by threefry_generate are used as the fraction (mantissa) section of the float32 or float64. Specifically, I use the top 23 bits of the random bits for float32 and the top 52 for float64. There is one difference from the jax implementation: jax uses a bitcast to turn the uint into a float:

# jax implementation
def _uniform(key, shape, dtype, minval, maxval) -> jnp.ndarray:
  ...
  bits = _random_bits(key, nbits, shape)

  # The strategy here is to randomize only the mantissa bits with an exponent of
  # 1 (after applying the bias), then shift and scale to the desired range. The
  # bit-level transformation we use relies on Numpy and XLA having bit-for-bit
  # equivalent float representations, which might not be true on all platforms.
  float_bits = lax.bitwise_or(
      lax.shift_right_logical(bits, np.array(nbits - nmant, lax.dtype(bits))),
      np.array(1., dtype).view(_UINT_DTYPES[nbits]))
  floats = lax.bitcast_convert_type(float_bits, dtype) - np.array(1., dtype)
  return lax.max(
      minval,
      lax.reshape(floats * (maxval - minval) + minval, shape.positional))

However, since I haven't found a bitcast op in te or topi, I use a division to convert the type instead, which may be slower:

    def uniform_scalar(bits):
        bits = bits >> (nbits - nfraction)
        standard_uniform = bits.astype(out_dtype) / float(1 << nfraction)
        return standard_uniform
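For intuition, both the divide-based conversion above and jax's mantissa-bitcast trick can be sketched in plain NumPy. This is only an illustrative sketch, not TVM code; the `uniform_by_*` helper names are made up for this example:

```python
import numpy as np

NBITS, NMANT = 32, 23  # float32 has 23 mantissa (fraction) bits

def uniform_by_divide(bits: np.ndarray) -> np.ndarray:
    # keep the top 23 random bits, then divide to land in [0, 1)
    mant = bits >> np.uint32(NBITS - NMANT)
    return mant.astype(np.float32) / np.float32(1 << NMANT)

def uniform_by_bitcast(bits: np.ndarray) -> np.ndarray:
    # OR the top 23 random bits into the mantissa of 1.0f, giving a
    # float in [1, 2), then subtract 1.0 to shift into [0, 1)
    one = np.float32(1.0).view(np.uint32)
    mant = bits >> np.uint32(NBITS - NMANT)
    return (mant | one).view(np.float32) - np.float32(1.0)

rng = np.random.default_rng(0)
bits = rng.integers(0, 2**32, size=1024, dtype=np.uint32)
# both conversions yield exactly mant * 2**-23, so they agree bit-for-bit
assert np.array_equal(uniform_by_divide(bits), uniform_by_bitcast(bits))
```

The two agree exactly here because a 23-bit integer is exactly representable in float32 and division by a power of two is exact; the performance difference is only in the operations emitted.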

Thank you for your time reviewing this PR. I may not be familiar enough with the tvm codebase yet, so I'm sorry if I've broken any conventions in the community; I'd love to fix them :).

Gently ping @tqchen @altanh @tkonolige

@zhuzilin zhuzilin force-pushed the prng_uniform branch 2 times, most recently from 2f831ea to 78bc8b7 Compare May 14, 2021 04:28
@zhuzilin zhuzilin requested a review from FrozenGene May 14, 2021 05:04

Parameters
----------
gen : Tensor[10, uint64]
Member:

What is the meaning of 10?

Contributor Author:

This is the ThreefryKeyType introduced in #7083. Please refer to:

static TensorType ThreefryKeyType() { return TensorType({10}, tvm::DataType::UInt(64)); }

Member:

If so, let us add a comment describing the meaning of 10.

Contributor:

You could probably say ThreefryKey instead of Tensor[10, uint64]

Contributor Author:

Fixed.


out_shape : Sequence[int]
Output shape of the random numbers. Product of all dimensions must be a multiple of 4.
Member:

What is the reason the product must be a multiple of 4?

Contributor Author:

It's the property of the threefry key. Please refer to this comment: #7083 (comment)

zhuzilin (Contributor Author), May 15, 2021:

Sorry, I've rethought this problem. There should not be any restriction on the output shape... We could change the input restriction of threefry_generate in another PR.

Contributor:

Do you mind sending a PR to update the threefry_generate output? Or rather, what approach do you have in mind? I tried to avoid this problem by truncating the output buffer, but that required an extra copy; I wonder if you have something else.

Contributor Author:

@altanh Sorry, I'm not familiar with the threefry algorithm. Is it possible to call _threefry twice in threefry_generate, something like:

out_array = irb.buffer_ptr(out_array_ptr)
# deal with all the complete 4-word blocks of the array
_threefry(irb, tmp, 0, tmp, 4, out_array, 0, out_len // 4)
if out_len % 4 != 0:
    # generate the remainder into a small tmp buffer
    tmp_array = irb.allocate(gen.dtype, 4, name="tmp", scope="global")
    # may need to update the tmp key/counter in between
    # ...
    _threefry(irb, tmp, 0, tmp, 4, tmp_array, 0, 1)  # one block for the tail
    # only copy the needed tail of the tmp buffer
    for i in range(out_len // 4 * 4, out_len):
        out_array[i] = tmp_array[i % 4]

In this way, we could avoid copying the whole generated tensor.
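The pattern in that comment (fill the complete 4-word blocks directly, generate the tail into a small temporary buffer, and copy only the needed remainder) can be illustrated in plain Python, independent of TVM's IR builder. Here `gen_block` is a hypothetical stand-in for one fixed-size _threefry invocation:

```python
def blockwise_fill(out_len: int, gen_block, block: int = 4) -> list:
    """Fill out_len values from a generator that can only produce
    fixed-size blocks (mimicking the threefry constraint)."""
    out = [None] * out_len
    full = out_len // block * block
    # generate all complete blocks straight into the output
    for start in range(0, full, block):
        out[start:start + block] = gen_block(start // block)
    if out_len % block:
        # generate the remainder into a small temporary block and copy
        # only the tail, avoiding a copy of the whole output tensor
        tmp = gen_block(full // block)
        for i in range(full, out_len):
            out[i] = tmp[i % block]
    return out

# with a deterministic stand-in generator, every slot is filled exactly once
vals = blockwise_fill(10, lambda c: [c * 4 + j for j in range(4)])
assert vals == list(range(10))
```

The extra cost is one temporary block of 4 words plus at most 3 scalar copies, rather than a copy of the entire tensor.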

Contributor:

Yeah, you could do that. Maybe submit it in a new PR?

Contributor Author:

@tkonolige Sure, I will submit one. Could you tell me what kind of update to the key tmp we need before the second _threefry call? I can only think of updating the increment counter (tmp[7]).

Contributor:

You'll need to update the counter buffer to be equal to out_len.

@@ -103,6 +103,19 @@ def test_threefry_split_infer():
assert tvm.ir.structural_equal(f.ret_type, expected_type)


def test_uniform_infer():
oshape = (12,)
odtype = "float32"
Member:

This should cover more types, for example the float64 you have implemented.


standard_uniform_values = tvm.te.compute(out_shape, lambda *i: uniform_scalar(random_bits(*i)))

uniform_values = tvm.topi.add(tvm.topi.multiply(standard_uniform_values, high - low), low)
Contributor:

How well does this approach work when we have a large range (high - low)? It seems like we would be losing a lot of potential randomness.

altanh (Contributor) commented May 14, 2021:

Thanks for this PR! I will be reading it soon, and just wanted to point you to a branch I worked on a while ago where I hacked a uniform op + dropout support: https://github.com/altanh/tvm/commits/prng (just in case it might be useful for you to check and compare).

> However, since I haven't found a bitcast op in te or topi, I use a division to convert the type instead, which may be slower:

Perhaps this is the operation you're looking for? https://github.com/altanh/tvm/blob/2d9ac7710ab055d4f20e8b5a0a3580836723efac/python/tvm/topi/generic/algorithm.py#L465

Thanks!

zhuzilin (Contributor Author) commented May 15, 2021:

@FrozenGene @tkonolige @altanh Thank you for your reviews. I've updated this PR based on them.

> I think there are 2 options: 1. Ref Conv2DRel. 2. You could restrict the type here and raise an exception. I prefer option 1.

@FrozenGene Thank you for the clue! However, I haven't found how to restrict the dtype attributes in Conv2DRel or Conv2DAttrs... Should I add the type restriction to UniformAttrs, or raise an error in MakeUniform and UniformRel?

> How well does this approach work when we have a large range (high - low)? It seems like we would be losing a lot of potential randomness.

@tkonolige Since this approach uses only the fraction bits to represent the float, there will be a loss of randomness for all floats: at least (2^nexp - 1) / 2^nexp of the representable floats can never be produced (nexp stands for the number of exponent digits). However, it's a little tricky to use all 64 bits of the random words to produce a uniformly distributed float number... Do you have any idea on that?
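One way to quantify that loss (an illustrative back-of-the-envelope check, not part of the PR): positive IEEE-754 floats are ordered the same way as their bit patterns, so the number of representable float32 values in [0, 1) equals the integer bit pattern of 1.0f, roughly 2^30, while the mantissa-only scheme can emit only 2^23 distinct values:

```python
import numpy as np

# positive float32 values sort like their bit patterns, so the count of
# representable float32 values in [0, 1) is exactly the bit pattern of 1.0f
total_in_unit_interval = int(np.float32(1.0).view(np.uint32))  # 0x3F800000
distinct_outputs = 1 << 23  # one output value per 23-bit mantissa pattern

assert total_in_unit_interval == 0x3F800000
# fewer than 1% of the representable values in [0, 1) can ever be produced
assert distinct_outputs / total_in_unit_interval < 0.01
```

The trade-off is that the produced values are exactly the multiples of 2^-23, which gives a clean uniform distribution at the cost of resolution near zero.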

> Perhaps this is the operation you're looking for? https://github.com/altanh/tvm/blob/2d9ac7710ab055d4f20e8b5a0a3580836723efac/python/tvm/topi/generic/algorithm.py#L465

@altanh Thank you for the reference! The reinterpret op is exactly what I was looking for. I've updated the algorithm, and it is now the same as the one used in jax.

@zhuzilin zhuzilin requested a review from FrozenGene May 16, 2021 02:33
FrozenGene (Member) commented May 17, 2021:

> @FrozenGene Thank you for the clue! However, I haven't found how to restrict the dtype attributes in Conv2DRel or Conv2DAttrs... Should I add the type restriction to UniformAttrs, or raise an error in MakeUniform and UniformRel?

Suggest UniformRel

zhuzilin (Contributor Author):

> Suggest UniformRel

@FrozenGene Thank you. I've added the type restriction.

@tkonolige (Contributor) left a review:

Looks pretty good to me. Just a couple of minor fixes.

@altanh (Contributor) left a review:

Overall LGTM with some minor comments. I did want to request that we keep the output shape restriction in the documentation for now, until a follow-up PR which relaxes it is merged. Thanks for the work!

zhuzilin (Contributor Author) commented May 18, 2021:

@FrozenGene @altanh @tkonolige I've updated the PR based on the reviews. Could you take another look? Thank you~

@tkonolige (Contributor) left a review:

Looks good to me. Just some small comments.

@altanh (Contributor) left a review:

LGTM! I'm a bit uneasy about introducing a nondeterministic test based on averaging the random numbers, but I imagine it will almost never fail. I also left a comment about comparing the min/max of the generated numbers: can we always guarantee <= or >= on the output, or will there be floating-point inaccuracy cases where this might be violated?

zhuzilin (Contributor Author):

@FrozenGene Could you take another look at this PR? Thank you~

zhuzilin (Contributor Author):

@FrozenGene Could you have another look at this PR? Thank you!

@FrozenGene FrozenGene merged commit e438a73 into apache:main May 21, 2021
FrozenGene (Member):

Thanks @zhuzilin @altanh @tkonolige, merged now.

trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Jun 17, 2021
…pache#8041)

* Add uniform distribution generator wrt threefry PRNG

* fix lint

* remove the redundant print

* modifications based on review

* update docs

* update uniform algorithm to use bit operations only

* add type restrictions

* minor fix upon review

* update test and error information
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request Jun 17, 2021