diff --git a/python/tvm/relax/frontend/nn/_tensor_op.py b/python/tvm/relax/frontend/nn/_tensor_op.py index 3a646e29b8dc..7f44ca24386d 100644 --- a/python/tvm/relax/frontend/nn/_tensor_op.py +++ b/python/tvm/relax/frontend/nn/_tensor_op.py @@ -67,6 +67,22 @@ def __truediv__(self, other): other = _convert_scalar(other, self) return _op().divide(self, other) + def __lt__(self, other): + other = _convert_scalar(other, self) + return _op().less(self, other) + + def __le__(self, other): + other = _convert_scalar(other, self) + return _op().less_equal(self, other) + + def __gt__(self, other): + other = _convert_scalar(other, self) + return _op().greater(self, other) + + def __ge__(self, other): + other = _convert_scalar(other, self) + return _op().greater_equal(self, other) + def astype(self, dtype): return _op().astype(self, dtype) diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py index b6c34ca265b8..6944fc8535af 100644 --- a/python/tvm/relax/frontend/nn/op.py +++ b/python/tvm/relax/frontend/nn/op.py @@ -24,6 +24,8 @@ import numpy as np from tvm import tir as _tir +from tvm.script import tir as T +from tvm import te from ... import expr as rx from ... import op as _op @@ -1825,3 +1827,502 @@ def print_(tensor: Tensor): filename, line_number = inspect.getframeinfo(inspect.currentframe().f_back)[:2] line_info = f"{filename}:{line_number}" debug_func("vm.builtin.debug_print", tensor, _line_info=line_info) + + +def less(a: Tensor, b: Tensor, name: str = "less") -> Tensor: + """Broadcasted element-wise comparison for (lhs < rhs). + + Parameters + ---------- + a : Tensor + The first input tensor. + + b : Tensor + The second input tensor. + + name : str + Name hint. + + Returns + ------- + result : Tensor + The computed result. + """ + return wrap_nested(_op.less(a._expr, b._expr), name) + + +def less_equal(a: Tensor, b: Tensor, name: str = "less_equal") -> Tensor: + """Broadcasted element-wise comparison for (lhs <= rhs). + + Parameters + ---------- + a : Tensor + The first input tensor. + + b : Tensor + The second input tensor. + + name : str + Name hint. + + Returns + ------- + result : Tensor + The computed result. + """ + return wrap_nested(_op.less_equal(a._expr, b._expr), name) + + +def greater(a: Tensor, b: Tensor, name: str = "greater") -> Tensor: + """Broadcasted element-wise comparison for (lhs > rhs). + + Parameters + ---------- + a : Tensor + The first input tensor. + + b : Tensor + The second input tensor. + + name : str + Name hint. + + Returns + ------- + result : Tensor + The computed result. + """ + return wrap_nested(_op.greater(a._expr, b._expr), name) + + +def greater_equal(a: Tensor, b: Tensor, name: str = "greater_equal") -> Tensor: + """Broadcasted element-wise comparison for (lhs >= rhs). + + Parameters + ---------- + a : Tensor + The first input tensor. + + b : Tensor + The second input tensor. + + name : str + Name hint. + + Returns + ------- + result : Tensor + The computed result. + """ + return wrap_nested(_op.greater_equal(a._expr, b._expr), name) + + +def equal(a: Tensor, b: Tensor, name: str = "equal") -> Tensor: + """Broadcasted element-wise comparison for (lhs == rhs). + + Parameters + ---------- + a : Tensor + The first input tensor. + + b : Tensor + The second input tensor. + + name : str + Name hint. + + Returns + ------- + result : Tensor + The computed result. + """ + return wrap_nested(_op.equal(a._expr, b._expr), name) + + +def not_equal(a: Tensor, b: Tensor, name: str = "not_equal") -> Tensor: + """Broadcasted element-wise comparison for (lhs != rhs). + + Parameters + ---------- + a : Tensor + The first input tensor. + + b : Tensor + The second input tensor. + + name : str + Name hint. + + Returns + ------- + result : Tensor + The computed result. + """ + return wrap_nested(_op.not_equal(a._expr, b._expr), name) + + +def where(condition: Tensor, x1: Tensor, x2: Tensor, name: str = "where") -> Tensor: + """Selecting elements from either the input tensors depending on the value of the + condition. + + For a given position, return the corresponding value in `x1` if `condition` is True, + and return the corresponding value in `x2` otherwise. + + Parameters + ---------- + condition : Tensor + When True, yield `x1`; otherwise, yield `x2`. + Must be broadcasting compatible with `x1` and `x2`. + Must have boolean dtype. + + x1 : Tensor + The first input tensor. + Must be broadcasting compatible with `condition` and `x2`. + + x2 : Tensor + The second input tensor. + Must be broadcasting compatible with `condition` and `x1`. + + name : str + Name hint. + + Returns + ------- + result : Tensor + The result tensor. + """ + return wrap_nested(_op.where(condition._expr, x1._expr, x2._expr), name) + + +def cumsum( + data: Tensor, + axis: Optional[int] = None, + dtype: Optional[str] = None, + exclusive: Optional[bool] = None, + name: str = "cumsum", +) -> Tensor: + """Numpy style cumsum op. Return the cumulative inclusive sum of the elements along + a given axis. + + Parameters + ---------- + data : Tensor + The input data to the operator. + + axis : Optional[int] + Axis along which the cumulative sum is computed. The default (None) is to compute + the cumsum over the flattened array. + + dtype : Optional[str] + Type of the returned array and of the accumulator in which the elements are summed. + If dtype is not specified, it defaults to the dtype of data. + + exclusive : Optional[bool] + If true will return exclusive sum in which the first element is not + included. + + name : str + Name hint. + + Returns + ------- + result : Tensor + The result has the same size as data, and the same shape as data if axis is not None. + If axis is None, the result is a 1-d array. + + Examples + -------- + .. code-block:: python + + a = [[1, 2, 3], [4, 5, 6]] + + cumsum(a) # if axis is not provided, cumsum is done over the flattened input. + -> [ 1, 3, 6, 10, 15, 21] + + cumsum(a, dtype="float32") + -> [ 1., 3., 6., 10., 15., 21.] + + cumsum(a, axis=0) # sum over rows for each of the 3 columns + -> [[1, 2, 3], + [5, 7, 9]] + + cumsum(a, axis=1) + -> [[ 1, 3, 6], + [ 4, 9, 15]] + + a = [1, 0, 1, 0, 1, 1, 0] # a is a boolean array + cumsum(a, dtype=int32) # dtype should be provided to get the expected results + -> [1, 1, 2, 2, 3, 4, 4] + """ + return wrap_nested(_op.cumsum(data._expr, axis, dtype, exclusive), name) + + +def multinomial_from_uniform(prob: Tensor, uniform_sample: Tensor, dtype: str = "int64"): + """Returns a tensor where each row contains the index sampled from the multinomial + probability distribution located in the corresponding row of tensor prob. + + Notes + ----- + For better cpu performance, use 'vm.builtin.multinomial_from_uniform'. + For accurate results, ensure probabilities are between 0 and 1 and sum to 1. + + Parameters + ---------- + prob : Tensor + A 2-D tensor of shape (batch, vocab_size) representing probability distributions. + Each row is a distribution across vocabulary for a batch, where: + Values range from [0, 1], indicating the probability of each vocabulary item. + The sum of values in each row is 1, forming a valid distribution. + + uniform_sample : Tensor + The uniformly sampled 2-D tensor with the shape (batch, 1). + Values range from 0 to 1, indicating probabilities sampled uniformly. + + Returns + ------- + result : Tensor + The computed tensor with shape (batch, 1). + + Examples + -------- + .. code-block:: python + + prob = [[0.2, 0.3, 0.5], [0.3, 0.4, 0.3]] + usample = [[0.4], [0.9]] + + multinomial_from_uniform(prob, usample) + -> [[1], [2]] + """ + prob_dtype = prob.dtype + sample_dtype = uniform_sample.dtype + batch = prob.shape[0] + + @T.prim_func(private=True) + def _get_sample_index(A: T.handle, B: T.handle, C: T.handle): + batch, vocab_size = T.int64(), T.int64() + prob = T.match_buffer(A, (batch, vocab_size), prob_dtype) + usample = T.match_buffer(B, (batch, 1), sample_dtype) + output_index = T.match_buffer(C, (batch, 1), dtype) + + for ax0, ax1 in T.grid(batch, vocab_size): + with T.block("T_get_sample_index"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.writes(output_index[v_ax0, 0]) + if usample[v_ax0, T.int64(0)] < prob[v_ax0, v_ax1] or v_ax1 + 1 == vocab_size: + if v_ax1 == 0: + output_index[v_ax0, 0] = 0 + elif usample[v_ax0, T.int64(0)] >= prob[v_ax0, v_ax1 - 1]: + output_index[v_ax0, 0] = v_ax1 + + cumsum_prob = cumsum(prob, axis=1, exclusive=False) + + return tensor_ir_op( + _get_sample_index, + "get_sample_index", + args=[cumsum_prob, uniform_sample], + out=Tensor.placeholder([batch, 1], dtype), + ) + + +def sample_top_p_top_k_from_sorted_prob( + sorted_prob: Tensor, sorted_index: Tensor, top_p: Tensor, top_k: Tensor, uniform_sample: Tensor +): + """Samples indices from a sorted probability tensor based on top_p and top_k criteria. + + Notes + ----- + For accurate results, ensure probabilities are between 0 and 1 and sum to 1. + + Parameters + ---------- + sorted_prob : Tensor + A 2-D tensor, with shape (batch, vocab_size), contains probabilities + sorted in descending order. + + sorted_index: Tensor + The indices tensor with shape (batch, vocab_size), corresponding to the + sorted_prob. Potentially from applying argsort on the original probability + tensor in descending order. + + top_p : Tensor + The cumulative probability threshold with shape (batch, 1) for nucleus sampling. + + top_k :Tensor + A tensor with shape (batch, 1), representing the number of top probabilities + to consider for top-k sampling. + + uniform_sample : Tensor + Uniformly sampled values with shape (batch, 1) are used to select the output indices. + + Returns + ------- + result : Tensor + The selected indices with shape (batch, 1). + + Examples + -------- + .. code-block:: python + + prob = [[0.1 , 0.4, 0.5], + [0.3, 0.3, 0.4]] + sorted_prob = [[0.5, 0.4, 0.1], + [0.4, 0.3, 0.3]] + sorted_index = [[2, 1, 0], + [2, 0, 1]] + top_p = [[0.6],[0.9]] + top_k = [[3],[2]] + uniform_sample = [[0.5], [0.6]] + + sample_top_p_top_k_from_sorted_prob( + sorted_prob, sorted_index,top_p, top_k, uniform_sample) + -> [2, 0] + + """ + prob_dtype = sorted_prob.dtype + index_dtype = sorted_index.dtype + batch = sorted_prob.shape[0] + + def _cumsum_mask(cumsum_sorted, top_p, top_k, i, j): + return _tir.all(cumsum_sorted[i, j] < top_p[i, 0], j + 1 < top_k[i, 0]) + + @T.prim_func(private=True) + def _get_renorm_prob(A: T.handle, B: T.handle, C: T.handle, D: T.handle): + batch, vocab_size = T.int64(), T.int64() + cumsum_sorted = T.match_buffer(A, (batch, vocab_size), prob_dtype) + top_p = T.match_buffer(B, (batch, 1), prob_dtype) + top_k = T.match_buffer(C, (batch, 1), index_dtype) + renorm_prob = T.match_buffer(D, (batch, 1), prob_dtype) + for ax0, ax1 in T.grid(batch, vocab_size): + with T.block("T_get_renorm_prob"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + if _cumsum_mask(cumsum_sorted, top_p, top_k, v_ax0, 0) == 0: + renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, 0] + elif _cumsum_mask(cumsum_sorted, top_p, top_k, v_ax0, v_ax1) == 1: + if v_ax1 + 1 == vocab_size: + renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, v_ax1] + elif _cumsum_mask(cumsum_sorted, top_p, top_k, v_ax0, v_ax1 + 1) == 0: + renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, v_ax1 + 1] + + @T.prim_func(private=True) + def _get_index_from_sorted(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle): + batch, vocab_size = T.int64(), T.int64() + cumsum_sorted = T.match_buffer(A, (batch, vocab_size), prob_dtype) + renorm_prob = T.match_buffer(B, (batch, 1), prob_dtype) + usample = T.match_buffer(C, (batch, 1), prob_dtype) + indices = T.match_buffer(D, (batch, vocab_size), index_dtype) + output_index = T.match_buffer(E, (batch, 1), index_dtype) + + for ax0, ax1 in T.grid(batch, vocab_size): + with T.block("T_get_index_from_sorted"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.writes(output_index[v_ax0, 0]) + if ( + usample[v_ax0, T.int64(0)] < cumsum_sorted[v_ax0, v_ax1] / renorm_prob[v_ax0, 0] + or v_ax1 + 1 == vocab_size + ): + if v_ax1 == 0: + output_index[v_ax0, 0] = indices[v_ax0, 0] + elif ( + usample[v_ax0, T.int64(0)] + >= cumsum_sorted[v_ax0, v_ax1 - 1] / renorm_prob[v_ax0, 0] + ): + output_index[v_ax0, 0] = indices[v_ax0, v_ax1] + + cumsum_sorted = cumsum(sorted_prob, axis=1) + + renorm_prob = tensor_ir_op( + _get_renorm_prob, + "get_renorm_prob", + args=[cumsum_sorted, top_p, top_k], + out=Tensor.placeholder( + [batch, 1], + prob_dtype, + ), + ) + + out_index_in_sorted = tensor_ir_op( + _get_index_from_sorted, + "get_index_from_sorted", + args=[cumsum_sorted, renorm_prob, uniform_sample, sorted_index], + out=Tensor.placeholder([batch, 1], index_dtype), + ) + return out_index_in_sorted + + +def renormalize_top_p_top_k_prob(prob, sorted_prob, top_p, top_k): + """Renormalizes probabilities after filtering with top_p and top_k, ensuring + they sum up to 1. + + Notes + ----- + For accurate results, ensure probabilities are between 0 and 1 and sum to 1. + + Parameters + ---------- + prob : Tensor + A 2-D tensor of shape (batch, vocab_size) representing probability distributions. + + sorted_prob : Tensor + Probabilities sorted in descending order. + + top_p : Tensor + The cumulative probability threshold with shape (batch, 1) for nucleus sampling. + + top_k :Tensor + A tensor with shape (batch, 1), representing the number of top probabilities + to consider for top-k sampling. + + Returns + ------- + result : Tensor + The filtered and nomalized tensor with the sampe shape as input prob. + """ + prob_dtype = prob.dtype + top_k_dtype = top_k.dtype + batch = sorted_prob.shape[0] + + def _cumsum_mask(cumsum_sorted, top_p, top_k, i, j): + return _tir.all(cumsum_sorted[i, j] < top_p[i, 0], j + 1 < top_k[i, 0]) + + @T.prim_func(private=True) + def _get_renorm_cutoff(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle): + batch, vocab_size = T.int64(), T.int64() + sorted_prob = T.match_buffer(A, (batch, vocab_size), prob_dtype) + cumsum_sorted = T.match_buffer(B, (batch, vocab_size), prob_dtype) + top_p = T.match_buffer(C, (batch, 1), prob_dtype) + top_k = T.match_buffer(D, (batch, 1), top_k_dtype) + cutoff = T.match_buffer(E, (batch, 1), prob_dtype) + for ax0, ax1 in T.grid(batch, vocab_size): + with T.block("T_get_renorm_prob"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + if _cumsum_mask(cumsum_sorted, top_p, top_k, v_ax0, 0) == 0: + cutoff[v_ax0, 0] = sorted_prob[v_ax0, 0] + elif _cumsum_mask(cumsum_sorted, top_p, top_k, v_ax0, v_ax1) == 1: + if v_ax1 + 1 == vocab_size: + cutoff[v_ax0, 0] = sorted_prob[v_ax0, v_ax1] + elif _cumsum_mask(cumsum_sorted, top_p, top_k, v_ax0, v_ax1 + 1) == 0: + cutoff[v_ax0, 0] = sorted_prob[v_ax0, v_ax1 + 1] + + cumsum_sorted = cumsum(sorted_prob, axis=1) + + renorm_cutoff = tensor_ir_op( + _get_renorm_cutoff, + "get_renorm_cutoff", + args=[sorted_prob, cumsum_sorted, top_p, top_k], + out=Tensor.placeholder( + [batch, 1], + prob_dtype, + ), + ) + + filtered_prob = tensor_expr_op( + lambda prob, renorm_cutoff: te.compute( + prob.shape, + lambda i, j: _tir.Select(prob[i, j] >= renorm_cutoff[i, 0], prob[i, j], 0.0), + name="filter_with_top_p_top_k", + ), + "filter_with_top_p_top_k", + args=[prob, renorm_cutoff], + ) + renorm_prob = filtered_prob / sum(filtered_prob, axis=1, keepdims=True) + return renorm_prob diff --git a/src/runtime/relax_vm/lm_support.cc b/src/runtime/relax_vm/lm_support.cc index fccff2cecdd0..eecac5a8c2d2 100644 --- a/src/runtime/relax_vm/lm_support.cc +++ b/src/runtime/relax_vm/lm_support.cc @@ -495,6 +495,43 @@ int SampleTopPFromProb(NDArray prob, double top_p, double uniform_sample) { TVM_REGISTER_GLOBAL("vm.builtin.sample_top_p_from_prob").set_body_typed(SampleTopPFromProb); +NDArray MultinomialFromUniform(NDArray prob, NDArray uniform_sample) { + ICHECK(prob.IsContiguous()); + ICHECK(uniform_sample.IsContiguous()); + + if (prob->device.device_type != kDLCPU) { + prob = prob.CopyTo(DLDevice{kDLCPU, 0}); + } + if (uniform_sample->device.device_type != kDLCPU) { + uniform_sample = uniform_sample.CopyTo(DLDevice{kDLCPU, 0}); + } + + ICHECK(prob->device.device_type == kDLCPU); + ICHECK(uniform_sample->device.device_type == kDLCPU); + + int64_t batch_size = prob->shape[0]; + int64_t vocab_size = prob->shape[prob->ndim - 1]; + const float* pprob = static_cast(prob->data); + const float* psample = static_cast(uniform_sample->data); + NDArray new_array = NDArray::Empty({batch_size, 1}, DataType::Int(64), uniform_sample->device); + int64_t* parray = static_cast(new_array->data); + for (int64_t i = 0; i < batch_size; ++i) { + float cum_sum_prob = 0.0f; + int64_t prob_idx = 0; + for (int64_t j = 0; j < vocab_size; ++j) { + prob_idx = j; + cum_sum_prob += pprob[i * vocab_size + j]; + if (cum_sum_prob > psample[i]) { + break; + } + } + parray[i] = prob_idx; + } + return new_array; +} + +TVM_REGISTER_GLOBAL("vm.builtin.multinomial_from_uniform").set_body_typed(MultinomialFromUniform); + // This is an inplace operation. void ApplyRepetitionPenalty(NDArray logits, NDArray token_ids, double penalty) { ICHECK(logits.IsContiguous()); diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py index 650d8ace303f..3457989a551f 100644 --- a/tests/python/relax/test_frontend_nn_op.py +++ b/tests/python/relax/test_frontend_nn_op.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-docstring, invalid-name +import numpy as np import tvm import tvm.testing from tvm import relax, tir @@ -61,11 +62,18 @@ def test(self, x: Tensor, y: Tensor): z4 = op.maximum(x, y) z5 = op.minimum(x, y) z6 = op.subtract(x, y) - return (z0, z1, z2, z3, z4, z5, z6) + z7 = op.greater(x, y) + z8 = op.greater_equal(x, y) + z9 = op.less(x, y) + z10 = op.less_equal(x, y) + z11 = op.equal(x, y) + z12 = op.not_equal(x, y) + + return (z0, z1, z2, z3, z4, z5, z6, z7, z8, z9, z10, z11, z12) # fmt: off @R.function - def test(x: R.Tensor((1, 10), dtype="float32"), y: R.Tensor((10, 1), dtype="float32"), _io: R.Object) -> R.Tuple(R.Tuple(R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10), dtype="float32"), R.Tensor((1, 1), dtype="float32"), R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10), dtype="float32")), R.Tuple(R.Object)): + def test(x: R.Tensor((1, 10), dtype="float32"), y: R.Tensor((10, 1), dtype="float32"), _io: R.Object): R.func_attr({"num_input": 3}) with R.dataflow(): add: R.Tensor((10, 10), dtype="float32") = R.add(x, y) @@ -75,7 +83,13 @@ def test(x: R.Tensor((1, 10), dtype="float32"), y: R.Tensor((10, 1), dtype="floa maximum: R.Tensor((10, 10), dtype="float32") = R.maximum(x, y) minimum: R.Tensor((10, 10), dtype="float32") = R.minimum(x, y) subtract: R.Tensor((10, 10), dtype="float32") = R.subtract(x, y) - gv1: R.Tuple(R.Tuple(R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10), dtype="float32"), R.Tensor((1, 1), dtype="float32"), R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10), dtype="float32")), R.Tuple(R.Object)) = (add, mul, divide, matmul, maximum, minimum, subtract), (_io,) + greater: R.Tensor((10, 10), dtype="bool") = x > y + greater_equal: R.Tensor((10, 10), dtype="bool") = x >= y + less: R.Tensor((10, 10), dtype="bool") = x < y + less_equal: R.Tensor((10, 10), dtype="bool") = x <= y + equal: R.Tensor((10, 10), dtype="bool") = R.equal(x, y) + not_equal: R.Tensor((10, 10), dtype="bool") = R.not_equal(x, y) + gv1 = (add, mul, divide, matmul, maximum, minimum, subtract, greater, greater_equal, less, less_equal, equal, not_equal), (_io,) R.output(gv1) return gv1 # fmt: on @@ -829,5 +843,350 @@ def test(self): vm["test"](*effects) +@tvm.testing.requires_gpu +def test_multinomial_from_uniform(): + + prob_shape = (4, 5) + sample_shape = (4, 1) + + class Model(Module): + def foo(self, prob: Tensor, uniform_sample: Tensor): + z0 = op.multinomial_from_uniform(prob, uniform_sample) + return z0 + + # fmt: off + @I.ir_module + class Expected: + @T.prim_func(private=True) + def get_sample_index(A: T.handle, B: T.handle, C: T.handle): + batch, vocab_size = T.int64(), T.int64() + prob = T.match_buffer(A, (batch, vocab_size)) + usample = T.match_buffer(B, (batch, 1)) + output_index = T.match_buffer(C, (batch, 1), "int64") + # with T.block("root"): + for ax0, ax1 in T.grid(batch, vocab_size): + with T.block("T_get_sample_index"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(usample[v_ax0, T.int64(0)], prob[v_ax0, v_ax1 - T.int64(1):v_ax1 - T.int64(1) + T.int64(2)]) + T.writes(output_index[v_ax0, 0]) + if usample[v_ax0, T.int64(0)] < prob[v_ax0, v_ax1] or v_ax1 + T.int64(1) == vocab_size: + if v_ax1 == T.int64(0): + output_index[v_ax0, 0] = T.int64(0) + else: + if usample[v_ax0, T.int64(0)] >= prob[v_ax0, v_ax1 - T.int64(1)]: + output_index[v_ax0, 0] = v_ax1 + + @R.function + def _initialize_effect() -> R.Tuple(R.Object): + with R.dataflow(): + _io: R.Object = R.null_value() + lv: R.Tuple(R.Object) = (_io,) + gv: R.Tuple(R.Object) = lv + R.output(gv) + return gv + + @R.function + def foo(prob: R.Tensor((4, 5), dtype="float32"), uniform_sample: R.Tensor((4, 1), dtype="float32"), _io: R.Object) -> R.Tuple(R.Tensor((4, 1), dtype="int64"), R.Tuple(R.Object)): + R.func_attr({"num_input": 3}) + cls = Expected + with R.dataflow(): + cumsum: R.Tensor((4, 5), dtype="float32") = R.cumsum(prob, axis=1, dtype="void", exclusive=False) + lv1 = R.call_tir(cls.get_sample_index, (cumsum, uniform_sample), out_sinfo=R.Tensor((4, 1), dtype="int64")) + gv1: R.Tuple(R.Tensor((4, 1), dtype="int64"), R.Tuple(R.Object)) = lv1, (_io,) + R.output(gv1) + return gv1 + # fmt: on + + m = Model() + mod, _ = m.export_tvm( + spec={ + "foo": { + "prob": spec.Tensor(prob_shape, "float32"), + "uniform_sample": spec.Tensor(sample_shape, "float32"), + } + }, + debug=True, + ) + + tvm.ir.assert_structural_equal(mod, Expected) + + target = tvm.target.Target("cuda -libs=thrust", host="llvm") + with target: + mod = tir.transform.DefaultGPUSchedule()(mod) + ex = relax.build(mod, target) + dev = tvm.cuda(0) + vm = relax.VirtualMachine(ex, dev) + + effects = vm["_initialize_effect"]() + + np_rand = np.random.rand(*prob_shape).astype(np.float32) + # normalize it to get the random prob + np_prob = np_rand / np_rand.sum(axis=1, keepdims=True) + nd_prob = tvm.nd.array(np_prob, dev) + # special sample to get deterministic results + nd_sample = tvm.nd.array(np.array([[1], [0], [0], [1]]).astype(np.float32), dev) + inputs = [nd_prob, nd_sample, effects] + res = vm["foo"](*inputs) + tvm.testing.assert_allclose(res[0].numpy(), np.array([[4], [0], [0], [4]]).astype(np.int64)) + + +@tvm.testing.requires_gpu +def test_sample_top_p_top_k_from_sorted_prob(): + prob_shape = (2, 3) + sample_shape = (2, 1) + + class Model(Module): + def foo( + self, prob: Tensor, index: Tensor, top_p: Tensor, top_k: Tensor, uniform_sample: Tensor + ): + z0 = op.sample_top_p_top_k_from_sorted_prob(prob, index, top_p, top_k, uniform_sample) + return z0 + + # fmt: off + @I.ir_module + class Expected: + @T.prim_func(private=True) + def get_index_from_sorted(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle): + batch, vocab_size = T.int64(), T.int64() + cumsum_sorted = T.match_buffer(A, (batch, vocab_size)) + renorm_prob = T.match_buffer(B, (batch, 1)) + usample = T.match_buffer(C, (batch, 1)) + indices = T.match_buffer(D, (batch, vocab_size), "int64") + output_index = T.match_buffer(E, (batch, 1), "int64") + # with T.block("root"): + for ax0, ax1 in T.grid(batch, vocab_size): + with T.block("T_get_index_from_sorted"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads( + usample[v_ax0, T.int64(0)], + cumsum_sorted[v_ax0, v_ax1 - T.int64(1) : v_ax1 - T.int64(1) + T.int64(2)], + renorm_prob[v_ax0, 0], + indices[ + v_ax0, + T.min(T.int64(0), v_ax1) : T.min(T.int64(0), v_ax1) + + (T.max(T.int64(0), v_ax1) + T.int64(1) - T.min(T.int64(0), v_ax1)), + ], + ) + T.writes(output_index[v_ax0, 0]) + if ( + usample[v_ax0, T.int64(0)] + < cumsum_sorted[v_ax0, v_ax1] / renorm_prob[v_ax0, 0] + or v_ax1 + T.int64(1) == vocab_size + ): + if v_ax1 == T.int64(0): + output_index[v_ax0, 0] = indices[v_ax0, 0] + else: + if ( + usample[v_ax0, T.int64(0)] + >= cumsum_sorted[v_ax0, v_ax1 - T.int64(1)] / renorm_prob[v_ax0, 0] + ): + output_index[v_ax0, 0] = indices[v_ax0, v_ax1] + + @T.prim_func(private=True) + def get_renorm_prob(A: T.handle, B: T.handle, C: T.handle, D: T.handle): + batch, vocab_size = T.int64(), T.int64() + cumsum_sorted = T.match_buffer(A, (batch, vocab_size)) + top_p = T.match_buffer(B, (batch, 1)) + top_k = T.match_buffer(C, (batch, 1), "int64") + renorm_prob = T.match_buffer(D, (batch, 1)) + # with T.block("root"): + for ax0, ax1 in T.grid(batch, vocab_size): + with T.block("T_get_renorm_prob"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(cumsum_sorted[v_ax0, T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)):T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)) + (T.max(T.max(T.int64(0), v_ax1), v_ax1 + T.int64(1)) + T.int64(1) - T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)))], top_p[v_ax0, 0], top_k[v_ax0, 0]) + T.writes(renorm_prob[v_ax0, 0]) + if (cumsum_sorted[v_ax0, 0] < top_p[v_ax0, 0] and top_k[v_ax0, 0] > T.int64(1)) == T.bool(False): + renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, 0] + else: + if (cumsum_sorted[v_ax0, v_ax1] < top_p[v_ax0, 0] and v_ax1 + T.int64(1) < top_k[v_ax0, 0]) == T.bool(True): + if v_ax1 + T.int64(1) == vocab_size: + renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, v_ax1] + else: + if (cumsum_sorted[v_ax0, v_ax1 + T.int64(1)] < top_p[v_ax0, 0] and v_ax1 + T.int64(1) + T.int64(1) < top_k[v_ax0, 0]) == T.bool(False): + renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, v_ax1 + T.int64(1)] + + @R.function + def _initialize_effect() -> R.Tuple(R.Object): + with R.dataflow(): + _io: R.Object = R.null_value() + lv: R.Tuple(R.Object) = (_io,) + gv: R.Tuple(R.Object) = lv + R.output(gv) + return gv + + @R.function + def foo( + prob: R.Tensor((2, 3), dtype="float32"), + index: R.Tensor((2, 3), dtype="int64"), + top_p: R.Tensor((2, 1), dtype="float32"), + top_k: R.Tensor((2, 1), dtype="int64"), + uniform_sample: R.Tensor((2, 1), dtype="float32"), + _io: R.Object, + ) -> R.Tuple(R.Tensor((2, 1), dtype="int64"), R.Tuple(R.Object)): + R.func_attr({"num_input": 6}) + cls = Expected + with R.dataflow(): + cumsum: R.Tensor((2, 3), dtype="float32") = R.cumsum(prob, axis=1, dtype="void", exclusive=None) + lv1 = R.call_tir(cls.get_renorm_prob, (cumsum, top_p, top_k), out_sinfo=R.Tensor((2, 1), dtype="float32")) + lv2 = R.call_tir(cls.get_index_from_sorted, (cumsum, lv1, uniform_sample, index), out_sinfo=R.Tensor((2, 1), dtype="int64")) + gv1: R.Tuple(R.Tensor((2, 1), dtype="int64"), R.Tuple(R.Object)) = lv2, (_io,) + R.output(gv1) + return gv1 + # fmt: on + + m = Model() + mod, _ = m.export_tvm( + spec={ + "foo": { + "prob": spec.Tensor(prob_shape, "float32"), + "index": spec.Tensor(prob_shape, "int64"), + "top_p": spec.Tensor(sample_shape, "float32"), + "top_k": spec.Tensor(sample_shape, "int64"), + "uniform_sample": spec.Tensor(sample_shape, "float32"), + } + }, + debug=True, + ) + + tvm.ir.assert_structural_equal(mod, Expected) + + target = tvm.target.Target("cuda -libs=thrust", host="llvm") + with target: + mod = tir.transform.DefaultGPUSchedule()(mod) + + ex = relax.build(mod, target) + dev = tvm.cuda(0) + vm = relax.VirtualMachine(ex, dev) + + effects = vm["_initialize_effect"]() + sorted_prob = tvm.nd.array(np.array([[0.5, 0.4, 0.1], [0.4, 0.3, 0.3]]).astype(np.float32), dev) + indices = tvm.nd.array(np.array([[2, 1, 0], [2, 0, 1]]).astype(np.int64), dev) + top_p = tvm.nd.array(np.array([[0.6], [0.9]]).astype(np.float32), dev) + top_k = tvm.nd.array(np.array([[3], [2]]).astype(np.int64), dev) + usample = tvm.nd.array(np.array([[0.5], [0.6]]).astype(np.float32), dev) + + inputs = [sorted_prob, indices, top_p, top_k, usample, effects] + + res = vm["foo"](*inputs) + tvm.testing.assert_allclose(res[0].numpy(), np.array([[2], [0]]).astype(np.int64)) + + +@tvm.testing.requires_gpu +def test_renormalize_top_p_top_k_prob(): + prob_shape = (2, 3) + sample_shape = (2, 1) + + class Model(Module): + def foo( + self, + prob: Tensor, + sorted_prob: Tensor, + top_p: Tensor, + top_k: Tensor, + ): + z0 = op.renormalize_top_p_top_k_prob(prob, sorted_prob, top_p, top_k) + return z0 + + # fmt: off + @I.ir_module + class Expected: + @T.prim_func(private=True) + def filter_with_top_p_top_k(A: T.Buffer((T.int64(2), T.int64(3)), "float32"), B: T.Buffer((T.int64(2), T.int64(1)), "float32"), filter_with_top_p_top_k: T.Buffer((T.int64(2), T.int64(3)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + for i, j in T.grid(T.int64(2), T.int64(3)): + with T.block("filter_with_top_p_top_k"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(B[v_i, T.int64(0)], A[v_i, v_j]) + T.writes(filter_with_top_p_top_k[v_i, v_j]) + filter_with_top_p_top_k[v_i, v_j] = T.Select(B[v_i, T.int64(0)] <= A[v_i, v_j], A[v_i, v_j], T.float32(0)) + + @T.prim_func(private=True) + def get_renorm_cutoff(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle): + batch, vocab_size = T.int64(), T.int64() + sorted_prob = T.match_buffer(A, (batch, vocab_size)) + cumsum_sorted = T.match_buffer(B, (batch, vocab_size)) + top_p = T.match_buffer(C, (batch, 1)) + top_k = T.match_buffer(D, (batch, 1), "int64") + cutoff = T.match_buffer(E, (batch, 1)) + # with T.block("root"): + for ax0, ax1 in T.grid(batch, vocab_size): + with T.block("T_get_renorm_prob"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(cumsum_sorted[v_ax0, T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)):T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)) + (T.max(T.max(T.int64(0), v_ax1), v_ax1 + T.int64(1)) + T.int64(1) - T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)))], top_p[v_ax0, 0], top_k[v_ax0, 0], sorted_prob[v_ax0, T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)):T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)) + (T.max(T.max(T.int64(0), v_ax1), v_ax1 + T.int64(1)) + T.int64(1) - T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)))]) + T.writes(cutoff[v_ax0, 0]) + if (cumsum_sorted[v_ax0, 0] < top_p[v_ax0, 0] and top_k[v_ax0, 0] > T.int64(1)) == T.bool(False): + cutoff[v_ax0, 0] = sorted_prob[v_ax0, 0] + else: + if (cumsum_sorted[v_ax0, v_ax1] < top_p[v_ax0, 0] and v_ax1 + T.int64(1) < top_k[v_ax0, 0]) == T.bool(True): + if v_ax1 + T.int64(1) == vocab_size: + cutoff[v_ax0, 0] = sorted_prob[v_ax0, v_ax1] + else: + if (cumsum_sorted[v_ax0, v_ax1 + T.int64(1)] < top_p[v_ax0, 0] and v_ax1 + T.int64(1) + T.int64(1) < top_k[v_ax0, 0]) == T.bool(False): + cutoff[v_ax0, 0] = sorted_prob[v_ax0, v_ax1 + T.int64(1)] + + @R.function + def _initialize_effect() -> R.Tuple(R.Object): + with R.dataflow(): + _io: R.Object = R.null_value() + lv: R.Tuple(R.Object) = (_io,) + gv: R.Tuple(R.Object) = lv + R.output(gv) + return gv + + @R.function + def foo(prob: R.Tensor((2, 3), dtype="float32"), sorted_prob: R.Tensor((2, 3), dtype="float32"), top_p: R.Tensor((2, 1), dtype="float32"), top_k: R.Tensor((2, 1), dtype="int64"), _io: R.Object) -> R.Tuple(R.Tensor((2, 3), dtype="float32"), R.Tuple(R.Object)): + R.func_attr({"num_input": 5}) + cls = Expected + with R.dataflow(): + cumsum: R.Tensor((2, 3), dtype="float32") = R.cumsum(sorted_prob, axis=1, dtype="void", exclusive=None) + lv1 = R.call_tir(cls.get_renorm_cutoff, (sorted_prob, cumsum, top_p, top_k), out_sinfo=R.Tensor((2, 1), dtype="float32")) + lv2 = R.call_tir(cls.filter_with_top_p_top_k, (prob, lv1), out_sinfo=R.Tensor((2, 3), dtype="float32")) + sum: R.Tensor((2, 1), dtype="float32") = R.sum(lv2, axis=[1], keepdims=True) + divide: R.Tensor((2, 3), dtype="float32") = R.divide(lv2, sum) + gv1: R.Tuple(R.Tensor((2, 3), dtype="float32"), R.Tuple(R.Object)) = divide, (_io,) + R.output(gv1) + return gv1 + # fmt: on + + m = Model() + mod, _ = m.export_tvm( + spec={ + "foo": { + "prob": spec.Tensor(prob_shape, "float32"), + "sorted_prob": spec.Tensor(prob_shape, "float32"), + "top_p": spec.Tensor(sample_shape, "float32"), + "top_k": spec.Tensor(sample_shape, "int64"), + } + }, + debug=True, + ) + + tvm.ir.assert_structural_equal(mod, Expected) + + target = tvm.target.Target("cuda -libs=thrust", host="llvm") + with target: + mod = relax.backend.DispatchSortScan()(mod) + mod = relax.transform.LegalizeOps()(mod) + mod = tir.transform.DefaultGPUSchedule()(mod) + + ex = relax.build(mod, target) + dev = tvm.cuda(0) + vm = relax.VirtualMachine(ex, dev) + + effects = vm["_initialize_effect"]() + prob = tvm.nd.array(np.array([[0.2, 0.3, 0.5], [0.3, 0.3, 0.4]]).astype(np.float32), dev) + sorted_prob = tvm.nd.array(np.array([[0.5, 0.3, 0.2], [0.4, 0.3, 0.3]]).astype(np.float32), dev) + top_p = tvm.nd.array(np.array([[0.6], [0.9]]).astype(np.float32), dev) + top_k = tvm.nd.array(np.array([[3], [2]]).astype(np.int64), dev) + + inputs = [prob, sorted_prob, top_p, top_k, effects] + + res = vm["foo"](*inputs) + tvm.testing.assert_allclose( + res[0].numpy(), np.array([[0, 0.375, 0.625], [0.3, 0.3, 0.4]]).astype(np.float32) + ) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_vm_builtin.py b/tests/python/relax/test_vm_builtin.py new file mode 100644 index 000000000000..f786f707aff0 --- /dev/null +++ b/tests/python/relax/test_vm_builtin.py @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import numpy as np +import pytest + +import tvm +import tvm.script +import tvm.testing +from tvm import relax +from tvm.script import relax as R + + +def test_multinomial_from_uniform(): + @tvm.script.ir_module + class CallSample: + @R.function + def foo(x: R.Tensor((3, 5), "float32"), y: R.Tensor((3, 1), "float32")): + z = R.call_pure_packed( + "vm.builtin.multinomial_from_uniform", + x, + y, + sinfo_args=(R.Tensor((3, 1), dtype="int64")), + ) + return z + + mod = CallSample + target = tvm.target.Target("llvm", host="llvm") + ex = relax.build(mod, target) + np_rand = np.random.rand(3, 5).astype(np.float32) + # normalize it to get the random prob + np_prob = np_rand / np_rand.sum(axis=1, keepdims=True) + nd_prob = tvm.nd.array(np_prob) + # special sample to get deterministic results + nd_sample = tvm.nd.array(np.array([[1.0], [0], [1]]).astype(np.float32)) + + vm = relax.VirtualMachine(ex, tvm.cpu()) + res = vm["foo"](nd_prob, nd_sample) + tvm.testing.assert_allclose(res.numpy(), np.array([[4], [0], [4]]).astype(np.int64)) + + +if __name__ == "__main__": + tvm.testing.main()