fix: add support for univ endpoints in javelin sdk
javelin authored and javelin committed Feb 20, 2025
1 parent a84d4df commit 7addf8e
Showing 12 changed files with 190 additions and 32 deletions.
13 changes: 0 additions & 13 deletions .pre-commit-config.yaml
@@ -1,23 +1,10 @@
repos:
  - repo: https://github.com/pycqa/isort
    rev: 5.12.0
    hooks:
      - id: isort
        name: isort (python)

  - repo: https://github.com/psf/black
    rev: 24.3.0
    hooks:
      - id: black
        language_version: python3

  - repo: https://github.com/charliermarsh/ruff-pre-commit
    # Ruff version.
    rev: "v0.0.265"
    hooks:
      - id: ruff
        args: [--fix, --exit-non-zero-on-fix]

  - repo: https://github.com/python-poetry/poetry
    rev: "1.4.0" # add version here
    hooks:
55 changes: 55 additions & 0 deletions examples/univ_endpoint_generic.py
@@ -0,0 +1,55 @@
import asyncio
import json
import os
from typing import Any, Dict

from javelin_sdk import JavelinClient, JavelinConfig


# Helper function to pretty print responses
def print_response(provider: str, response: Dict[str, Any]) -> None:
    print(f"=== Response from {provider} ===")
    print(json.dumps(response, indent=2))


# Setup client configuration
config = JavelinConfig(
    base_url="https://api-dev.javelin.live",
    javelin_api_key=os.getenv("JAVELIN_API_KEY"),
    llm_api_key=os.getenv("OPENAI_API_KEY"),
)
client = JavelinClient(config)

# Example messages in OpenAI format
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What are the three primary colors?"},
]

# Define the headers based on the curl command
custom_headers = {
    "Content-Type": "application/json",
    "x-javelin-route": "openai_univ",
    "x-javelin-model": "gpt-4",
    "x-javelin-provider": "https://api.openai.com/v1",
    "x-api-key": os.getenv("JAVELIN_API_KEY"),  # Use environment variable for security
    "Authorization": f"Bearer {os.getenv('OPENAI_API_KEY')}",  # Use environment variable for security
}


async def main():
    try:
        query_body = {"messages": messages, "temperature": 0.7}
        openai_response = await client.aquery_unified_endpoint(
            provider_name="openai",
            endpoint_type="chat",
            query_body=query_body,
            headers=custom_headers,
        )
        print_response("OpenAI", openai_response)
    except Exception as e:
        print(f"OpenAI query failed: {str(e)}")


# Run the async function
asyncio.run(main())
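
For reference, the synchronous counterpart of the call above would look like the following sketch; it assumes the same client, messages, and custom_headers defined in the example, and uses the query_unified_endpoint method that this commit adds alongside the async variant.

# Sketch: synchronous variant of the async example above (no asyncio required).
sync_response = client.query_unified_endpoint(
    provider_name="openai",
    endpoint_type="chat",
    query_body={"messages": messages, "temperature": 0.7},
    headers=custom_headers,
)
print_response("OpenAI (sync)", sync_response)
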
3 changes: 1 addition & 2 deletions javelin_cli/_internal/commands.py
@@ -2,8 +2,6 @@
import os
from pathlib import Path

from pydantic import ValidationError

from javelin_sdk.client import JavelinClient
from javelin_sdk.exceptions import (
BadRequest,
@@ -29,6 +27,7 @@
Template,
Templates,
)
from pydantic import ValidationError


def get_javelin_client():
72 changes: 67 additions & 5 deletions javelin_sdk/client.py
@@ -6,9 +6,6 @@
from urllib.parse import unquote, urljoin, urlparse, urlunparse

import httpx
from opentelemetry.semconv._incubating.attributes import gen_ai_attributes
from opentelemetry.trace import SpanKind, Status, StatusCode

from javelin_sdk.chat_completions import Chat, Completions
from javelin_sdk.models import HttpMethod, JavelinConfig, Request
from javelin_sdk.services.gateway_service import GatewayService
@@ -19,6 +16,8 @@
from javelin_sdk.services.template_service import TemplateService
from javelin_sdk.services.trace_service import TraceService
from javelin_sdk.tracing_setup import configure_span_exporter
from opentelemetry.semconv._incubating.attributes import gen_ai_attributes
from opentelemetry.trace import SpanKind, Status, StatusCode

API_BASEURL = "https://api-dev.javelin.live"
API_BASE_PATH = "/v1"
@@ -546,7 +545,7 @@ def get_inference_model(inference_profile_identifier: str) -> str:
)
model_id = foundation_model_response["modelDetails"]["modelId"]
return model_id
except Exception as e:
except Exception:
# Fail silently if the model is not found
return None

@@ -557,7 +556,7 @@ def get_foundation_model(model_identifier: str) -> str:
modelIdentifier=model_identifier
)
return response["modelDetails"]["modelId"]
except Exception as e:
except Exception:
# Fail silently if the model is not found
return None

@@ -668,6 +667,7 @@ def _prepare_request(self, request: Request) -> tuple:
is_transformation_rules=request.is_transformation_rules,
is_model_specs=request.is_model_specs,
is_reload=request.is_reload,
univ_model=request.univ_model_config,
)
headers = {**self._headers, **(request.headers or {})}
return url, headers
@@ -708,6 +708,7 @@ def _construct_url(
is_transformation_rules: bool = False,
is_model_specs: bool = False,
is_reload: bool = False,
univ_model: Optional[Dict[str, Any]] = None,
) -> str:
url_parts = [self.base_url]

@@ -770,6 +771,12 @@
if query_params:
query_string = "&".join(f"{k}={v}" for k, v in query_params.items())
url += f"?{query_string}"

# Integrate construct_endpoint_url logic
if univ_model:
endpoint_url = self.construct_endpoint_url(univ_model)
url = urljoin(url, endpoint_url)

return url

# Gateway methods
@@ -876,6 +883,12 @@ def _construct_url(
aquery_llama = lambda self, route_name, query_body: self.route_service.aquery_llama(
route_name, query_body
)
query_unified_endpoint = lambda self, provider_name, endpoint_type, query_body, headers=None, query_params=None: self.route_service.query_unified_endpoint(
provider_name, endpoint_type, query_body, headers, query_params
)
aquery_unified_endpoint = lambda self, provider_name, endpoint_type, query_body, headers=None, query_params=None: self.route_service.aquery_unified_endpoint(
provider_name, endpoint_type, query_body, headers, query_params
)

# Secret methods
create_secret = lambda self, secret: self.secret_service.create_secret(secret)
@@ -969,3 +982,52 @@ async def aget_last_n_chronicle_records(
)
response = await self._send_request_async(request)
return response

def construct_endpoint_url(self, request_model: Dict[str, Any]) -> str:
    """
    Constructs the endpoint URL based on the request model.

    The base URL is taken from self.base_url.

    :param request_model: The request model containing endpoint details.
    :return: The constructed endpoint URL.
    """
    base_url = self.base_url
    provider_name = request_model.get("provider_name")
    endpoint_type = request_model.get("endpoint_type")
    deployment = request_model.get("deployment")
    arn = request_model.get("arn")
    api_version = request_model.get(
        "api_version", "2023-07-01-preview"
    )  # Default version

    if not provider_name:
        raise ValueError("Provider name is not specified in the request model.")

    if provider_name == "azureopenai" and deployment:
        # Handle Azure OpenAI endpoints
        if endpoint_type == "chat":
            return f"{base_url}/{provider_name}/deployments/{deployment}/chat/completions?api-version={api_version}"
        elif endpoint_type == "completion":
            return f"{base_url}/{provider_name}/deployments/{deployment}/completions?api-version={api_version}"
        elif endpoint_type == "embeddings":
            return f"{base_url}/{provider_name}/deployments/{deployment}/embeddings?api-version={api_version}"
    elif arn:
        # Handle Bedrock endpoints
        if endpoint_type == "invoke":
            return f"{base_url}/v1/model/{arn}/invoke"
        elif endpoint_type == "converse":
            return f"{base_url}/v1/model/{arn}/converse"
        elif endpoint_type == "invoke_stream":
            return f"{base_url}/v1/model/{arn}/invoke-with-response-stream"
        elif endpoint_type == "converse_stream":
            return f"{base_url}/v1/model/{arn}/converse-stream"
    else:
        # Handle OpenAI-compatible endpoints
        if endpoint_type == "chat":
            return f"{base_url}/{provider_name}/chat/completions"
        elif endpoint_type == "completion":
            return f"{base_url}/{provider_name}/completions"
        elif endpoint_type == "embeddings":
            return f"{base_url}/{provider_name}/embeddings"

    raise ValueError("Invalid request model configuration")
21 changes: 19 additions & 2 deletions javelin_sdk/models.py
@@ -1,9 +1,8 @@
from enum import Enum, auto
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field, field_validator

from javelin_sdk.exceptions import UnauthorizedError
from pydantic import BaseModel, Field, field_validator


class GatewayConfig(BaseModel):
@@ -478,6 +477,7 @@ def __init__(
is_transformation_rules: bool = False,
is_model_specs: bool = False,
is_reload: bool = False,
univ_model_config: Optional[Dict[str, Any]] = None,
):
self.method = method
self.gateway = gateway
@@ -494,6 +494,7 @@ def __init__(
self.is_transformation_rules = is_transformation_rules
self.is_model_specs = is_model_specs
self.is_reload = is_reload
self.univ_model_config = univ_model_config


class Message(BaseModel):
@@ -568,3 +569,19 @@ class EndpointType(str, Enum):
INVOKE_STREAM = "invoke_stream"
CONVERSE_STREAM = "converse_stream"
ALL = "all"


class UnivModelConfig:
    def __init__(
        self,
        provider_name: str,
        endpoint_type: str,
        deployment: Optional[str] = None,
        arn: Optional[str] = None,
        api_version: Optional[str] = None,
    ):
        self.provider_name = provider_name
        self.endpoint_type = endpoint_type
        self.deployment = deployment
        self.arn = arn
        self.api_version = api_version
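
For illustration, a config aimed at a Bedrock converse endpoint could be built as in the sketch below (the provider name and ARN are hypothetical placeholders); with an arn set and a non-Azure provider, construct_endpoint_url routes it to the /v1/model/{arn}/converse path.

# Sketch: UnivModelConfig for a Bedrock "converse" call; provider name and ARN are hypothetical.
bedrock_config = UnivModelConfig(
    provider_name="bedrock",
    endpoint_type="converse",
    arn="arn:aws:bedrock:us-east-1:000000000000:foundation-model/anthropic.claude-v2",
)
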
1 change: 0 additions & 1 deletion javelin_sdk/services/gateway_service.py
@@ -1,7 +1,6 @@
from typing import List

import httpx

from javelin_sdk.exceptions import (
BadRequest,
GatewayAlreadyExistsError,
1 change: 0 additions & 1 deletion javelin_sdk/services/modelspec_service.py
@@ -1,7 +1,6 @@
from typing import Any, Dict, Optional

import httpx

from javelin_sdk.exceptions import (
BadRequest,
InternalServerError,
1 change: 0 additions & 1 deletion javelin_sdk/services/provider_service.py
@@ -1,7 +1,6 @@
from typing import Any, Dict, List

import httpx

from javelin_sdk.exceptions import (
BadRequest,
InternalServerError,
52 changes: 48 additions & 4 deletions javelin_sdk/services/route_service.py
@@ -1,10 +1,7 @@
import json
import time
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Union

import httpx
from jsonpath_ng import parse

from javelin_sdk.exceptions import (
BadRequest,
InternalServerError,
@@ -13,7 +10,8 @@
RouteNotFoundError,
UnauthorizedError,
)
from javelin_sdk.models import HttpMethod, Request, Route, Routes
from javelin_sdk.models import HttpMethod, Request, Route, Routes, UnivModelConfig
from jsonpath_ng import parse


class RouteService:
@@ -310,3 +308,49 @@ async def areload_route(self, route_name: str) -> str:
)
)
return response

def query_unified_endpoint(
    self,
    provider_name: str,
    endpoint_type: str,
    query_body: Dict[str, Any],
    headers: Optional[Dict[str, str]] = None,
    query_params: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    univ_model_config = UnivModelConfig(
        provider_name=provider_name,
        endpoint_type=endpoint_type,
    )

    request = Request(
        method=HttpMethod.POST,
        data=query_body,
        univ_model_config=univ_model_config.__dict__,
        headers=headers,
        query_params=query_params,
    )
    response = self.client._send_request_sync(request)
    return response.json()


async def aquery_unified_endpoint(
    self,
    provider_name: str,
    endpoint_type: str,
    query_body: Dict[str, Any],
    headers: Optional[Dict[str, str]] = None,
    query_params: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    univ_model_config = UnivModelConfig(
        provider_name=provider_name,
        endpoint_type=endpoint_type,
    )

    request = Request(
        method=HttpMethod.POST,
        data=query_body,
        univ_model_config=univ_model_config.__dict__,
        headers=headers,
        query_params=query_params,
    )
    response = await self.client._send_request_async(request)
    return response.json()
1 change: 0 additions & 1 deletion javelin_sdk/services/secret_service.py
@@ -1,7 +1,6 @@
from typing import List

import httpx

from javelin_sdk.exceptions import (
BadRequest,
InternalServerError,
1 change: 0 additions & 1 deletion javelin_sdk/services/template_service.py
@@ -1,7 +1,6 @@
from typing import List

import httpx

from javelin_sdk.exceptions import (
BadRequest,
InternalServerError,
1 change: 0 additions & 1 deletion javelin_sdk/services/trace_service.py
@@ -1,7 +1,6 @@
from typing import List

import httpx

from javelin_sdk.exceptions import (
BadRequest,
InternalServerError,
