Skip to content

Commit

Permalink
Fixes service_type discriminators (#270)
Browse files Browse the repository at this point in the history
To use Literal types
  • Loading branch information
markwaddle authored Nov 26, 2024
1 parent 9118a0e commit adc6dda
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ async def respond_to_conversation(

if result.error:
logger.exception(
f"exception occurred calling {config.service_config.service_type} chat completion: {result.error}"
f"exception occurred calling {config.service_config.llm_service_type} chat completion: {result.error}"
)

# set the message type based on the content
Expand Down Expand Up @@ -281,9 +281,9 @@ async def respond_to_conversation(

response_content = content
if not response_content and "error" in metadata:
response_content = f"[error from {config.service_config.service_type}: {metadata['error']}]"
response_content = f"[error from {config.service_config.llm_service_type}: {metadata['error']}]"
if not response_content:
response_content = f"[no response from {config.service_config.service_type}]"
response_content = f"[no response from {config.service_config.llm_service_type}]"

# send the response to the conversation
await context.send_messages(
Expand Down
16 changes: 7 additions & 9 deletions examples/python/python-03-multimodel-chatbot/assistant/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pathlib
from abc import ABC, abstractmethod
from enum import StrEnum
from typing import Annotated, Any
from typing import Annotated, Any, Literal

import google.generativeai as genai
import openai
Expand Down Expand Up @@ -48,12 +48,10 @@ class ServiceType(StrEnum):


class ServiceConfig(ABC, BaseModel):
llm_service_type: Annotated[ServiceType, UISchema(widget="hidden")]

@property
def service_type_display_name(self) -> str:
# get from the class title
return self.model_config.get("title") or self.llm_service_type
return self.model_config.get("title") or self.__class__.__name__

@abstractmethod
def new_client(self, **kwargs) -> Any:
Expand All @@ -76,7 +74,7 @@ class AzureOpenAIServiceConfig(ServiceConfig, openai_client.AzureOpenAIServiceCo
},
)

llm_service_type: Annotated[ServiceType, UISchema(widget="hidden")] = ServiceType.AzureOpenAI
llm_service_type: Annotated[Literal[ServiceType.AzureOpenAI], UISchema(widget="hidden")] = ServiceType.AzureOpenAI

openai_model: Annotated[
str,
Expand All @@ -103,7 +101,7 @@ class OpenAIServiceConfig(ServiceConfig, openai_client.OpenAIServiceConfig):
},
)

llm_service_type: Annotated[ServiceType, UISchema(widget="hidden")] = ServiceType.OpenAI
llm_service_type: Annotated[Literal[ServiceType.OpenAI], UISchema(widget="hidden")] = ServiceType.OpenAI

openai_model: Annotated[
str,
Expand All @@ -129,7 +127,7 @@ class AnthropicServiceConfig(ServiceConfig):
},
)

service_type: Annotated[ServiceType, UISchema(widget="hidden")] = ServiceType.Anthropic
llm_service_type: Annotated[Literal[ServiceType.Anthropic], UISchema(widget="hidden")] = ServiceType.Anthropic

anthropic_api_key: Annotated[
# ConfigSecretStr is a custom type that should be used for any secrets.
Expand Down Expand Up @@ -165,7 +163,7 @@ class GeminiServiceConfig(ServiceConfig):
},
)

service_type: Annotated[ServiceType, UISchema(widget="hidden")] = ServiceType.Gemini
llm_service_type: Annotated[Literal[ServiceType.Gemini], UISchema(widget="hidden")] = ServiceType.Gemini

gemini_api_key: Annotated[
# ConfigSecretStr is a custom type that should be used for any secrets.
Expand Down Expand Up @@ -202,7 +200,7 @@ class OllamaServiceConfig(ServiceConfig):
},
)

service_type: Annotated[ServiceType, UISchema(widget="hidden")] = ServiceType.Ollama
llm_service_type: Annotated[Literal[ServiceType.Ollama], UISchema(widget="hidden")] = ServiceType.Ollama

ollama_endpoint: Annotated[
str,
Expand Down
4 changes: 2 additions & 2 deletions libraries/python/openai-client/openai_client/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class AzureOpenAIServiceConfig(BaseModel):
},
)

service_type: Annotated[ServiceType, UISchema(widget="hidden")] = ServiceType.AzureOpenAI
service_type: Annotated[Literal[ServiceType.AzureOpenAI], UISchema(widget="hidden")] = ServiceType.AzureOpenAI

auth_config: Annotated[
AzureOpenAIAzureIdentityAuthConfig | AzureOpenAIApiKeyAuthConfig,
Expand Down Expand Up @@ -89,7 +89,7 @@ class AzureOpenAIServiceConfig(BaseModel):
class OpenAIServiceConfig(BaseModel):
model_config = ConfigDict(title="OpenAI", json_schema_extra={"required": ["openai_api_key"]})

service_type: Annotated[ServiceType, UISchema(widget="hidden")] = ServiceType.OpenAI
service_type: Annotated[Literal[ServiceType.OpenAI], UISchema(widget="hidden")] = ServiceType.OpenAI

openai_api_key: Annotated[
# ConfigSecretStr is a custom type that should be used for any secrets.
Expand Down

0 comments on commit adc6dda

Please sign in to comment.