Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Re-organize fill-form agent #214

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .agent import execute, extend
from .config import FormFillAgentConfig
from .step import LLMConfig
from .steps.types import LLMConfig

__all__ = [
"execute",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

from . import state
from .config import FormFillAgentConfig
from .step import Context, IncompleteErrorResult, IncompleteResult, LLMConfig
from .steps import acquire_form, extract_form_fields, fill_form
from .steps import acquire_form_step, extract_form_fields_step, fill_form_step
from .steps.types import ConfigT, Context, IncompleteErrorResult, IncompleteResult, LLMConfig

logger = logging.getLogger(__name__)

Expand All @@ -23,133 +23,105 @@ async def execute(
) -> None:
user_messages = [latest_user_message]

async with state.agent_state(context) as agent_state:
for mode in state.FormFillAgentMode:
if mode in agent_state.mode_debug_log:
continue

agent_state.mode_debug_log[mode] = []
def build_step_context(config: ConfigT) -> Context[ConfigT]:
return Context(
context=context, llm_config=llm_config, config=config, get_attachment_messages=get_attachment_messages
)

async with state.agent_state(context) as agent_state:
while True:
logger.info("form-fill-agent step; mode: %s", agent_state.mode)
logger.info("form-fill-agent execute loop; mode: %s", agent_state.mode)

match agent_state.mode:
case state.FormFillAgentMode.acquire_form_step:
step_context = Context(
context=context,
llm_config=llm_config,
config=config.acquire_form_config,
get_attachment_messages=get_attachment_messages,
)

result = await acquire_form.execute(
step_context=step_context,
result = await acquire_form_step.execute(
step_context=build_step_context(config.acquire_form_config),
latest_user_message=user_messages.pop() if user_messages else None,
)

agent_state.mode_debug_log[agent_state.mode].insert(0, result.debug)
match result:
case IncompleteResult():
await _send_message(context, result.ai_message, result.debug)
return

case IncompleteErrorResult():
await _send_error_message(context, result.error_message, result.debug)
return

case acquire_form.CompleteResult():
case acquire_form_step.CompleteResult():
await _send_message(context, result.ai_message, result.debug)

agent_state.form_filename = result.filename
agent_state.mode = state.FormFillAgentMode.extract_form_fields_step
continue

case IncompleteResult() | IncompleteErrorResult():
await _handle_incomplete_results(context, result)
return

case _:
raise ValueError(f"Unexpected result: {result}")

case state.FormFillAgentMode.extract_form_fields_step:
step_context = Context(
context=context,
llm_config=llm_config,
config=config.extract_form_fields_config,
get_attachment_messages=get_attachment_messages,
)

result = await extract_form_fields.execute(
step_context=step_context,
result = await extract_form_fields_step.execute(
step_context=build_step_context(config.extract_form_fields_config),
filename=agent_state.form_filename,
)

agent_state.mode_debug_log[agent_state.mode].insert(0, result.debug)
match result:
case IncompleteErrorResult():
await _send_error_message(context, result.error_message, result.debug)
return

case IncompleteResult():
await _send_message(context, result.ai_message, result.debug)
return

case extract_form_fields.CompleteResult():
case extract_form_fields_step.CompleteResult():
await _send_message(context, result.ai_message, result.debug)

agent_state.extracted_form_fields = result.extracted_form_fields
agent_state.mode = state.FormFillAgentMode.fill_form_step
continue

case IncompleteResult() | IncompleteErrorResult():
await _handle_incomplete_results(context, result)
return

case _:
raise ValueError(f"Unexpected result: {result}")

case state.FormFillAgentMode.fill_form_step:
step_context = Context(
context=context,
llm_config=llm_config,
config=config.fill_form_config,
get_attachment_messages=get_attachment_messages,
)

result = await fill_form.execute(
step_context=step_context,
result = await fill_form_step.execute(
step_context=build_step_context(config.fill_form_config),
latest_user_message=user_messages.pop() if user_messages else None,
form_fields=agent_state.extracted_form_fields,
)

agent_state.mode_debug_log[agent_state.mode].insert(0, result.debug)
match result:
case IncompleteResult():
await _send_message(context, result.ai_message, result.debug)
return

case IncompleteErrorResult():
await _send_error_message(context, result.error_message, result.debug)
return

case fill_form.CompleteResult():
case fill_form_step.CompleteResult():
await _send_message(context, result.ai_message, result.debug)

agent_state.fill_form_gc_artifact = result.artifact
agent_state.mode = state.FormFillAgentMode.generate_filled_form_step
continue

case IncompleteResult() | IncompleteErrorResult():
await _handle_incomplete_results(context, result)
return

case _:
raise ValueError(f"Unexpected result: {result}")

case state.FormFillAgentMode.generate_filled_form_step:
await context.send_messages(
NewConversationMessage(
content="I'd love to generate the fill form now, but it's not yet implemented. :)"
)
await _send_message(
context, "I'd love to generate the filled-out form now, but it's not yet implemented. :)", {}
)
return

case state.FormFillAgentMode.end_conversation:
await context.send_messages(NewConversationMessage(content="Conversation has ended."))
await _send_message(context, "Conversation has ended.", {})
return

case _:
raise ValueError(f"Unexpected mode: {state.mode}")


async def _handle_incomplete_results(
context: ConversationContext, result: IncompleteErrorResult | IncompleteResult
) -> None:
match result:
case IncompleteResult():
await _send_message(context, result.ai_message, result.debug)

case IncompleteErrorResult():
await _send_error_message(context, result.error_message, result.debug)

case _:
raise ValueError(f"Unexpected incomplete result: {result}")


async def _send_message(context: ConversationContext, message: str, debug: dict) -> None:
if not message:
return
Expand Down Expand Up @@ -182,5 +154,5 @@ def extend(app: AssistantAppProtocol) -> None:
app.add_inspector_state_provider(state.inspector.state_id, state.inspector)

# for step level states
acquire_form.extend(app)
fill_form.extend(app)
acquire_form_step.extend(app)
fill_form_step.extend(app)
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,21 @@

from pydantic import BaseModel, Field

from . import gce_config
from .steps import acquire_form, extract_form_fields, fill_form
from .steps import acquire_form_step, extract_form_fields_step, fill_form_step, types


class FormFillAgentConfig(BaseModel):
acquire_form_config: Annotated[
gce_config.GuidedConversationDefinition,
types.GuidedConversationDefinition,
Field(title="Form Acquisition", description="Guided conversation for acquiring a form from the user."),
] = acquire_form.definition.model_copy()
] = acquire_form_step.definition.model_copy()

extract_form_fields_config: Annotated[
extract_form_fields.ExtractFormFieldsConfig,
extract_form_fields_step.ExtractFormFieldsConfig,
Field(title="Extract Form Fields", description="Configuration for extracting form fields from the form."),
] = extract_form_fields.ExtractFormFieldsConfig()
] = extract_form_fields_step.ExtractFormFieldsConfig()

fill_form_config: Annotated[
gce_config.GuidedConversationDefinition,
types.GuidedConversationDefinition,
Field(title="Fill Form", description="Guided conversation for filling out the form."),
] = fill_form.definition.model_copy()
] = fill_form_step.definition.model_copy()
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@


class FileStateInspector(ReadOnlyAssistantConversationInspectorStateProvider):
"""
A conversation inspector state provider that reads the state from a file and displays it as a yaml code block.
"""

def __init__(
self,
display_name: str,
Expand Down Expand Up @@ -53,6 +57,7 @@ def read_state(path: Path) -> dict:

@contextlib.asynccontextmanager
async def state_change_event_after(context: ConversationContext, state_id: str, set_focus=False) -> AsyncIterator[None]:
"""Raise a state change event after the context manager block is executed (optionally set focus as well)"""
yield
if set_focus:
await context.send_conversation_state_event(AssistantStateEvent(state_id=state_id, event="focus", state=None))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ class FormFillAgentState(BaseModel):
extracted_form_fields: list[FormField] = []
fill_form_gc_artifact: dict | None = None

mode_debug_log: dict[FormFillAgentMode, list[dict]] = {}


def path_for_state(context: ConversationContext) -> Path:
return storage_directory_for_context(context) / "state.json"
Expand All @@ -48,6 +46,10 @@ def path_for_state(context: ConversationContext) -> Path:

@asynccontextmanager
async def agent_state(context: ConversationContext) -> AsyncIterator[FormFillAgentState]:
"""
Context manager that provides the agent state, reading it from disk, and saving back
to disk after the context manager block is executed.
"""
state = current_state.get()
if state is not None:
yield state
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""
Utility functions for handling attachments in chat messages.
"""

from typing import Awaitable, Callable, Sequence

from openai.types.chat import ChatCompletionMessageParam
from semantic_workbench_assistant.assistant_app.context import ConversationContext

from .. import state


async def message_with_recent_attachments(
context: ConversationContext,
latest_user_message: str | None,
get_attachment_messages: Callable[[Sequence[str]], Awaitable[Sequence[ChatCompletionMessageParam]]],
) -> str:
files = await context.get_files()

new_filenames = set()

async with state.agent_state(context) as agent_state:
max_timestamp = agent_state.most_recent_attachment_timestamp
for file in files.files:
if file.updated_datetime.timestamp() <= agent_state.most_recent_attachment_timestamp:
continue

max_timestamp = max(file.updated_datetime.timestamp(), max_timestamp)
new_filenames.add(file.filename)

agent_state.most_recent_attachment_timestamp = max_timestamp

attachment_messages = await get_attachment_messages(list(new_filenames))

return "\n\n".join(
(
latest_user_message or "",
*(
str(attachment.get("content"))
for attachment in attachment_messages
if "<ATTACHMENT>" in str(attachment.get("content", ""))
),
),
)


async def attachment_for_filename(
filename: str, get_attachment_messages: Callable[[Sequence[str]], Awaitable[Sequence[ChatCompletionMessageParam]]]
) -> str:
attachment_messages = await get_attachment_messages([filename])
return "\n\n".join(
(
str(attachment.get("content"))
for attachment in attachment_messages
if "<ATTACHMENT>" in str(attachment.get("content", ""))
)
)
Loading