diff --git a/src/genai/_utils/api_client.py b/src/genai/_utils/api_client.py index f5e565cc..f4c0777e 100644 --- a/src/genai/_utils/api_client.py +++ b/src/genai/_utils/api_client.py @@ -1,7 +1,8 @@ from asyncio import AbstractEventLoop -from typing import Any, Optional +from typing import Any, Optional, cast -from httpx import Timeout +from httpx import Auth, Request, Timeout +from httpx._auth import FunctionAuth from pydantic import BaseModel, ConfigDict, Field, field_validator from genai._types import ModelLike @@ -26,6 +27,7 @@ class HttpClientOptions(BaseModel): model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True) timeout: Timeout = Timeout(timeout=10 * 60, connect=10) + auth: Optional[Auth] = None @field_validator("timeout", mode="before") @classmethod @@ -129,15 +131,24 @@ def get_async_http_client( def _get_headers(self, override: Optional[dict]) -> dict: headers = { **(override or {}), - "Authorization": f"Bearer {self._credentials.api_key.get_secret_value()}", "x-request-origin": f"python-sdk/{__version__}", "user-agent": f"python-sdk/{__version__}", } return headers + def _get_default_auth(self) -> Auth: + """Default Authorization function, can be overridden in client_options""" + + def _auth_fn(request: Request) -> Request: + request.headers["Authorization"] = f"Bearer {self._credentials.api_key.get_secret_value()}" + return request + + return FunctionAuth(_auth_fn) + def _get_client_options(self, override: Optional[dict] = None) -> dict: final = merge_objects( + cast(dict[str, Any], {"auth": self._get_default_auth()}), self.config.client_options.model_dump(exclude_none=True), override, { diff --git a/tests/unit/test_api_client.py b/tests/unit/test_api_client.py new file mode 100644 index 00000000..2941fa2e --- /dev/null +++ b/tests/unit/test_api_client.py @@ -0,0 +1,45 @@ +import logging + +import pytest +from httpx import Request +from httpx._auth import FunctionAuth +from pytest_httpx import HTTPXMock + +from genai import Client, Credentials +from genai.schema import TextGenerationLimitRetrieveEndpoint +from tests.helpers import match_endpoint + +logger = logging.getLogger(__name__) + + +@pytest.mark.unit +class TestApiClient: + def test_default_auth_header(self, patch_generate_limits, httpx_mock: HTTPXMock): + test_api_key = "test_api_key" + client = Client(credentials=Credentials(api_key=test_api_key)) + + client.text.generation.limit.retrieve() + request = httpx_mock.get_request(url=match_endpoint(TextGenerationLimitRetrieveEndpoint)) + + assert request.headers.get("Authorization") == f"Bearer {test_api_key}" + + def test_custom_auth_header(self, patch_generate_limits, httpx_mock: HTTPXMock): + custom_auth_header = "CUSTOM_AUTH_HEADER" + custom_auth_fn_called = False + + def _custom_auth_fn(request: Request) -> Request: + nonlocal custom_auth_fn_called + custom_auth_fn_called = True + request.headers["Authorization"] = custom_auth_header + return request + + client = Client( + credentials=Credentials(api_key="dummy_api_key"), + config={"api_client_config": {"client_options": {"auth": FunctionAuth(_custom_auth_fn)}}}, + ) + + client.text.generation.limit.retrieve() + request = httpx_mock.get_request(url=match_endpoint(TextGenerationLimitRetrieveEndpoint)) + + assert custom_auth_fn_called + assert request.headers.get("Authorization") == custom_auth_header