From f115c74903df8f5ca00b3941798bd6df78f0ad8c Mon Sep 17 00:00:00 2001 From: Zhao Wu Date: Tue, 18 Aug 2020 00:38:31 +0800 Subject: [PATCH] [random] support random fill (#5913) --- cmake/config.cmake | 2 +- .../contrib/random/mt_random_engine.cc | 47 +++++++++++++++ src/runtime/contrib/random/random.cc | 6 ++ tests/python/contrib/test_random.py | 57 +++++++++++++++++++ 4 files changed, 111 insertions(+), 1 deletion(-) diff --git a/cmake/config.cmake b/cmake/config.cmake index b7b9de8d0b08e..c41ed95bccbc7 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -152,7 +152,7 @@ set(USE_MKLDNN OFF) set(USE_OPENMP none) # Whether use contrib.random in runtime -set(USE_RANDOM OFF) +set(USE_RANDOM ON) # Whether use NNPack set(USE_NNPACK OFF) diff --git a/src/runtime/contrib/random/mt_random_engine.cc b/src/runtime/contrib/random/mt_random_engine.cc index c628e327643e4..8a4ee9af24bb4 100644 --- a/src/runtime/contrib/random/mt_random_engine.cc +++ b/src/runtime/contrib/random/mt_random_engine.cc @@ -22,11 +22,15 @@ * \brief mt19937 random engine */ #include +#include +#include #include #include #include +#include "../3rdparty/compiler-rt/builtin_fp16.h" + namespace tvm { namespace contrib { @@ -111,6 +115,49 @@ class RandomEngine { } } + void RandomFill(DLTensor* data) { + int64_t size = 1; + for (int i = 0; i < data->ndim; ++i) { + size *= data->shape[i]; + } + + if (data->ctx.device_type == kDLCPU) { + FillData(data, size); + } else { + runtime::NDArray local = runtime::NDArray::Empty( + std::vector{data->shape, data->shape + data->ndim}, data->dtype, {kDLCPU, 0}); + FillData(&local.ToDLPack()->dl_tensor, size); + runtime::NDArray::CopyFromTo(&local.ToDLPack()->dl_tensor, data); + } + } + + private: + void FillData(DLTensor* tensor, int64_t size) { + // Make the value be 1.0 - 10.0, not (0.0 - 1.0) so that we could satisfy + // quantized dtype (uint8 / int8) data non-empty requirement + std::uniform_real_distribution<> dist(1.0, 10.0); + // Use float representation could make us work well on float / int type too. + if (tensor->dtype.bits == 1) { + std::generate_n(static_cast(tensor->data), size, [&]() { return dist(rnd_engine_); }); + } else if (tensor->dtype.bits == 8) { + std::generate_n(static_cast(tensor->data), size, + [&]() { return dist(rnd_engine_); }); + } else if (tensor->dtype.bits == 16) { + std::generate_n(static_cast(tensor->data), size, [&]() { + return __truncXfYf2__( + static_cast(dist(rnd_engine_))); + }); + } else if (tensor->dtype.bits == 32) { + std::generate_n(static_cast(tensor->data), size, [&]() { return dist(rnd_engine_); }); + } else if (tensor->dtype.bits == 64) { + std::generate_n(static_cast(tensor->data), size, + [&]() { return dist(rnd_engine_); }); + } else { + LOG(FATAL) << "Doesn't support dtype code " << tensor->dtype.code << " dtype bits " + << tensor->dtype.bits; + } + } + private: std::mt19937 rnd_engine_; unsigned rseed_; diff --git a/src/runtime/contrib/random/random.cc b/src/runtime/contrib/random/random.cc index acba193c12305..14bdd267d38c4 100644 --- a/src/runtime/contrib/random/random.cc +++ b/src/runtime/contrib/random/random.cc @@ -117,5 +117,11 @@ TVM_REGISTER_GLOBAL("tvm.contrib.random.normal").set_body([](TVMArgs args, TVMRe entry->random_engine.SampleNormal(out, loc, scale); }); +TVM_REGISTER_GLOBAL("tvm.contrib.random.random_fill").set_body([](TVMArgs args, TVMRetValue* ret) { + RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal(); + DLTensor* out = args[0]; + entry->random_engine.RandomFill(out); +}); + } // namespace contrib } // namespace tvm diff --git a/tests/python/contrib/test_random.py b/tests/python/contrib/test_random.py index 9efdc3e5a7631..bc081a422ef2a 100644 --- a/tests/python/contrib/test_random.py +++ b/tests/python/contrib/test_random.py @@ -18,6 +18,22 @@ from tvm import te import numpy as np from tvm.contrib import random +from tvm import rpc + +def enabled_ctx_list(): + ctx_list = [('cpu', tvm.cpu(0)), + ('gpu', tvm.gpu(0)), + ('cl', tvm.opencl(0)), + ('metal', tvm.metal(0)), + ('rocm', tvm.rocm(0)), + ('vulkan', tvm.vulkan(0)), + ('vpi', tvm.vpi(0))] + for k, v in ctx_list: + assert tvm.context(k, 0) == v + ctx_list = [x[1] for x in ctx_list if x[1].exist] + return ctx_list + +ENABLED_CTX_LIST = enabled_ctx_list() def test_randint(): m = 1024 @@ -89,8 +105,49 @@ def verify(target="llvm"): assert abs(np.std(na) - 4) < 1e-2 verify() +def test_random_fill(): + def test_local(ctx, dtype): + if not tvm.get_global_func("tvm.contrib.random.random_fill", True): + print("skip because extern function is not available") + return + np_ones = np.ones((512, 512), dtype=dtype) + value = tvm.nd.empty(np_ones.shape, np_ones.dtype, ctx) + random_fill = tvm.get_global_func("tvm.contrib.random.random_fill") + random_fill(value) + + assert np.count_nonzero(value.asnumpy()) == 512 * 512 + + # make sure arithmentic doesn't overflow too + np_values = value.asnumpy() + assert np.isfinite(np_values * np_values + np_values).any() + + def test_rpc(dtype): + if not tvm.get_global_func("tvm.contrib.random.random_fill", True): + print("skip because extern function is not available") + return + if not tvm.runtime.enabled("rpc") or not tvm.runtime.enabled("llvm"): + return + np_ones = np.ones((512, 512), dtype=dtype) + server = rpc.Server("localhost") + remote = rpc.connect(server.host, server.port) + value = tvm.nd.empty(np_ones.shape, np_ones.dtype, remote.cpu()) + random_fill = remote.get_function("tvm.contrib.random.random_fill") + random_fill(value) + + assert np.count_nonzero(value.asnumpy()) == 512 * 512 + + # make sure arithmentic doesn't overflow too + np_values = value.asnumpy() + assert np.isfinite(np_values * np_values + np_values).any() + + for dtype in ["bool", "int8", "uint8", "int16", "uint16", "int32", "int32", + "int64", "uint64", "float16", "float32", "float64"]: + for ctx in ENABLED_CTX_LIST: + test_local(ctx, dtype) + test_rpc(dtype) if __name__ == "__main__": test_randint() test_uniform() test_normal() + test_random_fill()