diff --git a/litellm/llms/databricks/chat/transformation.py b/litellm/llms/databricks/chat/transformation.py
index b1f79d565b29..7e5c1f6c23da 100644
--- a/litellm/llms/databricks/chat/transformation.py
+++ b/litellm/llms/databricks/chat/transformation.py
@@ -73,6 +73,8 @@ def get_supported_openai_params(self, model: Optional[str] = None) -> list:
             "max_completion_tokens",
             "n",
             "response_format",
+            "tools",
+            "tool_choice",
         ]

     def _should_fake_stream(self, optional_params: dict) -> bool:
diff --git a/tests/local_testing/test_function_calling.py b/tests/local_testing/test_function_calling.py
index 7dddeb11cf84..2452b362d481 100644
--- a/tests/local_testing/test_function_calling.py
+++ b/tests/local_testing/test_function_calling.py
@@ -687,3 +687,85 @@ async def test_watsonx_tool_choice(sync_mode):
             pytest.skip("Skipping test due to timeout")
         else:
             raise e
+
+
+@pytest.mark.asyncio
+async def test_function_calling_with_dbrx():
+    from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+
+    client = AsyncHTTPHandler()
+    with patch.object(client, "post", return_value=MagicMock()) as mock_completion:
+        try:
+            resp = await litellm.acompletion(
+                model="databricks/databricks-dbrx-instruct",
+                messages=[
+                    {
+                        "role": "system",
+                        "content": "You are a helpful customer support assistant. Use the supplied tools to assist the user.",
+                    },
+                    {
+                        "role": "user",
+                        "content": "Hi, can you tell me the delivery date for my order?",
+                    },
+                    {
+                        "role": "assistant",
+                        "content": "Hi there! I can help with that. Can you please provide your order ID?",
+                    },
+                    {
+                        "role": "user",
+                        "content": "i think it is order_12345, also what is the weather in Phoenix, AZ?",
+                    },
+                ],
+                tools=[
+                    {
+                        "type": "function",
+                        "function": {
+                            "name": "get_delivery_date",
+                            "description": "Get the delivery date for a customer's order. Call this whenever you need to know the delivery date, for example when a customer asks 'Where is my package'",
+                            "parameters": {
+                                "type": "object",
+                                "properties": {
+                                    "order_id": {
+                                        "type": "string",
+                                        "description": "The customer's order ID.",
+                                    }
+                                },
+                                "required": ["order_id"],
+                                "additionalProperties": False,
+                            },
+                        },
+                    },
+                    {
+                        "type": "function",
+                        "function": {
+                            "name": "check_weather",
+                            "description": "Check the current weather in a location. For example when asked: 'What is the temperature in San Francisco, CA?'",
+                            "parameters": {
+                                "type": "object",
+                                "properties": {
+                                    "city": {
+                                        "type": "string",
+                                        "description": "The city to check the weather for.",
+                                    },
+                                    "state": {
+                                        "type": "string",
+                                        "description": "The state to check the weather for.",
+                                    },
+                                },
+                                "required": ["city", "state"],
+                                "additionalProperties": False,
+                            },
+                        },
+                    },
+                ],
+                client=client,
+                tool_choice="auto",
+            )
+        except Exception as e:
+            print(e)
+
+        mock_completion.assert_called_once()
+        print(mock_completion.call_args.kwargs)
+        json_data = json.loads(mock_completion.call_args.kwargs["data"])
+        assert "tools" in json_data
+        assert "tool_choice" in json_data
diff --git a/tests/local_testing/test_parallel_request_limiter.py b/tests/local_testing/test_parallel_request_limiter.py
index 7dffdd2f370e..9221efa825ae 100644
--- a/tests/local_testing/test_parallel_request_limiter.py
+++ b/tests/local_testing/test_parallel_request_limiter.py
@@ -454,6 +454,7 @@ async def test_failure_call_hook():
 """


+@pytest.mark.flaky(retries=6, delay=2)
 @pytest.mark.asyncio
 async def test_normal_router_call():
     model_list = [
@@ -528,6 +529,7 @@ async def test_normal_router_call():
     )


+@pytest.mark.flaky(retries=6, delay=2)
 @pytest.mark.asyncio
 async def test_normal_router_tpm_limit():
     import logging
@@ -615,6 +617,7 @@ async def test_normal_router_tpm_limit():
         assert e.status_code == 429


+@pytest.mark.flaky(retries=6, delay=2)
 @pytest.mark.asyncio
 async def test_streaming_router_call():
     model_list = [
@@ -690,6 +693,7 @@ async def test_streaming_router_call():
     )


+@pytest.mark.flaky(retries=6, delay=2)
 @pytest.mark.asyncio
 async def test_streaming_router_tpm_limit():
     litellm.set_verbose = True
@@ -845,6 +849,7 @@ async def test_bad_router_call():
     )


+@pytest.mark.flaky(retries=6, delay=2)
 @pytest.mark.asyncio
 async def test_bad_router_tpm_limit():
     model_list = [
@@ -923,6 +928,7 @@ async def test_bad_router_tpm_limit():
     )


+@pytest.mark.flaky(retries=6, delay=2)
 @pytest.mark.asyncio
 async def test_bad_router_tpm_limit_per_model():
     model_list = [
@@ -1023,6 +1029,7 @@ async def test_bad_router_tpm_limit_per_model():
     )


+@pytest.mark.flaky(retries=6, delay=2)
 @pytest.mark.asyncio
 async def test_pre_call_hook_rpm_limits_per_model():
     """
@@ -1101,6 +1108,7 @@ async def test_pre_call_hook_rpm_limits_per_model():
     )


+@pytest.mark.flaky(retries=6, delay=2)
 @pytest.mark.asyncio
 async def test_pre_call_hook_tpm_limits_per_model():
     """
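
With `tools` and `tool_choice` added to the Databricks config's supported OpenAI params, both fields are now forwarded to the provider instead of being silently dropped. A minimal usage sketch of the behavior this enables (the model alias and tool schema below are illustrative, and `DATABRICKS_API_KEY`/`DATABRICKS_API_BASE` are assumed to be configured in the environment):

```python
import litellm

# After this change, "tools" and "tool_choice" pass through
# get_supported_openai_params for databricks/* models and land
# in the serialized request body.
response = litellm.completion(
    model="databricks/databricks-dbrx-instruct",  # illustrative alias
    messages=[
        {"role": "user", "content": "What's the delivery date for order_12345?"}
    ],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "get_delivery_date",
                "description": "Get the delivery date for a customer's order.",
                "parameters": {
                    "type": "object",
                    "properties": {"order_id": {"type": "string"}},
                    "required": ["order_id"],
                },
            },
        }
    ],
    tool_choice="auto",
)
print(response.choices[0].message.tool_calls)
```

The new test asserts exactly this pass-through without a live call: it patches the async HTTP client, then checks that `tools` and `tool_choice` appear in the JSON request body.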