Skip to content

Commit

Permalink
feat: Implement tool call and evaluate steps
Browse files Browse the repository at this point in the history
  • Loading branch information
whiterabbit1983 committed Jul 30, 2024
1 parent 7e40bbe commit 687f487
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 42 deletions.
44 changes: 20 additions & 24 deletions agents-api/agents_api/activities/task_steps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,16 @@

from openai.types.chat.chat_completion import ChatCompletion

# import celpy
from simpleeval import simple_eval
from temporalio import activity

from ...autogen.openapi_model import (
# ToolCallWorkflowStep,
PromptWorkflowStep,
EvaluateWorkflowStep,
ToolCallWorkflowStep,
# ErrorWorkflowStep,
IfElseWorkflowStep,
InputChatMLMessage,
PromptWorkflowStep,
# EvaluateWorkflowStep,
YieldWorkflowStep,
)
from ...clients.worker.types import ChatML
Expand Down Expand Up @@ -62,20 +61,17 @@ async def prompt_step(context: StepContext) -> dict:
return response


# @activity.defn
# async def evaluate_step(context: StepContext) -> dict:
# if not isinstance(context.definition, EvaluateWorkflowStep):
# return {}
@activity.defn
async def evaluate_step(context: StepContext) -> dict:
assert isinstance(context.definition, EvaluateWorkflowStep)

# FIXME: set the field to keep source code
source: str = context.definition.evaluate
# FIXME: set up names
names = {}
result = simple_eval(source, names=names)

# # FIXME: set the field to keep source code
# source: str = context.definition.evaluate
# env = celpy.Environment()
# ast = env.compile(source)
# prog = env.program(ast)
# # TODO: set args
# args = {}
# result = prog.evaluate(args)
# return {"result": result}
return {"result": result}


@activity.defn
Expand All @@ -88,14 +84,14 @@ async def yield_step(context: StepContext) -> dict:
return {"test": "result"}


# @activity.defn
# async def tool_call_step(context: StepContext) -> dict:
# assert isinstance(context.definition, ToolCallWorkflowStep)
@activity.defn
async def tool_call_step(context: StepContext) -> dict:
assert isinstance(context.definition, ToolCallWorkflowStep)

# context.definition.tool_id
# context.definition.arguments
# # get tool by id
# # call tool
context.definition.tool_id
context.definition.arguments
# get tool by id
# call tool


# @activity.defn
Expand Down
7 changes: 4 additions & 3 deletions agents-api/agents_api/worker/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@
from ..activities.salient_questions import salient_questions
from ..activities.summarization import summarization
from ..activities.task_steps import (
prompt_step,
evaluate_step,
yield_step,
# tool_call_step,
# error_step,
if_else_step,
prompt_step,
transition_step,
# evaluate_step,
yield_step,
)
from ..activities.truncation import truncation
from ..env import (
Expand Down Expand Up @@ -74,7 +75,7 @@ async def main():

task_activities = [
prompt_step,
# evaluate_step,
evaluate_step,
yield_step,
# tool_call_step,
# error_step,
Expand Down
32 changes: 17 additions & 15 deletions agents-api/agents_api/workflows/task_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,18 @@
if_else_step,
prompt_step,
transition_step,
evaluate_step,
tool_call_step,
)
from ..common.protocol.tasks import (
ExecutionInput,
# ToolCallWorkflowStep,
PromptWorkflowStep,
EvaluateWorkflowStep,
ToolCallWorkflowStep,
# ErrorWorkflowStep,
IfElseWorkflowStep,
PromptWorkflowStep,
StepContext,
TransitionInfo,
# EvaluateWorkflowStep,
YieldWorkflowStep,
)

Expand Down Expand Up @@ -67,23 +69,23 @@ async def run(
# if outputs.tool_calls is not None:
# should_wait = True

# case EvaluateWorkflowStep():
# result = await workflow.execute_activity(
# evaluate_step,
# context,
# schedule_to_close_timeout=timedelta(seconds=600),
# )
case EvaluateWorkflowStep():
outputs = await workflow.execute_activity(
evaluate_step,
context,
schedule_to_close_timeout=timedelta(seconds=600),
)
case YieldWorkflowStep():
outputs = await workflow.execute_child_workflow(
TaskExecutionWorkflow.run,
args=[execution_input, (step.workflow, 0), previous_inputs],
)
# case ToolCallWorkflowStep():
# outputs = await workflow.execute_activity(
# tool_call_step,
# context,
# schedule_to_close_timeout=timedelta(seconds=600),
# )
case ToolCallWorkflowStep():
outputs = await workflow.execute_activity(
tool_call_step,
context,
schedule_to_close_timeout=timedelta(seconds=600),
)
# case ErrorWorkflowStep():
# result = await workflow.execute_activity(
# error_step,
Expand Down

0 comments on commit 687f487

Please sign in to comment.