diff --git a/src/distilabel/models/llms/openai.py b/src/distilabel/models/llms/openai.py index 4c4ee6419..91f24a333 100644 --- a/src/distilabel/models/llms/openai.py +++ b/src/distilabel/models/llms/openai.py @@ -54,6 +54,7 @@ class OpenAILLM(AsyncLLM): api_key: the API key to authenticate the requests to the OpenAI API. Defaults to `None` which means that the value set for the environment variable `OPENAI_API_KEY` will be used, or `None` if not set. + default_headers: the default headers to use for the OpenAI API requests. max_retries: the maximum number of times to retry the request to the API before failing. Defaults to `6`. timeout: the maximum time in seconds to wait for a response from the API. Defaults diff --git a/tests/unit/models/llms/test_anyscale.py b/tests/unit/models/llms/test_anyscale.py index d12dbebd0..6a31d6080 100644 --- a/tests/unit/models/llms/test_anyscale.py +++ b/tests/unit/models/llms/test_anyscale.py @@ -46,6 +46,7 @@ def test_serialization(self) -> None: "model": self.model_id, "generation_kwargs": {}, "max_retries": 6, + "default_headers": None, "base_url": "https://api.endpoints.anyscale.com/v1", "timeout": 120, "structured_output": None, diff --git a/tests/unit/models/llms/test_azure.py b/tests/unit/models/llms/test_azure.py index a2122b611..1e874c5f9 100644 --- a/tests/unit/models/llms/test_azure.py +++ b/tests/unit/models/llms/test_azure.py @@ -71,6 +71,7 @@ def test_azure_openai_llm_env_vars(self) -> None: "api_version": "preview", "generation_kwargs": {}, "max_retries": 6, + "default_headers": None, "base_url": "https://example-resource.azure.openai.com/", "timeout": 120, "structured_output": None, @@ -95,6 +96,7 @@ def test_azure_openai_llm_env_vars(self) -> None: "generation_kwargs": {}, "max_retries": 6, "base_url": "https://example-resource.azure.openai.com/", + "default_headers": None, "timeout": 120, "structured_output": { "schema": DummyUserDetail.model_json_schema(), diff --git a/tests/unit/models/llms/test_together.py b/tests/unit/models/llms/test_together.py index 88208bf6c..b7a045fbb 100644 --- a/tests/unit/models/llms/test_together.py +++ b/tests/unit/models/llms/test_together.py @@ -46,6 +46,7 @@ def test_serialization(self) -> None: "model": self.model_id, "generation_kwargs": {}, "max_retries": 6, + "default_headers": None, "base_url": "https://api.together.xyz/v1", "timeout": 120, "structured_output": None,