Skip to content

Commit

Permalink
[Relay][PRNG] Add uniform distribution generator wrt threefry PRNG (a…
Browse files Browse the repository at this point in the history
…pache#8041)

* Add uniform distribution generator wrt threefry PRNG

* fix lint

* remove the redundant print

* modifications based on review

* update docs

* update uniform algorithm to use bit operations only

* add type restrictions

* minor fix upon review

* update test and error information
  • Loading branch information
zhuzilin authored and trevor-m committed Jun 17, 2021
1 parent 8e0a262 commit f8f3f6c
Show file tree
Hide file tree
Showing 9 changed files with 258 additions and 1 deletion.
12 changes: 12 additions & 0 deletions include/tvm/relay/attrs/random.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,18 @@ struct ThreefryGenerateAttrs : public tvm::AttrsNode<ThreefryGenerateAttrs> {
}
};

struct UniformAttrs : public tvm::AttrsNode<UniformAttrs> {
Array<Integer> out_shape;
DataType out_dtype;

TVM_DECLARE_ATTRS(UniformAttrs, "relay.attrs.UniformAttrs") {
TVM_ATTR_FIELD(out_shape).describe("Shape of random numbers to generate");
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Data type of the generated numbers");
}
};

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_RANDOM_H_
5 changes: 5 additions & 0 deletions python/tvm/relay/op/op_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,3 +567,8 @@ class BatchToSpaceNDAttrs(Attrs):
@tvm._ffi.register_object("relay.attrs.ThreefryGenerateAttrs")
class ThreefryGenerateAttrs(Attrs):
"""Attributes used in ThreefryGenerateAttrs operators"""


@tvm._ffi.register_object("relay.attrs.UniformAttrs")
class UniformAttrs(Attrs):
"""Attributes used in UniformAttrs operators"""
4 changes: 4 additions & 0 deletions python/tvm/relay/op/random/_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,7 @@
register_pattern("random.threefry_generate", OpPattern.OPAQUE)
register_strategy("random.threefry_split", strategy.threefry_split_strategy)
register_pattern("random.threefry_split", OpPattern.OPAQUE)

# Distribution
register_strategy("random.uniform", strategy.uniform_strategy)
register_pattern("random.uniform", OpPattern.OPAQUE)
54 changes: 53 additions & 1 deletion python/tvm/relay/op/random/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import sys
import numpy as np

from ...expr import Constant
from ...expr import Constant, Expr, const
from .... import nd
from . import _make

Expand Down Expand Up @@ -132,3 +132,55 @@ def foo(key):
:py:func:`threefry_generate`.
"""
return _make.threefry_split(key)


def uniform(key, shape, dtype="float32", low=0.0, high=1.0):
"""Draw samples from a uniform distribution.
Samples are uniformly distributed over the half-open interval [low, high)
(includes low, but excludes high). In other words, any value within the
given interval is equally likely to be drawn by uniform.
Example
-------
.. code-block:: python
key = threefry_key(0)
key, random_values = uniform(key, (100,), low=0, high=10)
Parameters
----------
key : relay.Expr
key that uniquely determines the random values. Multiple uses with the
same generator will generate the same random values. This generator should be
treated as an opaque pointer. You can create one from calling
:py:func:`threefry_key`, :py:func:`threefry_split`, or
:py:func:`threefry_generate`. **Do not use this generator again after calling
this function.**
shape : Sequence[int]
Desired outputs shape of random numbers.
dtype : str
Desired outputs type of random numbers.
low : float or relay.Expr, optional
Lower bound of the uniform distribution.
high : float or relay.Expr, optional
Upper bound of the uniform distribution.
Returns
-------
new_key : relay.Expr
New random key to pass to future uses of random functions.
random_values : relay.Expr
The generated uniform distributed random numbers.
"""
if not isinstance(low, Expr):
low = const(low, dtype=dtype)
if not isinstance(high, Expr):
high = const(high, dtype=dtype)
return _make.uniform(key, low, high, shape, dtype)
22 changes: 22 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1526,6 +1526,28 @@ def threefry_split_strategy(attrs, inputs, out_type, target):
return strategy


# uniform
def wrap_compute_uniform(topi_compute):
"""Wrap uniform topi compute"""

def _compute_uniform(attrs, inputs, _):
return list(topi_compute(inputs[0], inputs[1], inputs[2], attrs.out_shape, attrs.out_dtype))

return _compute_uniform


@override_native_generic_func("uniform_strategy")
def uniform_strategy(attrs, inputs, out_type, target):
"""uniform generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_uniform(topi.random.uniform),
wrap_topi_schedule(topi.generic.schedule_extern),
name="uniform.generic",
)
return strategy


def wrap_compute_scanop(topi_compute):
"""Wrap scanop style topi compute"""

Expand Down
66 changes: 66 additions & 0 deletions python/tvm/topi/random/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,3 +466,69 @@ def gen_ir(out_ptr):
out_ary = tvm.nd.array(np.ones((1,), "uint64"), device)
tvm.build(s, [f], target=target)(out_ary)
return out_ary.asnumpy()[0] == 0


def uniform(gen, low, high, out_shape, out_dtype):
"""Draw samples from a uniform distribution.
Samples are uniformly distributed over the half-open interval [low, high)
(includes low, but excludes high). In other words, any value within the
given interval is equally likely to be drawn by uniform.
Parameters
----------
gen : ThreefryKey
Generator state. Can be create with :py:func:`tvm.relay.threefry_key`. This should not be
reused in another function, otherwise random numbers will be repeated.
low : Tensor[(), out_dtype]
Lower boundary of the output interval. All values generated will be
greater than or equal to low.
high : Tensor[(), out_dtype]
Upper boundary of the output interval. All values generated will be
less than high.
out_shape : Sequence[int]
Output shape of the random numbers. Product of all dimensions must be a multiple of 4.
out_dtype : str
The output dtype.
Returns
-------
new_gen : ThreefryKey
New generator state that is distinct from `gen`.
out : Tensor[out_shape, out_dtype]
Tensor of random numbers with shape `out_shape` and type `out_dtype`.
"""
new_gen, random_bits = threefry_generate(gen, out_shape)
assert out_dtype in ("float32", "float64"), (
"Only support float32 or float64 for now, got %s" % out_dtype
)
if out_dtype == "float32":
random_dtype = "uint32"
nbits = 32
nfraction = 23
elif out_dtype == "float64":
random_dtype = "uint64"
nbits = 64
nfraction = 52
nexp = nbits - nfraction - 1
random_bits = random_bits.astype(random_dtype)

fraction = tvm.topi.right_shift(
random_bits, tvm.tir.const(nbits - nfraction, dtype=random_dtype)
)
exponent = tvm.topi.left_shift(
tvm.topi.full(out_shape, random_dtype, (1 << (nexp - 1)) - 1),
tvm.tir.const(nfraction, dtype=random_dtype),
)
mantissa = tvm.topi.bitwise_or(fraction, exponent).astype(random_dtype)
standard_uniform_values = tvm.topi.reinterpret(mantissa, out_dtype) - tvm.tir.const(
1, dtype=out_dtype
)
uniform_values = tvm.topi.add(tvm.topi.multiply(standard_uniform_values, high - low), low)

return new_gen, uniform_values
47 changes: 47 additions & 0 deletions src/relay/op/random/kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,5 +85,52 @@ RELAY_REGISTER_OP("random.threefry_split")
.add_argument("key", "Tensor", "Input Threefry key")
.add_type_rel("ThreefrySplit", ThreefrySplitRel);

TVM_REGISTER_NODE_TYPE(UniformAttrs);

bool UniformRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
const UniformAttrs* param = attrs.as<UniformAttrs>();
ICHECK_EQ(types.size(), 4) << "Uniform should have three inputs and one output";

std::vector<IndexExpr> oshape;
for (auto& x : param->out_shape) {
oshape.push_back(x);
}
DataType out_dtype = param->out_dtype;
// we are supporting float32 and float64 at the moment.
if (!(out_dtype.is_float() && (out_dtype.bits() == 32 || out_dtype.bits() == 64))) {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "We only support generating uniform random value of "
<< "type float32 or float64, got " << out_dtype << ".");
return false;
}
reporter->Assign(types[0], ThreefryKeyType());
reporter->Assign(types[1], TensorType({}, out_dtype));
reporter->Assign(types[2], TensorType({}, out_dtype));
// generate returns the next key and an array of random values
reporter->Assign(types[3], TupleType({ThreefryKeyType(), TensorType(oshape, out_dtype)}));
return true;
}

Expr MakeUniform(Expr key, Expr low, Expr high, Array<Integer> out_shape, DataType out_dtype) {
auto attrs = make_object<UniformAttrs>();
attrs->out_shape = out_shape;
attrs->out_dtype = out_dtype;
static const Op& op = Op::Get("random.uniform");
return Call(op, {key, low, high}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.random._make.uniform").set_body_typed(MakeUniform);

RELAY_REGISTER_OP("random.uniform")
.describe(
R"doc(Generate an array of random numbers under uniform distribution.)doc" TVM_ADD_FILELINE)
.set_num_inputs(3)
.set_attrs_type<UniformAttrs>()
.add_argument("key", "Tensor", "Input Threefry key")
.add_argument("low", "Tensor", "Lower bound of the distribution")
.add_argument("high", "Tensor", "Higher bound of the distribution")
.add_type_rel("Uniform", UniformRel);

} // namespace relay
} // namespace tvm
15 changes: 15 additions & 0 deletions tests/python/relay/test_prng.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,21 @@ def test_threefry_split_infer():
assert tvm.ir.structural_equal(f.ret_type, expected_type)


def test_uniform_infer():
oshape = (12,)
odtypes = ["float32", "float64"]
for odtype in odtypes:
key_type = tvm.relay.TensorType([10], dtype="uint64")
gen_type = tvm.relay.TensorType(oshape, dtype=odtype)
expected_type = tvm.relay.TupleType([key_type, gen_type])

key = tvm.relay.random.threefry_key(1)
rand1 = tvm.relay.random.uniform(key, oshape, odtype)
f = tvm.relay.Function([], rand1)
f = run_infer_type(f)
assert tvm.ir.structural_equal(f.ret_type, expected_type)


@pytest.mark.xfail(raises=tvm.error.TVMError)
def test_threefry_generate_infer_fail():
# xfail: key size should be 10
Expand Down
34 changes: 34 additions & 0 deletions tests/python/topi/python/test_topi_prng.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,23 @@ def threefry_generate(target, dev, gen, size):
return out_gen.asnumpy(), rands.asnumpy()


def uniform(target, dev, gen, low, high, size, dtype):
gen_placeholder = tvm.te.placeholder(gen.shape, name="gen", dtype="uint64")
low_placeholder = tvm.te.placeholder(low.shape, name="low", dtype=dtype)
high_placeholder = tvm.te.placeholder(high.shape, name="high", dtype=dtype)
left_placeholder, right_placeholder = tvm.topi.random.uniform(
gen_placeholder, low_placeholder, high_placeholder, size, dtype
)
s = tvm.topi.generic.schedule_extern([left_placeholder, right_placeholder])
f = tvm.build(
s, [gen_placeholder, low_placeholder, high_placeholder, left_placeholder, right_placeholder]
)
out_gen = tvm.nd.array(np.zeros(gen.shape, dtype="uint64"))
rands = tvm.nd.array(np.zeros(size, dtype=dtype))
f(tvm.nd.array(gen), tvm.nd.array(low), tvm.nd.array(high), out_gen, rands)
return out_gen.asnumpy(), rands.asnumpy()


@tvm.testing.parametrize_targets
def test_threefry_split(target, dev):
# test that results of split do not equal eachother or the input
Expand Down Expand Up @@ -118,7 +135,24 @@ def test_threefry_wrapping(target, dev):
), f"{target} does not suppport wrapping unsigned integer arithmetic"


@tvm.testing.parametrize_targets
def test_uniform(target, dev):
gen = tvm.relay.random.threefry_key(0).data.asnumpy()
m = 1024
n = 1024
dtypes = ["float32", "float64"]
for dtype in dtypes:
low = np.array(5.0, dtype=dtype)
high = np.array(10.0, dtype=dtype)
new_gen, rands = uniform(target, dev, gen, low, high, (m, n), dtype)
assert (gen != new_gen).any()
assert abs(np.mean(rands) - 7.5) < 1e-1
assert np.min(rands) >= 5.0
assert np.max(rands) <= 10.0


if __name__ == "__main__":
test_threefry_split(tvm.target.Target("llvm"), tvm.device("cpu"))
test_threefry_generate(tvm.target.Target("llvm"), tvm.device("cpu"))
test_threefry_wrapping(tvm.target.Target("llvm"), tvm.device("cpu"))
test_uniform(tvm.target.Target("llvm"), tvm.device("cpu"))

0 comments on commit f8f3f6c

Please sign in to comment.