Skip to content

Commit

Permalink
feat: support slash command
Browse files Browse the repository at this point in the history
  • Loading branch information
lisiur committed Apr 10, 2023
1 parent 6a43770 commit c4b231f
Show file tree
Hide file tree
Showing 13 changed files with 235 additions and 61 deletions.
16 changes: 16 additions & 0 deletions service/src/commands/cmd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,22 @@ impl UpdateChatCommand {
}
}

#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RemoveChatPromptCommand {
id: Id,
}

impl RemoveChatPromptCommand {
pub fn exec(self, conn: &DbConn) -> Result<()> {
let chat_service = ChatService::new(conn.clone());

chat_service.remove_prompt(self.id)?;

Ok(())
}
}

#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct LoadChatLogByCursorCommand {
Expand Down
4 changes: 4 additions & 0 deletions service/src/commands/exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ impl CommandExecutor {
.exec(conn)
.into_result(),

"remove_chat_prompt" => from_value::<RemoveChatPromptCommand>(payload)?
.exec(conn)
.into_result(),

"delete_chat" => from_value::<DeleteChatCommand>(payload)?
.exec(conn)
.into_result(),
Expand Down
16 changes: 14 additions & 2 deletions service/src/repositories/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@ impl ChatRepo {
let nonstick_chats = self.select_non_stick(user_id)?;
let archived = self.select_archived(user_id)?;

let all_chats = stick_chats.into_iter().chain(nonstick_chats).chain(archived).collect();
let all_chats = stick_chats
.into_iter()
.chain(nonstick_chats)
.chain(archived)
.collect();

Ok(all_chats)
}
Expand Down Expand Up @@ -210,7 +214,6 @@ impl ChatRepo {
}

pub fn insert_if_not_exist(&self, chat: &NewChat) -> Result<usize> {

let size = diesel::insert_into(chats::table)
.values(chat)
.on_conflict(chats::id)
Expand All @@ -229,6 +232,15 @@ impl ChatRepo {
Ok(size)
}

pub fn remove_prompt(&self, id: Id) -> Result<()> {
diesel::update(chats::table)
.filter(chats::id.eq(id))
.set(chats::prompt_id.eq(Option::<Id>::None))
.execute(&mut *self.0.conn())?;

Ok(())
}

pub fn add_cost_and_update(&self, id: Id, cost: f32) -> Result<usize> {
let size = diesel::update(chats::table)
.filter(chats::id.eq(id))
Expand Down
6 changes: 6 additions & 0 deletions service/src/services/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,12 @@ impl ChatService {
Ok(())
}

pub fn remove_prompt(&self, chat_id: Id) -> Result<()> {
self.chat_repo.remove_prompt(chat_id)?;

Ok(())
}

pub fn set_chat_archive(&self, chat_id: Id) -> Result<()> {
self.chat_repo.update(&PatchChat {
id: chat_id,
Expand Down
4 changes: 4 additions & 0 deletions web/src/api/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,10 @@ export async function updateChat(payload: ChatUpdatePayload) {
return execCommand<void>("update_chat", { payload });
}

export async function removeChatPrompt(chatId: string) {
return execCommand<void>("remove_chat_prompt", { id: chatId });
}

export async function newChat(params?: { promptId?: string; title?: string }) {
return execCommand<string>("new_chat", params);
}
Expand Down
10 changes: 9 additions & 1 deletion web/src/components/chat/Header.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,15 @@ export default defineComponent({
<NTooltip contentStyle="max-width: 30rem">
{{
trigger: () => (
<NTag size="small" round type="primary">
<NTag
size="small"
round
type="primary"
closable
onClose={() => {
props.chat.removePrompt();
}}
>
{prompt.value?.name}
</NTag>
),
Expand Down
156 changes: 100 additions & 56 deletions web/src/components/chat/UserInput.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@ import { Message, UserMessage } from "../../models/message";
import { message } from "../../utils/prompt";
import Backtrack from "./Backtrack";
import Cost from "./Cost";
import { NScrollbar } from "naive-ui";
import { NScrollbar, useMessage } from "naive-ui";
import { autoGrowTextarea } from "../../utils/autoGrowTextarea";
import { usePromptService } from "../../services/prompt";
import { PromptIndex } from "../../api";
import CommandPanel from "./commandPanel";

export default defineComponent({
props: {
Expand All @@ -36,10 +39,16 @@ export default defineComponent({
const inputRef = ref<HTMLTextAreaElement>();
const { isComposition } = useComposition(inputRef);
const userMessage = ref("");
const inputStatus = ref<"normal" | "historyNavigation">("normal");
const inputStatus = ref<"normal" | "command" | "historyNavigation">(
"normal"
);
let historyNavigationMessageId = null as string | null;
const historyNavigationStack = [] as Message[];

const { fuzzySearchPrompts } = usePromptService();
const filteredPrompts = ref<Array<PromptIndex>>([]);
const selectedPromptIndex = ref(0);

const publicInstance = {
focus,
};
Expand All @@ -49,12 +58,6 @@ export default defineComponent({
inputRef.value?.focus();
});

watch(userMessage, (msg) => {
if (!msg) {
inputStatus.value = "normal";
}
});

function setUserMessage(content: string) {
userMessage.value = content;
nextTick(() => {
Expand All @@ -72,61 +75,93 @@ export default defineComponent({
});
}

async function keydownHandler(e: KeyboardEvent) {
if (
inputStatus.value === "normal" &&
["ArrowUp", "ArrowDown"].includes(e.key)
) {
inputStatus.value = "historyNavigation";
historyNavigationStack.length = 0;
}

if (
inputStatus.value === "historyNavigation" &&
!["ArrowUp", "ArrowDown"].includes(e.key)
) {
watch(userMessage, (msg) => {
if (!msg) {
inputStatus.value = "normal";
historyNavigationMessageId = null;
historyNavigationStack.length = 0;
filteredPrompts.value = [];
} else if (msg.startsWith("/")) {
inputStatus.value = "command";
filteredPrompts.value = fuzzySearchPrompts(userMessage.value.slice(1));
selectedPromptIndex.value = 0;
}
});

if (e.key === "Tab") {
// Expand tab to 4 spaces
e.preventDefault();
const start = inputRef.value?.selectionStart;
const end = inputRef.value?.selectionEnd;
if (start !== undefined && end !== undefined) {
userMessage.value =
userMessage.value.substring(0, start) +
" " +
userMessage.value.substring(end);
nextTick(() => {
inputRef.value?.setSelectionRange(start + 4, start + 4);
});
}
} else if (
e.key === "Enter" &&
!e.ctrlKey &&
!e.altKey &&
!e.shiftKey &&
!isComposition.value
) {
// Send message
async function keydownHandler(e: KeyboardEvent) {
if (inputStatus.value === "normal") {
if (e.key === "/") {
inputStatus.value = "command";
return;
} else if (["ArrowUp", "ArrowDown"].includes(e.key)) {
inputStatus.value = "historyNavigation";
historyNavigationStack.length = 0;
} else if (e.key === "Tab") {
// Expand tab to 4 spaces
e.preventDefault();
const start = inputRef.value?.selectionStart;
const end = inputRef.value?.selectionEnd;
if (start !== undefined && end !== undefined) {
userMessage.value =
userMessage.value.substring(0, start) +
" " +
userMessage.value.substring(end);
nextTick(() => {
inputRef.value?.setSelectionRange(start + 4, start + 4);
});
}
} else if (
e.key === "Enter" &&
!e.ctrlKey &&
!e.altKey &&
!e.shiftKey &&
!isComposition.value
) {
// Send message

// Check if the reply is finished
if (props.chat.busy.value) {
message.warning(t("chat.busy"));
e.preventDefault();
return;
}

props.onMessage?.(new UserMessage(userMessage.value));
props.sendMessage(userMessage.value);
userMessage.value = "";

// Check if the reply is finished
if (props.chat.busy.value) {
message.warning(t("chat.busy"));
e.preventDefault();
}
} else if (inputStatus.value === "command") {
if (e.key === "ArrowUp") {
selectedPromptIndex.value = Math.max(
0,
selectedPromptIndex.value - 1
);
e.preventDefault();
return;
} else if (e.key === "ArrowDown") {
selectedPromptIndex.value = Math.min(
filteredPrompts.value.length - 1,
selectedPromptIndex.value + 1
);
e.preventDefault();
return;
} else if (e.key === "Enter") {
if (filteredPrompts.value.length > 0) {
inputStatus.value = "normal";
props.chat.changePrompt(
filteredPrompts.value[selectedPromptIndex.value]!.id
);
userMessage.value = "";
e.preventDefault();
return;
}
}

props.onMessage?.(new UserMessage(userMessage.value));
props.sendMessage(userMessage.value);
userMessage.value = "";

e.preventDefault();
} else if (inputStatus.value === "historyNavigation") {
if (e.key === "ArrowUp") {
if (!["ArrowUp", "ArrowDown"].includes(e.key)) {
inputStatus.value = "normal";
historyNavigationMessageId = null;
historyNavigationStack.length = 0;
} else if (e.key === "ArrowUp") {
let msg = await props.chat.getPreviousUserLog(
historyNavigationMessageId ?? undefined
);
Expand Down Expand Up @@ -179,7 +214,16 @@ export default defineComponent({
})} */}
</div>
</div>
<div class="h-[8rem] px-4 pt-2 pb-6">
<div class="h-[8rem] px-4 pt-2 pb-6 relative">
<CommandPanel
v-show={filteredPrompts.value.length > 0}
list={filteredPrompts.value.map((item) => ({
label: item.name,
value: item.name,
}))}
selected={selectedPromptIndex.value}
class="absolute left-4 top-0 translate-y-[-100%]"
></CommandPanel>
<NScrollbar class="h-full">
<textarea
ref={inputRef}
Expand Down
32 changes: 32 additions & 0 deletions web/src/components/chat/commandPanel.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import { PropType, defineComponent } from "vue";

export default defineComponent({
props: {
list: {
type: Array as PropType<Array<{ value: string; label: string }>>,
default: () => [],
},
selected: {
type: Number,
default: 0,
},
},
setup(props, { expose }) {
const publicInstance = {};

expose(publicInstance);

return (() => (
<div class="border bg-[var(--command-panel-bg-color)] rounded-md">
{props.list.map((item, index) => {
const isSelected = props.selected === index;
return (
<div class={[isSelected ? "bg-primary" : "", "py-1 px-2"]}>
{item.label}
</div>
);
})}
</div>
)) as unknown as typeof publicInstance;
},
});
18 changes: 16 additions & 2 deletions web/src/models/chat.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { reactive, ref, Ref, watch } from "vue";
import { reactive, ref, Ref } from "vue";
import {
resendMessage,
sendMessage,
Expand All @@ -10,6 +10,7 @@ import {
deleteChatLog,
loadChatLogByCursor,
stopReply,
removeChatPrompt,
} from "../api";
import {
AssistantMessage,
Expand Down Expand Up @@ -90,7 +91,7 @@ export class Chat {
// To avoid the stop reply message is sent before the stop reply command
setTimeout(() => {
this.stopReplyHandler?.();
}, 1000)
}, 1000);
}
}

Expand All @@ -100,6 +101,19 @@ export class Chat {
this.messages.splice(index, 1);
}

async changePrompt(promptId: string) {
await updateChat({
id: this.index.id,
promptId,
});
this.index.promptId = promptId;
}

async removePrompt() {
this.index.promptId = undefined;
await removeChatPrompt(this.index.id);
}

async getPreviousUserLog(logId?: string): Promise<Message | null> {
if (!logId) {
const previousLog = this.messages.findLast(
Expand Down
Loading

0 comments on commit c4b231f

Please sign in to comment.