From 0631921cc91016396c0c7034779a4b178ce5d696 Mon Sep 17 00:00:00 2001 From: Yidadaa Date: Thu, 9 Nov 2023 03:01:29 +0800 Subject: [PATCH] feat: close #3187 use CUSTOM_MODELS to control model list --- README.md | 7 ++++++ README_CN.md | 6 +++++ app/api/common.ts | 31 ++++++++++++------------- app/api/config/route.ts | 1 + app/components/chat.tsx | 12 ++++------ app/components/model-config.tsx | 7 +++--- app/config/server.ts | 17 +++++++++++++- app/store/access.ts | 7 +----- app/store/config.ts | 10 +-------- app/utils/hooks.ts | 16 +++++++++++++ app/utils/model.ts | 40 +++++++++++++++++++++++++++++++++ 11 files changed, 112 insertions(+), 42 deletions(-) create mode 100644 app/utils/hooks.ts create mode 100644 app/utils/model.ts diff --git a/README.md b/README.md index 91e03d80049..3973c84bfde 100644 --- a/README.md +++ b/README.md @@ -197,6 +197,13 @@ If you do want users to query balance, set this value to 1, or you should set it If you want to disable parse settings from url, set this to 1. +### `CUSTOM_MODELS` (optional) + +> Default: Empty +> Example: `+llama,+claude-2,-gpt-3.5-turbo` means add `llama, claude-2` to model list, and remove `gpt-3.5-turbo` from list. + +To control custom models, use `+` to add a custom model, use `-` to hide a model, separated by comma. + ## Requirements NodeJS >= 18, Docker >= 20 diff --git a/README_CN.md b/README_CN.md index 13b97417d40..d8e9553e183 100644 --- a/README_CN.md +++ b/README_CN.md @@ -106,6 +106,12 @@ OpenAI 接口代理 URL,如果你手动配置了 openai 接口代理,请填 如果你想禁用从链接解析预制设置,将此环境变量设置为 1 即可。 +### `CUSTOM_MODELS` (可选) + +> 示例:`+qwen-7b-chat,+glm-6b,-gpt-3.5-turbo` 表示增加 `qwen-7b-chat` 和 `glm-6b` 到模型列表,而从列表中删除 `gpt-3.5-turbo`。 + +用来控制模型列表,使用 `+` 增加一个模型,使用 `-` 来隐藏一个模型,用英文逗号隔开。 + ## 开发 点击下方按钮,开始二次开发: diff --git a/app/api/common.ts b/app/api/common.ts index 0af7761d88c..a1decd42f5b 100644 --- a/app/api/common.ts +++ b/app/api/common.ts @@ -1,10 +1,9 @@ import { NextRequest, NextResponse } from "next/server"; +import { getServerSideConfig } from "../config/server"; +import { DEFAULT_MODELS, OPENAI_BASE_URL } from "../constant"; +import { collectModelTable, collectModels } from "../utils/model"; -export const OPENAI_URL = "api.openai.com"; -const DEFAULT_PROTOCOL = "https"; -const PROTOCOL = process.env.PROTOCOL || DEFAULT_PROTOCOL; -const BASE_URL = process.env.BASE_URL || OPENAI_URL; -const DISABLE_GPT4 = !!process.env.DISABLE_GPT4; +const serverConfig = getServerSideConfig(); export async function requestOpenai(req: NextRequest) { const controller = new AbortController(); @@ -14,10 +13,10 @@ export async function requestOpenai(req: NextRequest) { "", ); - let baseUrl = BASE_URL; + let baseUrl = serverConfig.baseUrl ?? OPENAI_BASE_URL; if (!baseUrl.startsWith("http")) { - baseUrl = `${PROTOCOL}://${baseUrl}`; + baseUrl = `https://${baseUrl}`; } if (baseUrl.endsWith("/")) { @@ -26,10 +25,7 @@ export async function requestOpenai(req: NextRequest) { console.log("[Proxy] ", openaiPath); console.log("[Base Url]", baseUrl); - - if (process.env.OPENAI_ORG_ID) { - console.log("[Org ID]", process.env.OPENAI_ORG_ID); - } + console.log("[Org ID]", serverConfig.openaiOrgId); const timeoutId = setTimeout( () => { @@ -58,18 +54,23 @@ export async function requestOpenai(req: NextRequest) { }; // #1815 try to refuse gpt4 request - if (DISABLE_GPT4 && req.body) { + if (serverConfig.customModels && req.body) { try { + const modelTable = collectModelTable( + DEFAULT_MODELS, + serverConfig.customModels, + ); const clonedBody = await req.text(); fetchOptions.body = clonedBody; - const jsonBody = JSON.parse(clonedBody); + const jsonBody = JSON.parse(clonedBody) as { model?: string }; - if ((jsonBody?.model ?? "").includes("gpt-4")) { + // not undefined and is false + if (modelTable[jsonBody?.model ?? ""] === false) { return NextResponse.json( { error: true, - message: "you are not allowed to use gpt-4 model", + message: `you are not allowed to use ${jsonBody?.model} model`, }, { status: 403, diff --git a/app/api/config/route.ts b/app/api/config/route.ts index 44af8d3b9dd..db84fba175a 100644 --- a/app/api/config/route.ts +++ b/app/api/config/route.ts @@ -12,6 +12,7 @@ const DANGER_CONFIG = { disableGPT4: serverConfig.disableGPT4, hideBalanceQuery: serverConfig.hideBalanceQuery, disableFastLink: serverConfig.disableFastLink, + customModels: serverConfig.customModels, }; declare global { diff --git a/app/components/chat.tsx b/app/components/chat.tsx index a0b7307c298..9afb49f7a66 100644 --- a/app/components/chat.tsx +++ b/app/components/chat.tsx @@ -88,6 +88,7 @@ import { ChatCommandPrefix, useChatCommand, useCommand } from "../command"; import { prettyObject } from "../utils/format"; import { ExportMessageModal } from "./exporter"; import { getClientConfig } from "../config/client"; +import { useAllModels } from "../utils/hooks"; const Markdown = dynamic(async () => (await import("./markdown")).Markdown, { loading: () => , @@ -430,14 +431,9 @@ export function ChatActions(props: { // switch model const currentModel = chatStore.currentSession().mask.modelConfig.model; - const models = useMemo( - () => - config - .allModels() - .filter((m) => m.available) - .map((m) => m.name), - [config], - ); + const models = useAllModels() + .filter((m) => m.available) + .map((m) => m.name); const [showModelSelector, setShowModelSelector] = useState(false); return ( diff --git a/app/components/model-config.tsx b/app/components/model-config.tsx index 63950a40d04..6e4c9bcb17b 100644 --- a/app/components/model-config.tsx +++ b/app/components/model-config.tsx @@ -1,14 +1,15 @@ -import { ModalConfigValidator, ModelConfig, useAppConfig } from "../store"; +import { ModalConfigValidator, ModelConfig } from "../store"; import Locale from "../locales"; import { InputRange } from "./input-range"; import { ListItem, Select } from "./ui-lib"; +import { useAllModels } from "../utils/hooks"; export function ModelConfigList(props: { modelConfig: ModelConfig; updateConfig: (updater: (config: ModelConfig) => void) => void; }) { - const config = useAppConfig(); + const allModels = useAllModels(); return ( <> @@ -24,7 +25,7 @@ export function ModelConfigList(props: { ); }} > - {config.allModels().map((v, i) => ( + {allModels.map((v, i) => ( diff --git a/app/config/server.ts b/app/config/server.ts index 2df806fed41..007c3973863 100644 --- a/app/config/server.ts +++ b/app/config/server.ts @@ -1,4 +1,5 @@ import md5 from "spark-md5"; +import { DEFAULT_MODELS } from "../constant"; declare global { namespace NodeJS { @@ -7,6 +8,7 @@ declare global { CODE?: string; BASE_URL?: string; PROXY_URL?: string; + OPENAI_ORG_ID?: string; VERCEL?: string; HIDE_USER_API_KEY?: string; // disable user's api key input DISABLE_GPT4?: string; // allow user to use gpt-4 or not @@ -14,6 +16,7 @@ declare global { BUILD_APP?: string; // is building desktop app ENABLE_BALANCE_QUERY?: string; // allow user to query balance or not DISABLE_FAST_LINK?: string; // disallow parse settings from url or not + CUSTOM_MODELS?: string; // to control custom models } } } @@ -38,6 +41,16 @@ export const getServerSideConfig = () => { ); } + let disableGPT4 = !!process.env.DISABLE_GPT4; + let customModels = process.env.CUSTOM_MODELS ?? ""; + + if (disableGPT4) { + if (customModels) customModels += ","; + customModels += DEFAULT_MODELS.filter((m) => m.name.startsWith("gpt-4")) + .map((m) => "-" + m.name) + .join(","); + } + return { apiKey: process.env.OPENAI_API_KEY, code: process.env.CODE, @@ -45,10 +58,12 @@ export const getServerSideConfig = () => { needCode: ACCESS_CODES.size > 0, baseUrl: process.env.BASE_URL, proxyUrl: process.env.PROXY_URL, + openaiOrgId: process.env.OPENAI_ORG_ID, isVercel: !!process.env.VERCEL, hideUserApiKey: !!process.env.HIDE_USER_API_KEY, - disableGPT4: !!process.env.DISABLE_GPT4, + disableGPT4, hideBalanceQuery: !process.env.ENABLE_BALANCE_QUERY, disableFastLink: !!process.env.DISABLE_FAST_LINK, + customModels, }; }; diff --git a/app/store/access.ts b/app/store/access.ts index 3d889f6e72e..f87e44a2ac4 100644 --- a/app/store/access.ts +++ b/app/store/access.ts @@ -17,6 +17,7 @@ const DEFAULT_ACCESS_STATE = { hideBalanceQuery: false, disableGPT4: false, disableFastLink: false, + customModels: "", openaiUrl: DEFAULT_OPENAI_URL, }; @@ -52,12 +53,6 @@ export const useAccessStore = createPersistStore( .then((res: DangerConfig) => { console.log("[Config] got config from server", res); set(() => ({ ...res })); - - if (res.disableGPT4) { - DEFAULT_MODELS.forEach( - (m: any) => (m.available = !m.name.startsWith("gpt-4")), - ); - } }) .catch(() => { console.error("[Config] failed to fetch config"); diff --git a/app/store/config.ts b/app/store/config.ts index 0fbc26dfe0e..5fcd6ff514c 100644 --- a/app/store/config.ts +++ b/app/store/config.ts @@ -128,15 +128,7 @@ export const useAppConfig = createPersistStore( })); }, - allModels() { - const customModels = get() - .customModels.split(",") - .filter((v) => !!v && v.length > 0) - .map((m) => ({ name: m, available: true })); - const allModels = get().models.concat(customModels); - allModels.sort((a, b) => (a.name < b.name ? -1 : 1)); - return allModels; - }, + allModels() {}, }), { name: StoreKey.Config, diff --git a/app/utils/hooks.ts b/app/utils/hooks.ts new file mode 100644 index 00000000000..f6bfae67323 --- /dev/null +++ b/app/utils/hooks.ts @@ -0,0 +1,16 @@ +import { useMemo } from "react"; +import { useAccessStore, useAppConfig } from "../store"; +import { collectModels } from "./model"; + +export function useAllModels() { + const accessStore = useAccessStore(); + const configStore = useAppConfig(); + const models = useMemo(() => { + return collectModels( + configStore.models, + [accessStore.customModels, configStore.customModels].join(","), + ); + }, [accessStore.customModels, configStore.customModels, configStore.models]); + + return models; +} diff --git a/app/utils/model.ts b/app/utils/model.ts new file mode 100644 index 00000000000..23090f9d2f3 --- /dev/null +++ b/app/utils/model.ts @@ -0,0 +1,40 @@ +import { LLMModel } from "../client/api"; + +export function collectModelTable( + models: readonly LLMModel[], + customModels: string, +) { + const modelTable: Record = {}; + + // default models + models.forEach((m) => (modelTable[m.name] = m.available)); + + // server custom models + customModels + .split(",") + .filter((v) => !!v && v.length > 0) + .map((m) => { + if (m.startsWith("+")) { + modelTable[m.slice(1)] = true; + } else if (m.startsWith("-")) { + modelTable[m.slice(1)] = false; + } else modelTable[m] = true; + }); + return modelTable; +} + +/** + * Generate full model table. + */ +export function collectModels( + models: readonly LLMModel[], + customModels: string, +) { + const modelTable = collectModelTable(models, customModels); + const allModels = Object.keys(modelTable).map((m) => ({ + name: m, + available: modelTable[m], + })); + + return allModels; +}