Skip to content

Commit

Permalink
Added OpenAI-compatible vision support (e.g. for Llama 3.2) and upgraded the demo.
Browse files Browse the repository at this point in the history
Added OpenAI vision support
  • Loading branch information
slowsynapse authored Jan 21, 2025
2 parents 15ceccf + b76cbca commit a56150b
Show file tree
Hide file tree
Showing 7 changed files with 193 additions and 24 deletions.
18 changes: 17 additions & 1 deletion src/components/settings.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ import { WhisperCppSettingsPage } from './settings/WhisperCppSettingsPage';
import { VisionBackendPage } from './settings/VisionBackendPage';
import { VisionLlamaCppSettingsPage } from './settings/VisionLlamaCppSettingsPage';
import { VisionOllamaSettingsPage } from './settings/VisionOllamaSettingsPage';
import { VisionOpenAISettingsPage } from './settings/VisionOpenAISettingsPage';
import { VisionSystemPromptPage } from './settings/VisionSystemPromptPage';

import { NamePage } from './settings/NamePage';
Expand Down Expand Up @@ -130,6 +131,9 @@ export const Settings = ({
const [visionLlamaCppUrl, setVisionLlamaCppUrl] = useState(config("vision_llamacpp_url"));
const [visionOllamaUrl, setVisionOllamaUrl] = useState(config("vision_ollama_url"));
const [visionOllamaModel, setVisionOllamaModel] = useState(config("vision_ollama_model"));
const [visionOpenAIApiKey, setVisionOpenAIApiKey] = useState(config("vision_openai_apikey"));
const [visionOpenAIUrl, setVisionOpenAIUrl] = useState(config("vision_openai_url"));
const [visionOpenAIModel, setVisionOpenAIModel] = useState(config("vision_openai_model"));
const [visionSystemPrompt, setVisionSystemPrompt] = useState(config("vision_system_prompt"));

const [bgUrl, setBgUrl] = useState(config("bg_url"));
Expand Down Expand Up @@ -257,6 +261,7 @@ export const Settings = ({
visionBackend,
visionLlamaCppUrl,
visionOllamaUrl, visionOllamaModel,
visionOpenAIApiKey, visionOpenAIUrl, visionOpenAIModel,
visionSystemPrompt,
bgColor,
bgUrl, vrmHash, vrmUrl, youtubeVideoID, animationUrl,
Expand Down Expand Up @@ -304,7 +309,7 @@ export const Settings = ({

case 'vision':
return <MenuPage
keys={["vision_backend", "vision_llamacpp_settings", "vision_ollama_settings", "vision_system_prompt"]}
keys={["vision_backend", "vision_llamacpp_settings", "vision_ollama_settings", "vision_openai_settings", "vision_system_prompt"]}
menuClick={handleMenuClick} />;

case 'reset_settings':
Expand Down Expand Up @@ -572,6 +577,17 @@ export const Settings = ({
setSettingsUpdated={setSettingsUpdated}
/>

case 'vision_openai_settings':
return <VisionOpenAISettingsPage
visionOpenAIApiKey={visionOpenAIApiKey}
setVisionOpenAIApiKey={setVisionOpenAIApiKey}
visionOpenAIUrl={visionOpenAIUrl}
setVisionOpenAIUrl={setVisionOpenAIUrl}
visionOpenAIModel={visionOpenAIModel}
setVisionOpenAIModel={setVisionOpenAIModel}
setSettingsUpdated={setSettingsUpdated}
/>

case 'vision_system_prompt':
return <VisionSystemPromptPage
visionSystemPrompt={visionSystemPrompt}
Expand Down
3 changes: 2 additions & 1 deletion src/components/settings/VisionBackendPage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ const visionEngines = [
{key: "none", label: "None"},
{key: "vision_llamacpp", label: "LLama.cpp"},
{key: "vision_ollama", label: "Ollama"},
{key: "vision_openai", label: "OpenAI"},
];

function idToTitle(id: string): string {
Expand Down Expand Up @@ -53,7 +54,7 @@ export function VisionBackendPage({
</select>
</FormRow>
</li>
{ ["vision_llamacpp", "vision_ollama"].includes(visionBackend) && (
{ ["vision_llamacpp", "vision_ollama", "vision_openai"].includes(visionBackend) && (
<li className="py-4">
<FormRow label={`${t("Configure")} ${t(idToTitle(visionBackend))}`}>
<button
Expand Down
80 changes: 80 additions & 0 deletions src/components/settings/VisionOpenAISettingsPage.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import { useTranslation } from 'react-i18next';

import { BasicPage, FormRow, NotUsingAlert } from "./common";
import { TextInput } from "@/components/textInput";
import { SecretTextInput } from "@/components/secretTextInput";
import { config, updateConfig } from "@/utils/config";

/**
 * Settings page for the OpenAI-compatible vision backend.
 *
 * Renders three fields (API key, API URL, model). Each field is controlled:
 * edits update the local state via the passed setter, persist via
 * `updateConfig`, and flag the settings as dirty via `setSettingsUpdated`.
 * Shows a warning banner when this backend is not the active vision backend.
 */
export function VisionOpenAISettingsPage({
  visionOpenAIApiKey,
  setVisionOpenAIApiKey,
  visionOpenAIUrl,
  setVisionOpenAIUrl,
  visionOpenAIModel,
  setVisionOpenAIModel,
  setSettingsUpdated,
}: {
  visionOpenAIApiKey: string;
  // Parameter names now describe what each setter receives (previously all `url`).
  setVisionOpenAIApiKey: (apiKey: string) => void;
  visionOpenAIUrl: string;
  setVisionOpenAIUrl: (url: string) => void;
  visionOpenAIModel: string;
  setVisionOpenAIModel: (model: string) => void;
  setSettingsUpdated: (updated: boolean) => void;
}) {
  const { t } = useTranslation();

  const description = <>Configure OpenAI vision settings. You can get an API key from <a href="https://platform.openai.com">platform.openai.com</a>. You can generally use other OpenAI compatible URLs and models here too, provided they have vision support, such as <a href="https://openrouter.ai/">OpenRouter</a> or <a href="https://lmstudio.ai/">LM Studio</a>.</>;

  return (
    <BasicPage
      title={t("OpenAI") + " " + t("Settings")}
      description={description}
    >
      { config("vision_backend") !== "vision_openai" && (
        <NotUsingAlert>
          {t("not_using_alert", "You are not currently using {{name}} as your {{what}} backend. These settings will not be used.", {name: t("OpenAI"), what: t("Vision")})}
        </NotUsingAlert>
      ) }
      <ul role="list" className="divide-y divide-gray-100 max-w-xs">
        <li className="py-4">
          <FormRow label={t("API Key")}>
            <SecretTextInput
              value={visionOpenAIApiKey}
              // Typed as HTMLInputElement (was ChangeEvent<any>); the stray
              // preventDefault() was dropped — change events on inputs are
              // not cancelable, and no sibling handler called it.
              onChange={(event: React.ChangeEvent<HTMLInputElement>) => {
                setVisionOpenAIApiKey(event.target.value);
                updateConfig("vision_openai_apikey", event.target.value);
                setSettingsUpdated(true);
              }}
            />
          </FormRow>
        </li>
        <li className="py-4">
          <FormRow label={t("API URL")}>
            <TextInput
              value={visionOpenAIUrl}
              onChange={(event: React.ChangeEvent<HTMLInputElement>) => {
                setVisionOpenAIUrl(event.target.value);
                updateConfig("vision_openai_url", event.target.value);
                setSettingsUpdated(true);
              }}
            />
          </FormRow>
        </li>
        <li className="py-4">
          <FormRow label={t("Model")}>
            <TextInput
              value={visionOpenAIModel}
              onChange={(event: React.ChangeEvent<HTMLInputElement>) => {
                setVisionOpenAIModel(event.target.value);
                updateConfig("vision_openai_model", event.target.value);
                setSettingsUpdated(true);
              }}
            />
          </FormRow>
        </li>
      </ul>
    </BasicPage>
  );
}
2 changes: 2 additions & 0 deletions src/components/settings/common.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ export function getIconFromPage(page: string): JSX.Element {
case 'vision_backend': return <EyeDropperIcon className="h-5 w-5 flex-none text-gray-800" aria-hidden="true" />;
case 'vision_llamacpp_settings': return <AdjustmentsHorizontalIcon className="h-5 w-5 flex-none text-gray-800" aria-hidden="true" />;
case 'vision_ollama_settings': return <AdjustmentsHorizontalIcon className="h-5 w-5 flex-none text-gray-800" aria-hidden="true" />;
case 'vision_openai_settings': return <AdjustmentsHorizontalIcon className="h-5 w-5 flex-none text-gray-800" aria-hidden="true" />;
case 'vision_system_prompt': return <DocumentTextIcon className="h-5 w-5 flex-none text-gray-800" aria-hidden="true" />;
}

Expand Down Expand Up @@ -221,6 +222,7 @@ function getLabelFromPage(page: string): string {
case 'vision_backend': return t('Vision Backend');
case 'vision_llamacpp_settings': return t('LLama.cpp');
case 'vision_ollama_settings': return t('Ollama');
case 'vision_openai_settings': return t('OpenAI');
case 'vision_system_prompt': return t('System Prompt');

case 'stt_backend': return t('STT Backend');
Expand Down
60 changes: 46 additions & 14 deletions src/features/chat/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { Viewer } from "@/features/vrmViewer/viewer";
import { Alert } from "@/features/alert/alert";

import { getEchoChatResponseStream } from './echoChat';
import { getOpenAiChatResponseStream } from './openAiChat';
import { getOpenAiChatResponseStream, getOpenAiVisionChatResponse } from './openAiChat';
import { getLlamaCppChatResponseStream, getLlavaCppChatResponse } from './llamaCppChat';
import { getWindowAiChatResponseStream } from './windowAiChat';
import { getOllamaChatResponseStream, getOllamaVisionChatResponse } from './ollamaChat';
Expand Down Expand Up @@ -555,24 +555,56 @@ export class Chat {
try {
const visionBackend = config("vision_backend");

console.debug('vision_backend', visionBackend);
console.debug("vision_backend", visionBackend);

const messages: Message[] = [
{ role: "system", content: config("vision_system_prompt") },
...this.messageList!,
{
role: 'user',
content: "Describe the image as accurately as possible"
},
];
let res = "";
if (visionBackend === "vision_llamacpp") {
const messages: Message[] = [
{ role: "system", content: config("vision_system_prompt") },
...this.messageList!,
{
role: "user",
content: "Describe the image as accurately as possible",
},
];

let res = '';
if (visionBackend === 'vision_llamacpp') {
res = await getLlavaCppChatResponse(messages, imageData);
} else if (visionBackend === 'vision_ollama') {
} else if (visionBackend === "vision_ollama") {
const messages: Message[] = [
{ role: "system", content: config("vision_system_prompt") },
...this.messageList!,
{
role: "user",
content: "Describe the image as accurately as possible",
},
];

res = await getOllamaVisionChatResponse(messages, imageData);
} else if (visionBackend === "vision_openai") {
const messages: Message[] = [
{ role: "user", content: config("vision_system_prompt") },
...this.messageList! as any[],
{
role: "user",
// @ts-ignore normally this is a string
content: [
{
type: "text",
text: "Describe the image as accurately as possible",
},
{
type: "image_url",
image_url: {
url: `data:image/jpeg;base64,${imageData}`,
},
},
],
},
];

res = await getOpenAiVisionChatResponse(messages);
} else {
console.warn('vision_backend not supported', visionBackend);
console.warn("vision_backend not supported", visionBackend);
return;
}

Expand Down
47 changes: 41 additions & 6 deletions src/features/chat/openAiChat.ts
Original file line number Diff line number Diff line change
@@ -1,23 +1,33 @@
import { Message } from "./messages";
import { config } from '@/utils/config';

export async function getOpenAiChatResponseStream(messages: Message[]) {
const apiKey = config("openai_apikey");
// Look up an API key under `configKey` and fail fast when it is unset/empty.
function getApiKey(configKey: string) {
  const key = config(configKey);
  if (key) {
    return key;
  }
  throw new Error(`Invalid ${configKey} API Key`);
}

async function getResponseStream(
messages: Message[],
url: string,
model: string,
apiKey: string,
) {
const headers: Record<string, string> = {
"Content-Type": "application/json",
"Authorization": `Bearer ${apiKey}`,
"HTTP-Referer": "https://amica.arbius.ai",
"X-Title": "Amica",
};
const res = await fetch(`${config("openai_url")}/v1/chat/completions`, {

const res = await fetch(`${url}/v1/chat/completions`, {
headers: headers,
method: "POST",
body: JSON.stringify({
model: config("openai_model"),
messages: messages,
model,
messages,
stream: true,
max_tokens: 200,
}),
Expand Down Expand Up @@ -84,3 +94,28 @@ export async function getOpenAiChatResponseStream(messages: Message[]) {

return stream;
}

export async function getOpenAiChatResponseStream(messages: Message[]) {
const apiKey = getApiKey("openai_apikey");
const url = config("openai_url");
const model = config("openai_model");
return getResponseStream(messages, url, model, apiKey);
}

/**
 * Run a (non-streaming) vision chat completion against the OpenAI-compatible
 * vision backend and return the fully concatenated response text.
 *
 * Internally consumes the same streaming endpoint used for text chat, but
 * drains the stream to completion before returning.
 *
 * @param messages - conversation including the image content part(s)
 * @returns the complete assistant response as one string
 * @throws Error when `vision_openai_apikey` is unset (via getApiKey)
 */
export async function getOpenAiVisionChatResponse(messages: Message[]) {
  const apiKey = getApiKey("vision_openai_apikey");
  const url = config("vision_openai_url");
  const model = config("vision_openai_model");

  const stream = await getResponseStream(messages, url, model, apiKey);
  // getReader() is synchronous — the previous `await` here was a no-op.
  const reader = stream.getReader();

  let combined = "";
  try {
    while (true) {
      const { done, value } = await reader.read();
      if (done) break;
      combined += value;
    }
  } finally {
    // Release the lock even if read() rejects, so the stream can be
    // cancelled/garbage-collected instead of staying locked forever.
    reader.releaseLock();
  }

  return combined;
}
7 changes: 5 additions & 2 deletions src/utils/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,11 @@ const defaults = {
tts_muted: 'false',
tts_backend: process.env.NEXT_PUBLIC_TTS_BACKEND ?? 'piper',
stt_backend: process.env.NEXT_PUBLIC_STT_BACKEND ?? 'whisper_browser',
vision_backend: process.env.NEXT_PUBLIC_VISION_BACKEND ?? 'none',
vision_system_prompt: process.env.NEXT_PUBLIC_VISION_SYSTEM_PROMPT ?? `You are a friendly human named Amica. Describe the image in detail. Let's start the conversation.`,
vision_backend: process.env.NEXT_PUBLIC_VISION_BACKEND ?? 'vision_openai',
vision_system_prompt: process.env.NEXT_PUBLIC_VISION_SYSTEM_PROMPT ?? `Look at the image as you would if you are a human, be concise, witty and charming.`,
// NOTE(review): 'default' looks like a placeholder API key for the hosted endpoint below — confirm intended.
vision_openai_apikey: process.env.NEXT_PUBLIC_VISION_OPENAI_APIKEY ?? 'default',
vision_openai_url: process.env.NEXT_PUBLIC_VISION_OPENAI_URL ?? 'https://api-01.heyamica.com',
// BUG FIX: previously read NEXT_PUBLIC_VISION_OPENAI_URL, so setting the URL
// env var silently replaced the model name with the URL.
vision_openai_model: process.env.NEXT_PUBLIC_VISION_OPENAI_MODEL ?? 'gpt-4-vision-preview',
vision_llamacpp_url: process.env.NEXT_PUBLIC_VISION_LLAMACPP_URL ?? 'http://127.0.0.1:8081',
vision_ollama_url: process.env.NEXT_PUBLIC_VISION_OLLAMA_URL ?? 'http://localhost:11434',
vision_ollama_model: process.env.NEXT_PUBLIC_VISION_OLLAMA_MODEL ?? 'llava',
Expand Down

0 comments on commit a56150b

Please sign in to comment.