Skip to content

Commit

Permalink
Add preliminary chat history functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
jart committed Nov 30, 2024
1 parent abe0d1d commit 36696f3
Show file tree
Hide file tree
Showing 20 changed files with 633 additions and 64 deletions.
7 changes: 7 additions & 0 deletions llamafile/BUILD.mk
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ o/$(MODE)/llamafile: \
o/$(MODE)/llamafile/parse_cidr_test.runs \
o/$(MODE)/llamafile/pool_cancel_test.runs \
o/$(MODE)/llamafile/pool_test.runs \
o/$(MODE)/llamafile/json_test.runs \
o/$(MODE)/llamafile/thread_test.runs \
o/$(MODE)/llamafile/vmathf_test.runs \

Expand Down Expand Up @@ -156,6 +157,12 @@ o/$(MODE)/llamafile/tinyblas_cpu_sgemm_arm82.o: \
################################################################################
# testing

# Prerequisites of the json_test binary: its own object file, json.o,
# hextoint.o, and the double-conversion archive.
o/$(MODE)/llamafile/json_test: \
o/$(MODE)/llamafile/json_test.o \
o/$(MODE)/llamafile/json.o \
o/$(MODE)/llamafile/hextoint.o \
o/$(MODE)/double-conversion/double-conversion.a \

o/$(MODE)/llamafile/vmathf_test: \
o/$(MODE)/llamafile/vmathf_test.o \
o/$(MODE)/llama.cpp/llama.cpp.a \
Expand Down
312 changes: 295 additions & 17 deletions llamafile/db.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,24 @@
// limitations under the License.

#include "db.h"
#include "llamafile/json.h"
#include "llamafile/llamafile.h"
#include "third_party/sqlite/sqlite3.h"
#include <pthread.h>
#include <stdio.h>
#include <stdlib.h>
#include <string>

__static_yoink("llamafile/schema.sql");

#define SCHEMA_VERSION 1

namespace llamafile {
namespace lf {
namespace db {

static bool table_exists(sqlite3* db, const char* table_name) {
const char* query = "SELECT name FROM sqlite_master WHERE type='table' AND name=?;";
sqlite3_stmt* stmt;
static bool table_exists(sqlite3 *db, const char *table_name) {
const char *query = "SELECT name FROM sqlite_master WHERE type='table' AND name=?;";
sqlite3_stmt *stmt;
if (sqlite3_prepare_v2(db, query, -1, &stmt, nullptr) != SQLITE_OK) {
return false;
}
Expand All @@ -41,16 +46,16 @@ static bool table_exists(sqlite3* db, const char* table_name) {
return exists;
}

static bool init_schema(sqlite3* db) {
FILE* f = fopen("/zip/llamafile/schema.sql", "r");
static bool init_schema(sqlite3 *db) {
FILE *f = fopen("/zip/llamafile/schema.sql", "r");
if (!f)
return false;
std::string schema;
int c;
while ((c = fgetc(f)) != EOF)
schema += c;
fclose(f);
char* errmsg = nullptr;
char *errmsg = nullptr;
int rc = sqlite3_exec(db, schema.c_str(), nullptr, nullptr, &errmsg);
if (rc != SQLITE_OK) {
if (errmsg) {
Expand All @@ -62,37 +67,310 @@ static bool init_schema(sqlite3* db) {
return true;
}

sqlite3* open(const char* path) {
sqlite3* db;
int rc = sqlite3_open(path, &db);
// Opens the chat history database and prepares it for use.
//
// The database path is FLAG_db when that flag is set. Otherwise it is
// $HOME/.llamafile/llamafile.sqlite3, or llamafile.sqlite3 relative to
// the working directory when HOME is not set. Once open, the connection
// is switched to journal_mode=WAL and synchronous=NORMAL, and the schema
// is installed when no `metadata` table exists yet.
//
// Returns the open connection, or nullptr after printing a diagnostic
// to stderr. Every failure path releases the connection handle.
static sqlite3 *open_impl() {
    std::string path;
    if (FLAG_db) {
        path = FLAG_db;
    } else {
        const char *home = getenv("HOME");
        if (home) {
            // NOTE(review): sqlite3_open() does not create missing parent
            // directories; presumably ~/.llamafile is created elsewhere.
            path = std::string(home) + "/.llamafile/llamafile.sqlite3";
        } else {
            path = "llamafile.sqlite3";
        }
    }
    sqlite3 *db = nullptr;
    int rc = sqlite3_open(path.c_str(), &db);
    if (rc) {
        fprintf(stderr, "%s: can't open database: %s\n", path.c_str(), sqlite3_errmsg(db));
        // sqlite3_open() normally hands back a handle even when it fails,
        // and that handle must still be closed or it leaks. Closing a
        // null handle is a harmless no-op.
        sqlite3_close(db);
        return nullptr;
    }
    char *errmsg = nullptr;
    if (sqlite3_exec(db, "PRAGMA journal_mode=WAL;", nullptr, nullptr, &errmsg) != SQLITE_OK) {
        // errmsg can be left null (e.g. allocation failure), so fall back
        // to the connection's own error text rather than passing null to %s.
        fprintf(stderr, "%s: failed to set journal mode to wal: %s\n", path.c_str(),
                errmsg ? errmsg : sqlite3_errmsg(db));
        sqlite3_free(errmsg);
        sqlite3_close(db);
        return nullptr;
    }
    if (sqlite3_exec(db, "PRAGMA synchronous=NORMAL;", nullptr, nullptr, &errmsg) != SQLITE_OK) {
        fprintf(stderr, "%s: failed to set synchronous to normal: %s\n", path.c_str(),
                errmsg ? errmsg : sqlite3_errmsg(db));
        sqlite3_free(errmsg);
        sqlite3_close(db);
        return nullptr;
    }
    // Only run the schema script on a database that has never been set up.
    if (!table_exists(db, "metadata") && !init_schema(db)) {
        fprintf(stderr, "%s: failed to initialize database schema\n", path.c_str());
        sqlite3_close(db);
        return nullptr;
    }
    return db;
}

void close(sqlite3* db) {
// Opens the llamafile database with thread cancellation held off.
//
// Cancellation is disabled around open_impl() and the caller's previous
// cancel state is put back before returning. Returns whatever
// open_impl() returned (nullptr on failure).
sqlite3 *open() {
    int saved_state;
    pthread_setcancelstate(PTHREAD_CANCEL_DISABLE, &saved_state);
    sqlite3 *handle = open_impl();
    pthread_setcancelstate(saved_state, nullptr);
    return handle;
}

// Closes a database connection with thread cancellation held off.
//
// The caller's previous cancel state is restored once sqlite3_close()
// has returned. The result of sqlite3_close() is not reported.
void close(sqlite3 *db) {
    int saved_state;
    pthread_setcancelstate(PTHREAD_CANCEL_DISABLE, &saved_state);
    sqlite3_close(db);
    pthread_setcancelstate(saved_state, nullptr);
}

// Inserts one row into `chats` with the given model and title.
//
// Returns the rowid reported by sqlite3_last_insert_rowid() when the
// insert completes, or -1 when preparing, binding, or stepping the
// statement fails.
static int64_t add_chat_impl(sqlite3 *db, const std::string &model, const std::string &title) {
    const char *const sql = "INSERT INTO chats (model, title) VALUES (?, ?);";
    sqlite3_stmt *insert;
    if (sqlite3_prepare_v2(db, sql, -1, &insert, nullptr) != SQLITE_OK)
        return -1;
    // SQLITE_STATIC: both strings stay alive until the statement is
    // finalized below, so SQLite is told not to copy them.
    const bool bound =
        sqlite3_bind_text(insert, 1, model.data(), model.size(), SQLITE_STATIC) == SQLITE_OK &&
        sqlite3_bind_text(insert, 2, title.data(), title.size(), SQLITE_STATIC) == SQLITE_OK;
    const bool inserted = bound && sqlite3_step(insert) == SQLITE_DONE;
    sqlite3_finalize(insert);
    if (!inserted)
        return -1;
    return sqlite3_last_insert_rowid(db);
}

// Creates a chat row with thread cancellation held off.
//
// Forwards to add_chat_impl() and restores the caller's previous cancel
// state before returning. Returns the new row id, or -1 on failure.
int64_t add_chat(sqlite3 *db, const std::string &model, const std::string &title) {
    int saved_state;
    pthread_setcancelstate(PTHREAD_CANCEL_DISABLE, &saved_state);
    const int64_t chat_id = add_chat_impl(db, model, title);
    pthread_setcancelstate(saved_state, nullptr);
    return chat_id;
}

// Inserts one row into `messages` for the given chat.
//
// The role and content are stored as text and the four sampling
// parameters as doubles. Returns the rowid reported by
// sqlite3_last_insert_rowid() when the insert completes, or -1 when
// preparing, binding, or stepping the statement fails.
static int64_t add_message_impl(sqlite3 *db, int64_t chat_id, const std::string &role,
                                const std::string &content, double temperature, double top_p,
                                double presence_penalty, double frequency_penalty) {
    const char *const sql = "INSERT INTO messages (chat_id, role, content, temperature, "
                            "top_p, presence_penalty, frequency_penalty) "
                            "VALUES (?, ?, ?, ?, ?, ?, ?);";
    sqlite3_stmt *insert;
    if (sqlite3_prepare_v2(db, sql, -1, &insert, nullptr) != SQLITE_OK)
        return -1;
    // Binding stops at the first failure. SQLITE_STATIC is used because
    // role and content stay alive until the statement is finalized below.
    const bool bound =
        sqlite3_bind_int64(insert, 1, chat_id) == SQLITE_OK &&
        sqlite3_bind_text(insert, 2, role.data(), role.size(), SQLITE_STATIC) == SQLITE_OK &&
        sqlite3_bind_text(insert, 3, content.data(), content.size(), SQLITE_STATIC) == SQLITE_OK &&
        sqlite3_bind_double(insert, 4, temperature) == SQLITE_OK &&
        sqlite3_bind_double(insert, 5, top_p) == SQLITE_OK &&
        sqlite3_bind_double(insert, 6, presence_penalty) == SQLITE_OK &&
        sqlite3_bind_double(insert, 7, frequency_penalty) == SQLITE_OK;
    const bool inserted = bound && sqlite3_step(insert) == SQLITE_DONE;
    sqlite3_finalize(insert);
    if (!inserted)
        return -1;
    return sqlite3_last_insert_rowid(db);
}

// Appends a message to a chat with thread cancellation held off.
//
// Forwards every argument to add_message_impl() and restores the
// caller's previous cancel state before returning. Returns the new row
// id, or -1 on failure.
int64_t add_message(sqlite3 *db, int64_t chat_id, const std::string &role,
                    const std::string &content, double temperature, double top_p,
                    double presence_penalty, double frequency_penalty) {
    int saved_state;
    pthread_setcancelstate(PTHREAD_CANCEL_DISABLE, &saved_state);
    const int64_t message_id = add_message_impl(db, chat_id, role, content, temperature, top_p,
                                                presence_penalty, frequency_penalty);
    pthread_setcancelstate(saved_state, nullptr);
    return message_id;
}

// Sets the title of the chat whose id is `chat_id`.
//
// Returns true when the UPDATE statement runs to completion and false
// when preparing, binding, or stepping fails. The number of rows
// actually changed is not inspected.
static bool update_title_impl(sqlite3 *db, int64_t chat_id, const std::string &title) {
    const char *const sql = "UPDATE chats SET title = ? WHERE id = ?;";
    sqlite3_stmt *update;
    if (sqlite3_prepare_v2(db, sql, -1, &update, nullptr) != SQLITE_OK)
        return false;
    // The title outlives the statement, hence SQLITE_STATIC.
    bool ok = sqlite3_bind_text(update, 1, title.data(), title.size(), SQLITE_STATIC) == SQLITE_OK &&
              sqlite3_bind_int64(update, 2, chat_id) == SQLITE_OK;
    if (ok)
        ok = sqlite3_step(update) == SQLITE_DONE;
    sqlite3_finalize(update);
    return ok;
}

// Renames a chat with thread cancellation held off.
//
// Forwards to update_title_impl() and restores the caller's previous
// cancel state before returning its result.
bool update_title(sqlite3 *db, int64_t chat_id, const std::string &title) {
    int saved_state;
    pthread_setcancelstate(PTHREAD_CANCEL_DISABLE, &saved_state);
    const bool updated = update_title_impl(db, chat_id, title);
    pthread_setcancelstate(saved_state, nullptr);
    return updated;
}

// Deletes the message whose id is `message_id`.
//
// Returns true when the DELETE statement runs to completion and false
// when preparing, binding, or stepping fails. Whether a row was
// actually removed is not inspected.
static bool delete_message_impl(sqlite3 *db, int64_t message_id) {
    const char *const sql = "DELETE FROM messages WHERE id = ?;";
    sqlite3_stmt *removal;
    if (sqlite3_prepare_v2(db, sql, -1, &removal, nullptr) != SQLITE_OK)
        return false;
    bool ok = sqlite3_bind_int64(removal, 1, message_id) == SQLITE_OK;
    if (ok)
        ok = sqlite3_step(removal) == SQLITE_DONE;
    sqlite3_finalize(removal);
    return ok;
}

// Deletes a message with thread cancellation held off.
//
// Forwards to delete_message_impl() and restores the caller's previous
// cancel state before returning its result.
bool delete_message(sqlite3 *db, int64_t message_id) {
    int saved_state;
    pthread_setcancelstate(PTHREAD_CANCEL_DISABLE, &saved_state);
    const bool deleted = delete_message_impl(db, message_id);
    pthread_setcancelstate(saved_state, nullptr);
    return deleted;
}

// Lists every chat, newest `created_at` first, as a JSON array.
//
// Each element is an object with the keys id, created_at, model, and
// title. An empty array is returned when the query cannot be prepared.
// Iteration ends at the first step result that is not SQLITE_ROW, so a
// mid-query error yields the rows gathered so far.
static jt::Json get_chats_impl(sqlite3 *db) {
    jt::Json chats;
    chats.setArray();
    const char *const sql =
        "SELECT id, created_at, model, title FROM chats ORDER BY created_at DESC;";
    sqlite3_stmt *cursor;
    if (sqlite3_prepare_v2(db, sql, -1, &cursor, nullptr) != SQLITE_OK)
        return chats;
    // NOTE(review): sqlite3_column_text() yields a null pointer for SQL
    // NULL values; confirm jt::Json accepts a null const char *.
    const auto text_at = [cursor](int column) {
        return reinterpret_cast<const char *>(sqlite3_column_text(cursor, column));
    };
    for (int rc = sqlite3_step(cursor); rc == SQLITE_ROW; rc = sqlite3_step(cursor)) {
        jt::Json entry;
        entry.setObject();
        entry["id"] = sqlite3_column_int64(cursor, 0);
        entry["created_at"] = text_at(1);
        entry["model"] = text_at(2);
        entry["title"] = text_at(3);
        chats.getArray().push_back(std::move(entry));
    }
    sqlite3_finalize(cursor);
    return chats;
}

// Fetches the chat list with thread cancellation held off.
//
// Forwards to get_chats_impl() and restores the caller's previous
// cancel state before returning the resulting JSON array.
jt::Json get_chats(sqlite3 *db) {
    int saved_state;
    pthread_setcancelstate(PTHREAD_CANCEL_DISABLE, &saved_state);
    jt::Json chats = get_chats_impl(db);
    pthread_setcancelstate(saved_state, nullptr);
    return chats;
}

// Lists the messages of one chat, newest `created_at` first, as a JSON
// array.
//
// Each element is an object with the keys id, created_at, role,
// content, temperature, top_p, presence_penalty, and frequency_penalty.
// An empty array is returned when the query cannot be prepared or the
// chat id cannot be bound. Iteration ends at the first step result that
// is not SQLITE_ROW.
static jt::Json get_messages_impl(sqlite3 *db, int64_t chat_id) {
    jt::Json messages;
    messages.setArray();
    const char *const sql = "SELECT id, created_at, role, content, temperature, top_p, "
                            "presence_penalty, frequency_penalty "
                            "FROM messages "
                            "WHERE chat_id = ? "
                            "ORDER BY created_at DESC;";
    sqlite3_stmt *cursor;
    if (sqlite3_prepare_v2(db, sql, -1, &cursor, nullptr) != SQLITE_OK)
        return messages;
    if (sqlite3_bind_int64(cursor, 1, chat_id) == SQLITE_OK) {
        // NOTE(review): sqlite3_column_text() yields a null pointer for
        // SQL NULL values; confirm jt::Json accepts a null const char *.
        const auto text_at = [cursor](int column) {
            return reinterpret_cast<const char *>(sqlite3_column_text(cursor, column));
        };
        for (int rc = sqlite3_step(cursor); rc == SQLITE_ROW; rc = sqlite3_step(cursor)) {
            jt::Json entry;
            entry.setObject();
            entry["id"] = sqlite3_column_int64(cursor, 0);
            entry["created_at"] = text_at(1);
            entry["role"] = text_at(2);
            entry["content"] = text_at(3);
            entry["temperature"] = sqlite3_column_double(cursor, 4);
            entry["top_p"] = sqlite3_column_double(cursor, 5);
            entry["presence_penalty"] = sqlite3_column_double(cursor, 6);
            entry["frequency_penalty"] = sqlite3_column_double(cursor, 7);
            messages.getArray().push_back(std::move(entry));
        }
    }
    sqlite3_finalize(cursor);
    return messages;
}

// Fetches a chat's messages with thread cancellation held off.
//
// Forwards to get_messages_impl() and restores the caller's previous
// cancel state before returning the resulting JSON array.
jt::Json get_messages(sqlite3 *db, int64_t chat_id) {
    int saved_state;
    pthread_setcancelstate(PTHREAD_CANCEL_DISABLE, &saved_state);
    jt::Json messages = get_messages_impl(db, chat_id);
    pthread_setcancelstate(saved_state, nullptr);
    return messages;
}

// Looks up a single chat by id and returns it as a JSON object.
//
// On a hit the object carries the keys id, created_at, model, and
// title. The object is left empty when the query cannot be prepared,
// the id cannot be bound, or no row is produced.
static jt::Json get_chat_impl(sqlite3 *db, int64_t chat_id) {
    jt::Json chat;
    chat.setObject();
    const char *const sql = "SELECT id, created_at, model, title FROM chats WHERE id = ?;";
    sqlite3_stmt *row;
    if (sqlite3_prepare_v2(db, sql, -1, &row, nullptr) != SQLITE_OK)
        return chat;
    if (sqlite3_bind_int64(row, 1, chat_id) == SQLITE_OK && sqlite3_step(row) == SQLITE_ROW) {
        // NOTE(review): sqlite3_column_text() yields a null pointer for
        // SQL NULL values; confirm jt::Json accepts a null const char *.
        const auto text_at = [row](int column) {
            return reinterpret_cast<const char *>(sqlite3_column_text(row, column));
        };
        chat["id"] = sqlite3_column_int64(row, 0);
        chat["created_at"] = text_at(1);
        chat["model"] = text_at(2);
        chat["title"] = text_at(3);
    }
    sqlite3_finalize(row);
    return chat;
}

// Fetches one chat with thread cancellation held off.
//
// Forwards to get_chat_impl() and restores the caller's previous cancel
// state before returning the resulting JSON object.
jt::Json get_chat(sqlite3 *db, int64_t chat_id) {
    int saved_state;
    pthread_setcancelstate(PTHREAD_CANCEL_DISABLE, &saved_state);
    jt::Json chat = get_chat_impl(db, chat_id);
    pthread_setcancelstate(saved_state, nullptr);
    return chat;
}

// Looks up a single message by id and returns it as a JSON object.
//
// On a hit the object carries the keys id, created_at, chat_id, role,
// content, temperature, top_p, presence_penalty, and frequency_penalty.
// The object is left empty when the query cannot be prepared, the id
// cannot be bound, or no row is produced.
static jt::Json get_message_impl(sqlite3 *db, int64_t message_id) {
    // The literal pieces are concatenated by the compiler, so each one
    // must end with a space; without it the SQL would read
    // "WHERE id = ?ORDER BY ...".
    const char *query = "SELECT id, created_at, chat_id, role, content, temperature, top_p, "
                        "presence_penalty, frequency_penalty "
                        "FROM messages WHERE id = ? "
                        "ORDER BY created_at ASC;";
    sqlite3_stmt *stmt;
    jt::Json result;
    result.setObject();
    if (sqlite3_prepare_v2(db, query, -1, &stmt, nullptr) != SQLITE_OK) {
        return result;
    }
    if (sqlite3_bind_int64(stmt, 1, message_id) != SQLITE_OK) {
        sqlite3_finalize(stmt);
        return result;
    }
    // Only the first row is read.
    if (sqlite3_step(stmt) == SQLITE_ROW) {
        // NOTE(review): sqlite3_column_text() yields a null pointer for
        // SQL NULL values; confirm jt::Json accepts a null const char *.
        result["id"] = sqlite3_column_int64(stmt, 0);
        result["created_at"] = reinterpret_cast<const char *>(sqlite3_column_text(stmt, 1));
        result["chat_id"] = sqlite3_column_int64(stmt, 2);
        result["role"] = reinterpret_cast<const char *>(sqlite3_column_text(stmt, 3));
        result["content"] = reinterpret_cast<const char *>(sqlite3_column_text(stmt, 4));
        result["temperature"] = sqlite3_column_double(stmt, 5);
        result["top_p"] = sqlite3_column_double(stmt, 6);
        result["presence_penalty"] = sqlite3_column_double(stmt, 7);
        result["frequency_penalty"] = sqlite3_column_double(stmt, 8);
    }
    sqlite3_finalize(stmt);
    return result;
}

// Fetches one message with thread cancellation held off.
//
// Forwards to get_message_impl() and restores the caller's previous
// cancel state before returning the resulting JSON object.
jt::Json get_message(sqlite3 *db, int64_t message_id) {
    int saved_state;
    pthread_setcancelstate(PTHREAD_CANCEL_DISABLE, &saved_state);
    jt::Json message = get_message_impl(db, message_id);
    pthread_setcancelstate(saved_state, nullptr);
    return message;
}

} // namespace db
} // namespace llamafile
} // namespace lf
Loading

0 comments on commit 36696f3

Please sign in to comment.