
Commit

adapt for ChatGLM3
00INDEX committed Mar 21, 2024
1 parent 509e33b commit 554838e
Showing 5 changed files with 79 additions and 14 deletions.
10 changes: 9 additions & 1 deletion parallel_tokenizer/__init__.py
@@ -1,5 +1,13 @@
from .logger import get_logger
from .parallel_tokenizer import ParallelTokenizer, get_parallel_tokenizer
from .sp_tokenizer import SentencePieceTokenizer
from .special_cases import SPECIAL_KEYS_DICT, SPECIAL_TOKENIZERS_DICT

__all__ = ["get_parallel_tokenizer", "SentencePieceTokenizer", "ParallelTokenizer", "get_logger"]
__all__ = [
"get_parallel_tokenizer",
"SentencePieceTokenizer",
"ParallelTokenizer",
"get_logger",
"SPECIAL_KEYS_DICT",
"SPECIAL_TOKENIZERS_DICT",
]
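
With this change the special-case registries are part of the package's public surface and can be imported directly from the package root, for example:

    # Everything listed in __all__ above is importable from parallel_tokenizer.
    from parallel_tokenizer import (
        ParallelTokenizer,
        SentencePieceTokenizer,
        SPECIAL_KEYS_DICT,
        SPECIAL_TOKENIZERS_DICT,
        get_logger,
        get_parallel_tokenizer,
    )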
9 changes: 6 additions & 3 deletions parallel_tokenizer/parallel_tokenizer.py
@@ -11,6 +11,7 @@

from .logger import get_logger
from .sp_tokenizer import SentencePieceTokenizer
from .special_cases import SPECIAL_KEYS_DICT
from .utils import chunks, flatten, match, merge, pairs, to_list

logger = get_logger(__name__)
@@ -90,10 +91,12 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:

if isinstance(shards[0], (dict, BatchEncoding)):
result = shards[0].__class__()
for key in self.concat_keys:
result[key] = merge([shard[key] for shard in shards], matches)
for key in shards[0].keys():
if key not in self.concat_keys:
if key in SPECIAL_KEYS_DICT:
result[key] = SPECIAL_KEYS_DICT[key]([shard[key] for shard in shards], matches=matches)
elif key in self.concat_keys:
result[key] = merge([shard[key] for shard in shards], matches)
else:
result[key] = [shard[key] for shard in shards]
else:
result = merge(shards, matches)
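
The merge loop now consults SPECIAL_KEYS_DICT before falling back to concatenation, so a key such as position_ids (which the ChatGLM3 tokenizer appears to emit, judging by the test update below) is rebuilt by a dedicated handler rather than naively concatenated across shards. Since a handler is invoked as handler(list_of_per_shard_values, matches=matches), a project-specific key could be routed the same way; the key name and handler below are hypothetical and shown only to illustrate the call convention:

    from parallel_tokenizer import SPECIAL_KEYS_DICT

    # Hypothetical handler: keep the value from the first shard and ignore the rest.
    def keep_first_shard(values, matches=None):
        return values[0]

    # Hypothetical key; any key present in the tokenizer output could be registered.
    SPECIAL_KEYS_DICT["my_custom_key"] = keep_first_shard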
24 changes: 24 additions & 0 deletions parallel_tokenizer/special_cases.py
@@ -0,0 +1,24 @@
from typing import List, Tuple, Union

import numpy as np
import torch

from .logger import get_logger
from .utils import arange_like, get_size

logger = get_logger(__name__)


def position_ids(shards: List[Union[torch.Tensor, np.ndarray, List[int]]], matches: List[Tuple[int]]):
seqlen = sum(
[
get_size(x[0])[-1] - x[1][1] - ((get_size(x[0])[-1] - x[1][0]) % get_size(x[0])[-1])
for x in zip(shards, [matches[i : i + 2] for i in range(0, len(matches), 2)])
]
)
return arange_like(shards[0], start=0, end=seqlen)


SPECIAL_KEYS_DICT = {"position_ids": position_ids}

SPECIAL_TOKENIZERS_DICT = {}
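
position_ids sums each shard's effective length (its size minus the overlap trimmed away according to its match offsets) and then regenerates a fresh 0..seqlen-1 range shaped like the first shard via arange_like, instead of stitching the per-shard values together. A toy illustration of why regeneration is needed; the shard values below are made up, and the real matches structure comes from the matching step in parallel_tokenizer.py:

    import torch

    # Two toy per-shard outputs with no overlap, as two chunks of one input might produce.
    shard_position_ids = [torch.tensor([[0, 1, 2]]), torch.tensor([[0, 1, 2]])]

    # Concatenating would give [[0, 1, 2, 0, 1, 2]], which is not a valid position sequence
    # for the merged input. The handler instead emits one contiguous range over the merged length:
    merged = torch.arange(0, 6).unsqueeze(0)   # tensor([[0, 1, 2, 3, 4, 5]])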
34 changes: 34 additions & 0 deletions parallel_tokenizer/utils.py
@@ -318,3 +318,37 @@ def get_size(item: Union[torch.Tensor, np.ndarray, List[int]]) -> Tuple[int]:
return tuple(shape)
else:
raise ValueError(f"Unsupported type {type(item)} for get_size")


def add_scalar(
item: Union[torch.Tensor, np.ndarray, List[int]], scalar: int
) -> Union[torch.Tensor, np.ndarray, List[int]]:
if isinstance(item, (torch.Tensor, np.ndarray)):
return item + scalar
elif isinstance(item, List):
return [add_scalar(subitem, scalar) for subitem in item]
elif isinstance(item, int):
return item + scalar
else:
raise ValueError(f"Unsupported type {type(item)} for add_scalar")


def arange_like(
item: Union[torch.Tensor, np.ndarray, List[int]], start: int, end: int
) -> Union[torch.Tensor, np.ndarray, List[int]]:
ndim = get_ndim(item)
if isinstance(item, torch.Tensor):
item = torch.arange(start, end, device=item.device, dtype=item.dtype)
while get_ndim(item) < ndim:
item = item.unsqueeze(0)
elif isinstance(item, np.ndarray):
item = np.arange(start, end)
while get_ndim(item) < ndim:
item = item[np.newaxis]
elif isinstance(item, List):
item = list(range(start, end))
while get_ndim(item) < ndim:
item = [item]
else:
raise ValueError(f"Unsupported type {type(item)} for arange_like")
return item
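
Both helpers are type-preserving: add_scalar recurses into nested lists and broadcasts over tensors and arrays, while arange_like builds a fresh range and pads dimensions until it matches the reference item's rank (and, for tensors, its device and dtype). A quick sketch of the expected behaviour:

    import numpy as np
    import torch

    from parallel_tokenizer.utils import add_scalar, arange_like

    add_scalar([1, 2, [3, 4]], 10)               # [11, 12, [13, 14]]
    add_scalar(np.array([1, 2]), 10)             # array([11, 12])
    add_scalar(torch.tensor([[1, 2]]), 10)       # tensor([[11, 12]])

    arange_like([9, 9, 9], start=0, end=3)                   # [0, 1, 2]
    arange_like(torch.tensor([[7, 7, 7]]), start=0, end=3)   # tensor([[0, 1, 2]])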
16 changes: 6 additions & 10 deletions tests/test_interface.py
@@ -1,20 +1,19 @@
import pytest
import random

import pytest
import torch
from transformers import AutoTokenizer
from wonderwords import RandomWord

from parallel_tokenizer import ParallelTokenizer


TEST_MODELS = [
"internlm/internlm2-7b",
# "meta-llama/Llama-2-7b-chat-hf",
"meta-llama/Llama-2-7b-chat-hf",
"baichuan-inc/Baichuan2-7B-Chat",
"mistralai/Mistral-7B-Instruct-v0.2",
# "google/gemma-7b-it",
"Qwen/Qwen1.5-7B-Chat",
"google/gemma-7b-it",
"Qwen/Qwen1.5-72B-Chat",
"THUDM/chatglm3-6b",
]
TEST_LENGTHS = [8192, 16384]
@@ -26,11 +25,7 @@
@pytest.mark.parametrize("return_tensors", [None, "pt"])
@pytest.mark.parametrize("batch", [False])
def test_call(
model_name_or_path: str,
sentence_length: int,
add_special_tokens: bool,
return_tensors: str or None,
batch: bool
model_name_or_path: str, sentence_length: int, add_special_tokens: bool, return_tensors: str or None, batch: bool
):
random.seed(1024)
r = RandomWord()
@@ -53,6 +48,7 @@ def test_call(
ret_parallel = parallel_tokenizer(input_text, add_special_tokens=add_special_tokens, return_tensors=return_tensors)

for k in ret_hf:

if isinstance(ret_hf[k], list):
assert ret_hf[k] == ret_parallel[k], f"{k} is not equal"
elif isinstance(ret_hf[k], torch.Tensor):
