Skip to content

Commit

Permalink
feat: add support for custom auth (#335)
Browse files Browse the repository at this point in the history
Signed-off-by: Radek Ježek <[email protected]>
  • Loading branch information
jezekra1 authored Mar 4, 2024
1 parent 6d6455c commit 4082c4b
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 3 deletions.
17 changes: 14 additions & 3 deletions src/genai/_utils/api_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from asyncio import AbstractEventLoop
from typing import Any, Optional
from typing import Any, Optional, cast

from httpx import Timeout
from httpx import Auth, Request, Timeout
from httpx._auth import FunctionAuth
from pydantic import BaseModel, ConfigDict, Field, field_validator

from genai._types import ModelLike
Expand All @@ -26,6 +27,7 @@ class HttpClientOptions(BaseModel):
model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)

timeout: Timeout = Timeout(timeout=10 * 60, connect=10)
auth: Optional[Auth] = None

@field_validator("timeout", mode="before")
@classmethod
Expand Down Expand Up @@ -129,15 +131,24 @@ def get_async_http_client(
def _get_headers(self, override: Optional[dict]) -> dict:
headers = {
**(override or {}),
"Authorization": f"Bearer {self._credentials.api_key.get_secret_value()}",
"x-request-origin": f"python-sdk/{__version__}",
"user-agent": f"python-sdk/{__version__}",
}

return headers

def _get_default_auth(self) -> Auth:
"""Default Authorization function, can be overridden in client_options"""

def _auth_fn(request: Request) -> Request:
request.headers["Authorization"] = f"Bearer {self._credentials.api_key.get_secret_value()}"
return request

return FunctionAuth(_auth_fn)

def _get_client_options(self, override: Optional[dict] = None) -> dict:
final = merge_objects(
cast(dict[str, Any], {"auth": self._get_default_auth()}),
self.config.client_options.model_dump(exclude_none=True),
override,
{
Expand Down
45 changes: 45 additions & 0 deletions tests/unit/test_api_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import logging

import pytest
from httpx import Request
from httpx._auth import FunctionAuth
from pytest_httpx import HTTPXMock

from genai import Client, Credentials
from genai.schema import TextGenerationLimitRetrieveEndpoint
from tests.helpers import match_endpoint

logger = logging.getLogger(__name__)


@pytest.mark.unit
class TestApiClient:
def test_default_auth_header(self, patch_generate_limits, httpx_mock: HTTPXMock):
test_api_key = "test_api_key"
client = Client(credentials=Credentials(api_key=test_api_key))

client.text.generation.limit.retrieve()
request = httpx_mock.get_request(url=match_endpoint(TextGenerationLimitRetrieveEndpoint))

assert request.headers.get("Authorization") == f"Bearer {test_api_key}"

def test_custom_auth_header(self, patch_generate_limits, httpx_mock: HTTPXMock):
custom_auth_header = "CUSTOM_AUTH_HEADER"
custom_auth_fn_called = False

def _custom_auth_fn(request: Request) -> Request:
nonlocal custom_auth_fn_called
custom_auth_fn_called = True
request.headers["Authorization"] = custom_auth_header
return request

client = Client(
credentials=Credentials(api_key="dummy_api_key"),
config={"api_client_config": {"client_options": {"auth": FunctionAuth(_custom_auth_fn)}}},
)

client.text.generation.limit.retrieve()
request = httpx_mock.get_request(url=match_endpoint(TextGenerationLimitRetrieveEndpoint))

assert custom_auth_fn_called
assert request.headers.get("Authorization") == custom_auth_header

0 comments on commit 4082c4b

Please sign in to comment.