Skip to content

Commit

Permalink
Fix streaming tests (#2032)
Browse files Browse the repository at this point in the history
* Fix streaming tests

* Updated lockfile

* Update poetry version
  • Loading branch information
CyrusNuevoDia authored Jan 10, 2025
1 parent 238e312 commit fea2d38
Show file tree
Hide file tree
Showing 10 changed files with 1,047 additions and 789 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ on:
types: [opened, synchronize, reopened]

env:
POETRY_VERSION: "1.7.1"
POETRY_VERSION: "2.0.0"

jobs:
fix:
Expand Down
3 changes: 2 additions & 1 deletion dspy/adapters/json_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ def __call__(self, lm, lm_kwargs, signature, demos, inputs):

try:
provider = lm.model.split("/", 1)[0] or "openai"
if "response_format" in litellm.get_supported_openai_params(model=lm.model, custom_llm_provider=provider):
params = litellm.get_supported_openai_params(model=lm.model, custom_llm_provider=provider)
if params and "response_format" in params:
try:
response_format = _get_structured_outputs_response_format(signature)
outputs = lm(**inputs, **lm_kwargs, response_format=response_format)
Expand Down
4 changes: 2 additions & 2 deletions dspy/clients/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,9 +262,9 @@ def cache_key(request: Dict[str, Any]) -> str:

def transform_value(value):
if isinstance(value, type) and issubclass(value, pydantic.BaseModel):
return value.schema()
return value.model_json_schema()
elif isinstance(value, pydantic.BaseModel):
return value.dict()
return value.model_dump()
elif callable(value) and hasattr(value, "__code__") and hasattr(value.__code__, "co_code"):
return value.__code__.co_code.decode("utf-8")
else:
Expand Down
2 changes: 1 addition & 1 deletion dspy/utils/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ async def streaming_response(streamer: AsyncGenerator) -> AsyncGenerator:
if isinstance(value, Prediction):
data = {"prediction": {k: v for k, v in value.items(include_dspy=False)}}
yield f"data: {ujson.dumps(data)}\n\n"
elif isinstance(value, litellm.ModelResponse):
elif isinstance(value, litellm.ModelResponseStream):
data = {"chunk": value.json()}
yield f"data: {ujson.dumps(data)}\n\n"
elif isinstance(value, str) and value.startswith("data:"):
Expand Down
1,800 changes: 1,028 additions & 772 deletions poetry.lock

Large diffs are not rendered by default.

15 changes: 8 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ requires = ["setuptools>=40.8.0", "wheel"]
build-backend = "setuptools.build_meta"

[project]
# Do not add spaces around the '=' sign for any of the fields
# preceeded by a marker comment as it affects the publish workflow.
# Do not add spaces around the '=' sign for any of the fields
# preceeded by a marker comment as it affects the publish workflow.
#replace_package_name_marker
name="dspy"
name = "dspy"
#replace_package_version_marker
version="2.6.0rc8"
version = "2.6.0rc8"
description = "DSPy"
readme = "README.md"
authors = [{ name = "Omar Khattab", email = "[email protected]" }]
Expand Down Expand Up @@ -40,7 +40,7 @@ dependencies = [
"pydantic~=2.0",
"jinja2",
"magicattr~=0.1.6",
"litellm==1.55.3",
"litellm==1.57.4",
"diskcache",
"json-repair",
"tenacity>=8.2.3",
Expand Down Expand Up @@ -134,7 +134,7 @@ pgvector = { version = "^0.2.5", optional = true }
llama-index = { version = "^0.10.30", optional = true }
jinja2 = "^3.1.3"
magicattr = "^0.1.6"
litellm = { version = "==1.55.3", extras = ["proxy"] }
litellm = { version = "==1.57.4", extras = ["proxy"] }
diskcache = "^5.6.0"
json-repair = "^0.30.0"
tenacity = ">=8.2.3"
Expand Down Expand Up @@ -231,7 +231,8 @@ target-version = "py39"
select = [
"F", # Pyflakes
"E", # Pycodestyle
"TID252", # Absolute imports
"TID252", # Absolute imports

]
ignore = [
"E501", # Line too long
Expand Down
Empty file added tests/__init__.py
Empty file.
4 changes: 2 additions & 2 deletions tests/reliability/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import pytest

import dspy
from tests.conftest import clear_settings
from tests.reliability.utils import get_adapter, parse_reliability_conf_yaml
from ..conftest import clear_settings
from ..reliability.utils import get_adapter, parse_reliability_conf_yaml

# Standard list of models that should be used for periodic DSPy reliability testing
MODEL_LIST = [
Expand Down
Empty file added tests/utils/__init__.py
Empty file.
6 changes: 3 additions & 3 deletions tests/utils/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

import dspy
from dspy.utils.streaming import streaming_response
from tests.test_utils.server import litellm_test_server
from ..test_utils.server import litellm_test_server


@pytest.mark.asyncio
@pytest.mark.anyio
async def test_streamify_yields_expected_response_chunks(litellm_test_server):
api_base, _ = litellm_test_server
lm = dspy.LM(
Expand Down Expand Up @@ -37,7 +37,7 @@ class TestSignature(dspy.Signature):
assert last_chunk2.output_text == "Hello!"


@pytest.mark.asyncio
@pytest.mark.anyio
async def test_streaming_response_yields_expected_response_chunks(litellm_test_server):
api_base, _ = litellm_test_server
lm = dspy.LM(
Expand Down

0 comments on commit fea2d38

Please sign in to comment.