[Fix][Relax] Fix top-p/top-k sampling kernel
This PR fixes a typo in the sampling kernel of the top-p/top-k sampling op.
Prior to this PR, the kernel had out-of-bounds global memory accesses
due to an oversight when introducing `sample_indices` in apache#16675:
the sorted `indices` tensor was still indexed by the sample row instead
of through `sample_indices`.

Running the test directly or through pytest did not reveal the issue,
since the correctness check still passed. Running under NVIDIA's
compute-sanitizer, however, reports the illegal memory access:
```
> compute-sanitizer --tool memcheck --print-limit=5 --launch-timeout 3600 python tests/python/relax/test_frontend_nn_op.py
========= COMPUTE-SANITIZER
========= Invalid __global__ read of size 8 bytes
=========     at 0x4e90 in get_index_from_sorted_kernel
=========     by thread (7,0,0) in block (0,0,0)
=========     Address 0x7fe35ac00238 is out of bounds
=========     and is 9 bytes after the nearest allocation at 0x7fe35ac00200 of size 48 bytes
=========     Saved host backtrace up to driver entry point at kernel launch time
...
```
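For intuition, here is a minimal NumPy sketch (not the TVM kernel itself, and with made-up shapes) of why indexing the sorted `indices` tensor by the sample row rather than through `sample_indices` reads past the allocation: `indices` has one row per probability distribution, while there can be more sample rows than distributions.

```python
import numpy as np

# Illustrative shapes (assumed, not taken from the failing test):
# 2 probability distributions, vocab size 3, and 6 requested samples.
batch, vocab, num_samples = 2, 3, 6

# Sorted token indices: one row per probability distribution.
indices = np.arange(batch * vocab, dtype="int64").reshape(batch, vocab)

# Each sample row points back to the distribution it samples from.
sample_indices = np.array([[0], [0], [0], [1], [1], [1]], dtype="int64")

v_ax0, v_ax1 = 5, 0  # sample row 5, first vocab position

# Buggy indexing: `indices` has only `batch` (= 2) rows, so row 5 is out of
# bounds -- NumPy raises IndexError; the CUDA kernel reads illegal memory.
try:
    _ = indices[v_ax0, v_ax1]
except IndexError:
    pass

# Fixed indexing: first map the sample row back to its probability row.
print(indices[sample_indices[v_ax0, 0], v_ax1])  # prints 3
```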
MasterJH5574 committed Mar 12, 2024
1 parent e051945 commit 62fc412
Showing 2 changed files with 6 additions and 6 deletions.
6 changes: 3 additions & 3 deletions python/tvm/relax/frontend/nn/op.py
@@ -23,9 +23,9 @@

import numpy as np

from tvm import te
from tvm import tir as _tir
from tvm.script import tir as T
from tvm import te

from ... import expr as rx
from ... import op as _op
@@ -2386,13 +2386,13 @@ def _get_index_from_sorted(
or v_ax1 + 1 == vocab_size
):
if v_ax1 == 0:
output_index[v_ax0, 0] = indices[v_ax0, 0]
output_index[v_ax0, 0] = indices[sample_indices[v_ax0, T.int64(0)], 0]
elif (
usample[v_ax0, T.int64(0)]
>= cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1 - 1]
/ renorm_prob[sample_indices[v_ax0, T.int64(0)], 0]
):
output_index[v_ax0, 0] = indices[v_ax0, v_ax1]
output_index[v_ax0, 0] = indices[sample_indices[v_ax0, T.int64(0)], v_ax1]

cumsum_sorted = cumsum(sorted_prob, axis=1)

6 changes: 3 additions & 3 deletions tests/python/relax/test_frontend_nn_op.py
@@ -973,14 +973,14 @@ def get_index_from_sorted(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E:
for ax0, ax1 in T.grid(out_batch, vocab_size):
with T.block("T_get_index_from_sorted"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(usample[v_ax0, T.int64(0)], cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1 - T.int64(1):v_ax1 - T.int64(1) + T.int64(2)], sample_indices[v_ax0, T.int64(0)], renorm_prob[sample_indices[v_ax0, T.int64(0)], 0], indices[v_ax0, T.min(T.int64(0), v_ax1):T.min(T.int64(0), v_ax1) + (T.max(T.int64(0), v_ax1) + T.int64(1) - T.min(T.int64(0), v_ax1))])
T.reads(usample[v_ax0, T.int64(0)], cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1 - T.int64(1):v_ax1 - T.int64(1) + T.int64(2)], sample_indices[v_ax0, T.int64(0)], renorm_prob[sample_indices[v_ax0, T.int64(0)], 0], indices[sample_indices[v_ax0, T.int64(0)], T.min(T.int64(0), v_ax1):T.min(T.int64(0), v_ax1) + (T.max(T.int64(0), v_ax1) + T.int64(1) - T.min(T.int64(0), v_ax1))])
T.writes(output_index[v_ax0, 0])
if usample[v_ax0, T.int64(0)] < cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1] / renorm_prob[sample_indices[v_ax0, T.int64(0)], 0] or v_ax1 + T.int64(1) == vocab_size:
if v_ax1 == T.int64(0):
output_index[v_ax0, 0] = indices[v_ax0, 0]
output_index[v_ax0, 0] = indices[sample_indices[v_ax0, T.int64(0)], 0]
else:
if usample[v_ax0, T.int64(0)] >= cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1 - T.int64(1)] / renorm_prob[sample_indices[v_ax0, T.int64(0)], 0]:
output_index[v_ax0, 0] = indices[v_ax0, v_ax1]
output_index[v_ax0, 0] = indices[sample_indices[v_ax0, T.int64(0)], v_ax1]

@T.prim_func(private=True)
def get_renorm_prob(A: T.handle, B: T.handle, C: T.handle, D: T.handle):
