[KVCache] Support mode "None" for Rotary Embedding #16580

Merged 1 commit on Feb 16, 2024
src/runtime/relax_vm/paged_kv_cache.cc (6 changes: 4 additions & 2 deletions)
@@ -153,12 +153,14 @@ struct Sequence {
/*!
* \brief The rotary embedding mode adopted by the paged KV cache
* when computing attention.
* "None" means RoPE is never applied to q and k.
* "Normal" means RoPE is computed in a standalone kernel.
* "Inline" means RoPE is computed on-the-fly in attention kernels.
*/
enum class RoPEMode : int {
kNormal = 0,
kInline = 1,
kNone = 0,
kNormal = 1,
kInline = 2,
};

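For reviewers skimming the diff, the net effect of the three modes on where rotary position embedding is applied can be summarized from the comment above and the test changes below. A minimal sketch (illustration only, not the runtime kernels; `apply_rotary` stands in for the tests' `f_apply_rotary` helper):

```python
import enum

class RoPEMode(enum.IntEnum):
    NONE = 0    # RoPE is never applied to q and k
    NORMAL = 1  # RoPE is computed in a standalone kernel; k is rotated before it enters the cache
    INLINE = 2  # RoPE is applied to q and k on the fly inside the attention kernels

def reference_qk(mode, q, k_cache, rope_offset, apply_rotary, scale=1.0, theta=1e4):
    """Sketch of how the reference checks in the tests treat q and the cached k."""
    # q: rotated for NORMAL and INLINE (starting at its absolute position), untouched for NONE.
    q_ref = q if mode == RoPEMode.NONE else apply_rotary(q, rope_offset, scale, theta)
    # cached k: for NORMAL it was already rotated at append time, so only INLINE rotates it here.
    k_ref = k_cache if mode != RoPEMode.INLINE else apply_rotary(k_cache, 0, scale, theta)
    return q_ref, k_ref
```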
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import enum
from typing import Dict, List, Tuple, Union

import numpy as np
@@ -322,7 +323,19 @@ def create_kv_cache(rope_mode):
return cache


@pytest.fixture(params=[0, 1])
class RopeMode(enum.IntEnum):
"""The RoPE mode of the Paged KV cache.
If it is none, the KV cache will not apply RoPE to q and k.
If it is normal, RoPE will be applied to k before adding k to cache.
Otherwise, RoPE will be applied to q/k in attention kernel on-the-fly.
"""

NONE = 0
NORMAL = 1
INLINE = 2


@pytest.fixture(params=[RopeMode.NONE, RopeMode.NORMAL, RopeMode.INLINE])
def kv_cache_and_rope_mode(request):
set_global_func()
return create_kv_cache(request.param), request.param
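Note that `f_apply_rotary(x, offset, scale, theta)` itself is outside the changed hunks, so its exact layout is not visible here. For context, a common "rotate-half" numpy formulation of rotary embedding looks roughly like the following; the pairing of dimensions and the way `scale` enters are assumptions for illustration, not taken from this PR:

```python
import numpy as np

def rope_rotate_half(x, offset, scale, theta):
    # x: (seq_len, num_heads, head_dim); token positions are offset .. offset + seq_len - 1.
    seq_len, _, head_dim = x.shape
    positions = scale * np.arange(offset, offset + seq_len, dtype=np.float32)
    inv_freq = theta ** (-np.arange(0, head_dim, 2, dtype=np.float32) / head_dim)
    angles = positions[:, None] * inv_freq[None, :]   # (seq_len, head_dim // 2)
    cos = np.cos(angles)[:, None, :]                   # broadcast across heads
    sin = np.sin(angles)[:, None, :]
    x1, x2 = x[..., : head_dim // 2], x[..., head_dim // 2 :]
    return np.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1).astype(x.dtype)
```

The `offset` argument is what lets the tests rotate a freshly appended chunk starting from its absolute position in the sequence (the `rope_offset` computed further down).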
@@ -361,7 +374,7 @@ def f_apply_rotary(x, offset, scale, theta):

def apply_attention(
kv_cache,
rope_mode: int,
rope_mode: RopeMode,
batch: List[Tuple[Union[int, Tuple[int, int]], int]],
cached_k: Dict[int, np.ndarray],
cached_v: Dict[int, np.ndarray],
@@ -406,10 +419,12 @@ def apply_attention(
cached_k[seq_id],
np.stack(
[
new_k[l]
if rope_mode == 1
else f_apply_rotary(
new_k[l], cached_k[seq_id].shape[1], rope_scale, rope_theta
(
new_k[l]
if rope_mode != RopeMode.NORMAL
else f_apply_rotary(
new_k[l], cached_k[seq_id].shape[1], rope_scale, rope_theta
)
)
for l in range(num_layers)
],
@@ -445,15 +460,19 @@ def apply_attention(
assert cached_k[seq_id].shape[1] == cached_v[seq_id].shape[1] >= append_length

rope_offset = cached_k[seq_id].shape[1] - append_length
q_seq = f_apply_rotary(
q_array[i][layer_id],
rope_offset,
rope_scale,
rope_theta,
q_seq = (
q_array[i][layer_id]
if rope_mode == RopeMode.NONE
else f_apply_rotary(
q_array[i][layer_id],
rope_offset,
rope_scale,
rope_theta,
)
).transpose(1, 0, 2)
k_seq = (
cached_k[seq_id][layer_id]
if rope_mode == 0
if rope_mode != RopeMode.INLINE
else f_apply_rotary(cached_k[seq_id][layer_id], 0, rope_scale, rope_theta)
).transpose(1, 2, 0)
v_seq = cached_v[seq_id][layer_id].transpose(1, 0, 2)
@@ -586,7 +605,7 @@ def test_paged_attention_kv_cache_popn(kv_cache_and_rope_mode, fuse_qkv):

if __name__ == "__main__":
set_global_func()
for rope_mode in [0, 1]:
for rope_mode in [RopeMode.NONE, RopeMode.NORMAL, RopeMode.INLINE]:
cache = create_kv_cache(rope_mode)
for fuse_qkv in [False, True]:
test_paged_attention_kv_cache_prefill_and_decode((cache, rope_mode), fuse_qkv)
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import enum
import itertools
import math
from typing import Dict, List, Tuple, Union
@@ -140,7 +141,25 @@ def create_kv_cache(head_dim, dtype, rope_mode):
return cache


@pytest.fixture(params=itertools.product([64, 128], ["float16", "float32"], [0, 1]))
class RopeMode(enum.IntEnum):
"""The RoPE mode of the Paged KV cache.
If it is none, the KV cache will not apply RoPE to q and k.
If it is normal, RoPE will be applied to k before adding k to cache.
Otherwise, RoPE will be applied to q/k in attention kernel on-the-fly.
"""

NONE = 0
NORMAL = 1
INLINE = 2


@pytest.fixture(
params=itertools.product(
[64, 128],
["float16", "float32"],
[RopeMode.NONE, RopeMode.NORMAL, RopeMode.INLINE],
)
)
def kv_cache_and_rope_mode(request):
global head_dim, dtype
head_dim, dtype, rope_mode = request.param
@@ -181,7 +200,7 @@ def f_apply_rotary(x, offset, scale, theta):

def apply_attention(
kv_cache,
rope_mode: int,
rope_mode: RopeMode,
batch: List[Tuple[Union[int, Tuple[int, int]], int]],
cached_k: Dict[int, np.ndarray],
cached_v: Dict[int, np.ndarray],
@@ -228,7 +247,7 @@ def apply_attention(
[
(
new_k[l]
if rope_mode == 1
if rope_mode != RopeMode.NORMAL
else f_apply_rotary(
new_k[l], cached_k[seq_id].shape[1], rope_scale, rope_theta
)
@@ -267,15 +286,19 @@ def apply_attention(
assert cached_k[seq_id].shape[1] == cached_v[seq_id].shape[1] >= append_length

rope_offset = cached_k[seq_id].shape[1] - append_length
q_seq = f_apply_rotary(
q_array[i][layer_id],
rope_offset,
rope_scale,
rope_theta,
q_seq = (
q_array[i][layer_id]
if rope_mode == RopeMode.NONE
else f_apply_rotary(
q_array[i][layer_id],
rope_offset,
rope_scale,
rope_theta,
)
).transpose(1, 0, 2)
k_seq = (
cached_k[seq_id][layer_id]
if rope_mode == 0
if rope_mode != RopeMode.INLINE
else f_apply_rotary(cached_k[seq_id][layer_id], 0, rope_scale, rope_theta)
).transpose(1, 2, 0)
v_seq = cached_v[seq_id][layer_id].transpose(1, 0, 2)
@@ -1639,7 +1662,7 @@ def merge_state_inplace(
if __name__ == "__main__":
for head_dim in [64, 128]:
for dtype in ["float16", "float32"]:
for rope_mode in [0, 1]:
for rope_mode in [RopeMode.NONE, RopeMode.NORMAL, RopeMode.INLINE]:
set_global_func(head_dim, dtype)
cache = create_kv_cache(head_dim, dtype, rope_mode)
for fuse_qkv in [False, True]: