Skip to content

Commit

Permalink
Re-organize fill-form agent (#214)
Browse files Browse the repository at this point in the history
To ease reading/understanding/debugging of code
  • Loading branch information
markwaddle authored Nov 5, 2024
1 parent c48cafd commit e4537e4
Show file tree
Hide file tree
Showing 13 changed files with 329 additions and 323 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .agent import execute, extend
from .config import FormFillAgentConfig
from .step import LLMConfig
from .steps.types import LLMConfig

__all__ = [
"execute",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

from . import state
from .config import FormFillAgentConfig
from .step import Context, IncompleteErrorResult, IncompleteResult, LLMConfig
from .steps import acquire_form, extract_form_fields, fill_form
from .steps import acquire_form_step, extract_form_fields_step, fill_form_step
from .steps.types import ConfigT, Context, IncompleteErrorResult, IncompleteResult, LLMConfig

logger = logging.getLogger(__name__)

Expand All @@ -23,133 +23,105 @@ async def execute(
) -> None:
user_messages = [latest_user_message]

async with state.agent_state(context) as agent_state:
for mode in state.FormFillAgentMode:
if mode in agent_state.mode_debug_log:
continue

agent_state.mode_debug_log[mode] = []
def build_step_context(config: ConfigT) -> Context[ConfigT]:
return Context(
context=context, llm_config=llm_config, config=config, get_attachment_messages=get_attachment_messages
)

async with state.agent_state(context) as agent_state:
while True:
logger.info("form-fill-agent step; mode: %s", agent_state.mode)
logger.info("form-fill-agent execute loop; mode: %s", agent_state.mode)

match agent_state.mode:
case state.FormFillAgentMode.acquire_form_step:
step_context = Context(
context=context,
llm_config=llm_config,
config=config.acquire_form_config,
get_attachment_messages=get_attachment_messages,
)

result = await acquire_form.execute(
step_context=step_context,
result = await acquire_form_step.execute(
step_context=build_step_context(config.acquire_form_config),
latest_user_message=user_messages.pop() if user_messages else None,
)

agent_state.mode_debug_log[agent_state.mode].insert(0, result.debug)
match result:
case IncompleteResult():
await _send_message(context, result.ai_message, result.debug)
return

case IncompleteErrorResult():
await _send_error_message(context, result.error_message, result.debug)
return

case acquire_form.CompleteResult():
case acquire_form_step.CompleteResult():
await _send_message(context, result.ai_message, result.debug)

agent_state.form_filename = result.filename
agent_state.mode = state.FormFillAgentMode.extract_form_fields_step
continue

case IncompleteResult() | IncompleteErrorResult():
await _handle_incomplete_results(context, result)
return

case _:
raise ValueError(f"Unexpected result: {result}")

case state.FormFillAgentMode.extract_form_fields_step:
step_context = Context(
context=context,
llm_config=llm_config,
config=config.extract_form_fields_config,
get_attachment_messages=get_attachment_messages,
)

result = await extract_form_fields.execute(
step_context=step_context,
result = await extract_form_fields_step.execute(
step_context=build_step_context(config.extract_form_fields_config),
filename=agent_state.form_filename,
)

agent_state.mode_debug_log[agent_state.mode].insert(0, result.debug)
match result:
case IncompleteErrorResult():
await _send_error_message(context, result.error_message, result.debug)
return

case IncompleteResult():
await _send_message(context, result.ai_message, result.debug)
return

case extract_form_fields.CompleteResult():
case extract_form_fields_step.CompleteResult():
await _send_message(context, result.ai_message, result.debug)

agent_state.extracted_form_fields = result.extracted_form_fields
agent_state.mode = state.FormFillAgentMode.fill_form_step
continue

case IncompleteResult() | IncompleteErrorResult():
await _handle_incomplete_results(context, result)
return

case _:
raise ValueError(f"Unexpected result: {result}")

case state.FormFillAgentMode.fill_form_step:
step_context = Context(
context=context,
llm_config=llm_config,
config=config.fill_form_config,
get_attachment_messages=get_attachment_messages,
)

result = await fill_form.execute(
step_context=step_context,
result = await fill_form_step.execute(
step_context=build_step_context(config.fill_form_config),
latest_user_message=user_messages.pop() if user_messages else None,
form_fields=agent_state.extracted_form_fields,
)

agent_state.mode_debug_log[agent_state.mode].insert(0, result.debug)
match result:
case IncompleteResult():
await _send_message(context, result.ai_message, result.debug)
return

case IncompleteErrorResult():
await _send_error_message(context, result.error_message, result.debug)
return

case fill_form.CompleteResult():
case fill_form_step.CompleteResult():
await _send_message(context, result.ai_message, result.debug)

agent_state.fill_form_gc_artifact = result.artifact
agent_state.mode = state.FormFillAgentMode.generate_filled_form_step
continue

case IncompleteResult() | IncompleteErrorResult():
await _handle_incomplete_results(context, result)
return

case _:
raise ValueError(f"Unexpected result: {result}")

case state.FormFillAgentMode.generate_filled_form_step:
await context.send_messages(
NewConversationMessage(
content="I'd love to generate the fill form now, but it's not yet implemented. :)"
)
await _send_message(
context, "I'd love to generate the filled-out form now, but it's not yet implemented. :)", {}
)
return

case state.FormFillAgentMode.end_conversation:
await context.send_messages(NewConversationMessage(content="Conversation has ended."))
await _send_message(context, "Conversation has ended.", {})
return

case _:
raise ValueError(f"Unexpected mode: {state.mode}")


async def _handle_incomplete_results(
    context: ConversationContext, result: IncompleteErrorResult | IncompleteResult
) -> None:
    """
    Relay the outcome of a step that did not complete to the conversation.

    An ``IncompleteResult`` has its ``ai_message`` forwarded through
    ``_send_message``; an ``IncompleteErrorResult`` has its ``error_message``
    forwarded through ``_send_error_message``. In both cases the result's
    ``debug`` payload is passed along with the message.

    Raises:
        ValueError: if ``result`` is neither of the two incomplete result types.
    """
    # The checks run in this order on purpose: IncompleteResult is tested first.
    if isinstance(result, IncompleteResult):
        await _send_message(context, result.ai_message, result.debug)
        return

    if isinstance(result, IncompleteErrorResult):
        await _send_error_message(context, result.error_message, result.debug)
        return

    raise ValueError(f"Unexpected incomplete result: {result}")


async def _send_message(context: ConversationContext, message: str, debug: dict) -> None:
if not message:
return
Expand Down Expand Up @@ -182,5 +154,5 @@ def extend(app: AssistantAppProtocol) -> None:
app.add_inspector_state_provider(state.inspector.state_id, state.inspector)

# for step level states
acquire_form.extend(app)
fill_form.extend(app)
acquire_form_step.extend(app)
fill_form_step.extend(app)
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,21 @@

from pydantic import BaseModel, Field

from . import gce_config
from .steps import acquire_form, extract_form_fields, fill_form
from .steps import acquire_form_step, extract_form_fields_step, fill_form_step, types


class FormFillAgentConfig(BaseModel):
acquire_form_config: Annotated[
gce_config.GuidedConversationDefinition,
types.GuidedConversationDefinition,
Field(title="Form Acquisition", description="Guided conversation for acquiring a form from the user."),
] = acquire_form.definition.model_copy()
] = acquire_form_step.definition.model_copy()

extract_form_fields_config: Annotated[
extract_form_fields.ExtractFormFieldsConfig,
extract_form_fields_step.ExtractFormFieldsConfig,
Field(title="Extract Form Fields", description="Configuration for extracting form fields from the form."),
] = extract_form_fields.ExtractFormFieldsConfig()
] = extract_form_fields_step.ExtractFormFieldsConfig()

fill_form_config: Annotated[
gce_config.GuidedConversationDefinition,
types.GuidedConversationDefinition,
Field(title="Fill Form", description="Guided conversation for filling out the form."),
] = fill_form.definition.model_copy()
] = fill_form_step.definition.model_copy()
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@


class FileStateInspector(ReadOnlyAssistantConversationInspectorStateProvider):
"""
A conversation inspector state provider that reads the state from a file and displays it as a yaml code block.
"""

def __init__(
self,
display_name: str,
Expand Down Expand Up @@ -53,6 +57,7 @@ def read_state(path: Path) -> dict:

@contextlib.asynccontextmanager
async def state_change_event_after(context: ConversationContext, state_id: str, set_focus=False) -> AsyncIterator[None]:
"""Raise a state change event after the context manager block is executed (optionally set focus as well)"""
yield
if set_focus:
await context.send_conversation_state_event(AssistantStateEvent(state_id=state_id, event="focus", state=None))
Expand Down
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ class FormFillAgentState(BaseModel):
extracted_form_fields: list[FormField] = []
fill_form_gc_artifact: dict | None = None

mode_debug_log: dict[FormFillAgentMode, list[dict]] = {}


def path_for_state(context: ConversationContext) -> Path:
return storage_directory_for_context(context) / "state.json"
Expand All @@ -48,6 +46,10 @@ def path_for_state(context: ConversationContext) -> Path:

@asynccontextmanager
async def agent_state(context: ConversationContext) -> AsyncIterator[FormFillAgentState]:
"""
Context manager that provides the agent state, reading it from disk, and saving back
to disk after the context manager block is executed.
"""
state = current_state.get()
if state is not None:
yield state
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""
Utility functions for handling attachments in chat messages.
"""

from typing import Awaitable, Callable, Sequence

from openai.types.chat import ChatCompletionMessageParam
from semantic_workbench_assistant.assistant_app.context import ConversationContext

from .. import state


async def message_with_recent_attachments(
    context: ConversationContext,
    latest_user_message: str | None,
    get_attachment_messages: Callable[[Sequence[str]], Awaitable[Sequence[ChatCompletionMessageParam]]],
) -> str:
    """
    Combine the latest user message with the content of newly updated attachments.

    A file counts as "new" when its ``updated_datetime`` timestamp is strictly
    greater than ``most_recent_attachment_timestamp`` in the agent state. That
    state value is then advanced to the newest timestamp seen, so the same file
    is not picked up again on the next call.

    Args:
        context: conversation context used to list files and to load the agent state.
        latest_user_message: the user's most recent message; ``None`` is treated
            as an empty string.
        get_attachment_messages: async callable that takes filenames and returns
            chat messages for those attachments.

    Returns:
        The user message followed by the content of every returned attachment
        message containing the ``<ATTACHMENT>`` marker, joined by blank lines.
        The user-message slot is always present, so the result starts with a
        blank-line separator when there is no user message but there are attachments.
    """
    files = await context.get_files()

    async with state.agent_state(context) as agent_state:
        previous_timestamp = agent_state.most_recent_attachment_timestamp

        # only files updated since the last recorded timestamp are considered new
        new_files = [file for file in files.files if file.updated_datetime.timestamp() > previous_timestamp]

        # advance the high-water mark; it stays unchanged when nothing is new
        agent_state.most_recent_attachment_timestamp = max(
            (file.updated_datetime.timestamp() for file in new_files),
            default=previous_timestamp,
        )

    # de-duplicate, then sort: set iteration order is arbitrary, and a stable
    # order keeps the request (and anything derived from it) reproducible
    new_filenames = sorted({file.filename for file in new_files})

    attachment_messages = await get_attachment_messages(new_filenames)

    attachment_contents = [
        str(attachment.get("content"))
        for attachment in attachment_messages
        if "<ATTACHMENT>" in str(attachment.get("content", ""))
    ]

    return "\n\n".join((latest_user_message or "", *attachment_contents))


async def attachment_for_filename(
    filename: str, get_attachment_messages: Callable[[Sequence[str]], Awaitable[Sequence[ChatCompletionMessageParam]]]
) -> str:
    """
    Return the attachment content for a single file as one string.

    Requests the attachment messages for ``filename`` and keeps only those whose
    content contains the ``<ATTACHMENT>`` marker. The kept contents are joined
    with blank lines; the result is an empty string when nothing matches.
    """
    messages = await get_attachment_messages([filename])

    attachment_blocks: list[str] = []
    for message in messages:
        content_text = str(message.get("content"))
        # messages without the marker (or without any content) are not attachments
        if "<ATTACHMENT>" not in content_text:
            continue
        attachment_blocks.append(content_text)

    return "\n\n".join(attachment_blocks)
Loading

0 comments on commit e4537e4

Please sign in to comment.