Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add chat profiles #465

Merged
merged 23 commits into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
58d54bd
feat: add chat profiles
tpatel Oct 10, 2023
8bedc24
fix: handle the 1 chat profile edge case
tpatel Oct 10, 2023
6679cb0
extract the markdown element configuration into a reusable component
tpatel Oct 11, 2023
93f15b7
add the popover markdown description to the chat profile selection
tpatel Oct 11, 2023
ae72a36
prevent uppercasing in the chat profile selector
tpatel Oct 11, 2023
52e3bfe
fixed missing text selector
tpatel Oct 11, 2023
c654858
Added markdown in the e2e example
tpatel Oct 11, 2023
3edc6ff
Moved chat profile reset into the useChat clear function
tpatel Oct 11, 2023
425264b
renamed chat_profiles to set_chat_profiles
tpatel Oct 11, 2023
68f72d5
renamed the chat profile `description` field to `markdown_description`
tpatel Oct 11, 2023
0f105b1
upgrade design and behavior based on team feedback
tpatel Oct 12, 2023
64c2af3
fixed confusing error message
tpatel Oct 12, 2023
0a2373f
make the icons grayscale when not focused nor hovered over
tpatel Oct 12, 2023
42c5d98
Add a grayscaled center logo while the conversation hasn't started yet
tpatel Oct 12, 2023
b71308b
iterated on design feedback and fixed tests
tpatel Oct 12, 2023
fd602e8
removed all #welcome-screen waits as we removed this element
tpatel Oct 12, 2023
b4a4857
wait for the chat-input after logging-in
tpatel Oct 12, 2023
99f53f5
fixed the linting issue
tpatel Oct 12, 2023
090f3d3
test fix
tpatel Oct 12, 2023
03655e0
another attempt, getting closer to the password_auth test
tpatel Oct 12, 2023
f2cb973
debug ci
willydouhard Oct 13, 2023
600aaf8
debug ci
willydouhard Oct 13, 2023
7bc5c4f
fix test
willydouhard Oct 13, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion backend/chainlit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from chainlit.oauth_providers import get_configured_oauth_providers
from chainlit.sync import make_async, run_sync
from chainlit.telemetry import trace
from chainlit.types import FileSpec
from chainlit.types import ChatProfile, FileSpec
from chainlit.user_session import user_session
from chainlit.utils import make_module_getattr, wrap_user_function
from chainlit.version import __version__
Expand Down Expand Up @@ -150,6 +150,24 @@ def on_chat_start(func: Callable) -> Callable:
return func


@trace
def set_chat_profiles(
func: Callable[[Optional["AppUser"]], List["ChatProfile"]]
) -> Callable:
"""
Programmatic declaration of the available chat profiles (can depend on the AppUser from the session if authentication is setup).

Args:
func (Callable[[Optional["AppUser"]], List["ChatProfile"]]): The function declaring the chat profiles.

Returns:
Callable[[Optional["AppUser"]], List["ChatProfile"]]: The decorated function.
"""

config.code.set_chat_profiles = wrap_user_function(func)
return func


@trace
def on_chat_end(func: Callable) -> Callable:
"""
Expand Down
5 changes: 5 additions & 0 deletions backend/chainlit/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
if TYPE_CHECKING:
from chainlit.action import Action
from chainlit.client.base import AppUser
from chainlit.types import ChatProfile


BACKEND_ROOT = os.path.dirname(__file__)
PACKAGE_ROOT = os.path.dirname(os.path.dirname(BACKEND_ROOT))
Expand Down Expand Up @@ -171,6 +173,9 @@ class CodeSettings:
on_file_upload: Optional[Callable[[str], Any]] = None
author_rename: Optional[Callable[[str], str]] = None
on_settings_update: Optional[Callable[[Dict[str, Any]], Any]] = None
set_chat_profiles: Optional[
Callable[[Optional["AppUser"]], List["ChatProfile"]]
] = None


@dataclass()
Expand Down
6 changes: 6 additions & 0 deletions backend/chainlit/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,12 +440,18 @@ async def project_settings(
current_user: Annotated[Union[AppUser, PersistedAppUser], Depends(get_current_user)]
):
"""Return project settings. This is called by the UI before the establishing the websocket connection."""
profiles = []
if config.code.set_chat_profiles:
chat_profiles = await config.code.set_chat_profiles(current_user)
if chat_profiles:
profiles = [p.to_dict() for p in chat_profiles]
return JSONResponse(
content={
"ui": config.ui.to_dict(),
"userEnv": config.project.user_env,
"dataPersistence": config.data_persistence,
"markdown": get_markdown_str(config.root),
"chatProfiles": profiles,
}
)

Expand Down
14 changes: 12 additions & 2 deletions backend/chainlit/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,19 @@ def __init__(
user: Optional[Union["AppUser", "PersistedAppUser"]],
# Logged-in user token
token: Optional[str],
# User specific environment variables. Empty if no user environment variables are required.
user_env: Optional[Dict[str, str]],
# Last message at the root of the chat
root_message: Optional["Message"] = None,
# User specific environment variables. Empty if no user environment variables are required.
# Chat profile selected before the session was created
chat_profile: Optional[str] = None,
):
self.user = user
self.token = token
self.root_message = root_message
self.has_user_message = False
self.user_env = user_env or {}
self.chat_profile = chat_profile

self.id = id
self.conversation_id: Optional[str] = None
Expand Down Expand Up @@ -110,9 +113,16 @@ def __init__(
token: Optional[str],
# Last message at the root of the chat
root_message: Optional["Message"] = None,
# Chat profile selected before the session was created
chat_profile: Optional[str] = None,
):
super().__init__(
id=id, user=user, token=token, user_env=user_env, root_message=root_message
id=id,
user=user,
token=token,
user_env=user_env,
root_message=root_message,
chat_profile=chat_profile,
)

self.socket_id = socket_id
Expand Down
5 changes: 4 additions & 1 deletion backend/chainlit/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ def load_user_env(user_env):
@socket.on("connect")
async def connect(sid, environ, auth):
if not config.code.on_chat_start and not config.code.on_message:
raise ConnectionRefusedError("No websocket endpoint configured")
raise ConnectionRefusedError(
"You need to configure at least an on_chat_start or an on_message callback"
)

user = None
token = None
Expand Down Expand Up @@ -91,6 +93,7 @@ def ask_user_fn(data, timeout):
user_env=user_env,
user=user,
token=token,
chat_profile=environ.get("HTTP_X_CHAINLIT_CHAT_PROFILE"),
)

trace_event("connection_successful")
Expand Down
9 changes: 9 additions & 0 deletions backend/chainlit/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,12 @@ class GetConversationsRequest(BaseModel):
class Theme(str, Enum):
light = "light"
dark = "dark"


@dataclass
class ChatProfile(DataClassJsonMixin):
"""Specification for a chat profile that can be chosen by the user at the conversation start."""

name: str
markdown_description: str
icon: Optional[str] = None
2 changes: 2 additions & 0 deletions backend/chainlit/user_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class UserSessionDict(TypedDict):
headers: Dict[str, str]
user: Optional[Union["AppUser", "PersistedAppUser"]]
root_message: Optional["Message"]
chat_profile: Optional[str]


user_sessions: Dict[str, UserSessionDict] = {}
Expand All @@ -39,6 +40,7 @@ def get(self, key, default=None):
user_session["env"] = context.session.user_env
user_session["chat_settings"] = context.session.chat_settings
user_session["user"] = context.session.user
user_session["chat_profile"] = context.session.chat_profile

if context.session.root_message:
user_session["root_message"] = context.session.root_message
Expand Down
20 changes: 10 additions & 10 deletions cypress.config.ts
Original file line number Diff line number Diff line change
@@ -1,26 +1,26 @@
import { defineConfig } from "cypress";
import { defineConfig } from 'cypress';

export default defineConfig({
projectId: "ij1tyk",
projectId: 'ij1tyk',
component: {
devServer: {
framework: "react",
bundler: "vite",
},
framework: 'react',
bundler: 'vite'
}
},

e2e: {
supportFile: false,
defaultCommandTimeout: 10000,
video: false,
baseUrl: "http://127.0.0.1:8000",
baseUrl: 'http://127.0.0.1:8000',
setupNodeEvents(on, config) {
on("task", {
on('task', {
log(message) {
console.log(message);
return null;
},
}
});
},
},
}
}
});
15 changes: 7 additions & 8 deletions cypress/e2e/ask_user/spec.cy.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
import { runTestServer, submitMessage } from "../../support/testUtils";
import { runTestServer, submitMessage } from '../../support/testUtils';

describe("Ask User", () => {
describe('Ask User', () => {
before(() => {
runTestServer();
});

it("should send a new message containing the user input", () => {
cy.get("#welcome-screen").should("exist");
cy.get(".message").should("have.length", 1);
submitMessage("Jeeves");
it('should send a new message containing the user input', () => {
cy.get('.message').should('have.length', 1);
submitMessage('Jeeves');
cy.wait(2000);
cy.get(".message").should("have.length", 3);
cy.get('.message').should('have.length', 3);

cy.get(".message").eq(2).should("contain", "Jeeves");
cy.get('.message').eq(2).should('contain', 'Jeeves');
});
});
16 changes: 7 additions & 9 deletions cypress/e2e/audio_element/spec.cy.ts
Original file line number Diff line number Diff line change
@@ -1,23 +1,21 @@
import { runTestServer } from "../../support/testUtils";
import { runTestServer } from '../../support/testUtils';

describe("audio", () => {
describe('audio', () => {
before(() => {
runTestServer();
});

it("should be able to display an audio element", () => {
cy.get("#welcome-screen").should("exist");
it('should be able to display an audio element', () => {
cy.get('.message').should('have.length', 1);
cy.get('.message').eq(0).find('.inline-audio').should('have.length', 1);

cy.get(".message").should("have.length", 1);
cy.get(".message").eq(0).find(".inline-audio").should("have.length", 1);

cy.get(".inline-audio audio")
cy.get('.inline-audio audio')
.then(($el) => {
const audioElement = $el.get(0) as HTMLAudioElement;
return audioElement.play().then(() => {
return audioElement.duration;
});
})
.should("be.greaterThan", 0);
.should('be.greaterThan', 0);
});
});
18 changes: 8 additions & 10 deletions cypress/e2e/avatar/spec.cy.ts
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
import { runTestServer } from "../../support/testUtils";
import { runTestServer } from '../../support/testUtils';

describe("Avatar", () => {
describe('Avatar', () => {
before(() => {
runTestServer();
});

it("should be able to display a nested CoT", () => {
cy.get("#welcome-screen").should("exist");
it('should be able to display a nested CoT', () => {
cy.get('.message').should('have.length', 3);

cy.get(".message").should("have.length", 3);
cy.get('.message').eq(0).find('.message-avatar').should('have.length', 0);
cy.get('.message').eq(1).find('.message-avatar').should('have.length', 1);
cy.get('.message').eq(2).find('.message-avatar').should('have.length', 0);

cy.get(".message").eq(0).find(".message-avatar").should("have.length", 0);
cy.get(".message").eq(1).find(".message-avatar").should("have.length", 1);
cy.get(".message").eq(2).find(".message-avatar").should("have.length", 0);

cy.get(".element-link").should("have.length", 0);
cy.get('.element-link').should('have.length', 0);
});
});
69 changes: 69 additions & 0 deletions cypress/e2e/chat_profiles/.chainlit/config.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
[project]
# Whether to enable telemetry (default: true). No personal data is collected.
enable_telemetry = true

# List of environment variables to be provided by each user to use the app.
user_env = []

# Duration (in seconds) during which the session is saved when the connection is lost
session_timeout = 3600

# Enable third parties caching (e.g LangChain cache)
cache = false

# Follow symlink for asset mount (see https://github.com/Chainlit/chainlit/issues/317)
# follow_symlink = false

[features]
# Show the prompt playground
prompt_playground = true

[UI]
# Name of the app and chatbot.
name = "Chatbot"

# Description of the app and chatbot. This is used for HTML tags.
# description = ""

# Large size content are by default collapsed for a cleaner ui
default_collapse_content = true

# The default value for the expand messages settings.
default_expand_messages = false

# Hide the chain of thought details from the user in the UI.
hide_cot = false

# Link to your github repo. This will add a github button in the UI's header.
# github = ""

# Specify a CSS file that can be used to customize the user interface.
# The CSS file can be served from the public directory or via an external link.
# custom_css = "/public/test.css"

# If the app is served behind a reverse proxy (like cloud run) we need to know the base url for oauth
# base_url = "https://mydomain.com"

# Override default MUI light theme. (Check theme.ts)
[UI.theme.light]
#background = "#FAFAFA"
#paper = "#FFFFFF"

[UI.theme.light.primary]
#main = "#F80061"
#dark = "#980039"
#light = "#FFE7EB"

# Override default MUI dark theme. (Check theme.ts)
[UI.theme.dark]
#background = "#FAFAFA"
#paper = "#FFFFFF"

[UI.theme.dark.primary]
#main = "#F80061"
#dark = "#980039"
#light = "#FFE7EB"


[meta]
generated_by = "0.7.1"
48 changes: 48 additions & 0 deletions cypress/e2e/chat_profiles/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from typing import Optional

import chainlit as cl


@cl.set_chat_profiles
async def chat_profile(current_user: cl.AppUser):
if current_user.role != "ADMIN":
return None

return [
cl.ChatProfile(
name="GPT-3.5",
markdown_description="The underlying LLM model is **GPT-3.5**, a *175B parameter model* trained on 410GB of text data.",
),
cl.ChatProfile(
name="GPT-4",
markdown_description="The underlying LLM model is **GPT-4**, a *1.5T parameter model* trained on 3.5TB of text data.",
icon="https://picsum.photos/250",
),
cl.ChatProfile(
name="GPT-5",
markdown_description="The underlying LLM model is **GPT-5**.",
icon="https://picsum.photos/200",
),
]


@cl.password_auth_callback
def auth_callback(username: str, password: str) -> Optional[cl.AppUser]:
if (username, password) == ("admin", "admin"):
return cl.AppUser(username="admin", role="ADMIN", provider="credentials")
else:
return None


# @cl.on_message
# async def on_message(message: str):
# await cl.Message(content=f"echo: {message}").send()


@cl.on_chat_start
async def on_chat_start():
app_user = cl.user_session.get("user")
chat_profile = cl.user_session.get("chat_profile")
await cl.Message(
content=f"starting chat with {app_user.username} using the {chat_profile} chat profile"
).send()
Loading