From 214f22269b99464892993deb286cfc48f098a15b Mon Sep 17 00:00:00 2001
From: sxhu
Date: Thu, 11 Nov 2021 10:30:22 +0800
Subject: [PATCH] [Frontend][ONNX] Support RandomNormal operator

---
 include/tvm/relay/attrs/random.h           |  12 ++
 python/tvm/relay/frontend/onnx.py          |  91 +++++++++++++
 python/tvm/relay/op/random/_kernel.py      |   2 +
 python/tvm/relay/op/random/kernel.py       |  48 +++++++
 python/tvm/relay/op/strategy/generic.py    |  12 ++
 python/tvm/topi/random/kernel.py           |  58 ++++++++
 src/relay/op/random/kernel.cc              |  47 +++++++
 tests/python/frontend/onnx/test_forward.py | 150 ++++++++++++++++++++-
 8 files changed, 418 insertions(+), 2 deletions(-)

diff --git a/include/tvm/relay/attrs/random.h b/include/tvm/relay/attrs/random.h
index 46cab8831caf..650737f1d0fb 100644
--- a/include/tvm/relay/attrs/random.h
+++ b/include/tvm/relay/attrs/random.h
@@ -49,6 +49,18 @@ struct UniformAttrs : public tvm::AttrsNode<UniformAttrs> {
   }
 };
 
+struct NormalAttrs : public tvm::AttrsNode<NormalAttrs> {
+  Array<Integer> out_shape;
+  DataType out_dtype;
+
+  TVM_DECLARE_ATTRS(NormalAttrs, "relay.attrs.NormalAttrs") {
+    TVM_ATTR_FIELD(out_shape).describe("Shape of random numbers to generate");
+    TVM_ATTR_FIELD(out_dtype)
+        .set_default(NullValue<DataType>())
+        .describe("Data type of the generated numbers");
+  }
+};
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_ATTRS_RANDOM_H_
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 3c88f659f6f0..337fb903faac 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -3835,6 +3835,62 @@ def _impl_v12(cls, inputs, attr, params):
         return _op.einsum(inputs, equation)
 
 
+class RandomNormal(OnnxOpConverter):
+    """Operator converter for random_normal"""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        dtype = get_type(attr.get("dtype", 1))
+        mean = attr.get("mean", 0.0)
+        scale = attr.get("scale", 1.0)
+        seed = attr.get("seed", None)
+        shape = attr["shape"]
+
+        assert dtype in [
+            "float32",
+            "float64",
+        ], "Only float random value generation is currently supported."
+
+        if seed is None:
+            seed = np.random.randint(1e6)
+        else:
+            seed = int(seed)
+        key = _random.threefry_key(seed)
+        output = _op.random.normal(key, shape, dtype=dtype, mean=mean, scale=scale)
+        _, vals = _expr.TupleWrapper(output, 2)
+        return vals
+
+
+class RandomNormalLike(OnnxOpConverter):
+    """Operator converter for random_normal_like"""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        dtype = attr.get("dtype", None)
+        scale = attr.get("scale", 1.0)
+        mean = attr.get("mean", 0.0)
+        seed = attr.get("seed", None)
+        shape = infer_shape(inputs[0])
+        if dtype is None:
+            dtype = infer_type(inputs[0]).checked_type.dtype
+        else:
+            dtype = get_type(dtype)
+
+        assert dtype in [
+            "float32",
+            "float64",
+        ], "Only float random value generation is currently supported."
+
+        if seed is None:
+            seed = np.random.randint(1e6)
+        else:
+            seed = int(seed)
+        key = _random.threefry_key(seed)
+        output = _op.random.normal(key, shape, dtype=dtype, mean=mean, scale=scale)
+        _, vals = _expr.TupleWrapper(output, 2)
+        return vals
+
+
 class RandomUniform(OnnxOpConverter):
     """Operator converter for random_uniform"""
 
@@ -3853,6 +3909,38 @@ def _impl_v1(cls, inputs, attr, params):
         if seed is None:
             seed = np.random.randint(1e6)
+        else:
+            seed = int(seed)
         key = _random.threefry_key(seed)
         output = _op.random.uniform(key, shape, dtype=dtype, low=low, high=high)
         _, vals = _expr.TupleWrapper(output, 2)
         return vals
+
+
+class RandomUniformLike(OnnxOpConverter):
+    """Operator converter for random_uniform_like"""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        dtype = attr.get("dtype", None)
+        high = attr.get("high", 1.0)
+        low = attr.get("low", 0.0)
+        seed = attr.get("seed", None)
+        shape = infer_shape(inputs[0])
+        if dtype is None:
+            dtype = infer_type(inputs[0]).checked_type.dtype
+        else:
+            dtype = get_type(dtype)
+
+        assert dtype in [
+            "float32",
+            "float64",
+        ], "Only float random value generation is currently supported."
+
+        if seed is None:
+            seed = np.random.randint(1e6)
+        else:
+            seed = int(seed)
+        key = _random.threefry_key(seed)
+        output = _op.random.uniform(key, shape, dtype=dtype, low=low, high=high)
+        _, vals = _expr.TupleWrapper(output, 2)
+        return vals
@@ -4343,7 +4431,10 @@ def _get_convert_map(opset):
         "QLinearGlobalAveragePool": QLinearGlobalAveragePool.get_converter(opset),
         "QLinearLeakyRelu": QLinearLeakyRelu.get_converter(opset),
         # Random number generation.
+        "RandomNormal": RandomNormal.get_converter(opset),
+        "RandomNormalLike": RandomNormalLike.get_converter(opset),
         "RandomUniform": RandomUniform.get_converter(opset),
+        "RandomUniformLike": RandomUniformLike.get_converter(opset),
         # Loss functions / training
         "NegativeLogLikelihoodLoss": NegativeLogLikelihoodLoss.get_converter(opset),
         "SoftmaxCrossEntropyLoss": SoftmaxCrossEntropyLoss.get_converter(opset),
diff --git a/python/tvm/relay/op/random/_kernel.py b/python/tvm/relay/op/random/_kernel.py
index e601a7073cff..70c8f9a855cf 100644
--- a/python/tvm/relay/op/random/_kernel.py
+++ b/python/tvm/relay/op/random/_kernel.py
@@ -31,3 +31,5 @@
 # Distribution
 register_strategy("random.uniform", strategy.uniform_strategy)
 register_pattern("random.uniform", OpPattern.OPAQUE)
+register_strategy("random.normal", strategy.normal_strategy)
+register_pattern("random.normal", OpPattern.OPAQUE)
diff --git a/python/tvm/relay/op/random/kernel.py b/python/tvm/relay/op/random/kernel.py
index 6c82cc154eb6..7b5e955c9c55 100644
--- a/python/tvm/relay/op/random/kernel.py
+++ b/python/tvm/relay/op/random/kernel.py
@@ -183,3 +183,51 @@ def uniform(key, shape, dtype="float32", low=0.0, high=1.0):
     if not isinstance(high, Expr):
         high = const(high, dtype=dtype)
     return _make.uniform(key, low, high, shape, dtype)
+
+
+def normal(key, shape, dtype="float32", mean=0.0, scale=1.0):
+    """Draw samples from a normal distribution.
+
+    Example
+    -------
+
+    .. code-block:: python
+
+        key = threefry_key(0)
+        key, random_values = normal(key, (100,), mean=0.0, scale=10.0)
+
+    Parameters
+    ----------
+    key : relay.Expr
+        key that uniquely determines the random values. Multiple uses with the
+        same generator will generate the same random values. This generator should be
+        treated as an opaque pointer. You can create one from calling
+        :py:func:`threefry_key`, :py:func:`threefry_split`, or
+        :py:func:`threefry_generate`. **Do not use this generator again after
+        calling this function.**
+
+    shape : Sequence[int]
+        Desired output shape of the random numbers.
+
+    dtype : str
+        Desired output dtype of the random numbers.
+
+    mean : float or relay.Expr, optional
+        Mean of the normal distribution.
+
+    scale : float or relay.Expr, optional
+        Standard deviation of the normal distribution.
+
+    Returns
+    -------
+    new_key : relay.Expr
+        New random key to pass to future uses of random functions.
+
+    random_values : relay.Expr
+        The generated normally distributed random numbers.
+    """
+    if not isinstance(mean, Expr):
+        mean = const(mean, dtype=dtype)
+    if not isinstance(scale, Expr):
+        scale = const(scale, dtype=dtype)
+    return _make.normal(key, mean, scale, shape, dtype)
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index 777f17ba6084..c59237e5f866 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -1629,6 +1629,18 @@ def uniform_strategy(attrs, inputs, out_type, target):
     return strategy
 
 
+@override_native_generic_func("normal_strategy")
+def normal_strategy(attrs, inputs, out_type, target):
+    """normal generic strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        # normal shares uniform's compute signature (a key plus two scalar
+        # parameters, with out_shape/out_dtype attrs), so the uniform compute
+        # wrapper can be reused here.
+        wrap_compute_uniform(topi.random.normal),
+        wrap_topi_schedule(topi.generic.schedule_extern),
+        name="normal.generic",
+    )
+    return strategy
+
+
 def wrap_compute_scanop(topi_compute):
     """Wrap scanop style topi compute"""
diff --git a/python/tvm/topi/random/kernel.py b/python/tvm/topi/random/kernel.py
index 2ef97e2edc5c..64afcf066c11 100644
--- a/python/tvm/topi/random/kernel.py
+++ b/python/tvm/topi/random/kernel.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Pseudorandom number kernels."""
+import math
 import numpy as np
 
 import tvm
@@ -544,3 +545,60 @@ def uniform(gen, low, high, out_shape, out_dtype):
     uniform_values = tvm.topi.add(tvm.topi.multiply(standard_uniform_values, high - low), low)
 
     return new_gen, uniform_values
+
+
+def normal(gen, mean, scale, out_shape, out_dtype):
+    """Draw samples from a normal distribution.
+
+    The algorithm is based on the Box-Muller transform.
+
+    Parameters
+    ----------
+    gen : ThreefryKey
+        Generator state. Can be created with :py:func:`tvm.relay.threefry_key`. This should not
+        be reused in another function, otherwise random numbers will be repeated.
+
+    mean : Tensor[(), out_dtype]
+        The mean of the normal distribution.
+
+    scale : Tensor[(), out_dtype]
+        The standard deviation of the normal distribution.
+
+    out_shape : Sequence[int]
+        Output shape of the random numbers.
+
+    out_dtype : str
+        The output dtype.
+
+    Returns
+    -------
+    new_gen : ThreefryKey
+        New generator state that is distinct from `gen`.
+
+    out : Tensor[out_shape, out_dtype]
+        Tensor of random numbers with shape `out_shape` and type `out_dtype`.
+ """ + out_shape = list(out_shape) + # Box-Muller transform need two pieces of original uniform data + out_shape.insert(0, 2) + new_gen, uniform_values = uniform( + gen, + tvm.tir.const(0.0, out_dtype), + tvm.tir.const(1.0, out_dtype), + out_shape, + out_dtype, + ) + two_pi = tvm.tir.const(2.0 * math.pi, out_dtype) + uniform_values_1 = tvm.topi.strided_slice(uniform_values, [0], [1], strides=[1], axes=[0]) + uniform_values_1 = tvm.topi.squeeze(uniform_values_1, axis=0) + uniform_values_2 = tvm.topi.strided_slice(uniform_values, [1], [2], strides=[1], axes=[0]) + uniform_values_2 = tvm.topi.squeeze(uniform_values_2, axis=0) + uniform_values_1 = tvm.topi.subtract(tvm.tir.const(1.0, out_dtype), uniform_values_1) + sqrt_values = tvm.topi.sqrt( + tvm.topi.multiply(tvm.tir.const(-2.0, out_dtype), tvm.topi.log(uniform_values_1)) + ) + sin_values = tvm.topi.sin(tvm.topi.multiply(two_pi, uniform_values_2)) + random_values = tvm.topi.add( + tvm.topi.multiply(tvm.topi.multiply(sqrt_values, sin_values), scale), mean + ) + + return new_gen, random_values diff --git a/src/relay/op/random/kernel.cc b/src/relay/op/random/kernel.cc index 077a1ea11673..7c957df3ed9b 100644 --- a/src/relay/op/random/kernel.cc +++ b/src/relay/op/random/kernel.cc @@ -132,5 +132,52 @@ RELAY_REGISTER_OP("random.uniform") .add_argument("high", "Tensor", "Higher bound of the distribution") .add_type_rel("Uniform", UniformRel); +TVM_REGISTER_NODE_TYPE(NormalAttrs); + +bool NormalRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + const NormalAttrs* param = attrs.as(); + ICHECK_EQ(types.size(), 4) << "Normal should have three inputs and one output"; + + std::vector oshape; + for (auto& x : param->out_shape) { + oshape.push_back(x); + } + DataType out_dtype = param->out_dtype; + // we are supporting float32 and float64 at the moment. 
diff --git a/src/relay/op/random/kernel.cc b/src/relay/op/random/kernel.cc
index 077a1ea11673..7c957df3ed9b 100644
--- a/src/relay/op/random/kernel.cc
+++ b/src/relay/op/random/kernel.cc
@@ -132,5 +132,52 @@ RELAY_REGISTER_OP("random.uniform")
     .add_argument("high", "Tensor", "Higher bound of the distribution")
     .add_type_rel("Uniform", UniformRel);
 
+TVM_REGISTER_NODE_TYPE(NormalAttrs);
+
+bool NormalRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+               const TypeReporter& reporter) {
+  const NormalAttrs* param = attrs.as<NormalAttrs>();
+  ICHECK_EQ(types.size(), 4) << "Normal should have three inputs and one output";
+
+  std::vector<IndexExpr> oshape;
+  for (auto& x : param->out_shape) {
+    oshape.push_back(x);
+  }
+  DataType out_dtype = param->out_dtype;
+  // Only float32 and float64 are supported at the moment.
+  if (!(out_dtype.is_float() && (out_dtype.bits() == 32 || out_dtype.bits() == 64))) {
+    reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
+                                     << "We only support generating Normal random value of "
+                                     << "type float32 or float64, got " << out_dtype << ".");
+    return false;
+  }
+  reporter->Assign(types[0], ThreefryKeyType());
+  reporter->Assign(types[1], TensorType({}, out_dtype));
+  reporter->Assign(types[2], TensorType({}, out_dtype));
+  // generate returns the next key and an array of random values
+  reporter->Assign(types[3], TupleType({ThreefryKeyType(), TensorType(oshape, out_dtype)}));
+  return true;
+}
+
+Expr MakeNormal(Expr key, Expr mean, Expr scale, Array<Integer> out_shape, DataType out_dtype) {
+  auto attrs = make_object<NormalAttrs>();
+  attrs->out_shape = out_shape;
+  attrs->out_dtype = out_dtype;
+  static const Op& op = Op::Get("random.normal");
+  return Call(op, {key, mean, scale}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op.random._make.normal").set_body_typed(MakeNormal);
+
+RELAY_REGISTER_OP("random.normal")
+    .describe(
+        R"doc(Generate an array of random numbers under normal distribution.)doc" TVM_ADD_FILELINE)
+    .set_num_inputs(3)
+    .set_attrs_type<NormalAttrs>()
+    .add_argument("key", "Tensor", "Input Threefry key")
+    .add_argument("mean", "Tensor", "Mean of the distribution")
+    .add_argument("scale", "Tensor", "Standard deviation of the distribution")
+    .add_type_rel("Normal", NormalRel);
+
 }  // namespace relay
 }  // namespace tvm
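
With the attrs, type relation, strategy, and TOPI kernel in place, the new op can also be driven directly from Relay, independent of the ONNX frontend. A minimal sketch (assuming an LLVM-enabled build; executor API spellings varied slightly across TVM releases, so treat this as illustrative rather than canonical):

    import tvm
    from tvm import relay

    key = relay.random.threefry_key(0)
    out = relay.random.normal(key, (8,), dtype="float32", mean=0.0, scale=1.0)
    _, values = relay.TupleWrapper(out, 2)  # the op returns (new_key, values)
    mod = tvm.IRModule.from_expr(relay.Function([], values))
    result = relay.create_executor("vm", mod=mod, target="llvm").evaluate()()
    print(result.numpy())
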
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index dd1c77330986..75ce7f7c0218 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -5594,7 +5594,7 @@ def get_random_uniform(shape, dtype="float32", high=1.0, low=0.0, seed=None):
     assert list(vals.shape) == [1, 3, 100, 100]
 
     # Check that bounds aren't exceeded.
-    vals = get_random_uniform(shape=[100], high=100, low=-100)
+    vals = get_random_uniform(shape=[100], high=100.0, low=-100.0)
     assert list(vals.shape) == [100]
     assert all(vals >= -100) and all(vals <= 100)
 
@@ -5604,7 +5604,7 @@ ...
     assert all(vals_1 == vals_2)
 
     # Test against an expected output with a fixed seed.
-    real = get_random_uniform(shape=[10], seed=5)
+    real = get_random_uniform(shape=[10], seed=5.0)
     expected = np.asarray(
         [
             0.043976,
@@ -5622,6 +5622,149 @@ def get_random_uniform(shape, dtype="float32", high=1.0, low=0.0, seed=None):
     tvm.testing.assert_allclose(real, expected, rtol=1e-5)
 
 
+@tvm.testing.parametrize_targets
+def test_random_uniform_like(target, dev):
+    def get_random_uniform_like(input, shape, dtype=None, high=1.0, low=0.0, seed=None):
+        node = helper.make_node("RandomUniformLike", ["in"], ["out"], high=high, low=low)
+        if seed is not None:
+            seed_attr = helper.make_attribute("seed", seed)
+            node.attribute.append(seed_attr)
+
+        ONNX_DTYPE = None
+        if dtype is not None:
+            ONNX_DTYPE = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)]
+            dtype_attr = helper.make_attribute("dtype", ONNX_DTYPE)
+            node.attribute.append(dtype_attr)
+        else:
+            dtype = input.dtype
+            ONNX_DTYPE = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)]
+
+        graph = helper.make_graph(
+            [node],
+            "random_uniform_test",
+            inputs=[helper.make_tensor_value_info("in", ONNX_DTYPE, shape)],
+            outputs=[helper.make_tensor_value_info("out", ONNX_DTYPE, shape)],
+        )
+        model = helper.make_model(graph, producer_name="random_uniform_like_test")
+        return get_tvm_output_with_vm(model, [input], target=target, dev=dev)
+
+    # Check that function runs and produces proper shape and dtype.
+    shape = [10]
+    input = np.random.random(shape).astype("float32")
+    vals = get_random_uniform_like(input, shape, dtype="float32")
+    assert list(vals.shape) == [10]
+    assert vals.dtype == "float32"
+
+    # Test N-D tensor generation.
+    shape = [1, 3, 100, 100]
+    input = np.random.random(shape).astype("float32")
+    vals = get_random_uniform_like(input, shape, dtype="float64")
+    assert list(vals.shape) == shape
+    assert vals.dtype == "float64"
+
+    # Check that bounds aren't exceeded.
+    shape = [100]
+    input = np.random.random(shape).astype("float64")
+    vals = get_random_uniform_like(input, shape, high=100.0, low=-100.0)
+    assert list(vals.shape) == shape
+    assert all(vals >= -100) and all(vals <= 100)
+
+    # Test against an expected output with a fixed seed.
+    shape = [10]
+    input = np.random.random(shape).astype("float32")
+    real = get_random_uniform_like(input, shape=[10], seed=5.0)
+    expected = np.asarray(
+        [
+            0.043976,
+            0.96656,
+            0.292199,
+            0.904297,
+            0.25167,
+            0.521778,
+            0.778985,
+            0.085463,
+            0.939846,
+            0.194201,
+        ]
+    )
+    tvm.testing.assert_allclose(real, expected, rtol=1e-5)
+
+
+@tvm.testing.parametrize_targets
+def test_random_normal(target, dev):
+    def get_random_normal(shape, dtype="float32", scale=1.0, mean=0.0, seed=None):
+        ONNX_DTYPE = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)]
+        node = helper.make_node(
+            "RandomNormal", [], ["out"], shape=shape, dtype=ONNX_DTYPE, scale=scale, mean=mean
+        )
+        if seed is not None:
+            seed_attr = helper.make_attribute("seed", seed)
+            node.attribute.append(seed_attr)
+
+        graph = helper.make_graph(
+            [node],
+            "random_normal_test",
+            inputs=[],
+            outputs=[helper.make_tensor_value_info("out", ONNX_DTYPE, shape)],
+        )
+        model = helper.make_model(graph, producer_name="random_normal_test")
+        return get_tvm_output_with_vm(model, [], target=target, dev=dev)
+
+    # Test N-D tensor generation.
+    vals = get_random_normal([1, 3, 100, 100], dtype="float32")
+    assert list(vals.shape) == [1, 3, 100, 100]
+    tvm.testing.assert_allclose(vals.mean(), 0.0, rtol=0.1, atol=0.1)
+    tvm.testing.assert_allclose(np.std(vals), 1.0, rtol=0.1, atol=0.1)
+
+    # Test mean=2.0 scale=10.0
+    vals = get_random_normal([1, 3, 100, 100], mean=2.0, scale=10.0, dtype="float32")
+    assert list(vals.shape) == [1, 3, 100, 100]
+    tvm.testing.assert_allclose(vals.mean(), 2.0, rtol=0.1, atol=0.1)
+    tvm.testing.assert_allclose(np.std(vals), 10.0, rtol=0.1, atol=0.1)
+
+    # Check that a fixed seed produces the same values when run twice.
+    vals_1 = get_random_normal(shape=[10], seed=1.0)
+    vals_2 = get_random_normal(shape=[10], seed=1.0)
+    assert all(vals_1 == vals_2)
+
+
+@tvm.testing.parametrize_targets
+def test_random_normal_like(target, dev):
+    def get_random_normal_like(input, shape, dtype="float32", scale=1.0, mean=0.0, seed=None):
+        ONNX_DTYPE = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)]
+        node = helper.make_node(
+            "RandomNormalLike", ["in"], ["out"], dtype=ONNX_DTYPE, scale=scale, mean=mean
+        )
+        if seed is not None:
+            seed_attr = helper.make_attribute("seed", seed)
+            node.attribute.append(seed_attr)
+
+        graph = helper.make_graph(
+            [node],
+            "random_normal_like_test",
+            inputs=[helper.make_tensor_value_info("in", ONNX_DTYPE, shape)],
+            outputs=[helper.make_tensor_value_info("out", ONNX_DTYPE, shape)],
+        )
+        model = helper.make_model(graph, producer_name="random_normal_like_test")
+        return get_tvm_output_with_vm(model, [input], target=target, dev=dev)
+
+    # Test N-D tensor generation.
+    shape = [1, 3, 100, 100]
+    input = np.random.random(shape).astype("float32")
+    vals = get_random_normal_like(input, [1, 3, 100, 100], dtype="float32")
+    assert list(vals.shape) == [1, 3, 100, 100]
+    tvm.testing.assert_allclose(vals.mean(), 0.0, rtol=0.1, atol=0.1)
+    tvm.testing.assert_allclose(np.std(vals), 1.0, rtol=0.1, atol=0.1)
+
+    # Test mean=2.0 scale=10.0
+    shape = [1, 3, 100, 100]
+    input = np.random.random(shape).astype("float32")
+    vals = get_random_normal_like(input, [1, 3, 100, 100], mean=2.0, scale=10.0, dtype="float32")
+    assert list(vals.shape) == [1, 3, 100, 100]
+    tvm.testing.assert_allclose(vals.mean(), 2.0, rtol=0.1, atol=0.1)
+    tvm.testing.assert_allclose(np.std(vals), 10.0, rtol=0.1, atol=0.1)
+
+
 @tvm.testing.parametrize_targets
 def test_convinteger(target, dev):
     def verify_convinteger(
@@ -5864,3 +6007,6 @@ def repeat(N, D):
     test_convinteger()
     test_batch_matmul()
     test_global_lppool()
+    test_random_uniform_like()
+    test_random_normal()
+    test_random_normal_like()
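
A note on the tolerances in the statistical checks above: each N-D test draws 1 * 3 * 100 * 100 = 30,000 samples, so the standard error of the sample mean is scale / sqrt(n), roughly 0.06 for scale=10.0. That sits comfortably inside assert_allclose's atol=0.1 (plus the rtol term), so these checks should be stable rather than flaky. Back-of-envelope, in plain NumPy:

    import numpy as np

    n = 1 * 3 * 100 * 100     # samples per test tensor
    print(10.0 / np.sqrt(n))  # ~0.058, standard error of the mean for scale=10.0
    print(1.0 / np.sqrt(n))   # ~0.006, for the scale=1.0 checks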