Dev model providers #3628

Merged
7 commits merged on Apr 6, 2024
3 changes: 2 additions & 1 deletion .gitignore
@@ -183,4 +183,5 @@ configs/*.py
/knowledge_base/samples/content/imi_temeplate.txt
/chatchat/configs/*.py
/chatchat/configs/*.yaml
chatchat/data
chatchat/data
/chatchat-server/chatchat/configs/model_providers.yaml
26 changes: 0 additions & 26 deletions chatchat-server/chatchat/configs/loom.yaml.example

This file was deleted.

81 changes: 10 additions & 71 deletions chatchat-server/chatchat/configs/model_config.py.example
@@ -1,6 +1,5 @@
import os


# Name of the default LLM
DEFAULT_LLM_MODEL = "chatglm3-6b"

@@ -31,7 +30,7 @@ SUPPORT_AGENT_MODELS = [


LLM_MODEL_CONFIG = {
# Intent recognition does not need to produce output; the model only needs it internally
    # Intent recognition does not need to produce output; the model only needs it internally
"preprocess_model": {
DEFAULT_LLM_MODEL: {
"temperature": 0.05,
@@ -57,7 +56,7 @@ LLM_MODEL_CONFIG = {
"prompt_name": "ChatGLM3",
"callbacks": True
},
},
},
"postprocess_model": {
DEFAULT_LLM_MODEL: {
"temperature": 0.01,
@@ -76,47 +75,15 @@
},
}

# You can start a model service with loom/xinference/oneapi/fastchat, then configure its URL and KEY here.
# model_providers can convert the interfaces of different platforms into an OpenAI-compatible endpoint; after startup, the variable below is automatically extended with the corresponding platforms.
# ### If you already have the address of an OpenAI-compatible endpoint, you can configure it here directly
# - platform_name can be any value, as long as it is not duplicated
# - platform_type one of: openai, xinference, oneapi, fastchat. Features may be differentiated by platform type in the future
# - platform_type may be used to differentiate features by platform type in the future; just keep it identical to platform_name
# - Fill the models deployed by each framework into the corresponding list. Different frameworks can load models with the same name, and the project will load-balance automatically.

MODEL_PLATFORMS = [
# {
# "platform_name": "openai-api",
# "platform_type": "openai",
# "api_base_url": "https://api.openai.com/v1",
# "api_key": "sk-",
# "api_proxy": "",
# "api_concurrencies": 5,
# "llm_models": [
# "gpt-3.5-turbo",
# ],
# "embed_models": [],
# "image_models": [],
# "multimodal_models": [],
# },

{
"platform_name": "xinference",
"platform_type": "xinference",
"api_base_url": "http://127.0.0.1:9997/v1",
"api_key": "EMPTY",
"api_concurrencies": 5,
        # Note: enter the UID of the model deployed in xinference here, not the model name
"llm_models": [
"chatglm3-6b",
],
"embed_models": [
"bge-large-zh-v1.5",
],
"image_models": [
"sd-turbo",
],
"multimodal_models": [
"qwen-vl",
],
},
# Create a global shared dictionary
MODEL_PLATFORMS = [

{
"platform_name": "oneapi",
@@ -152,41 +119,13 @@ MODEL_PLATFORMS = [
"multimodal_models": [],
},

{
"platform_name": "ollama",
"platform_type": "ollama",
"api_base_url": "http://{host}:{port}/v1",
"api_key": "sk-",
"api_concurrencies": 5,
"llm_models": [
            # Qwen API; for more models see https://ollama.com/library
"qwen:7b",
],
"embed_models": [
            # ollama must be upgraded to 0.1.29 or later; the embedding service is broken in earlier versions
"nomic-embed-text"
],
"image_models": [],
"multimodal_models": [],
},

# {
# "platform_name": "loom",
# "platform_type": "loom",
# "api_base_url": "http://127.0.0.1:7860/v1",
# "api_key": "",
# "api_concurrencies": 5,
# "llm_models": [
# "chatglm3-6b",
# ],
# "embed_models": [],
# "image_models": [],
# "multimodal_models": [],
# },
]

LOOM_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "loom.yaml")
MODEL_PROVIDERS_CFG_PATH_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "model_providers.yaml")
MODEL_PROVIDERS_CFG_HOST = "127.0.0.1"

MODEL_PROVIDERS_CFG_PORT = 20000
# Tool configuration
TOOL_CONFIG = {
"search_local_knowledgebase": {
29 changes: 29 additions & 0 deletions chatchat-server/chatchat/configs/model_providers.yaml.example
@@ -0,0 +1,29 @@
openai:
model_credential:
- model: 'gpt-3.5-turbo'
model_type: 'llm'
model_credentials:
openai_api_key: 'sk-'
openai_organization: ''
openai_api_base: ''
- model: 'gpt-4'
model_type: 'llm'
model_credentials:
openai_api_key: 'sk-'
openai_organization: ''
openai_api_base: ''

provider_credential:
openai_api_key: 'sk-'
openai_organization: ''
openai_api_base: ''

xinference:
model_credential:
- model: 'chatglm3-6b'
model_type: 'llm'
model_credentials:
server_url: 'http://127.0.0.1:9997/'
model_uid: 'chatglm3-6b'
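
The providers declared in this YAML are exposed by the bundled model_providers service as OpenAI-compatible endpoints of the form http://127.0.0.1:20000/<provider>/v1 (see init_server.py below). As a minimal sketch, not part of this diff and assuming the service is already running with the default MODEL_PROVIDERS_CFG_HOST/PORT and the openai Python SDK (v1+) installed, a proxied model can be called like this:

# Minimal sketch (not part of this diff): call a model through the
# model_providers OpenAI-compatible proxy. Assumes the service runs on the
# default 127.0.0.1:20000 and the openai Python SDK (>= 1.0) is installed.
from openai import OpenAI

client = OpenAI(
    base_url="http://127.0.0.1:20000/xinference/v1",  # per-provider endpoint built by init_server.py
    api_key="EMPTY",  # init_server.py configures the generated platforms with api_key "EMPTY"
)

resp = client.chat.completions.create(
    model="chatglm3-6b",  # a model listed under xinference.model_credential above
    messages=[{"role": "user", "content": "Hello"}],
)
print(resp.choices[0].message.content)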


6 changes: 0 additions & 6 deletions chatchat-server/chatchat/configs/openai-plugins-list.json

This file was deleted.

109 changes: 109 additions & 0 deletions chatchat-server/chatchat/model_loaders/init_server.py
@@ -0,0 +1,109 @@
from typing import List, Dict
from chatchat.configs import MODEL_PROVIDERS_CFG_HOST, MODEL_PROVIDERS_CFG_PORT, MODEL_PROVIDERS_CFG_PATH_CONFIG
from model_providers import BootstrapWebBuilder
from model_providers.bootstrap_web.entities.model_provider_entities import ProviderResponse
from model_providers.core.bootstrap.providers_wapper import ProvidersWrapper
from model_providers.core.provider_manager import ProviderManager
from model_providers.core.utils.utils import (
get_config_dict,
get_log_file,
get_timestamp_ms,
)
import multiprocessing as mp
import asyncio
import logging

logger = logging.getLogger(__name__)


def init_server(model_platforms_shard: Dict,
started_event: mp.Event = None,
model_providers_cfg_path: str = MODEL_PROVIDERS_CFG_PATH_CONFIG,
provider_host: str = MODEL_PROVIDERS_CFG_HOST,
provider_port: int = MODEL_PROVIDERS_CFG_PORT,
log_path: str = "logs"
) -> None:
logging_conf = get_config_dict(
"DEBUG",
get_log_file(log_path=log_path, sub_dir=f"provider_{get_timestamp_ms()}"),
122,
111,
)

try:
boot = (
BootstrapWebBuilder()
.model_providers_cfg_path(
model_providers_cfg_path=model_providers_cfg_path
)
.host(host=provider_host)
.port(port=provider_port)
.build()
)
boot.set_app_event(started_event=started_event)

provider_platforms = init_provider_platforms(boot.provider_manager.provider_manager)
model_platforms_shard['provider_platforms'] = provider_platforms

boot.serve(logging_conf=logging_conf)

async def pool_join_thread():
await boot.join()

asyncio.run(pool_join_thread())
except SystemExit:
logger.info("SystemExit raised, exiting")
raise


def init_provider_platforms(provider_manager: ProviderManager) -> List[Dict]:
provider_list: List[ProviderResponse] = ProvidersWrapper(
provider_manager=provider_manager).get_provider_list()
logger.info(f"Provider list: {provider_list}")
    # Convert to the MODEL_PLATFORMS format
provider_platforms = []
for provider in provider_list:
provider_dict = {
"platform_name": provider.provider,
"platform_type": provider.provider,
"api_base_url": f"http://127.0.0.1:20000/{provider.provider}/v1",
"api_key": "EMPTY",
"api_concurrencies": 5
}

provider_dict["llm_models"] = []
provider_dict["embed_models"] = []
provider_dict["image_models"] = []
provider_dict["multimodal_models"] = []
supported_model_str_types = [model_type.to_origin_model_type() for model_type in
provider.supported_model_types]

for model_type in supported_model_str_types:

providers_model_type = ProvidersWrapper(
provider_manager=provider_manager
).get_models_by_model_type(model_type=model_type)
cur_model_type: List[str] = []
        # Query the models of the current provider
for provider_model in providers_model_type:
if provider_model.provider == provider.provider:
models = [model.model for model in provider_model.models]
cur_model_type.extend(models)

if cur_model_type:
if model_type == "text-generation":
provider_dict["llm_models"] = cur_model_type
elif model_type == "text-embedding":
provider_dict["embed_models"] = cur_model_type
elif model_type == "text2img":
provider_dict["image_models"] = cur_model_type
elif model_type == "multimodal":
provider_dict["multimodal_models"] = cur_model_type
else:
logger.warning(f"Unsupported model type: {model_type}")

provider_platforms.append(provider_dict)

logger.info(f"Provider platforms: {provider_platforms}")

return provider_platforms
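
init_server appears intended to run in its own process: it fills the shared dict with the auto-generated platform list and then serves the provider web app. A minimal usage sketch, not part of this diff, assuming a multiprocessing.Manager supplies the shared dict and event expected by the signature above:

# Minimal usage sketch (not part of this diff): launch the provider server in a
# child process and read back the generated MODEL_PLATFORMS entries.
import multiprocessing as mp

from chatchat.model_loaders.init_server import init_server

if __name__ == "__main__":
    manager = mp.Manager()
    model_platforms_shard = manager.dict()  # filled with "provider_platforms" by init_server
    started_event = manager.Event()         # set once the provider web app has started

    server = mp.Process(
        target=init_server,
        kwargs={
            "model_platforms_shard": model_platforms_shard,
            "started_event": started_event,
        },
        daemon=True,
    )
    server.start()
    started_event.wait()  # block until the server signals it is up

    print(model_platforms_shard.get("provider_platforms"))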