From a0bf926440472ce7966fce23009fcad5ad7cba08 Mon Sep 17 00:00:00 2001
From: KCaverly
Date: Fri, 23 Feb 2024 10:35:47 -0500
Subject: [PATCH] feat: add token counting to BaseLM

---
 dspy/backends/lm/base.py    |  5 +++++
 dspy/backends/lm/litellm.py | 10 ++++++++--
 2 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/dspy/backends/lm/base.py b/dspy/backends/lm/base.py
index 76fd9b09d..911a7ca06 100644
--- a/dspy/backends/lm/base.py
+++ b/dspy/backends/lm/base.py
@@ -15,3 +15,8 @@ def __call__(
     ) -> list[str]:
         """Generates `n` predictions for the signature output."""
         ...
+
+    @abstractmethod
+    def count_tokens(self, prompt: str) -> int:
+        """Counts the number of tokens for a specific prompt."""
+        ...
diff --git a/dspy/backends/lm/litellm.py b/dspy/backends/lm/litellm.py
index ac9e884a6..d647622b5 100644
--- a/dspy/backends/lm/litellm.py
+++ b/dspy/backends/lm/litellm.py
@@ -1,6 +1,6 @@
 import typing as t
 
-from litellm import completion
+from litellm import completion, token_counter
 from pydantic import Field
 
 
@@ -8,7 +8,7 @@
 
 
 class LiteLLM(BaseLM):
-    STANDARD_PARAMS = {
+    STANDARD_PARAMS: t.Dict[str, t.Union[float, int]] = {
         "temperature": 0.0,
         "max_tokens": 150,
         "top_p": 1,
@@ -33,3 +33,9 @@ def __call__(
         )
         choices = [c for c in response["choices"] if c["finish_reason"] != "length"]
         return [c["message"]["content"] for c in choices]
+
+    def count_tokens(self, prompt: str) -> int:
+        """Counts the number of tokens for a specific prompt."""
+        return token_counter(
+            model=self.model, messages=[{"role": "user", "content": prompt}]
+        )