Dev model providers (#3628)
* gemini: initialization parameter issue

* gemini: synchronize tool calling
glide-the authored Apr 6, 2024
1 parent b3dee0b commit 5169228
Showing 1 changed file with 101 additions and 36 deletions.
@@ -1,17 +1,19 @@
import json
import logging
from collections.abc import Generator
from typing import Optional, Union

import google.api_core.exceptions as exceptions
import google.generativeai as genai
import google.generativeai.client as client
from google.ai.generativelanguage_v1beta import FunctionCall, FunctionResponse
from google.generativeai.types import (
ContentType,
GenerateContentResponse,
HarmBlockThreshold,
HarmCategory,
)
from google.generativeai.types.content_types import to_part
from google.generativeai.types.content_types import to_part, FunctionDeclaration, Tool, FunctionLibrary

from model_providers.core.model_runtime.entities.llm_entities import (
LLMResult,
@@ -56,15 +58,15 @@

class GoogleLargeLanguageModel(LargeLanguageModel):
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model
@@ -81,15 +83,15 @@ def _invoke(
"""
# invoke model
return self._generate(
model, credentials, prompt_messages, model_parameters, stop, stream, user
model, credentials, prompt_messages, model_parameters, tools, stop, stream, user
)
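
For orientation, here is a hedged caller-side sketch of the updated _invoke with a tool attached. The PromptMessageTool fields (name, description, parameters) match how the diff reads them below; the model name, credential key, UserPromptMessage, and the weather tool itself are illustrative assumptions, not part of this commit.

# Hypothetical usage sketch (not from the commit): tool schemas are now forwarded
# from _invoke to _generate and on to Gemini.
weather_tool = PromptMessageTool(
    name="get_current_weather",
    description="Look up the current weather for a city",
    parameters={
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    },
)
llm = GoogleLargeLanguageModel()  # construction simplified for illustration
result = llm._invoke(
    model="gemini-pro",                     # assumed model name
    credentials={"google_api_key": "..."},  # assumed credential key
    prompt_messages=[UserPromptMessage(content="What is the weather in Berlin?")],
    model_parameters={"temperature": 0.2},
    tools=[weather_tool],
    stream=False,
)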

def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> int:
"""
Get number of tokens for given prompt messages
@@ -138,14 +140,15 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
raise CredentialsValidateFailedError(str(ex))

def _generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model
@@ -160,9 +163,13 @@ def _generate(
:return: full response or stream response chunk generator result
"""
config_kwargs = model_parameters.copy()
config_kwargs["max_output_tokens"] = config_kwargs.pop(
config_kwargs.pop(
"max_tokens_to_sample", None
)
# https://github.com/google/generative-ai-python/issues/170
# config_kwargs["max_output_tokens"] = config_kwargs.pop(
# "max_tokens_to_sample", None
# )

if stop:
config_kwargs["stop_sequences"] = stop
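
In plain terms, the change above stops renaming max_tokens_to_sample to max_output_tokens (see the generative-ai-python issue linked in the comment) and now simply discards the key. A small sketch of the resulting behaviour, with made-up parameter values:

# Illustration only: the unsupported key is dropped rather than mapped.
model_parameters = {"temperature": 0.2, "max_tokens_to_sample": 1024}
config_kwargs = model_parameters.copy()
config_kwargs.pop("max_tokens_to_sample", None)
# config_kwargs is now {"temperature": 0.2}; no max_output_tokens is sent to Gemini,
# and stop sequences are still added via config_kwargs["stop_sequences"] when given.
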
@@ -197,12 +204,21 @@ def _generate(
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
}
tools_one = []
for tool in tools:
one_tool = Tool(function_declarations=[FunctionDeclaration(name=tool.name,
description=tool.description,
parameters=tool.parameters
)
])
tools_one.append(one_tool)

response = google_model.generate_content(
contents=history,
generation_config=genai.types.GenerationConfig(**config_kwargs),
stream=stream,
safety_settings=safety_settings,
tools=FunctionLibrary(tools=tools_one),
)
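
The loop above wraps each PromptMessageTool in a FunctionDeclaration inside a Tool and hands the whole set to generate_content as a FunctionLibrary. Below is a hedged sketch of the same conversion with a guard for the tools=None default; the guard and the comprehension are an assumption about defensive usage, not what the commit itself does.

# Sketch (not from the commit): build the Gemini tool library only when tools exist.
gemini_tools = None
if tools:
    gemini_tools = FunctionLibrary(
        tools=[
            Tool(
                function_declarations=[
                    FunctionDeclaration(
                        name=tool.name,
                        description=tool.description,
                        parameters=tool.parameters,
                    )
                ]
            )
            for tool in tools
        ]
    )
# generate_content would then receive tools=gemini_tools, or omit the argument
# entirely when gemini_tools is None.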

if stream:
)

def _handle_generate_response(
self,
model: str,
credentials: dict,
response: GenerateContentResponse,
prompt_messages: list[PromptMessage],
) -> LLMResult:
"""
Handle llm response
@@ -230,8 +246,23 @@ def _handle_generate_response(
:param prompt_messages: prompt messages
:return: llm response
"""
part = response.candidates[0].content.parts[0]
part_message_function_call = part.function_call
tool_calls = []
if part_message_function_call:
function_call = self._extract_response_function_call(
part_message_function_call
)
tool_calls.append(function_call)
part_message_function_response = part.function_response
if part_message_function_response:
function_call = self._extract_response_function_call(
part_message_function_call
)
tool_calls.append(function_call)

# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(content=response.text)
assistant_prompt_message = AssistantPromptMessage(content=part.text, tool_calls=tool_calls)

# calculate num tokens
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
return result
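
Downstream, callers can now inspect the assistant message for tool calls instead of assuming plain text. A minimal check, assuming the LLMResult shape used elsewhere in this runtime (result comes from the earlier _invoke sketch):

# Hypothetical caller-side check (not from the commit).
message = result.message  # AssistantPromptMessage returned inside LLMResult
if message.tool_calls:
    for tool_call in message.tool_calls:
        print(tool_call.function.name, tool_call.function.arguments)
else:
    print(message.content)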

def _handle_generate_stream_response(
self,
model: str,
credentials: dict,
response: GenerateContentResponse,
prompt_messages: list[PromptMessage],
) -> Generator:
"""
Handle llm stream response
@@ -413,3 +444,37 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]
exceptions.Cancelled,
],
}

def _extract_response_function_call(
self, response_function_call: Union[FunctionCall, FunctionResponse]
) -> AssistantPromptMessage.ToolCall:
"""
Extract function call from response
:param response_function_call: response function call
:return: tool call
"""
tool_call = None
if response_function_call:
from google.protobuf import json_format

if isinstance(response_function_call, FunctionCall):
map_composite_dict = dict(response_function_call.args.items())
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_function_call.name,
arguments=str(map_composite_dict),
)
elif isinstance(response_function_call, FunctionResponse):
map_composite_dict = dict(response_function_call.response.items())
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_function_call.name,
arguments=str(map_composite_dict),
)
else:
raise ValueError(f"Unsupported response_function_call type: {type(response_function_call)}")

tool_call = AssistantPromptMessage.ToolCall(
id=response_function_call.name, type="function", function=function
)

return tool_call
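
Finally, a small usage sketch of the new _extract_response_function_call helper with an invented FunctionCall; constructing the proto from a plain dict and the exact string form of arguments both depend on proto-plus marshaling and are assumptions here.

# Hypothetical example (not from the commit): a Gemini function-call proto is
# mapped onto the runtime's ToolCall representation.
call = FunctionCall(name="get_current_weather", args={"city": "Berlin"})
tool_call = llm._extract_response_function_call(call)  # llm from the first sketch
assert tool_call.id == "get_current_weather"
assert tool_call.type == "function"
assert tool_call.function.name == "get_current_weather"
# tool_call.function.arguments holds str() of the extracted mapping,
# roughly "{'city': 'Berlin'}".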
