Skip to content

Commit

Permalink
Ensures tool calls do not have unknown args (#146)
Browse files Browse the repository at this point in the history
  • Loading branch information
markwaddle authored Oct 18, 2024
1 parent 576b792 commit dc8d978
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from semantic_kernel import Kernel
from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior
from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings
from semantic_kernel.functions import FunctionResult, KernelArguments, KernelPlugin
from semantic_kernel.functions import FunctionResult, KernelArguments, KernelFunction
from semantic_kernel.functions.kernel_function_decorator import kernel_function

execution_template = """<message role="system">You are a helpful, thoughtful, and meticulous assistant.
Expand Down Expand Up @@ -44,22 +44,22 @@ def end_conversation() -> None:


async def execution(
kernel: Kernel, reasoning: str, filter: list[str], req_settings: PromptExecutionSettings, artifact_schema: str
kernel: Kernel, reasoning: str, functions: list[str], req_settings: PromptExecutionSettings, artifact_schema: str
) -> FunctionResult:
"""Executes the actions recommended by the reasoning/planning call in the given context.
Args:
kernel (Kernel): The kernel object.
reasoning (str): The reasoning from a previous model call.
filter (list[str]): The list of plugins to INCLUDE for the tool call.
functions (list[str]): The list of plugins to INCLUDE for the tool call.
req_settings (PromptExecutionSettings): The prompt execution settings.
artifact (str): The artifact schema for the execution prompt.
Returns:
FunctionResult: The result of the execution.
"""
req_settings.function_choice_behavior = FunctionChoiceBehavior.Auto(
auto_invoke=False, filters={"included_plugins": filter}
auto_invoke=False, filters={"included_plugins": functions}
)

kernel_function = kernel.add_function(
Expand All @@ -69,7 +69,7 @@ async def execution(
template_format="jinja2",
prompt_execution_settings=req_settings,
)
if isinstance(kernel_function, KernelPlugin):
if not isinstance(kernel_function, KernelFunction):
raise ValueError("Invalid kernel function type.")

arguments = KernelArguments(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ async def execute_plan(
result = await execution(
kernel=self.kernel,
reasoning=plan,
filter=functions,
functions=functions,
req_settings=req_settings,
artifact_schema=self.artifact.get_schema_for_prompt(),
)
Expand Down Expand Up @@ -318,7 +318,7 @@ async def final_update(self):
execution_response = await execution(
kernel=self.kernel,
reasoning=reasoning_response.value[0].content,
filter=functions,
functions=functions,
req_settings=req_settings,
artifact_schema=self.artifact.get_schema_for_prompt(),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class ToolValidationResult(Enum):
INVALID_TOOL_CALLED = "A tool was called with an unexpected name"
MISSING_REQUIRED_ARGUMENT = "The tool called is missing a required argument"
INVALID_ARGUMENT_TYPE = "The value of an argument is of an unexpected type"
INVALID_ARGUMENT = "The tool called has an unexpected argument"
SUCCESS = "success"


Expand Down Expand Up @@ -158,4 +159,9 @@ def validate_tool_calling(response: dict[str, Any], request_tool_param: dict) ->
logger.warning(f"Missing required argument '{arg.argument_name}' for tool '{tool_name}'.")
return ToolValidationResult.MISSING_REQUIRED_ARGUMENT

for tool_arg_name in tool_args.keys():
if tool_arg_name not in [arg.argument_name for arg in tool.args]:
logger.warning(f"Unexpected argument '{tool_arg_name}' for tool '{tool_name}'.")
return ToolValidationResult.INVALID_ARGUMENT

return ToolValidationResult.SUCCESS

0 comments on commit dc8d978

Please sign in to comment.