Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensures tool calls do not have unknown args #146

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from semantic_kernel import Kernel
from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior
from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings
from semantic_kernel.functions import FunctionResult, KernelArguments, KernelPlugin
from semantic_kernel.functions import FunctionResult, KernelArguments, KernelFunction
from semantic_kernel.functions.kernel_function_decorator import kernel_function

execution_template = """<message role="system">You are a helpful, thoughtful, and meticulous assistant.
Expand Down Expand Up @@ -44,22 +44,22 @@ def end_conversation() -> None:


async def execution(
kernel: Kernel, reasoning: str, filter: list[str], req_settings: PromptExecutionSettings, artifact_schema: str
kernel: Kernel, reasoning: str, functions: list[str], req_settings: PromptExecutionSettings, artifact_schema: str
) -> FunctionResult:
"""Executes the actions recommended by the reasoning/planning call in the given context.

Args:
kernel (Kernel): The kernel object.
reasoning (str): The reasoning from a previous model call.
filter (list[str]): The list of plugins to INCLUDE for the tool call.
functions (list[str]): The list of plugins to INCLUDE for the tool call.
req_settings (PromptExecutionSettings): The prompt execution settings.
artifact (str): The artifact schema for the execution prompt.

Returns:
FunctionResult: The result of the execution.
"""
req_settings.function_choice_behavior = FunctionChoiceBehavior.Auto(
auto_invoke=False, filters={"included_plugins": filter}
auto_invoke=False, filters={"included_plugins": functions}
)

kernel_function = kernel.add_function(
Expand All @@ -69,7 +69,7 @@ async def execution(
template_format="jinja2",
prompt_execution_settings=req_settings,
)
if isinstance(kernel_function, KernelPlugin):
if not isinstance(kernel_function, KernelFunction):
raise ValueError("Invalid kernel function type.")

arguments = KernelArguments(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ async def execute_plan(
result = await execution(
kernel=self.kernel,
reasoning=plan,
filter=functions,
functions=functions,
req_settings=req_settings,
artifact_schema=self.artifact.get_schema_for_prompt(),
)
Expand Down Expand Up @@ -318,7 +318,7 @@ async def final_update(self):
execution_response = await execution(
kernel=self.kernel,
reasoning=reasoning_response.value[0].content,
filter=functions,
functions=functions,
req_settings=req_settings,
artifact_schema=self.artifact.get_schema_for_prompt(),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class ToolValidationResult(Enum):
INVALID_TOOL_CALLED = "A tool was called with an unexpected name"
MISSING_REQUIRED_ARGUMENT = "The tool called is missing a required argument"
INVALID_ARGUMENT_TYPE = "The value of an argument is of an unexpected type"
INVALID_ARGUMENT = "The tool called has an unexpected argument"
SUCCESS = "success"


Expand Down Expand Up @@ -158,4 +159,9 @@ def validate_tool_calling(response: dict[str, Any], request_tool_param: dict) ->
logger.warning(f"Missing required argument '{arg.argument_name}' for tool '{tool_name}'.")
return ToolValidationResult.MISSING_REQUIRED_ARGUMENT

for tool_arg_name in tool_args.keys():
if tool_arg_name not in [arg.argument_name for arg in tool.args]:
logger.warning(f"Unexpected argument '{tool_arg_name}' for tool '{tool_name}'.")
return ToolValidationResult.INVALID_ARGUMENT

return ToolValidationResult.SUCCESS