Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support custom default headers in OpenAILLM class. #1088

Merged
merged 4 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/distilabel/models/llms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def _offline_batch_generate_polling(
f" for {self.offline_batch_generation_block_until_done} seconds before"
" trying to get the results again."
)
# When running a `Step` in a child process, SIGINT is overriden so the child
# When running a `Step` in a child process, SIGINT is overridden so the child
# process doesn't stop when the parent process receives a SIGINT signal.
# The new handler sets an environment variable that is checked here to stop
# the polling.
Expand Down
7 changes: 7 additions & 0 deletions src/distilabel/models/llms/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,10 @@ class User(BaseModel):
default_factory=lambda: os.getenv(_OPENAI_API_KEY_ENV_VAR_NAME),
description="The API key to authenticate the requests to the OpenAI API.",
)
default_headers: Optional[RuntimeParameter[Dict[str, str]]] = Field(
default=None,
description="The default headers to use for the OpenAI API requests.",
)
max_retries: RuntimeParameter[int] = Field(
default=6,
description="The maximum number of times to retry the request to the API before"
Expand Down Expand Up @@ -196,13 +200,15 @@ def load(self) -> None:
api_key=self.api_key.get_secret_value(),
max_retries=self.max_retries, # type: ignore
timeout=self.timeout,
default_headers=self.default_headers,
)

self._aclient = AsyncOpenAI(
base_url=self.base_url,
api_key=self.api_key.get_secret_value(),
max_retries=self.max_retries, # type: ignore
timeout=self.timeout,
default_headers=self.default_headers,
)

if self.structured_output:
Expand All @@ -221,6 +227,7 @@ def unload(self) -> None:

self._client = None # type: ignore
self._aclient = None # type: ignore
self.default_headers = None
self.structured_output = None
super().unload()

Expand Down
25 changes: 23 additions & 2 deletions tests/unit/models/llms/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,16 +569,18 @@ def test_create_jsonl_row(
}

@pytest.mark.parametrize(
"structured_output, dump",
"default_headers, structured_output, dump",
[
(
None,
None,
{
"model": "gpt-4",
"generation_kwargs": {},
"max_retries": 6,
"base_url": "https://api.openai.com/v1",
"timeout": 120,
"default_headers": None,
"structured_output": None,
"jobs_ids": None,
"offline_batch_generation_block_until_done": None,
Expand All @@ -590,6 +592,7 @@ def test_create_jsonl_row(
},
),
(
{"X-Custom-Header": "test"},
{
"schema": DummyUserDetail.model_json_schema(),
"mode": "tool_call",
Expand All @@ -601,6 +604,7 @@ def test_create_jsonl_row(
"max_retries": 6,
"base_url": "https://api.openai.com/v1",
"timeout": 120,
"default_headers": {"X-Custom-Header": "test"},
"structured_output": {
"schema": DummyUserDetail.model_json_schema(),
"mode": "tool_call",
Expand All @@ -621,10 +625,27 @@ def test_serialization(
self,
_async_openai_mock: MagicMock,
_openai_mock: MagicMock,
default_headers: Dict[str, Any],
structured_output: Dict[str, Any],
dump: Dict[str, Any],
) -> None:
llm = OpenAILLM(model=self.model_id, structured_output=structured_output)
llm = OpenAILLM(
model=self.model_id,
default_headers=default_headers,
structured_output=structured_output,
)

assert llm.dump() == dump
assert isinstance(OpenAILLM.from_dict(dump), OpenAILLM)

def test_openai_llm_default_headers(
self, _async_openai_mock: MagicMock, _openai_mock: MagicMock
) -> None:
custom_headers = {"X-Custom-Header": "test"}
llm = OpenAILLM(
model=self.model_id, api_key="api.key", default_headers=custom_headers
) # type: ignore

assert isinstance(llm, OpenAILLM)
assert llm.model_name == self.model_id
assert llm.default_headers == custom_headers