Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ask user action #473

Merged
merged 1 commit into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion backend/chainlit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,13 @@
Video,
)
from chainlit.logger import logger
from chainlit.message import AskFileMessage, AskUserMessage, ErrorMessage, Message
from chainlit.message import (
AskActionMessage,
AskFileMessage,
AskUserMessage,
ErrorMessage,
Message,
)
from chainlit.oauth_providers import get_configured_oauth_providers
from chainlit.sync import make_async, run_sync
from chainlit.telemetry import trace
Expand Down Expand Up @@ -299,6 +305,7 @@ def sleep(duration: int):
"Message",
"ErrorMessage",
"AskUserMessage",
"AskActionMessage",
"AskFileMessage",
"on_chat_start",
"on_chat_end",
Expand Down
84 changes: 82 additions & 2 deletions backend/chainlit/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import uuid
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Union, cast

from chainlit.action import Action
from chainlit.client.base import MessageDict
Expand All @@ -13,7 +13,14 @@
from chainlit.logger import logger
from chainlit.prompt import Prompt
from chainlit.telemetry import trace_event
from chainlit.types import AskFileResponse, AskFileSpec, AskResponse, AskSpec
from chainlit.types import (
AskActionResponse,
AskActionSpec,
AskFileResponse,
AskFileSpec,
AskResponse,
AskSpec,
)
from syncer import asyncio


Expand Down Expand Up @@ -485,3 +492,76 @@ async def send(self) -> Union[List[AskFileResponse], None]:
return [AskFileResponse(**r) for r in res]
else:
return None


class AskActionMessage(AskMessageBase):
"""
Ask the user to select an action before continuing.
If the user does not answer in time (see timeout), a TimeoutError will be raised or None will be returned depending on raise_on_timeout.
"""

def __init__(
self,
content: str,
actions: List[Action],
author=config.ui.name,
disable_human_feedback=False,
timeout=90,
raise_on_timeout=False,
):
self.content = content
self.actions = actions
self.author = author
self.disable_human_feedback = disable_human_feedback
self.timeout = timeout
self.raise_on_timeout = raise_on_timeout

super().__post_init__()

def to_dict(self):
return {
"id": self.id,
"createdAt": self.created_at,
"content": self.content,
"author": self.author,
"waitForAnswer": True,
"disableHumanFeedback": self.disable_human_feedback,
"timeout": self.timeout,
"raiseOnTimeout": self.raise_on_timeout,
}

async def send(self) -> Union[AskActionResponse, None]:
"""
Sends the question to ask to the UI and waits for the reply
"""
trace_event("send_ask_action")

if self.streaming:
self.streaming = False

if config.code.author_rename:
self.author = await config.code.author_rename(self.author)

msg_dict = await self._create()
action_keys = []

for action in self.actions:
action_keys.append(action.id)
await action.send(for_id=str(msg_dict["id"]))

spec = AskActionSpec(type="action", timeout=self.timeout, keys=action_keys)

res = cast(
Union[AskActionResponse, None],
await context.emitter.send_ask_user(msg_dict, spec, self.raise_on_timeout),
)

for action in self.actions:
await action.remove()
if res is None:
self.content = "Timed out: no action was taken"
else:
self.content = f'**Selected action:** {res["label"]}'
await self.update()

return res
22 changes: 21 additions & 1 deletion backend/chainlit/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,29 @@ class FileSpec(DataClassJsonMixin):
max_size_mb: int


@dataclass
class ActionSpec(DataClassJsonMixin):
keys: List[str]


@dataclass
class AskSpec(DataClassJsonMixin):
"""Specification for asking the user."""

timeout: int
type: Literal["text", "file"]
type: Literal["text", "file", "action"]


@dataclass
class AskFileSpec(FileSpec, AskSpec, DataClassJsonMixin):
"""Specification for asking the user a file."""


@dataclass
class AskActionSpec(ActionSpec, AskSpec, DataClassJsonMixin):
"""Specification for asking the user an action"""


class AskResponse(TypedDict):
content: str
author: str
Expand All @@ -46,6 +56,16 @@ class AskFileResponse:
content: bytes


class AskActionResponse(TypedDict):
name: str
value: str
label: str
description: str
forId: str
id: str
collapsed: bool


class CompletionRequest(BaseModel):
prompt: Prompt
userEnv: Dict[str, str]
Expand Down
21 changes: 21 additions & 0 deletions cypress/e2e/action/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,24 @@ async def main():
message = cl.Message("Hello, this is a test message!", actions=actions)
cl.user_session.set("to_remove", message)
await message.send()

result = await cl.AskActionMessage(
content="Please, pick an action!",
actions=[
cl.Action(
id="first-action",
name="first_action",
value="first-action",
label="First action",
),
cl.Action(
id="second-action",
name="second_action",
value="second-action",
label="Second action",
),
],
).send()

if result != None:
await cl.Message(f"Thanks for pressing: {result['value']}").send()
26 changes: 17 additions & 9 deletions cypress/e2e/action/spec.cy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,47 +6,55 @@ describe('Action', () => {
});

it('should correctly execute and display actions', () => {
// Click on "first action"
cy.get('#first-action').should('be.visible');
cy.get('#first-action').click();
cy.get('.message').should('have.length', 3);
cy.get('.message')
.eq(2)
.should('contain', 'Thanks for pressing: first-action');

// Click on "test action"
cy.get("[id='test-action']").should('be.visible');
cy.get("[id='test-action']").click();
cy.get('.message').should('have.length', 2);
cy.get('.message').eq(1).should('contain', 'Executed test action!');
cy.get('.message').should('have.length', 4);
cy.get('.message').eq(3).should('contain', 'Executed test action!');
cy.get("[id='test-action']").should('exist');

cy.wait(100);

// Click on "removable action"
cy.get("[id='removable-action']").should('be.visible');
cy.get("[id='removable-action']").click();
cy.get('.message').should('have.length', 3);
cy.get('.message').eq(2).should('contain', 'Executed removable action!');
cy.get('.message').should('have.length', 5);
cy.get('.message').eq(4).should('contain', 'Executed removable action!');
cy.get("[id='removable-action']").should('not.exist');

cy.wait(100);

// Click on "multiple action one" in the action drawer, should remove the correct action button
cy.get("[id='actions-drawer-button']").should('be.visible');
cy.get("[id='actions-drawer-button']").click();
cy.get('.message').should('have.length', 3);
cy.get('.message').should('have.length', 5);

cy.wait(100);

cy.get("[id='multiple-action-one']").should('be.visible');
cy.get("[id='multiple-action-one']").click();
cy.get('.message')
.eq(3)
.eq(5)
.should('contain', 'Action(id=multiple-action-one) has been removed!');
cy.get("[id='multiple-action-one']").should('not.exist');

cy.wait(100);

// Click on "multiple action two", should remove the correct action button
cy.get('.message').should('have.length', 4);
cy.get('.message').should('have.length', 6);
cy.get("[id='actions-drawer-button']").click();
cy.get("[id='multiple-action-two']").should('be.visible');
cy.get("[id='multiple-action-two']").click();
cy.get('.message')
.eq(4)
.eq(6)
.should('contain', 'Action(id=multiple-action-two) has been removed!');
cy.get("[id='multiple-action-two']").should('not.exist');

Expand All @@ -56,7 +64,7 @@ describe('Action', () => {
cy.get("[id='all-actions-removed']").should('be.visible');
cy.get("[id='all-actions-removed']").click();
cy.get('.message')
.eq(5)
.eq(7)
.should('contain', 'All actions have been removed!');
cy.get("[id='all-actions-removed']").should('not.exist');
cy.get("[id='test-action']").should('not.exist');
Expand Down
6 changes: 5 additions & 1 deletion frontend/src/components/organisms/chat/inputBox/input.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,11 @@ const Input = ({ onSubmit, onReply }: Props) => {
const [value, setValue] = useState('');
const [isComposing, setIsComposing] = useState(false);

const disabled = !connected || loading || askUser?.spec.type === 'file';
const disabled =
!connected ||
loading ||
askUser?.spec.type === 'file' ||
askUser?.spec.type === 'action';

useEffect(() => {
if (ref.current && !loading && !disabled) {
Expand Down
19 changes: 13 additions & 6 deletions libs/components/src/messages/components/ActionButton.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,17 @@ interface ActionProps {
}

const ActionButton = ({ action, margin, onClick }: ActionProps) => {
const { loading } = useContext(MessageContext);
const { askUser, loading } = useContext(MessageContext);
const isAskingAction = askUser?.spec.type === 'action';
const isDisabled = isAskingAction && !askUser?.spec.keys?.includes(action.id);
const handleClick = () => {
if (isAskingAction) {
askUser?.callback(action);
} else {
action.onClick();
onClick?.();
}
};

return (
<Tooltip title={action.description} placement="top">
Expand All @@ -25,11 +35,8 @@ const ActionButton = ({ action, margin, onClick }: ActionProps) => {
margin
}}
id={action.id}
onClick={() => {
action.onClick();
onClick?.();
}}
disabled={loading}
onClick={handleClick}
disabled={loading || isDisabled}
>
{action.label || action.name}
</Button>
Expand Down
12 changes: 9 additions & 3 deletions libs/components/src/types/file.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { IAction } from './action';
import { IMessage } from './message';

export interface FileSpec {
Expand All @@ -6,6 +7,10 @@ export interface FileSpec {
max_files?: number;
}

export interface ActionSpec {
keys?: string[];
}

export interface IFileResponse {
name: string;
path?: string;
Expand All @@ -15,9 +20,10 @@ export interface IFileResponse {
}

export interface IAsk {
callback: (payload: IMessage | IFileResponse[]) => void;
callback: (payload: IMessage | IFileResponse[] | IAction) => void;
spec: {
type: 'text' | 'file';
type: 'text' | 'file' | 'action';
timeout: number;
} & FileSpec;
} & FileSpec &
ActionSpec;
}