Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new model type to use custom OpenAI api compatible servers #1692

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion gpt4all-chat/chatgpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ void ChatGPTWorker::request(const QString &apiKey,
{
m_ctx = promptCtx;

QUrl openaiUrl("https://api.openai.com/v1/chat/completions");
QUrl openaiUrl(m_chat->APIBase() + "/chat/completions");
const QString authorization = QString("Bearer %1").arg(apiKey).trimmed();
QNetworkRequest request(openaiUrl);
request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json");
Expand Down
4 changes: 4 additions & 0 deletions gpt4all-chat/chatgpt.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,11 @@ class ChatGPT : public QObject, public LLModel {
void setThreadCount(int32_t n_threads) override;
int32_t threadCount() const override;

const QString& APIBase() const { return m_apiBase; }

void setModelName(const QString &modelName) { m_modelName = modelName; }
void setAPIKey(const QString &apiKey) { m_apiKey = apiKey; }
void setAPIBase(const QString &apiBase) { m_apiBase = apiBase; }

QList<QString> context() const { return m_context; }
void setContext(const QList<QString> &context) { m_context = context; }
Expand All @@ -91,6 +94,7 @@ class ChatGPT : public QObject, public LLModel {
std::function<bool(int32_t, const std::string&)> m_responseCallback;
QString m_modelName;
QString m_apiKey;
QString m_apiBase;
QList<QString> m_context;
};

Expand Down
22 changes: 18 additions & 4 deletions gpt4all-chat/chatllm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
if (isModelLoaded() && this->modelInfo() == modelInfo)
return true;

bool isChatGPT = modelInfo.isChatGPT;
bool isChatGPT = modelInfo.isChatGPT || modelInfo.isOpenAICompatible;
QString filePath = modelInfo.dirpath + modelInfo.filename();
QFileInfo fileInfo(filePath);

Expand Down Expand Up @@ -233,18 +233,32 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
if (fileInfo.exists()) {
if (isChatGPT) {
QString apiKey;
QString chatGPTModel = fileInfo.completeBaseName().remove(0, 8); // remove the chatgpt- prefix
QString apiBase;
QString chatGPTModel;
{
QFile file(filePath);
file.open(QIODeviceBase::ReadOnly | QIODeviceBase::Text);
QTextStream stream(&file);
apiKey = stream.readAll();
QStringList chatGPTParams = stream.readAll().split('\n');
if (chatGPTParams.isEmpty()) {

emit modelLoadingError(QString("Could not load model due to invalid settings for %1").arg(modelInfo.filename()));
} else {
apiKey = chatGPTParams[0];
if (chatGPTParams.size() >= 2) {
apiBase = chatGPTParams[1];
}
if (chatGPTParams.size() >= 3) {
chatGPTModel = chatGPTParams[2];
}
}
file.close();
}
m_llModelType = LLModelType::CHATGPT_;
ChatGPT *model = new ChatGPT();
model->setModelName(chatGPTModel);
model->setModelName(chatGPTModel.size() ? chatGPTModel : fileInfo.completeBaseName().remove(0, 8)); // remove the chatgpt- prefix
model->setAPIKey(apiKey);
model->setAPIBase(apiBase.size() ? apiBase : "https://api.openai.com/v1/");
m_llModelInfo.model = model;
} else {

Expand Down
11 changes: 7 additions & 4 deletions gpt4all-chat/download.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,18 +172,21 @@ void Download::cancelDownload(const QString &modelFile)
}
}

void Download::installModel(const QString &modelFile, const QString &apiKey)
void Download::installModel(const QString &modelFile, const QString &apiBase, const QString &apiKey, const QString &modelName)
{
Q_ASSERT(!apiKey.isEmpty());
if (apiKey.isEmpty())
Q_ASSERT(!(apiKey.isEmpty() && apiBase.isEmpty()));
if (apiKey.isEmpty() && apiBase.isEmpty())
return;

Network::globalInstance()->sendInstallModel(modelFile);
QString filePath = MySettings::globalInstance()->modelPath() + modelFile;
QFile file(filePath);
if (file.open(QIODeviceBase::WriteOnly | QIODeviceBase::Text)) {
QTextStream stream(&file);
stream << apiKey;
stream << apiKey + "\n" + apiBase;
if (modelName.size()) {
stream << "\n" + modelName;
}
file.close();
}
}
Expand Down
2 changes: 1 addition & 1 deletion gpt4all-chat/download.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class Download : public QObject
bool hasNewerRelease() const;
Q_INVOKABLE void downloadModel(const QString &modelFile);
Q_INVOKABLE void cancelDownload(const QString &modelFile);
Q_INVOKABLE void installModel(const QString &modelFile, const QString &apiKey);
Q_INVOKABLE void installModel(const QString &modelFile, const QString &apiBase, const QString &apiKey, const QString &modelName = "");
Q_INVOKABLE void removeModel(const QString &modelFile);
Q_INVOKABLE bool isFirstStart() const;

Expand Down
2 changes: 1 addition & 1 deletion gpt4all-chat/main.qml
Original file line number Diff line number Diff line change
Expand Up @@ -967,7 +967,7 @@ Window {
}

Image {
visible: currentChat.isServer || currentChat.modelInfo.isChatGPT
visible: currentChat.isServer || currentChat.modelInfo.isChatGPT || currentChat.modelInfo.isOpenAICompatible
anchors.fill: parent
sourceSize.width: 1024
sourceSize.height: 1024
Expand Down
42 changes: 40 additions & 2 deletions gpt4all-chat/modellist.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ ModelInfo ModelList::defaultModelInfo() const
const size_t ramrequired = defaultModel->ramrequired;

// If we don't have either setting, then just use the first model that requires less than 16GB that is installed
if (!hasUserDefaultName && !info->isChatGPT && ramrequired > 0 && ramrequired < 16)
if (!hasUserDefaultName && !(info->isChatGPT || info->isOpenAICompatible) && ramrequired > 0 && ramrequired < 16)
break;

// If we have a user specified default and match, then use it
Expand Down Expand Up @@ -479,6 +479,8 @@ QVariant ModelList::dataInternal(const ModelInfo *info, int role) const
return info->isDefault;
case ChatGPTRole:
return info->isChatGPT;
case OpenAICompatibleRole:
return info->isOpenAICompatible;
case DisableGUIRole:
return info->disableGUI;
case DescriptionRole:
Expand Down Expand Up @@ -604,6 +606,8 @@ void ModelList::updateData(const QString &id, int role, const QVariant &value)
info->isDefault = value.toBool(); break;
case ChatGPTRole:
info->isChatGPT = value.toBool(); break;
case OpenAICompatibleRole:
info->isOpenAICompatible = value.toBool(); break;
case DisableGUIRole:
info->disableGUI = value.toBool(); break;
case DescriptionRole:
Expand Down Expand Up @@ -735,6 +739,7 @@ QString ModelList::clone(const ModelInfo &model)
updateData(id, ModelList::DirpathRole, model.dirpath);
updateData(id, ModelList::InstalledRole, model.installed);
updateData(id, ModelList::ChatGPTRole, model.isChatGPT);
updateData(id, ModelList::OpenAICompatibleRole, model.isOpenAICompatible);
updateData(id, ModelList::TemperatureRole, model.temperature());
updateData(id, ModelList::TopPRole, model.topP());
updateData(id, ModelList::TopKRole, model.topK());
Expand Down Expand Up @@ -875,7 +880,9 @@ void ModelList::updateModelsFromDirectory()

for (const QString &id : modelsById) {
updateData(id, FilenameRole, filename);
updateData(id, ChatGPTRole, filename.startsWith("chatgpt-"));
const bool is_chatgpt_custom = filename.startsWith("chatgpt-custom");
updateData(id, OpenAICompatibleRole, is_chatgpt_custom);
updateData(id, ChatGPTRole, (!is_chatgpt_custom && filename.startsWith("chatgpt-")));
updateData(id, DirpathRole, info.dir().absolutePath() + "/");
updateData(id, FilesizeRole, toFileSize(info.size()));
}
Expand Down Expand Up @@ -1116,6 +1123,35 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save)
updateData(id, ModelList::SystemPromptRole, obj["systemPrompt"].toString());
}

const QString CustomOpenAIDesc = tr("<ul><li>Requires access to an OpenAI-compatible server</li>"
"<li>Works with any OpenAI-compatible server that supports the chat/completions method</li></ul>"
"<br />"
"<p>For OpenAI, the base path is \"https://api.openai.com/v1/\"</p>"
"<p>e.g. for a self-hosted server it could be \"http://localhost:8008/v1/\"</p>");

{
const QString modelName = "Custom OpenAI API";
const QString id = modelName;
const QString modelFilename = "chatgpt-custom.txt";
if (contains(modelFilename))
changeId(modelFilename, id);
if (!contains(id))
addModel(id);
updateData(id, ModelList::NameRole, modelName);
updateData(id, ModelList::FilenameRole, modelFilename);
updateData(id, ModelList::FilesizeRole, "minimal");
updateData(id, ModelList::ChatGPTRole, false);
updateData(id, ModelList::OpenAICompatibleRole, true);
updateData(id, ModelList::DescriptionRole,
tr("<strong>Any OpenAI compatible server</strong><br>") + CustomOpenAIDesc);
updateData(id, ModelList::RequiresVersionRole, "2.4.2");
updateData(id, ModelList::OrderRole, "ca");
updateData(id, ModelList::RamrequiredRole, 0);
updateData(id, ModelList::ParametersRole, "?");
updateData(id, ModelList::QuantRole, "NA");
updateData(id, ModelList::TypeRole, "GPT");
}

const QString chatGPTDesc = tr("<ul><li>Requires personal OpenAI API key.</li><li>WARNING: Will send"
" your chats to OpenAI!</li><li>Your API key will be stored on disk</li><li>Will only be used"
" to communicate with OpenAI</li><li>You can apply for an API key"
Expand All @@ -1133,6 +1169,7 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save)
updateData(id, ModelList::FilenameRole, modelFilename);
updateData(id, ModelList::FilesizeRole, "minimal");
updateData(id, ModelList::ChatGPTRole, true);
updateData(id, ModelList::OpenAICompatibleRole, false);
updateData(id, ModelList::DescriptionRole,
tr("<strong>OpenAI's ChatGPT model GPT-3.5 Turbo</strong><br>") + chatGPTDesc);
updateData(id, ModelList::RequiresVersionRole, "2.4.2");
Expand All @@ -1157,6 +1194,7 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save)
updateData(id, ModelList::FilenameRole, modelFilename);
updateData(id, ModelList::FilesizeRole, "minimal");
updateData(id, ModelList::ChatGPTRole, true);
updateData(id, ModelList::OpenAICompatibleRole, false);
updateData(id, ModelList::DescriptionRole,
tr("<strong>OpenAI's ChatGPT model GPT-4</strong><br>") + chatGPTDesc + chatGPT4Warn);
updateData(id, ModelList::RequiresVersionRole, "2.4.2");
Expand Down
4 changes: 4 additions & 0 deletions gpt4all-chat/modellist.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ struct ModelInfo {
Q_PROPERTY(bool isDefault MEMBER isDefault)
Q_PROPERTY(bool disableGUI MEMBER disableGUI)
Q_PROPERTY(bool isChatGPT MEMBER isChatGPT)
Q_PROPERTY(bool isOpenAICompatible MEMBER isOpenAICompatible)
Q_PROPERTY(QString description MEMBER description)
Q_PROPERTY(QString requiresVersion MEMBER requiresVersion)
Q_PROPERTY(QString deprecatedVersion MEMBER deprecatedVersion)
Expand Down Expand Up @@ -61,6 +62,7 @@ struct ModelInfo {
bool installed = false;
bool isDefault = false;
bool isChatGPT = false;
bool isOpenAICompatible = false;
bool disableGUI = false;
QString description;
QString requiresVersion;
Expand Down Expand Up @@ -204,6 +206,7 @@ class ModelList : public QAbstractListModel
InstalledRole,
DefaultRole,
ChatGPTRole,
OpenAICompatibleRole,
DisableGUIRole,
DescriptionRole,
RequiresVersionRole,
Expand Down Expand Up @@ -246,6 +249,7 @@ class ModelList : public QAbstractListModel
roles[InstalledRole] = "installed";
roles[DefaultRole] = "isDefault";
roles[ChatGPTRole] = "isChatGPT";
roles[OpenAICompatibleRole] = "isOpenAICompatible";
roles[DisableGUIRole] = "disableGUI";
roles[DescriptionRole] = "description";
roles[RequiresVersionRole] = "requiresVersion";
Expand Down
Loading