From abe0d1dc5a6cde8f4cec3c91c41325ef83ba6629 Mon Sep 17 00:00:00 2001 From: Justine Tunney Date: Thu, 28 Nov 2024 23:57:14 -0800 Subject: [PATCH] Begin setting up chat history database --- llamafile/db.cpp | 98 +++++++++++++++++++++++++++++ llamafile/db.h | 28 +++++++++ llamafile/flags.cpp | 8 +++ llamafile/llamafile.h | 1 + llamafile/schema.sql | 24 +++++++ third_party/sqlite/BUILD.mk | 7 +++ third_party/sqlite/README.llamafile | 1 + third_party/sqlite/shell.c | 12 ++-- third_party/sqlite/sqlite3.c | 12 ++-- 9 files changed, 179 insertions(+), 12 deletions(-) create mode 100644 llamafile/db.cpp create mode 100644 llamafile/db.h create mode 100644 llamafile/schema.sql diff --git a/llamafile/db.cpp b/llamafile/db.cpp new file mode 100644 index 0000000000..2f13510f47 --- /dev/null +++ b/llamafile/db.cpp @@ -0,0 +1,98 @@ +// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*- +// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi +// +// Copyright 2024 Mozilla Foundation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "db.h" +#include +#include + +__static_yoink("llamafile/schema.sql"); + +#define SCHEMA_VERSION 1 + +namespace llamafile { +namespace db { + +static bool table_exists(sqlite3* db, const char* table_name) { + const char* query = "SELECT name FROM sqlite_master WHERE type='table' AND name=?;"; + sqlite3_stmt* stmt; + if (sqlite3_prepare_v2(db, query, -1, &stmt, nullptr) != SQLITE_OK) { + return false; + } + if (sqlite3_bind_text(stmt, 1, table_name, -1, SQLITE_STATIC) != SQLITE_OK) { + sqlite3_finalize(stmt); + return false; + } + bool exists = sqlite3_step(stmt) == SQLITE_ROW; + sqlite3_finalize(stmt); + return exists; +} + +static bool init_schema(sqlite3* db) { + FILE* f = fopen("/zip/llamafile/schema.sql", "r"); + if (!f) + return false; + std::string schema; + int c; + while ((c = fgetc(f)) != EOF) + schema += c; + fclose(f); + char* errmsg = nullptr; + int rc = sqlite3_exec(db, schema.c_str(), nullptr, nullptr, &errmsg); + if (rc != SQLITE_OK) { + if (errmsg) { + fprintf(stderr, "SQL error: %s\n", errmsg); + sqlite3_free(errmsg); + } + return false; + } + return true; +} + +sqlite3* open(const char* path) { + sqlite3* db; + int rc = sqlite3_open(path, &db); + if (rc) { + fprintf(stderr, "%s: can't open database: %s\n", path, sqlite3_errmsg(db)); + return nullptr; + } + char* errmsg = nullptr; + if (sqlite3_exec(db, "PRAGMA journal_mode=WAL;", nullptr, nullptr, &errmsg) != SQLITE_OK) { + fprintf(stderr, "Failed to set journal mode to WAL: %s\n", errmsg); + sqlite3_free(errmsg); + sqlite3_close(db); + return nullptr; + } + if (sqlite3_exec(db, "PRAGMA synchronous=NORMAL;", nullptr, nullptr, &errmsg) != SQLITE_OK) { + fprintf(stderr, "Failed to set synchronous to NORMAL: %s\n", errmsg); + sqlite3_free(errmsg); + sqlite3_close(db); + return nullptr; + } + if (!table_exists(db, "metadata") && !init_schema(db)) { + fprintf(stderr, "%s: failed to initialize database schema\n", path); + sqlite3_close(db); + return nullptr; + } + return db; +} + +void close(sqlite3* db) { + sqlite3_close(db); +} + +} // namespace db +} // namespace llamafile diff --git a/llamafile/db.h b/llamafile/db.h new file mode 100644 index 0000000000..94e98b7b37 --- /dev/null +++ b/llamafile/db.h @@ -0,0 +1,28 @@ +// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*- +// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi +// +// Copyright 2024 Mozilla Foundation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "third_party/sqlite/sqlite3.h" + +namespace llamafile { +namespace db { + +sqlite3* open(const char*); +void close(sqlite3*); + +} // namespace db +} // namespace llamafile diff --git a/llamafile/flags.cpp b/llamafile/flags.cpp index fd67dc8d1f..936a79dd90 100644 --- a/llamafile/flags.cpp +++ b/llamafile/flags.cpp @@ -53,6 +53,7 @@ bool FLAG_tinyblas = false; bool FLAG_trace = false; bool FLAG_unsecure = false; const char *FLAG_chat_template = ""; +const char *FLAG_db = nullptr; const char *FLAG_file = nullptr; const char *FLAG_ip_header = nullptr; const char *FLAG_listen = "127.0.0.1:8080"; @@ -185,6 +186,13 @@ void llamafile_get_flags(int argc, char **argv) { continue; } + if (!strcmp(flag, "--db")) { + if (i == argc) + missing("--db"); + FLAG_db = argv[i++]; + continue; + } + ////////////////////////////////////////////////////////////////////// // server flags diff --git a/llamafile/llamafile.h b/llamafile/llamafile.h index 4288ceb80a..d74b14c0eb 100644 --- a/llamafile/llamafile.h +++ b/llamafile/llamafile.h @@ -24,6 +24,7 @@ extern bool FLAG_trace; extern bool FLAG_trap; extern bool FLAG_unsecure; extern const char *FLAG_chat_template; +extern const char *FLAG_db; extern const char *FLAG_file; extern const char *FLAG_ip_header; extern const char *FLAG_listen; diff --git a/llamafile/schema.sql b/llamafile/schema.sql new file mode 100644 index 0000000000..9673b00ed1 --- /dev/null +++ b/llamafile/schema.sql @@ -0,0 +1,24 @@ +CREATE TABLE metadata ( + key TEXT PRIMARY KEY, + value TEXT +); + +CREATE TABLE chats ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + model TEXT, + title TEXT +); + +CREATE TABLE messages ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + chat_id INTEGER, + role TEXT, + message TEXT, + temperature REAL, + top_p REAL, + presence_penalty REAL, + frequency_penalty REAL, + FOREIGN KEY (chat_id) REFERENCES chats(id) +); diff --git a/third_party/sqlite/BUILD.mk b/third_party/sqlite/BUILD.mk index 3c433a741a..4d82b027ba 100644 --- a/third_party/sqlite/BUILD.mk +++ b/third_party/sqlite/BUILD.mk @@ -3,6 +3,13 @@ PKGS += THIRD_PARTY_SQLITE +THIRD_PARTY_SQLITE_SRCS = \ + third_party/sqlite/sqlite3.c \ + third_party/sqlite/shell.c \ + +THIRD_PARTY_SQLITE_HDRS = \ + third_party/sqlite/sqlite3.h \ + o/$(MODE)/third_party/sqlite/sqlite.a: \ o/$(MODE)/third_party/sqlite/sqlite3.o \ diff --git a/third_party/sqlite/README.llamafile b/third_party/sqlite/README.llamafile index 31aa6ac3fe..a0e3fe9ce1 100644 --- a/third_party/sqlite/README.llamafile +++ b/third_party/sqlite/README.llamafile @@ -13,3 +13,4 @@ LICENSE LOCAL CHANGES - Renamed to + - Mangled some quoted includes to not confuse mkdeps diff --git a/third_party/sqlite/shell.c b/third_party/sqlite/shell.c index 7f014b4ac3..52465d5ae4 100644 --- a/third_party/sqlite/shell.c +++ b/third_party/sqlite/shell.c @@ -123,7 +123,7 @@ typedef sqlite3_int64 i64; typedef sqlite3_uint64 u64; typedef unsigned char u8; #if SQLITE_USER_AUTHENTICATION -# include "sqlite3userauth.h" +# includez "sqlite3userauth.h" #endif #include #include @@ -169,7 +169,7 @@ typedef unsigned char u8; #elif HAVE_LINENOISE -# include "linenoise.h" +# includez "linenoise.h" # define shell_add_history(X) linenoiseHistoryAdd(X) # define shell_read_history(X) linenoiseHistoryLoad(X) # define shell_write_history(X) linenoiseHistorySave(X) @@ -1710,7 +1710,7 @@ static void shellAddSchemaName( #define WIN32_LEAN_AND_MEAN #endif -#include "windows.h" +#includez "windows.h" /* ** We need several support functions from the SQLite core. @@ -7996,10 +7996,10 @@ SQLITE_EXTENSION_INIT1 # include # include #else -# include "windows.h" +# includez "windows.h" # include # include -/* # include "test_windirent.h" */ +/* # includez "test_windirent.h" */ # define dirent DIRENT # ifndef chmod # define chmod _chmod @@ -8945,7 +8945,7 @@ int sqlite3_fileio_init( * redefined SQLite API calls as the above extension code does. * Just pull in this .c to accomplish this. As a beneficial side * effect, this extension becomes a single translation unit. */ -# include "test_windirent.c" +# includez "test_windirent.c" #endif /************************* End ../ext/misc/fileio.c ********************/ diff --git a/third_party/sqlite/sqlite3.c b/third_party/sqlite/sqlite3.c index 099c5482f6..3b9b41a03d 100644 --- a/third_party/sqlite/sqlite3.c +++ b/third_party/sqlite/sqlite3.c @@ -280,9 +280,9 @@ ** disabled. */ #if defined(_HAVE_MINGW_H) -# include "mingw.h" +# includez "mingw.h" #elif defined(_HAVE__MINGW_H) -# include "_mingw.h" +# includez "_mingw.h" #endif /* @@ -13911,7 +13911,7 @@ struct fts5_api { ** autoconf-based build */ #if defined(_HAVE_SQLITE_CONFIG_H) && !defined(SQLITECONFIG_H) -#include "sqlite_cfg.h" +#includez "sqlite_cfg.h" #define SQLITECONFIG_H 1 #endif @@ -29996,7 +29996,7 @@ SQLITE_PRIVATE sqlite3_mutex_methods const *sqlite3DefaultMutex(void){ /* ** Include the primary Windows SDK header file. */ -#include "windows.h" +#includez "windows.h" #ifdef __CYGWIN__ # include @@ -196803,7 +196803,7 @@ SQLITE_PRIVATE int sqlite3Fts3InitTokenizer( #ifdef SQLITE_TEST -#include "tclsqlite.h" +#includez "tclsqlite.h" /* #include */ /* @@ -211715,7 +211715,7 @@ SQLITE_PRIVATE int sqlite3GetToken(const unsigned char*,int*); /* In the SQLite ** found in sqliteInt.h */ #if !defined(SQLITE_AMALGAMATION) -#include "sqlite3rtree.h" +#includez "sqlite3rtree.h" typedef sqlite3_int64 i64; typedef sqlite3_uint64 u64; typedef unsigned char u8;