Commit

added AsyncLM
itay-verkh-lightricks committed Jan 1, 2025
1 parent b652c8e commit 671a6a1
Showing 7 changed files with 533 additions and 18 deletions.
2 changes: 2 additions & 0 deletions dspy/clients/__init__.py
@@ -1,4 +1,5 @@
 from dspy.clients.lm import LM
+from dspy.clients.async_lm import AsyncLM
 from dspy.clients.provider import Provider, TrainingJob
 from dspy.clients.base_lm import BaseLM, inspect_history
 from dspy.clients.embedding import Embedder
@@ -37,6 +38,7 @@ def disable_litellm_logging():

 __all__ = [
     "LM",
+    "AsyncLM",
     "Provider",
     "TrainingJob",
     "BaseLM",
138 changes: 138 additions & 0 deletions dspy/clients/async_lm.py
@@ -0,0 +1,138 @@
import os
from typing import Any, Awaitable, Dict, cast

import litellm
from anyio.streams.memory import MemoryObjectSendStream
from litellm.types.router import RetryPolicy

import dspy
from dspy.clients.lm import LM, request_cache
from dspy.utils import with_callbacks


class AsyncLM(LM):
    @with_callbacks
    def __call__(self, prompt=None, messages=None, **kwargs) -> Awaitable:
        async def _async_call(prompt, messages, **kwargs):
            # Build the request.
            cache = kwargs.pop("cache", self.cache)
            messages = messages or [{"role": "user", "content": prompt}]
            kwargs = {**self.kwargs, **kwargs}

            # Make the request and handle LRU & disk caching.
            if self.model_type == "chat":
                completion = cached_litellm_completion if cache else litellm_acompletion
            else:
                completion = cached_litellm_text_completion if cache else litellm_text_acompletion

            response = await completion(
                request=dict(model=self.model, messages=messages, **kwargs),
                num_retries=self.num_retries,
            )
            outputs = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]]
            self._log_entry(prompt, messages, kwargs, response, outputs)
            return outputs

        return _async_call(prompt, messages, **kwargs)
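
For orientation, a minimal usage sketch of the class above (not part of this commit): the model name and prompt are illustrative, and it assumes an OpenAI API key is configured in the environment.

# Hypothetical usage sketch: AsyncLM.__call__ returns an awaitable,
# so outputs are retrieved with `await` inside an event loop.
import asyncio

from dspy.clients import AsyncLM


async def main():
    lm = AsyncLM("openai/gpt-4o-mini", cache=False)  # illustrative model name
    outputs = await lm(prompt="Reply with one word: hello or goodbye.")
    print(outputs[0])


asyncio.run(main())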


@request_cache(maxsize=None)
async def cached_litellm_completion(request: Dict[str, Any], num_retries: int):
    return await litellm_acompletion(
        request,
        cache={"no-cache": False, "no-store": False},
        num_retries=num_retries,
    )


async def litellm_acompletion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}):
    retry_kwargs = dict(
        retry_policy=_get_litellm_retry_policy(num_retries),
        # In LiteLLM version 1.55.3 (the first version that supports retry_policy as an argument
        # to completion()), the default value of max_retries is non-zero for certain providers, and
        # max_retries is stacked on top of the retry_policy. To avoid this, we set max_retries=0
        max_retries=0,
    )

    stream = dspy.settings.send_stream
    if stream is None:
        return await litellm.acompletion(
            cache=cache,
            **retry_kwargs,
            **request,
        )

    # The stream is already opened, and will be closed by the caller.
    stream = cast(MemoryObjectSendStream, stream)

    async def stream_completion():
        response = await litellm.acompletion(
            cache=cache,
            stream=True,
            **retry_kwargs,
            **request,
        )
        chunks = []
        async for chunk in response:
            chunks.append(chunk)
            await stream.send(chunk)
        return litellm.stream_chunk_builder(chunks)

    return await stream_completion()
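
To make the streaming branch above concrete, here is a hypothetical consumer sketch (not part of this commit). It assumes dspy.settings can be configured with a send_stream key, that an API key is available in the environment, and that chunks follow LiteLLM's streaming delta shape; the model name and buffer size are illustrative.

# Hypothetical consumer sketch: the caller opens the anyio stream, exposes the
# send side via dspy.settings, and closes it once the completion finishes.
import anyio

import dspy
from dspy.clients.async_lm import litellm_acompletion


async def stream_demo():
    send_stream, receive_stream = anyio.create_memory_object_stream(max_buffer_size=16)
    dspy.settings.configure(send_stream=send_stream)
    request = dict(model="openai/gpt-4o-mini",  # illustrative model
                   messages=[{"role": "user", "content": "Count to three."}])

    async def consume():
        # Ends once the send side is closed below.
        async for chunk in receive_stream:
            print(chunk.choices[0].delta.content or "", end="")

    async with anyio.create_task_group() as tg:
        tg.start_soon(consume)
        async with send_stream:  # the caller owns and closes the stream
            await litellm_acompletion(request, num_retries=0)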


@request_cache(maxsize=None)
async def cached_litellm_text_completion(request: Dict[str, Any], num_retries: int):
    return await litellm_text_acompletion(
        request,
        num_retries=num_retries,
        cache={"no-cache": False, "no-store": False},
    )


async def litellm_text_acompletion(
    request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}
):
    # Extract the provider and model from the model string.
    # TODO: Not all the models are in the format of "provider/model"
    model = request.pop("model").split("/", 1)
    provider, model = model[0] if len(model) > 1 else "openai", model[-1]

    # Use the API key and base from the request, or from the environment.
    api_key = request.pop("api_key", None) or os.getenv(f"{provider}_API_KEY")
    api_base = request.pop("api_base", None) or os.getenv(f"{provider}_API_BASE")

    # Build the prompt from the messages.
    prompt = "\n\n".join([x["content"] for x in request.pop("messages")] + ["BEGIN RESPONSE:"])

    return await litellm.atext_completion(
        cache=cache,
        model=f"text-completion-openai/{model}",
        api_key=api_key,
        api_base=api_base,
        prompt=prompt,
        num_retries=num_retries,
        **request,
    )
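
As a small illustration of the prompt construction above (the messages are made up): the role fields are dropped and the contents are joined with blank lines, ending in a "BEGIN RESPONSE:" marker.

# Illustrative only: how a chat-style request is flattened for text completion.
messages = [
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Name one prime number."},
]
prompt = "\n\n".join([x["content"] for x in messages] + ["BEGIN RESPONSE:"])
# -> "You are terse.\n\nName one prime number.\n\nBEGIN RESPONSE:"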


def _get_litellm_retry_policy(num_retries: int) -> RetryPolicy:
"""
Get a LiteLLM retry policy for retrying requests when transient API errors occur.
Args:
num_retries: The number of times to retry a request if it fails transiently due to
network error, rate limiting, etc. Requests are retried with exponential
backoff.
Returns:
A LiteLLM RetryPolicy instance.
"""
return RetryPolicy(
TimeoutErrorRetries=num_retries,
RateLimitErrorRetries=num_retries,
InternalServerErrorRetries=num_retries,
ContentPolicyViolationErrorRetries=num_retries,
# We don't retry on errors that are unlikely to be transient
# (e.g. bad request, invalid auth credentials)
BadRequestErrorRetries=0,
AuthenticationErrorRetries=0,
)
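
A quick sanity-check sketch of the helper above (hypothetical, not part of the commit): the single num_retries budget is applied to transient error classes, while non-transient ones get zero retries. It assumes RetryPolicy exposes these counts as attributes.

# Illustrative only.
policy = _get_litellm_retry_policy(3)
assert policy.TimeoutErrorRetries == policy.RateLimitErrorRetries == 3
assert policy.BadRequestErrorRetries == 0 and policy.AuthenticationErrorRetries == 0
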
6 changes: 4 additions & 2 deletions dspy/clients/lm.py
@@ -113,6 +113,10 @@ def __call__(self, prompt=None, messages=None, **kwargs):
         else:
             outputs = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]]

+        self._log_entry(prompt, messages, kwargs, response, outputs)
+        return outputs
+
+    def _log_entry(self, prompt, messages, kwargs, response, outputs):
         # Logging, with removed api key & where `cost` is None on cache hit.
         kwargs = {k: v for k, v in kwargs.items() if not k.startswith("api_")}
         entry = dict(prompt=prompt, messages=messages, kwargs=kwargs, response=response)
@@ -129,8 +133,6 @@ def __call__(self, prompt=None, messages=None, **kwargs):
         self.history.append(entry)
         self.update_global_history(entry)

-        return outputs
-
     def launch(self, launch_kwargs: Optional[Dict[str, Any]] = None):
         launch_kwargs = launch_kwargs or self.launch_kwargs
         self.provider.launch(self.model, launch_kwargs)
(Diffs for the remaining 4 changed files are not shown.)
