From 36696f3764dc0947181690ecee368779d1e4ac67 Mon Sep 17 00:00:00 2001 From: Justine Tunney Date: Fri, 29 Nov 2024 19:34:34 -0800 Subject: [PATCH] Add preliminary chat history functionality --- llamafile/BUILD.mk | 7 + llamafile/db.cpp | 312 +++++++++++++++++++++-- llamafile/db.h | 22 +- llamafile/{server => }/hextoint.cpp | 2 - llamafile/{server => }/json.cpp | 0 llamafile/{server => }/json.h | 0 llamafile/{server => }/json_test.cpp | 0 llamafile/schema.sql | 2 +- llamafile/server/BUILD.mk | 8 +- llamafile/server/client.cpp | 31 +++ llamafile/server/client.h | 4 + llamafile/server/db.cpp | 227 +++++++++++++++++ llamafile/server/embedding.cpp | 2 +- llamafile/server/flagz.cpp | 10 +- llamafile/server/tokenize.cpp | 31 +-- llamafile/server/utils.h | 2 - llamafile/server/v1_chat_completions.cpp | 7 +- llamafile/server/v1_completions.cpp | 2 +- llamafile/utils.h | 24 ++ third_party/sqlite/BUILD.mk | 4 +- 20 files changed, 633 insertions(+), 64 deletions(-) rename llamafile/{server => }/hextoint.cpp (98%) rename llamafile/{server => }/json.cpp (100%) rename llamafile/{server => }/json.h (100%) rename llamafile/{server => }/json_test.cpp (100%) create mode 100644 llamafile/server/db.cpp create mode 100644 llamafile/utils.h diff --git a/llamafile/BUILD.mk b/llamafile/BUILD.mk index 9a4db2fdc5..2dd2444431 100644 --- a/llamafile/BUILD.mk +++ b/llamafile/BUILD.mk @@ -63,6 +63,7 @@ o/$(MODE)/llamafile: \ o/$(MODE)/llamafile/parse_cidr_test.runs \ o/$(MODE)/llamafile/pool_cancel_test.runs \ o/$(MODE)/llamafile/pool_test.runs \ + o/$(MODE)/llamafile/json_test.runs \ o/$(MODE)/llamafile/thread_test.runs \ o/$(MODE)/llamafile/vmathf_test.runs \ @@ -156,6 +157,12 @@ o/$(MODE)/llamafile/tinyblas_cpu_sgemm_arm82.o: \ ################################################################################ # testing +o/$(MODE)/llamafile/json_test: \ + o/$(MODE)/llamafile/json_test.o \ + o/$(MODE)/llamafile/json.o \ + o/$(MODE)/llamafile/hextoint.o \ + 
o/$(MODE)/double-conversion/double-conversion.a \ + o/$(MODE)/llamafile/vmathf_test: \ o/$(MODE)/llamafile/vmathf_test.o \ o/$(MODE)/llama.cpp/llama.cpp.a \ diff --git a/llamafile/db.cpp b/llamafile/db.cpp index 2f13510f47..2501985f0d 100644 --- a/llamafile/db.cpp +++ b/llamafile/db.cpp @@ -16,19 +16,24 @@ // limitations under the License. #include "db.h" +#include "llamafile/json.h" +#include "llamafile/llamafile.h" +#include "third_party/sqlite/sqlite3.h" +#include #include +#include #include __static_yoink("llamafile/schema.sql"); #define SCHEMA_VERSION 1 -namespace llamafile { +namespace lf { namespace db { -static bool table_exists(sqlite3* db, const char* table_name) { - const char* query = "SELECT name FROM sqlite_master WHERE type='table' AND name=?;"; - sqlite3_stmt* stmt; +static bool table_exists(sqlite3 *db, const char *table_name) { + const char *query = "SELECT name FROM sqlite_master WHERE type='table' AND name=?;"; + sqlite3_stmt *stmt; if (sqlite3_prepare_v2(db, query, -1, &stmt, nullptr) != SQLITE_OK) { return false; } @@ -41,8 +46,8 @@ static bool table_exists(sqlite3* db, const char* table_name) { return exists; } -static bool init_schema(sqlite3* db) { - FILE* f = fopen("/zip/llamafile/schema.sql", "r"); +static bool init_schema(sqlite3 *db) { + FILE *f = fopen("/zip/llamafile/schema.sql", "r"); if (!f) return false; std::string schema; @@ -50,7 +55,7 @@ static bool init_schema(sqlite3* db) { while ((c = fgetc(f)) != EOF) schema += c; fclose(f); - char* errmsg = nullptr; + char *errmsg = nullptr; int rc = sqlite3_exec(db, schema.c_str(), nullptr, nullptr, &errmsg); if (rc != SQLITE_OK) { if (errmsg) { @@ -62,37 +67,310 @@ static bool init_schema(sqlite3* db) { return true; } -sqlite3* open(const char* path) { - sqlite3* db; - int rc = sqlite3_open(path, &db); +static sqlite3 *open_impl() { + std::string path; + if (FLAG_db) { + path = FLAG_db; + } else { + const char *home = getenv("HOME"); + if (home) { + path = std::string(home) + 
"/.llamafile/llamafile.sqlite3"; + } else { + path = "llamafile.sqlite3"; + } + } + sqlite3 *db; + int rc = sqlite3_open(path.c_str(), &db); if (rc) { - fprintf(stderr, "%s: can't open database: %s\n", path, sqlite3_errmsg(db)); + fprintf(stderr, "%s: can't open database: %s\n", path.c_str(), sqlite3_errmsg(db)); return nullptr; } - char* errmsg = nullptr; + char *errmsg = nullptr; if (sqlite3_exec(db, "PRAGMA journal_mode=WAL;", nullptr, nullptr, &errmsg) != SQLITE_OK) { - fprintf(stderr, "Failed to set journal mode to WAL: %s\n", errmsg); + fprintf(stderr, "%s: failed to set journal mode to wal: %s\n", path.c_str(), errmsg); sqlite3_free(errmsg); sqlite3_close(db); return nullptr; } if (sqlite3_exec(db, "PRAGMA synchronous=NORMAL;", nullptr, nullptr, &errmsg) != SQLITE_OK) { - fprintf(stderr, "Failed to set synchronous to NORMAL: %s\n", errmsg); + fprintf(stderr, "%s: failed to set synchronous to normal: %s\n", path.c_str(), errmsg); sqlite3_free(errmsg); sqlite3_close(db); return nullptr; } if (!table_exists(db, "metadata") && !init_schema(db)) { - fprintf(stderr, "%s: failed to initialize database schema\n", path); + fprintf(stderr, "%s: failed to initialize database schema\n", path.c_str()); sqlite3_close(db); return nullptr; } return db; } -void close(sqlite3* db) { +sqlite3 *open() { + int cs; + pthread_setcancelstate(PTHREAD_CANCEL_DISABLE, &cs); + sqlite3 *res = open_impl(); + pthread_setcancelstate(cs, 0); + return res; +} + +void close(sqlite3 *db) { + int cs; + pthread_setcancelstate(PTHREAD_CANCEL_DISABLE, &cs); sqlite3_close(db); + pthread_setcancelstate(cs, 0); +} + +static int64_t add_chat_impl(sqlite3 *db, const std::string &model, const std::string &title) { + const char *query = "INSERT INTO chats (model, title) VALUES (?, ?);"; + sqlite3_stmt *stmt; + if (sqlite3_prepare_v2(db, query, -1, &stmt, nullptr) != SQLITE_OK) { + return -1; + } + if (sqlite3_bind_text(stmt, 1, model.data(), model.size(), SQLITE_STATIC) != SQLITE_OK || + 
sqlite3_bind_text(stmt, 2, title.data(), title.size(), SQLITE_STATIC) != SQLITE_OK) { + sqlite3_finalize(stmt); + return -1; + } + if (sqlite3_step(stmt) != SQLITE_DONE) { + sqlite3_finalize(stmt); + return -1; + } + sqlite3_finalize(stmt); + return sqlite3_last_insert_rowid(db); +} + +int64_t add_chat(sqlite3 *db, const std::string &model, const std::string &title) { + int cs; + pthread_setcancelstate(PTHREAD_CANCEL_DISABLE, &cs); + int64_t res = add_chat_impl(db, model, title); + pthread_setcancelstate(cs, 0); + return res; +} + +static int64_t add_message_impl(sqlite3 *db, int64_t chat_id, const std::string &role, + const std::string &content, double temperature, double top_p, + double presence_penalty, double frequency_penalty) { + const char *query = "INSERT INTO messages (chat_id, role, content, temperature, " + "top_p, presence_penalty, frequency_penalty) " + "VALUES (?, ?, ?, ?, ?, ?, ?);"; + sqlite3_stmt *stmt; + if (sqlite3_prepare_v2(db, query, -1, &stmt, nullptr) != SQLITE_OK) { + return -1; + } + if (sqlite3_bind_int64(stmt, 1, chat_id) != SQLITE_OK || + sqlite3_bind_text(stmt, 2, role.data(), role.size(), SQLITE_STATIC) != SQLITE_OK || + sqlite3_bind_text(stmt, 3, content.data(), content.size(), SQLITE_STATIC) != SQLITE_OK || + sqlite3_bind_double(stmt, 4, temperature) != SQLITE_OK || + sqlite3_bind_double(stmt, 5, top_p) != SQLITE_OK || + sqlite3_bind_double(stmt, 6, presence_penalty) != SQLITE_OK || + sqlite3_bind_double(stmt, 7, frequency_penalty) != SQLITE_OK) { + sqlite3_finalize(stmt); + return -1; + } + if (sqlite3_step(stmt) != SQLITE_DONE) { + sqlite3_finalize(stmt); + return -1; + } + sqlite3_finalize(stmt); + return sqlite3_last_insert_rowid(db); +} + +int64_t add_message(sqlite3 *db, int64_t chat_id, const std::string &role, + const std::string &content, double temperature, double top_p, + double presence_penalty, double frequency_penalty) { + int cs; + pthread_setcancelstate(PTHREAD_CANCEL_DISABLE, &cs); + int64_t res = 
add_message_impl(db, chat_id, role, content, temperature, top_p, presence_penalty, + frequency_penalty); + pthread_setcancelstate(cs, 0); + return res; +} + +static bool update_title_impl(sqlite3 *db, int64_t chat_id, const std::string &title) { + const char *query = "UPDATE chats SET title = ? WHERE id = ?;"; + sqlite3_stmt *stmt; + if (sqlite3_prepare_v2(db, query, -1, &stmt, nullptr) != SQLITE_OK) { + return false; + } + if (sqlite3_bind_text(stmt, 1, title.data(), title.size(), SQLITE_STATIC) != SQLITE_OK || + sqlite3_bind_int64(stmt, 2, chat_id) != SQLITE_OK) { + sqlite3_finalize(stmt); + return false; + } + bool success = sqlite3_step(stmt) == SQLITE_DONE; + sqlite3_finalize(stmt); + return success; +} + +bool update_title(sqlite3 *db, int64_t chat_id, const std::string &title) { + int cs; + pthread_setcancelstate(PTHREAD_CANCEL_DISABLE, &cs); + bool res = update_title_impl(db, chat_id, title); + pthread_setcancelstate(cs, 0); + return res; +} + +static bool delete_message_impl(sqlite3 *db, int64_t message_id) { + const char *query = "DELETE FROM messages WHERE id = ?;"; + sqlite3_stmt *stmt; + if (sqlite3_prepare_v2(db, query, -1, &stmt, nullptr) != SQLITE_OK) { + return false; + } + if (sqlite3_bind_int64(stmt, 1, message_id) != SQLITE_OK) { + sqlite3_finalize(stmt); + return false; + } + bool success = sqlite3_step(stmt) == SQLITE_DONE; + sqlite3_finalize(stmt); + return success; +} + +bool delete_message(sqlite3 *db, int64_t message_id) { + int cs; + pthread_setcancelstate(PTHREAD_CANCEL_DISABLE, &cs); + bool res = delete_message_impl(db, message_id); + pthread_setcancelstate(cs, 0); + return res; +} + +static jt::Json get_chats_impl(sqlite3 *db) { + const char *query = "SELECT id, created_at, model, title FROM chats ORDER BY created_at DESC;"; + sqlite3_stmt *stmt; + jt::Json result; + result.setArray(); + if (sqlite3_prepare_v2(db, query, -1, &stmt, nullptr) != SQLITE_OK) { + return result; + } + while (sqlite3_step(stmt) == SQLITE_ROW) { + jt::Json 
chat; + chat.setObject(); + chat["id"] = sqlite3_column_int64(stmt, 0); + chat["created_at"] = reinterpret_cast(sqlite3_column_text(stmt, 1)); + chat["model"] = reinterpret_cast(sqlite3_column_text(stmt, 2)); + chat["title"] = reinterpret_cast(sqlite3_column_text(stmt, 3)); + result.getArray().push_back(std::move(chat)); + } + sqlite3_finalize(stmt); + return result; +} + +jt::Json get_chats(sqlite3 *db) { + int cs; + pthread_setcancelstate(PTHREAD_CANCEL_DISABLE, &cs); + jt::Json res = get_chats_impl(db); + pthread_setcancelstate(cs, 0); + return res; +} + +static jt::Json get_messages_impl(sqlite3 *db, int64_t chat_id) { + const char *query = "SELECT id, created_at, role, content, temperature, top_p, " + "presence_penalty, frequency_penalty " + "FROM messages " + "WHERE chat_id = ? " + "ORDER BY created_at DESC;"; + sqlite3_stmt *stmt; + jt::Json result; + result.setArray(); + if (sqlite3_prepare_v2(db, query, -1, &stmt, nullptr) != SQLITE_OK) { + return result; + } + if (sqlite3_bind_int64(stmt, 1, chat_id) != SQLITE_OK) { + sqlite3_finalize(stmt); + return result; + } + while (sqlite3_step(stmt) == SQLITE_ROW) { + jt::Json msg; + msg.setObject(); + msg["id"] = sqlite3_column_int64(stmt, 0); + msg["created_at"] = reinterpret_cast(sqlite3_column_text(stmt, 1)); + msg["role"] = reinterpret_cast(sqlite3_column_text(stmt, 2)); + msg["content"] = reinterpret_cast(sqlite3_column_text(stmt, 3)); + msg["temperature"] = sqlite3_column_double(stmt, 4); + msg["top_p"] = sqlite3_column_double(stmt, 5); + msg["presence_penalty"] = sqlite3_column_double(stmt, 6); + msg["frequency_penalty"] = sqlite3_column_double(stmt, 7); + result.getArray().push_back(std::move(msg)); + } + sqlite3_finalize(stmt); + return result; +} + +jt::Json get_messages(sqlite3 *db, int64_t chat_id) { + int cs; + pthread_setcancelstate(PTHREAD_CANCEL_DISABLE, &cs); + jt::Json res = get_messages_impl(db, chat_id); + pthread_setcancelstate(cs, 0); + return res; +} + +static jt::Json get_chat_impl(sqlite3 
*db, int64_t chat_id) { + const char *query = "SELECT id, created_at, model, title FROM chats WHERE id = ?;"; + sqlite3_stmt *stmt; + jt::Json result; + result.setObject(); + if (sqlite3_prepare_v2(db, query, -1, &stmt, nullptr) != SQLITE_OK) { + return result; + } + if (sqlite3_bind_int64(stmt, 1, chat_id) != SQLITE_OK) { + sqlite3_finalize(stmt); + return result; + } + if (sqlite3_step(stmt) == SQLITE_ROW) { + result["id"] = sqlite3_column_int64(stmt, 0); + result["created_at"] = reinterpret_cast(sqlite3_column_text(stmt, 1)); + result["model"] = reinterpret_cast(sqlite3_column_text(stmt, 2)); + result["title"] = reinterpret_cast(sqlite3_column_text(stmt, 3)); + } + sqlite3_finalize(stmt); + return result; +} + +jt::Json get_chat(sqlite3 *db, int64_t chat_id) { + int cs; + pthread_setcancelstate(PTHREAD_CANCEL_DISABLE, &cs); + jt::Json res = get_chat_impl(db, chat_id); + pthread_setcancelstate(cs, 0); + return res; +} + +static jt::Json get_message_impl(sqlite3 *db, int64_t message_id) { + const char *query = "SELECT id, created_at, chat_id, role, content, temperature, top_p, " + "presence_penalty, frequency_penalty " + "FROM messages WHERE id = ?" 
+ "ORDER BY created_at ASC;"; + sqlite3_stmt *stmt; + jt::Json result; + result.setObject(); + if (sqlite3_prepare_v2(db, query, -1, &stmt, nullptr) != SQLITE_OK) { + return result; + } + if (sqlite3_bind_int64(stmt, 1, message_id) != SQLITE_OK) { + sqlite3_finalize(stmt); + return result; + } + if (sqlite3_step(stmt) == SQLITE_ROW) { + result["id"] = sqlite3_column_int64(stmt, 0); + result["created_at"] = reinterpret_cast(sqlite3_column_text(stmt, 1)); + result["chat_id"] = sqlite3_column_int64(stmt, 2); + result["role"] = reinterpret_cast(sqlite3_column_text(stmt, 3)); + result["content"] = reinterpret_cast(sqlite3_column_text(stmt, 4)); + result["temperature"] = sqlite3_column_double(stmt, 5); + result["top_p"] = sqlite3_column_double(stmt, 6); + result["presence_penalty"] = sqlite3_column_double(stmt, 7); + result["frequency_penalty"] = sqlite3_column_double(stmt, 8); + } + sqlite3_finalize(stmt); + return result; +} + +jt::Json get_message(sqlite3 *db, int64_t message_id) { + int cs; + pthread_setcancelstate(PTHREAD_CANCEL_DISABLE, &cs); + jt::Json res = get_message_impl(db, message_id); + pthread_setcancelstate(cs, 0); + return res; } } // namespace db -} // namespace llamafile +} // namespace lf diff --git a/llamafile/db.h b/llamafile/db.h index 94e98b7b37..87a787299a 100644 --- a/llamafile/db.h +++ b/llamafile/db.h @@ -16,13 +16,25 @@ // limitations under the License. 
#pragma once -#include "third_party/sqlite/sqlite3.h" +#include "json.h" +#include <__fwd/string.h> -namespace llamafile { +struct sqlite3; + +namespace lf { namespace db { -sqlite3* open(const char*); -void close(sqlite3*); +sqlite3 *open(); +void close(sqlite3 *); +int64_t add_chat(sqlite3 *, const std::string &, const std::string &); +int64_t add_message(sqlite3 *, int64_t, const std::string &, const std::string &, double, double, + double, double); +bool update_title(sqlite3 *, int64_t, const std::string &); +bool delete_message(sqlite3 *, int64_t); +jt::Json get_chat(sqlite3 *, int64_t); +jt::Json get_chats(sqlite3 *); +jt::Json get_message(sqlite3 *, int64_t); +jt::Json get_messages(sqlite3 *, int64_t); } // namespace db -} // namespace llamafile +} // namespace lf diff --git a/llamafile/server/hextoint.cpp b/llamafile/hextoint.cpp similarity index 98% rename from llamafile/server/hextoint.cpp rename to llamafile/hextoint.cpp index dc7eafb3b1..5bc5965c60 100644 --- a/llamafile/server/hextoint.cpp +++ b/llamafile/hextoint.cpp @@ -18,7 +18,6 @@ #include "utils.h" namespace lf { -namespace server { alignas(signed char) const signed char kHexToInt[256] = { -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // 0x00 @@ -39,5 +38,4 @@ alignas(signed char) const signed char kHexToInt[256] = { -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // 0xf0 }; -} // namespace server } // namespace lf diff --git a/llamafile/server/json.cpp b/llamafile/json.cpp similarity index 100% rename from llamafile/server/json.cpp rename to llamafile/json.cpp diff --git a/llamafile/server/json.h b/llamafile/json.h similarity index 100% rename from llamafile/server/json.h rename to llamafile/json.h diff --git a/llamafile/server/json_test.cpp b/llamafile/json_test.cpp similarity index 100% rename from llamafile/server/json_test.cpp rename to llamafile/json_test.cpp diff --git a/llamafile/schema.sql b/llamafile/schema.sql index 9673b00ed1..ff67e36e5d 100644 --- 
a/llamafile/schema.sql +++ b/llamafile/schema.sql @@ -15,7 +15,7 @@ CREATE TABLE messages ( created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, chat_id INTEGER, role TEXT, - message TEXT, + content TEXT, temperature REAL, top_p REAL, presence_penalty REAL, diff --git a/llamafile/server/BUILD.mk b/llamafile/server/BUILD.mk index a7f76ce7ac..6d48aea4f2 100644 --- a/llamafile/server/BUILD.mk +++ b/llamafile/server/BUILD.mk @@ -23,6 +23,7 @@ o/$(MODE)/llamafile/server/main: \ o/$(MODE)/llama.cpp/llava/llava.a \ o/$(MODE)/third_party/double-conversion/double-conversion.a \ o/$(MODE)/third_party/stb/stb.a \ + o/$(MODE)/third_party/sqlite/sqlite3.a \ $(LLAMAFILE_SERVER_ASSETS:%=o/$(MODE)/%.zip.o) \ # turn /zip/llamafile/server/www/... @@ -45,12 +46,6 @@ o/$(MODE)/llamafile/server/fastjson_test: \ o/$(MODE)/llamafile/server/fastjson.o \ o/$(MODE)/double-conversion/double-conversion.a \ -o/$(MODE)/llamafile/server/json_test: \ - o/$(MODE)/llamafile/server/json_test.o \ - o/$(MODE)/llamafile/server/json.o \ - o/$(MODE)/llamafile/server/hextoint.o \ - o/$(MODE)/double-conversion/double-conversion.a \ - o/$(MODE)/llamafile/server/tokenbucket_test: \ o/$(MODE)/llamafile/server/tokenbucket_test.o \ o/$(MODE)/llamafile/server/tokenbucket.o \ @@ -63,5 +58,4 @@ o/$(MODE)/llamafile/server: \ o/$(MODE)/llamafile/server/atom_test.runs \ o/$(MODE)/llamafile/server/fastjson_test.runs \ o/$(MODE)/llamafile/server/image_test.runs \ - o/$(MODE)/llamafile/server/json_test.runs \ o/$(MODE)/llamafile/server/tokenbucket_test.runs \ diff --git a/llamafile/server/client.cpp b/llamafile/server/client.cpp index 5705c530a0..7ce3be3f89 100644 --- a/llamafile/server/client.cpp +++ b/llamafile/server/client.cpp @@ -48,6 +48,19 @@ namespace lf { namespace server { +static int64_t +atoi(std::string_view s) +{ + int64_t n = 0; + for (char c : s) { + if (c < '0' || c > '9') + return -1; + n *= 10; + n += c - '0'; + } + return n; +} + static void on_http_cancel(Client* client) { @@ -659,6 +672,24 @@ 
Client::dispatcher() if (p1 == "flagz") return flagz(); + if (p1 == "db/chats" || p1 == "db/chats/") + return db_chats(); + if (p1.starts_with("db/chat/")) { + int64_t id = atoi(p1.substr(strlen("db/chat/"))); + if (id != -1) + return db_chat(id); + } + if (p1.starts_with("db/messages/")) { + int64_t id = atoi(p1.substr(strlen("db/messages/"))); + if (id != -1) + return db_messages(id); + } + if (p1.starts_with("db/message/")) { + int64_t id = atoi(p1.substr(strlen("db/message/"))); + if (id != -1) + return db_message(id); + } + // serve static endpoints int infd; size_t size; diff --git a/llamafile/server/client.h b/llamafile/server/client.h index 726801b44e..4e1e209084 100644 --- a/llamafile/server/client.h +++ b/llamafile/server/client.h @@ -118,6 +118,10 @@ struct Client bool slotz() __wur; bool flagz() __wur; + bool db_chat(int64_t) __wur; + bool db_chats() __wur; + bool db_message(int64_t) __wur; + bool db_messages(int64_t) __wur; }; } // namespace server diff --git a/llamafile/server/db.cpp b/llamafile/server/db.cpp new file mode 100644 index 0000000000..2927adf4c5 --- /dev/null +++ b/llamafile/server/db.cpp @@ -0,0 +1,227 @@ +// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*- +// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi +// +// Copyright 2024 Mozilla Foundation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "llamafile/db.h" +#include "client.h" +#include "llama.cpp/llama.h" +#include "llamafile/llamafile.h" +#include "llamafile/string.h" +#include + +namespace lf { +namespace server { + +bool +Client::db_chats() +{ + if (msg_.method == kHttpGet) { + sqlite3* db = db::open(); + if (!db) + return send_error(500, "db::open failed"); + jt::Json json = db::get_chats(db); + db::close(db); + dump_ = json.toStringPretty(); + dump_ += '\n'; + char* p = append_http_response_message(obuf_.p, 200); + p = stpcpy(p, "Content-Type: application/json\r\n"); + return send_response(obuf_.p, p, dump_); + } else if (msg_.method == kHttpPut) { + if (!HasHeader(kHttpContentType) || + !IsMimeType(HeaderData(kHttpContentType), + HeaderLength(kHttpContentType), + "application/json")) { + return send_error(501, "Content Type Not Implemented"); + } + if (!read_payload()) + return false; + auto [status, json] = jt::Json::parse(std::string(payload_)); + if (status != jt::Json::success) + return send_error(400, jt::Json::StatusToString(status)); + if (!json.isObject()) + return send_error(400, "JSON body must be an object"); + if (!json["title"].isString()) + return send_error(400, "title must be a string"); + sqlite3* db = db::open(); + if (!db) + return send_error(500, "db::open failed"); + int64_t chat_id = + db::add_chat(db, FLAG_model, json["title"].getString()); + if (chat_id == -1) { + db::close(db); + return send_error(500, "db::add_chat failed"); + } + jt::Json json2 = db::get_chat(db, chat_id); + db::close(db); + dump_ = json2.toStringPretty(); + dump_ += '\n'; + char* p = append_http_response_message(obuf_.p, 200); + p = stpcpy(p, "Content-Type: application/json\r\n"); + return send_response(obuf_.p, p, dump_); + } else { + return send_error(405); + } +} + +bool +Client::db_chat(int64_t id) +{ + if (msg_.method == kHttpGet) { + sqlite3* db = db::open(); + if (!db) + return send_error(500, "db::open failed"); + jt::Json json = db::get_chat(db, id); + db::close(db); + dump_ = 
json.toStringPretty(); + dump_ += '\n'; + char* p = append_http_response_message(obuf_.p, 200); + p = stpcpy(p, "Content-Type: application/json\r\n"); + return send_response(obuf_.p, p, dump_); + } else if (msg_.method == kHttpPut) { + if (!HasHeader(kHttpContentType) || + !IsMimeType(HeaderData(kHttpContentType), + HeaderLength(kHttpContentType), + "application/json")) { + return send_error(501, "Content Type Not Implemented"); + } + if (!read_payload()) + return false; + auto [status, json] = jt::Json::parse(std::string(payload_)); + if (status != jt::Json::success) + return send_error(400, jt::Json::StatusToString(status)); + if (!json.isObject()) + return send_error(400, "JSON body must be an object"); + if (!json["title"].isString()) + return send_error(400, "title must be a string"); + sqlite3* db = db::open(); + if (!db) + return send_error(500, "db::open failed"); + if (!db::update_title(db, id, json["title"].getString())) { + db::close(db); + return send_error(500, "db::update_title failed"); + } + jt::Json json2 = db::get_chat(db, id); + db::close(db); + dump_ = json2.toStringPretty(); + dump_ += '\n'; + char* p = append_http_response_message(obuf_.p, 200); + p = stpcpy(p, "Content-Type: application/json\r\n"); + return send_response(obuf_.p, p, dump_); + } else { + return send_error(405); + } +} + +bool +Client::db_messages(int64_t chat_id) +{ + if (msg_.method == kHttpGet) { + sqlite3* db = db::open(); + if (!db) + return send_error(500, "db::open failed"); + jt::Json json = db::get_messages(db, chat_id); + db::close(db); + dump_ = json.toStringPretty(); + dump_ += '\n'; + char* p = append_http_response_message(obuf_.p, 200); + p = stpcpy(p, "Content-Type: application/json\r\n"); + return send_response(obuf_.p, p, dump_); + } else if (msg_.method == kHttpPut) { + if (!HasHeader(kHttpContentType) || + !IsMimeType(HeaderData(kHttpContentType), + HeaderLength(kHttpContentType), + "application/json")) { + return send_error(501, "Content Type Not 
Implemented"); + } + if (!read_payload()) + return false; + auto [status, json] = jt::Json::parse(std::string(payload_)); + if (status != jt::Json::success) + return send_error(400, jt::Json::StatusToString(status)); + if (!json.isObject()) + return send_error(400, "JSON body must be an object"); + if (!json["role"].isString()) + return send_error(400, "role must be a string"); + if (!json["content"].isString()) + return send_error(400, "content must be a string"); + if (!json["temperature"].isNumber()) + return send_error(400, "temperature must be a number"); + if (!json["top_p"].isNumber()) + return send_error(400, "top_p must be a number"); + if (!json["presence_penalty"].isNumber()) + return send_error(400, "presence_penalty must be a number"); + if (!json["frequency_penalty"].isNumber()) + return send_error(400, "frequency_penalty must be a number"); + sqlite3* db = db::open(); + if (!db) + return send_error(500, "db::open failed"); + int64_t message_id = + db::add_message(db, + chat_id, + json["role"].getString(), + json["content"].getString(), + json["temperature"].getNumber(), + json["top_p"].getNumber(), + json["presence_penalty"].getNumber(), + json["frequency_penalty"].getNumber()); + if (message_id == -1) { + db::close(db); + return send_error(500, "db::add_message failed"); + } + jt::Json json2 = db::get_message(db, message_id); + db::close(db); + dump_ = json2.toStringPretty(); + dump_ += '\n'; + char* p = append_http_response_message(obuf_.p, 200); + p = stpcpy(p, "Content-Type: application/json\r\n"); + return send_response(obuf_.p, p, dump_); + } else { + return send_error(405); + } +} + +bool +Client::db_message(int64_t id) +{ + if (msg_.method == kHttpGet) { + sqlite3* db = db::open(); + if (!db) + return send_error(500, "db::open failed"); + jt::Json json = db::get_message(db, id); + db::close(db); + dump_ = json.toStringPretty(); + dump_ += '\n'; + char* p = append_http_response_message(obuf_.p, 200); + p = stpcpy(p, "Content-Type: application/json\r\n"); + 
return send_response(obuf_.p, p, dump_); + } else if (msg_.method == kHttpDelete) { + sqlite3* db = db::open(); + if (!db) + return send_error(500, "db::open failed"); + if (!db::delete_message(db, id)) { + db::close(db); + return send_error(500, "db::delete_message failed"); + } + db::close(db); + char* p = append_http_response_message(obuf_.p, 200); + return send_response(obuf_.p, p, ""); + } else { + return send_error(405); + } +} + +} // namespace server +} // namespace lf diff --git a/llamafile/server/embedding.cpp b/llamafile/server/embedding.cpp index a97b7ff78d..7a23d7d726 100644 --- a/llamafile/server/embedding.cpp +++ b/llamafile/server/embedding.cpp @@ -17,9 +17,9 @@ #include "client.h" #include "llama.cpp/llama.h" +#include "llamafile/json.h" #include "llamafile/server/cleanup.h" #include "llamafile/server/fastjson.h" -#include "llamafile/server/json.h" #include "llamafile/server/log.h" #include "llamafile/server/utils.h" #include diff --git a/llamafile/server/flagz.cpp b/llamafile/server/flagz.cpp index b3dfc9b0a0..3658163d24 100644 --- a/llamafile/server/flagz.cpp +++ b/llamafile/server/flagz.cpp @@ -17,22 +17,24 @@ #include "client.h" #include "llama.cpp/llama.h" +#include "llamafile/json.h" #include "llamafile/llamafile.h" #include "llamafile/string.h" -#include "llamafile/server/json.h" namespace lf { namespace server { -static bool is_base_model(llama_model *model) { - +static bool +is_base_model(llama_model* model) +{ // check if user explicitly passed --chat-template flag if (*FLAG_chat_template) return false; // check if gguf metadata has chat template. 
this should always be // present for "instruct" models, and never specified on base ones - return llama_model_meta_val_str(model, "tokenizer.chat_template", 0, 0) == -1; + return llama_model_meta_val_str(model, "tokenizer.chat_template", 0, 0) == + -1; } bool diff --git a/llamafile/server/tokenize.cpp b/llamafile/server/tokenize.cpp index 3c1a1d4cef..87c2f65a60 100644 --- a/llamafile/server/tokenize.cpp +++ b/llamafile/server/tokenize.cpp @@ -17,9 +17,9 @@ #include "client.h" #include "llama.cpp/llama.h" +#include "llamafile/json.h" #include "llamafile/server/cleanup.h" #include "llamafile/server/fastjson.h" -#include "llamafile/server/json.h" #include "llamafile/server/log.h" #include "llamafile/server/signals.h" #include "llamafile/server/utils.h" @@ -63,20 +63,19 @@ Client::get_tokenize_params(TokenizeParams* params) } else if (IsMimeType(HeaderData(kHttpContentType), HeaderLength(kHttpContentType), "application/json")) { - std::pair json = - Json::parse(std::string(payload_)); - if (json.first != Json::success) - return send_error(400, Json::StatusToString(json.first)); - if (!json.second.isObject()) + auto [status, json] = Json::parse(std::string(payload_)); + if (status != Json::success) + return send_error(400, Json::StatusToString(status)); + if (!json.isObject()) return send_error(400, "JSON body must be an object"); - if (!json.second["prompt"].isString()) + if (!json["prompt"].isString()) return send_error(400, "JSON missing \"prompt\" key"); - params->content = std::move(json.second["prompt"].getString()); + params->content = std::move(json["prompt"].getString()); params->prompt = params->content; - if (json.second["add_special"].isBool()) - params->add_special = json.second["add_special"].getBool(); - if (json.second["parse_special"].isBool()) - params->parse_special = json.second["parse_special"].getBool(); + if (json["add_special"].isBool()) + params->add_special = json["add_special"].getBool(); + if (json["parse_special"].isBool()) + 
params->parse_special = json["parse_special"].getBool(); } else { return send_error(501, "Content Type Not Implemented"); } @@ -101,10 +100,6 @@ Client::tokenize() if (!get_tokenize_params(params)) return false; - // get optional parameters - bool add_special = atob(or_empty(param("add_special")), true); - bool parse_special = atob(or_empty(param("parse_special")), false); - // setup statistics rusage rustart = {}; getrusage(RUSAGE_THREAD, &rustart); @@ -130,10 +125,10 @@ Client::tokenize() char* p = obuf_.p; p = stpcpy(p, "{\n"); p = stpcpy(p, " \"add_special\": "); - p = encode_bool(p, add_special); + p = encode_bool(p, params->add_special); p = stpcpy(p, ",\n"); p = stpcpy(p, " \"parse_special\": "); - p = encode_bool(p, parse_special); + p = encode_bool(p, params->parse_special); p = stpcpy(p, ",\n"); p = stpcpy(p, " \"tokens\": ["); for (int i = 0; i < count; ++i) { diff --git a/llamafile/server/utils.h b/llamafile/server/utils.h index bd00dad5f0..1b2ad8047f 100644 --- a/llamafile/server/utils.h +++ b/llamafile/server/utils.h @@ -28,8 +28,6 @@ namespace server { class Atom; -extern const signed char kHexToInt[256]; - bool atob(std::string_view, bool); diff --git a/llamafile/server/v1_chat_completions.cpp b/llamafile/server/v1_chat_completions.cpp index 92b24ba126..fb4e640497 100644 --- a/llamafile/server/v1_chat_completions.cpp +++ b/llamafile/server/v1_chat_completions.cpp @@ -18,12 +18,12 @@ #include "client.h" #include "llama.cpp/llama.h" #include "llama.cpp/sampling.h" +#include "llamafile/json.h" #include "llamafile/llama.h" #include "llamafile/macros.h" #include "llamafile/server/atom.h" #include "llamafile/server/cleanup.h" #include "llamafile/server/fastjson.h" -#include "llamafile/server/json.h" #include "llamafile/server/log.h" #include "llamafile/server/server.h" #include "llamafile/server/slot.h" @@ -461,9 +461,8 @@ Client::v1_chat_completions() state->atoms.emplace_back(llama_token_bos(model_)); // turn text into tokens - state->prompt = - 
llama_chat_apply_template( - model_, FLAG_chat_template, params->messages, ADD_ASSISTANT); + state->prompt = llama_chat_apply_template( + model_, FLAG_chat_template, params->messages, ADD_ASSISTANT); atomize(model_, &state->atoms, state->prompt, PARSE_SPECIAL); // find appropriate slot diff --git a/llamafile/server/v1_completions.cpp b/llamafile/server/v1_completions.cpp index f199ceb6f7..ac9e4678cf 100644 --- a/llamafile/server/v1_completions.cpp +++ b/llamafile/server/v1_completions.cpp @@ -18,12 +18,12 @@ #include "client.h" #include "llama.cpp/llama.h" #include "llama.cpp/sampling.h" +#include "llamafile/json.h" #include "llamafile/llama.h" #include "llamafile/macros.h" #include "llamafile/server/atom.h" #include "llamafile/server/cleanup.h" #include "llamafile/server/fastjson.h" -#include "llamafile/server/json.h" #include "llamafile/server/log.h" #include "llamafile/server/server.h" #include "llamafile/server/slot.h" diff --git a/llamafile/utils.h b/llamafile/utils.h new file mode 100644 index 0000000000..18959579fe --- /dev/null +++ b/llamafile/utils.h @@ -0,0 +1,24 @@ +// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*- +// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi +// +// Copyright 2024 Mozilla Foundation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +namespace lf { + +extern const signed char kHexToInt[256]; + +} // namespace lf diff --git a/third_party/sqlite/BUILD.mk b/third_party/sqlite/BUILD.mk index 4d82b027ba..ddaa68857c 100644 --- a/third_party/sqlite/BUILD.mk +++ b/third_party/sqlite/BUILD.mk @@ -10,7 +10,7 @@ THIRD_PARTY_SQLITE_SRCS = \ THIRD_PARTY_SQLITE_HDRS = \ third_party/sqlite/sqlite3.h \ -o/$(MODE)/third_party/sqlite/sqlite.a: \ +o/$(MODE)/third_party/sqlite/sqlite3.a: \ o/$(MODE)/third_party/sqlite/sqlite3.o \ o/$(MODE)/third_party/sqlite/shell: \ @@ -60,4 +60,4 @@ o/$(MODE)/third_party/sqlite/sqlite3.o: \ .PHONY: o/$(MODE)/third_party/sqlite o/$(MODE)/third_party/sqlite: \ o/$(MODE)/third_party/sqlite/shell \ - o/$(MODE)/third_party/sqlite/sqlite.a \ + o/$(MODE)/third_party/sqlite/sqlite3.a \