Skip to content

Commit

Permalink
mistral, openai: support custom tokenizers in chat models (langchain-…
Browse files Browse the repository at this point in the history
  • Loading branch information
ccurme authored and pprados committed Apr 26, 2024
1 parent 97ba189 commit 4417aa1
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 2 deletions.
10 changes: 9 additions & 1 deletion libs/partners/mistralai/tests/unit_tests/test_chat_models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Test MistralAI Chat API wrapper."""

import os
from typing import Any, AsyncGenerator, Dict, Generator, cast
from typing import Any, AsyncGenerator, Dict, Generator, List, cast
from unittest.mock import patch

import pytest
Expand Down Expand Up @@ -190,3 +190,11 @@ def test__convert_dict_to_message_tool_call() -> None:
)
assert result == expected_output
assert _convert_message_to_mistral_chat_message(expected_output) == message


def test_custom_token_counting() -> None:
    """A user-supplied ``custom_get_token_ids`` must override default tokenization."""

    def fake_tokenizer(text: str) -> List[int]:
        # Fixed token ids regardless of input — makes the override observable.
        return [1, 2, 3]

    model = ChatMistralAI(custom_get_token_ids=fake_tokenizer)
    assert model.get_token_ids("foo") == [1, 2, 3]
2 changes: 2 additions & 0 deletions libs/partners/openai/langchain_openai/chat_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,8 @@ def _get_encoding_model(self) -> Tuple[str, tiktoken.Encoding]:

def get_token_ids(self, text: str) -> List[int]:
"""Get the tokens present in the text with tiktoken package."""
if self.custom_get_token_ids is not None:
return self.custom_get_token_ids(text)
# tiktoken NOT supported for Python 3.7 or below
if sys.version_info[1] <= 7:
return super().get_token_ids(text)
Expand Down
10 changes: 9 additions & 1 deletion libs/partners/openai/tests/unit_tests/chat_models/test_base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Test OpenAI Chat API wrapper."""

import json
from typing import Any
from typing import Any, List
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
Expand Down Expand Up @@ -279,3 +279,11 @@ def test_openai_invoke_name(mock_completion: dict) -> None:
# check return type has name
assert res.content == "Bar Baz"
assert res.name == "Erick"


def test_custom_token_counting() -> None:
    """A user-supplied ``custom_get_token_ids`` must override tiktoken-based counting."""

    def fake_tokenizer(text: str) -> List[int]:
        # Fixed token ids regardless of input — makes the override observable.
        return [1, 2, 3]

    model = ChatOpenAI(custom_get_token_ids=fake_tokenizer)
    assert model.get_token_ids("foo") == [1, 2, 3]

0 comments on commit 4417aa1

Please sign in to comment.