Skip to content

Commit

Permalink
feat(databricks/chat/transformation.py): add tools and 'tool_choice' …
Browse files Browse the repository at this point in the history
…param support (BerriAI#8076)

* feat(databricks/chat/transformation.py): add tools and 'tool_choice' param support

Closes BerriAI#7788

* refactor: cleanup redundant file

* test: mark flaky test

* test: mark all parallel request tests as flaky
  • Loading branch information
krrishdholakia authored Jan 30, 2025
1 parent 9fa44a4 commit ba8ba9e
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 0 deletions.
2 changes: 2 additions & 0 deletions litellm/llms/databricks/chat/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ def get_supported_openai_params(self, model: Optional[str] = None) -> list:
"max_completion_tokens",
"n",
"response_format",
"tools",
"tool_choice",
]

def _should_fake_stream(self, optional_params: dict) -> bool:
Expand Down
82 changes: 82 additions & 0 deletions tests/local_testing/test_function_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,3 +687,85 @@ async def test_watsonx_tool_choice(sync_mode):
pytest.skip("Skipping test due to timeout")
else:
raise e


@pytest.mark.asyncio
async def test_function_calling_with_dbrx():
    """Check that `tools` and `tool_choice` are forwarded in the Databricks request body.

    The HTTP client's `post` is mocked, so the completion call itself is expected
    to fail while parsing the fake response — only the captured request payload
    is inspected.
    """
    from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler

    conversation = [
        {
            "role": "system",
            "content": "You are a helpful customer support assistant. Use the supplied tools to assist the user.",
        },
        {
            "role": "user",
            "content": "Hi, can you tell me the delivery date for my order?",
        },
        {
            "role": "assistant",
            "content": "Hi there! I can help with that. Can you please provide your order ID?",
        },
        {
            "role": "user",
            "content": "i think it is order_12345, also what is the weather in Phoenix, AZ?",
        },
    ]

    tool_definitions = [
        {
            "type": "function",
            "function": {
                "name": "get_delivery_date",
                "description": "Get the delivery date for a customer'''s order. Call this whenever you need to know the delivery date, for example when a customer asks '''Where is my package'''",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "order_id": {
                            "type": "string",
                            "description": "The customer'''s order ID.",
                        }
                    },
                    "required": ["order_id"],
                    "additionalProperties": False,
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "check_weather",
                "description": "Check the current weather in a location. For example when asked: '''What is the temperature in San Fransisco, CA?'''",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "city": {
                            "type": "string",
                            "description": "The city to check the weather for.",
                        },
                        "state": {
                            "type": "string",
                            "description": "The state to check the weather for.",
                        },
                    },
                    "required": ["city", "state"],
                    "additionalProperties": False,
                },
            },
        },
    ]

    http_client = AsyncHTTPHandler()
    with patch.object(http_client, "post", return_value=MagicMock()) as mocked_post:
        try:
            # The mocked response is unusable downstream; we only need the call to fire.
            await litellm.acompletion(
                model="databricks/databricks-dbrx-instruct",
                messages=conversation,
                tools=tool_definitions,
                client=http_client,
                tool_choice="auto",
            )
        except Exception as e:
            print(e)

        mocked_post.assert_called_once()
        print(mocked_post.call_args.kwargs)
        request_body = json.loads(mocked_post.call_args.kwargs["data"])
        assert "tools" in request_body
        assert "tool_choice" in request_body
8 changes: 8 additions & 0 deletions tests/local_testing/test_parallel_request_limiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,7 @@ async def test_failure_call_hook():
"""


@pytest.mark.flaky(retries=6, delay=2)
@pytest.mark.asyncio
async def test_normal_router_call():
model_list = [
Expand Down Expand Up @@ -528,6 +529,7 @@ async def test_normal_router_call():
)


@pytest.mark.flaky(retries=6, delay=2)
@pytest.mark.asyncio
async def test_normal_router_tpm_limit():
import logging
Expand Down Expand Up @@ -615,6 +617,7 @@ async def test_normal_router_tpm_limit():
assert e.status_code == 429


@pytest.mark.flaky(retries=6, delay=2)
@pytest.mark.asyncio
async def test_streaming_router_call():
model_list = [
Expand Down Expand Up @@ -690,6 +693,7 @@ async def test_streaming_router_call():
)


@pytest.mark.flaky(retries=6, delay=2)
@pytest.mark.asyncio
async def test_streaming_router_tpm_limit():
litellm.set_verbose = True
Expand Down Expand Up @@ -845,6 +849,7 @@ async def test_bad_router_call():
)


@pytest.mark.flaky(retries=6, delay=2)
@pytest.mark.asyncio
async def test_bad_router_tpm_limit():
model_list = [
Expand Down Expand Up @@ -923,6 +928,7 @@ async def test_bad_router_tpm_limit():
)


@pytest.mark.flaky(retries=6, delay=2)
@pytest.mark.asyncio
async def test_bad_router_tpm_limit_per_model():
model_list = [
Expand Down Expand Up @@ -1023,6 +1029,7 @@ async def test_bad_router_tpm_limit_per_model():
)


@pytest.mark.flaky(retries=6, delay=2)
@pytest.mark.asyncio
async def test_pre_call_hook_rpm_limits_per_model():
"""
Expand Down Expand Up @@ -1101,6 +1108,7 @@ async def test_pre_call_hook_rpm_limits_per_model():
)


@pytest.mark.flaky(retries=6, delay=2)
@pytest.mark.asyncio
async def test_pre_call_hook_tpm_limits_per_model():
"""
Expand Down

0 comments on commit ba8ba9e

Please sign in to comment.