Skip to content

Commit

Permalink
Fix Bedrock token count and IDs for Anthropic models (#341)
Browse files Browse the repository at this point in the history
Fixes #314

The `get_num_tokens()` and `get_token_ids()` methods of `BedrockLLM` and
`ChatBedrock` currently fail when used with Anthropic models and
`anthropic>=0.39.0` installed.

In such scenarios, this PR fixes `get_num_tokens()` and
`get_token_ids()` by falling back to the base class implementations [in
`BaseLanguageModel`](https://python.langchain.com/api_reference/core/language_models/langchain_core.language_models.base.BaseLanguageModel.html#langchain_core.language_models.base.BaseLanguageModel.get_num_tokens)
and [in
`BaseChatModel`](https://python.langchain.com/api_reference/core/language_models/langchain_core.language_models.chat_models.BaseChatModel.html#langchain_core.language_models.chat_models.BaseChatModel.get_num_tokens),
which use the HuggingFace [GPT2
Tokenizer](https://huggingface.co/docs/transformers/en/model_doc/gpt2#transformers.GPT2TokenizerFast).

If `anthropic<=0.38.0` (and `httpx<=0.27.2`, which it requires) are installed
instead, the Anthropic SDK token methods will continue to be used as normal.

**Note about tokenizer accuracy:** 

The GPT2 and Anthropic SDK tokenizers (see
[here](https://github.com/anthropics/anthropic-sdk-python/pull/726/files#diff-9595fafc42ceb1044adb6f4f2a93774f58704ee22b6f9d40ad9d0a336c118d39L540-L542))
both may produce inaccurate token estimates for Claude 3 and 3.5 models.
To obtain a more accurate estimate, Anthropic recommends using the new
[Count Message Tokens
API](https://docs.anthropic.com/en/api/messages-count-tokens). For
example:
```
import os
import anthropic

os.environ["ANTHROPIC_API_KEY"] = "<your-api-key>"
anthropic.Anthropic().messages.count_tokens(
    model="claude-3-5-sonnet-20241022",
    messages=[
        {"role": "user", "content": "Hello, world"}
    ]
)
```

As another alternative, you can implement your own token counter method,
and pass this using
[`custom_get_token_ids`](https://python.langchain.com/api_reference/core/language_models/langchain_core.language_models.base.BaseLanguageModel.html#langchain_core.language_models.base.BaseLanguageModel.custom_get_token_ids)
when initializing the model. This will override both the GPT2 and
Anthropic tokenizers.

---------

Co-authored-by: Piyush Jain <[email protected]>
  • Loading branch information
michaelnchin and 3coins authored Feb 7, 2025
1 parent c5ec714 commit 6355b0f
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 27 deletions.
27 changes: 20 additions & 7 deletions libs/aws/langchain_aws/chat_models/bedrock.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import re
import warnings
from collections import defaultdict
from operator import itemgetter
from typing import (
Expand Down Expand Up @@ -51,6 +52,7 @@
_combine_generation_info_for_llm_result,
)
from langchain_aws.utils import (
anthropic_tokens_supported,
get_num_tokens_anthropic,
get_token_ids_anthropic,
)
Expand Down Expand Up @@ -620,16 +622,27 @@ def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
return final_output

def get_num_tokens(self, text: str) -> int:
    """Return the number of tokens in ``text``.

    Uses the Anthropic SDK token counter only when all of the following
    hold: the model is an Anthropic model, the caller did not supply
    ``custom_get_token_ids``, and a compatible ``anthropic`` package
    (<=0.38.0, per ``anthropic_tokens_supported``) is installed.
    Otherwise falls back to the base-class implementation (GPT-2
    tokenizer in langchain-core's ``BaseChatModel``).

    Args:
        text: The string to tokenize.

    Returns:
        The integer token count.
    """
    if (
        self._model_is_anthropic
        and not self.custom_get_token_ids
        and anthropic_tokens_supported()
    ):
        return get_num_tokens_anthropic(text)
    return super().get_num_tokens(text)

def get_token_ids(self, text: str) -> List[int]:
if self._model_is_anthropic:
return get_token_ids_anthropic(text)
else:
return super().get_token_ids(text)
if self._model_is_anthropic and not self.custom_get_token_ids:
if anthropic_tokens_supported():
return get_token_ids_anthropic(text)
else:
warnings.warn(
f"Falling back to default token method due to missing or incompatible `anthropic` installation "
f"(needs <=0.38.0).\n\nIf using `anthropic>0.38.0`, it is recommended to provide the model "
f"class with a custom_get_token_ids method implementing a more accurate tokenizer for Anthropic. "
f"For get_num_tokens, as another alternative, you can implement your own token counter method "
f"using the ChatAnthropic or AnthropicLLM classes."
)
return super().get_token_ids(text)

def set_system_prompt_with_tools(self, xml_tools_system_prompt: str) -> None:
"""Workaround to bind. Sets the system prompt with tools"""
Expand Down
25 changes: 17 additions & 8 deletions libs/aws/langchain_aws/llms/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

from langchain_aws.function_calling import _tools_in_params
from langchain_aws.utils import (
anthropic_tokens_supported,
enforce_stop_tokens,
get_num_tokens_anthropic,
get_token_ids_anthropic,
Expand Down Expand Up @@ -1301,13 +1302,21 @@ async def _acall(
return "".join([chunk.text for chunk in chunks])

def get_num_tokens(self, text: str) -> int:
    """Return the number of tokens in ``text``.

    Uses the Anthropic SDK token counter only when the model is an
    Anthropic model, no ``custom_get_token_ids`` override was supplied,
    and a compatible ``anthropic`` package (<=0.38.0) is installed;
    otherwise falls back to the base-class implementation (GPT-2
    tokenizer in langchain-core's ``BaseLanguageModel``).

    Args:
        text: The string to tokenize.

    Returns:
        The integer token count.
    """
    # Single combined condition, kept consistent with ChatBedrock's
    # get_num_tokens in chat_models/bedrock.py.
    if (
        self._model_is_anthropic
        and not self.custom_get_token_ids
        and anthropic_tokens_supported()
    ):
        return get_num_tokens_anthropic(text)
    return super().get_num_tokens(text)

def get_token_ids(self, text: str) -> List[int]:
    """Return the token IDs for ``text``.

    Prefers the Anthropic SDK tokenizer when the model is Anthropic, no
    ``custom_get_token_ids`` override was supplied, and a compatible
    ``anthropic`` package (<=0.38.0) is installed. Otherwise warns and
    falls back to the base-class (GPT-2) implementation.

    Args:
        text: The string to tokenize.

    Returns:
        A list of integer token IDs.
    """
    if self._model_is_anthropic and not self.custom_get_token_ids:
        if anthropic_tokens_supported():
            return get_token_ids_anthropic(text)
        # NOTE(review): this relies on `warnings` being imported at the
        # top of llms/bedrock.py — confirm the module-level import exists.
        warnings.warn(
            "Falling back to default token method due to missing or incompatible `anthropic` installation "
            "(needs <=0.38.0).\n\nFor `anthropic>0.38.0`, it is recommended to provide the model "
            "class with a custom_get_token_ids method implementing a more accurate tokenizer for Anthropic. "
            "For get_num_tokens, as another alternative, you can implement your own token counter method "
            "using the ChatAnthropic or AnthropicLLM classes."
        )
    return super().get_token_ids(text)
29 changes: 23 additions & 6 deletions libs/aws/langchain_aws/utils.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,38 @@
import re
from typing import Any, List

from packaging import version


def enforce_stop_tokens(text: str, stop: List[str]) -> str:
    """Cut off the text as soon as any stop words occur.

    Args:
        text: The generated text to truncate.
        stop: Literal stop strings; the text is cut at the first
            occurrence of any of them.

    Returns:
        ``text`` truncated just before the earliest stop string, or the
        full text if no stop string (or an empty ``stop`` list) is given.
    """
    # Guard: re.split("", text) would split on every boundary and
    # return "" — an empty stop list must leave the text untouched.
    if not stop:
        return text
    # re.escape each stop word so regex metacharacters ('.', '|', '(',
    # etc.) are matched literally instead of being interpreted.
    pattern = "|".join(re.escape(s) for s in stop)
    return re.split(pattern, text, maxsplit=1)[0]


def anthropic_tokens_supported() -> bool:
    """Check if we have all requirements for Anthropic count_tokens() and get_tokenizer().

    The Anthropic SDK's local token-counting helpers are only usable with
    ``anthropic<=0.38.0`` (they were removed in 0.39.0) and require
    ``httpx<=0.27.2``.

    Returns:
        True when a compatible ``anthropic`` and ``httpx`` are installed;
        False when ``anthropic`` is missing or too new.

    Raises:
        ImportError: If a compatible ``anthropic`` is installed but
            ``httpx`` is missing or newer than 0.27.2.
    """
    try:
        import anthropic
    except ImportError:
        # No SDK at all: callers fall back to the default tokenizer.
        return False

    if version.parse(anthropic.__version__) > version.parse("0.38.0"):
        # SDK too new — its count_tokens()/get_tokenizer() no longer exist.
        return False

    try:
        import httpx

        if version.parse(httpx.__version__) > version.parse("0.27.2"):
            raise ImportError()
    except ImportError:
        raise ImportError("httpx<=0.27.2 is required.")

    return True


def _get_anthropic_client() -> Any:
    """Instantiate and return an Anthropic SDK client.

    Imported lazily so this module can be loaded when the optional
    `anthropic` package is not installed.
    """
    import anthropic

    return anthropic.Anthropic()


Expand Down
8 changes: 2 additions & 6 deletions libs/aws/tests/unit_tests/retrievers/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,9 +494,7 @@ def test_when_get_content_from_result_then_get_expected_content(
search_result_input, expected_output
):
assert (
AmazonKnowledgeBasesRetriever._get_content_from_result(
search_result_input
)
AmazonKnowledgeBasesRetriever._get_content_from_result(search_result_input)
== expected_output
)

Expand All @@ -518,9 +516,7 @@ def test_when_get_content_from_result_with_invalid_content_then_raise_error(
search_result_input,
):
with pytest.raises(ValueError):
AmazonKnowledgeBasesRetriever._get_content_from_result(
search_result_input
)
AmazonKnowledgeBasesRetriever._get_content_from_result(search_result_input)


def set_return_value_and_query(
Expand Down

0 comments on commit 6355b0f

Please sign in to comment.