Skip to content

Commit

Permalink
feat: Enhance OpenAI client registration with provider handling and m…
Browse files Browse the repository at this point in the history
…odel-specific headers
  • Loading branch information
ankumar committed Feb 7, 2025
1 parent f8b9c75 commit f6b7da8
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 115 deletions.
10 changes: 5 additions & 5 deletions examples/openai_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
javelin_api_key=javelin_api_key,
)
client = JavelinClient(config)
client.register_openai(openai_client, provider_name="OpenAI", route_name="openai")
client.register_openai(openai_client, route_name="openai")

# Call OpenAI endpoints
print("OpenAI: 1 - Chat completions")
Expand Down Expand Up @@ -85,7 +85,7 @@
javelin_api_key=javelin_api_key,
)
client = JavelinClient(config)
client.register_openai(openai_async_client, provider_name="OpenAI", route_name="openai")
client.register_openai(openai_async_client, route_name="openai")

async def main() -> None:
chat_completion = await openai_async_client.chat.completions.create(
Expand Down Expand Up @@ -135,7 +135,7 @@ async def main():
javelin_api_key=javelin_api_key,
)
client = JavelinClient(config)
client.register_openai(openai_client, provider_name="Google", route_name="openai")
client.register_gemini(openai_client, route_name="openai")

print("Gemini: 1 - Chat completions")

Expand Down Expand Up @@ -289,7 +289,7 @@ class CalendarEvent(BaseModel):
javelin_api_key=javelin_api_key,
)
client = JavelinClient(config)
client.register_openai(openai_client, provider_name="Azure", route_name="openai")
client.register_azureopenai(openai_client, route_name="openai")
completion = openai_client.chat.completions.create(
model="gpt-4o-mini", # e.g. gpt-35-instant
Expand Down Expand Up @@ -319,7 +319,7 @@ class CalendarEvent(BaseModel):
)
# client = JavelinClient(config)
# client.register_openai(openai_client, provider_name="DeepSeek", route_name="openai")
# client.register_deepseek(openai_client, route_name="openai")
response = openai_client.chat.completions.create(
model="deepseek-chat",
Expand Down
145 changes: 35 additions & 110 deletions javelin_sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,148 +96,73 @@ def close(self):
if self._client:
self._client.close()

def register_openai(self,
openai_client: Any,
provider_name: str = None,
route_name: str = None) -> Any:
def register_provider(self,
openai_client: Any,
provider_name: str,
route_name: str = None) -> Any:
"""
Register the passed-in OpenAI client so that calls to:
- client.chat.completions.create(...)
- client.completions.create(...)
- client.embeddings.create(...)
are intercepted by Javelin.
Generalized function to register OpenAI, Azure OpenAI, and Gemini clients.
Additionally sets:
- openai_client.base_url to self.base_url
- openai_client._custom_headers to include self._headers
"""

'''
self.client_is_async = None
if type(openai_client) == OpenAI:
print("DEBUG - OpenAI client is sync")
self.client_is_async = False
elif type(openai_client) == AsyncOpenAI:
print("DEBUG - OpenAI client is async")
self.client_is_async = True
else:
raise Exception(f"Unknown client type: {type(openai_client)}")
'''

# Store the OpenAI base URL
if self.openai_base_url is None:
self.openai_base_url = openai_client.base_url

# Point the OpenAI client to Javelin's base URL
openai_client.base_url=self.base_url + "/" + provider_name
openai_client.base_url = f"{self.base_url}/{provider_name}"

if not hasattr(openai_client, "_custom_headers"):
openai_client._custom_headers = {}
openai_client._custom_headers.update(self._headers)

base_url_str = str(self.openai_base_url)
# Remove trailing slash if present
if base_url_str.endswith("/"):
base_url_str = base_url_str[:-1]
base_url_str = str(self.openai_base_url).rstrip("/") # Remove trailing slash if present

# Update Javelin headers into the client's _custom_headers
# Update Javelin headers into the client's _custom_headers
openai_client._custom_headers["x-javelin-provider"] = base_url_str
openai_client._custom_headers["x-javelin-route"] = route_name

# Print out the headers you’ve set (for debug)
print("DEBUG - Patched OpenAI client headers:", openai_client._custom_headers)

# Store references to the original methods
original_chat_completions_create = openai_client.chat.completions.create
original_completions_create = openai_client.completions.create
original_embeddings_create = openai_client.embeddings.create

# Define patched versions, injecting Javelin logs/traces

def patched_chat_completions_create(*args, **kwargs):
# Update openai_client._custom_headers directly if model is set
model = kwargs.get('model') # Extract the 'model' field
if model and hasattr(openai_client, "_custom_headers"):
openai_client._custom_headers['x-javelin-model'] = model # Add or update the custom header

'''
TODO: self.trace_service.log_trace(
message="OpenAI chat.completions.create called",
extra={"args": args, "kwargs": kwargs},
)
'''

# Call the real method
response = original_chat_completions_create(*args, **kwargs)

# AFTER calling original
'''
TODO: self.trace_service.log_trace(
message="OpenAI chat.completions.create response",
extra={"response": response},
)
'''

return response

def patched_completions_create(*args, **kwargs):
# Update openai_client._custom_headers directly if model is set
model = kwargs.get('model') # Extract the 'model' field
if model and hasattr(openai_client, "_custom_headers"):
openai_client._custom_headers['x-javelin-model'] = model # Add or update the custom header
original_methods = {
"chat_completions_create": openai_client.chat.completions.create,
"completions_create": openai_client.completions.create,
"embeddings_create": openai_client.embeddings.create,
}

'''
TODO: self.trace_service.log_trace(
message="OpenAI completions.create called",
extra={"args": args, "kwargs": kwargs},
)
'''
# Patch methods with tracing and header updates
def create_patched_method(original_method):
def patched_method(*args, **kwargs):
model = kwargs.get('model')
if model and hasattr(openai_client, "_custom_headers"):
openai_client._custom_headers['x-javelin-model'] = model

# Call the real method
response = original_completions_create(*args, **kwargs)
response = original_method(*args, **kwargs)
return response

'''
TODO: self.trace_service.log_trace(
message="OpenAI completions.create response",
extra={"response": response},
)
'''
return patched_method

return response
# Apply patches
openai_client.chat.completions.create = create_patched_method(original_methods["chat_completions_create"])
openai_client.completions.create = create_patched_method(original_methods["completions_create"])
openai_client.embeddings.create = create_patched_method(original_methods["embeddings_create"])

def patched_embeddings_create(*args, **kwargs):
# Update openai_client._custom_headers directly if model is set
model = kwargs.get('model') # Extract the 'model' field
if model and hasattr(openai_client, "_custom_headers"):
openai_client._custom_headers['x-javelin-model'] = model # Add or update the custom header

'''
TODO: self.trace_service.log_trace(
message="OpenAI embeddings.create called",
extra={"args": args, "kwargs": kwargs},
)
'''
return openai_client

# Call the real method
response = original_embeddings_create(*args, **kwargs)
def register_openai(self, openai_client: Any, route_name: str = None) -> Any:
return self.register_provider(openai_client, provider_name="openai", route_name=route_name)

'''
TODO: self.trace_service.log_trace(
message="OpenAI embeddings.create response",
extra={"response": response},
)
'''
def register_azureopenai(self, openai_client: Any, route_name: str = None) -> Any:
return self.register_provider(openai_client, provider_name="azureopenai", route_name=route_name)

return response
def register_gemini(self, openai_client: Any, route_name: str = None) -> Any:
return self.register_provider(openai_client, provider_name="gemini", route_name=route_name)

# patch the client’s methods
openai_client.chat.completions.create = patched_chat_completions_create
openai_client.completions.create = patched_completions_create
openai_client.embeddings.create = patched_embeddings_create
def register_deepseek(self, openai_client: Any, route_name: str = None) -> Any:
return self.register_provider(openai_client, provider_name="deepseek", route_name=route_name)

# Return the patched client
return openai_client

def register_bedrock(self,
bedrock_runtime_client: Any,
bedrock_client: Any = None,
Expand Down

0 comments on commit f6b7da8

Please sign in to comment.