Skip to content

Commit

Permalink
Fix patching of NeMo tokenizers for correct Lambada evaluation (#11326)
Browse files (browse the repository at this point in the history)
* Fix patching of NeMo tokenizers for correct Lambada evaluation

Signed-off-by: Jan Lasek <[email protected]>

* Apply isort and black reformatting

Signed-off-by: janekl <[email protected]>

---------

Signed-off-by: Jan Lasek <[email protected]>
Signed-off-by: janekl <[email protected]>
Co-authored-by: janekl <[email protected]>
  • Loading branch information
janekl and janekl authored Nov 20, 2024
1 parent 4b93e7f commit 341580e
Showing 1 changed file with 23 additions and 16 deletions.
39 changes: 23 additions & 16 deletions nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -317,24 +317,31 @@ def build_tokenizer(tokenizer):
if tokenizer.eos_token_id is None:
tokenizer.add_special_tokens({"eos_token": "</s>"})
else:
try:
# If NeMo tokenizer, monkey patch interface
from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec

if isinstance(tokenizer, TokenizerSpec):

def batch_encode_patch(self, ids):
# For NeMo tokenizers, monkey patch encode & batch_decode methods for unified interface
from nemo.collections.common.tokenizers import AutoTokenizer, SentencePieceTokenizer, TokenizerSpec

if isinstance(tokenizer, TokenizerSpec):
if isinstance(tokenizer, AutoTokenizer):
# Unwrap the original methods of HF tokenizer
batch_decode = tokenizer.tokenizer.batch_decode
encode = tokenizer.tokenizer.encode
elif isinstance(tokenizer, SentencePieceTokenizer):
# Define HF equivalents based on available SP methods
def batch_decode(self, ids):
if torch.is_tensor(ids):
ids = ids.cpu().numpy()
ids = ids[0] if len(ids.shape) > 1 else ids
return self.ids_to_text(ids)

tokenizer.bos_token_id = tokenizer.bos_id
tokenizer.eos_token_id = tokenizer.eos_id
tokenizer.encode = tokenizer.text_to_ids
TokenizerSpec.batch_decode = batch_encode_patch
except:
raise TypeError(f'Unsupported tokenizer build input: {type(tokenizer)}')
if isinstance(ids, np.ndarray):
ids = ids.tolist()
return self.tokenizer.decode(ids)

encode = tokenizer.tokenizer.encode_as_ids
else:
raise NotImplementedError(f"Patching tokenizer methods for {type(tokenizer)} is not available")

tokenizer.bos_token_id = tokenizer.bos_id
tokenizer.eos_token_id = tokenizer.eos_id
TokenizerSpec.encode = encode
TokenizerSpec.batch_decode = batch_decode

return tokenizer

Expand Down

0 comments on commit 341580e

Please sign in to comment.