diff --git a/Cargo.lock b/Cargo.lock
index fb241d57..4b5bc211 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -5358,6 +5358,7 @@ version = "0.1.2"
 dependencies = [
  "anyhow",
  "chrono",
+ "httpmock",
  "rig-core",
  "rusqlite",
  "serde",
diff --git a/rig-sqlite/Cargo.toml b/rig-sqlite/Cargo.toml
index 3529f9b2..3b71cd28 100644
--- a/rig-sqlite/Cargo.toml
+++ b/rig-sqlite/Cargo.toml
@@ -23,5 +23,6 @@ chrono = "0.4"
 
 [dev-dependencies]
 anyhow = "1.0.86"
+httpmock = "0.7.0"
 tokio = { version = "1.38.0", features = ["macros", "rt-multi-thread"] }
 tracing-subscriber = { version = "0.3", features = ["env-filter"] }
diff --git a/rig-sqlite/src/lib.rs b/rig-sqlite/src/lib.rs
index e3525247..7e6f62ef 100644
--- a/rig-sqlite/src/lib.rs
+++ b/rig-sqlite/src/lib.rs
@@ -456,102 +456,3 @@ impl ColumnValue for String {
         "TEXT"
     }
 }
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use crate::{Column, ColumnValue, SqliteVectorStore, SqliteVectorStoreTable};
-    use rig::{
-        embeddings::EmbeddingsBuilder,
-        providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
-        Embed,
-    };
-    use rusqlite::ffi::sqlite3_auto_extension;
-    use sqlite_vec::sqlite3_vec_init;
-    use tokio_rusqlite::Connection;
-
-    #[derive(Embed, Clone, Debug, Deserialize)]
-    struct TestDocument {
-        id: String,
-        #[embed]
-        content: String,
-    }
-
-    impl SqliteVectorStoreTable for TestDocument {
-        fn name() -> &'static str {
-            "test_documents"
-        }
-
-        fn schema() -> Vec<Column> {
-            vec![
-                Column::new("id", "TEXT PRIMARY KEY"),
-                Column::new("content", "TEXT"),
-            ]
-        }
-
-        fn id(&self) -> String {
-            self.id.clone()
-        }
-
-        fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)> {
-            vec![
-                ("id", Box::new(self.id.clone())),
-                ("content", Box::new(self.content.clone())),
-            ]
-        }
-    }
-
-    #[tokio::test]
-    async fn test_vector_search() -> Result<(), anyhow::Error> {
-        // Initialize the sqlite-vec extension
-        unsafe {
-            sqlite3_auto_extension(Some(std::mem::transmute(sqlite3_vec_init as *const ())));
-        }
-
-        // Initialize in-memory SQLite connection
-        let conn = Connection::open(":memory:").await?;
-
-        // Initialize OpenAI client
-        let openai_api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
-        let openai_client = Client::new(&openai_api_key);
-        let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);
-
-        let documents = vec![
-            TestDocument {
-                id: "doc0".to_string(),
-                content: "The quick brown fox jumps over the lazy dog".to_string(),
-            },
-            TestDocument {
-                id: "doc1".to_string(),
-                content: "The lazy dog sleeps while the quick brown fox runs".to_string(),
-            },
-        ];
-
-        let embeddings = EmbeddingsBuilder::new(model.clone())
-            .documents(documents)?
-            .build()
-            .await?;
-
-        // Initialize SQLite vector store
-        let vector_store = SqliteVectorStore::new(conn, &model).await?;
-
-        // Add embeddings to vector store
-        vector_store.add_rows(embeddings).await?;
-
-        // Create vector index
-        let index = vector_store.index(model);
-
-        // Query the index
-        let results = index
-            .top_n::<TestDocument>("The quick brown fox jumps over the lazy dog", 1)
-            .await?;
-        assert_eq!(results.len(), 1);
-
-        let id_results = index
-            .top_n_ids("The quick brown fox jumps over the lazy dog", 1)
-            .await?;
-        assert_eq!(id_results.len(), 1);
-
-        Ok(())
-    }
-}
diff --git a/rig-sqlite/tests/integration_test.rs b/rig-sqlite/tests/integration_test.rs
new file mode 100644
index 00000000..e818f87a
--- /dev/null
+++ b/rig-sqlite/tests/integration_test.rs
@@ -0,0 +1,213 @@
+//! Integration test for `rig-sqlite`: embeds three documents through a mocked
+//! OpenAI embeddings endpoint, stores them in SQLite and runs a vector search.
+
+use serde_json::json;
+
+use rig::vector_store::VectorStoreIndex;
+use rig::{
+    embeddings::{Embedding, EmbeddingsBuilder},
+    providers::openai,
+    Embed, OneOrMany,
+};
+use rig_sqlite::{Column, ColumnValue, SqliteVectorStore, SqliteVectorStoreTable};
+use rusqlite::ffi::sqlite3_auto_extension;
+use sqlite_vec::sqlite3_vec_init;
+use tokio_rusqlite::Connection;
+
+/// Test document: a made-up word whose `definition` is the text that gets embedded.
+#[derive(Embed, Clone, serde::Deserialize, Debug)]
+struct Word {
+    id: String,
+    #[embed]
+    definition: String,
+}
+
+/// Maps [`Word`] onto a two-column `documents` table keyed by `id`.
+impl SqliteVectorStoreTable for Word {
+    fn name() -> &'static str {
+        "documents"
+    }
+
+    fn schema() -> Vec<Column> {
+        vec![
+            Column::new("id", "TEXT PRIMARY KEY"),
+            Column::new("definition", "TEXT"),
+        ]
+    }
+
+    fn id(&self) -> String {
+        self.id.clone()
+    }
+
+    fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)> {
+        vec![
+            ("id", Box::new(self.id.clone())),
+            ("definition", Box::new(self.definition.clone())),
+        ]
+    }
+}
+
+/// Stores three mock-embedded definitions and checks that the single nearest
+/// neighbour of "What is a glarb?" is the `glarb-glarb` document (`doc1`).
+#[tokio::test]
+async fn vector_search_test() {
+    // Initialize the `sqlite-vec` extension
+    // See: https://alexgarcia.xyz/sqlite-vec/rust.html
+    //
+    // SAFETY: the transmute only reinterprets `sqlite3_vec_init` as the
+    // entry-point pointer type `sqlite3_auto_extension` takes; this is the
+    // registration sequence given in the sqlite-vec guide linked above.
+    unsafe {
+        sqlite3_auto_extension(Some(std::mem::transmute(sqlite3_vec_init as *const ())));
+    }
+
+    // Use an in-memory database: a file-backed one would be left behind in the
+    // working directory and carry rows over into the next test run.
+    let conn = Connection::open(":memory:")
+        .await
+        .expect("Could not initialize SQLite connection");
+
+    // Setup mock openai API
+    let server = httpmock::MockServer::start();
+
+    // Mock for the document batch. Documents 0 and 2 share one vector, while
+    // document 1 gets the vector nearest to the query vector mocked below.
+    server.mock(|when, then| {
+        when.method(httpmock::Method::POST)
+            .path("/embeddings")
+            .header("Authorization", "Bearer TEST")
+            .json_body(json!({
+                "input": [
+                    "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets",
+                    "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.",
+                    "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans."
+                ],
+                "model": "text-embedding-ada-002",
+            }));
+        then.status(200)
+            .header("content-type", "application/json")
+            .json_body(json!({
+                "object": "list",
+                "data": [
+                    {
+                        "object": "embedding",
+                        "embedding": vec![-0.001; 1536],
+                        "index": 0
+                    },
+                    {
+                        "object": "embedding",
+                        "embedding": vec![0.0023064255; 1536],
+                        "index": 1
+                    },
+                    {
+                        "object": "embedding",
+                        "embedding": vec![-0.001; 1536],
+                        "index": 2
+                    },
+                ],
+                "model": "text-embedding-ada-002",
+                "usage": {
+                    "prompt_tokens": 8,
+                    "total_tokens": 8
+                }
+            }
+        ));
+    });
+    // Mock for the query embedding.
+    server.mock(|when, then| {
+        when.method(httpmock::Method::POST)
+            .path("/embeddings")
+            .header("Authorization", "Bearer TEST")
+            .json_body(json!({
+                "input": [
+                    "What is a glarb?",
+                ],
+                "model": "text-embedding-ada-002",
+            }));
+        then.status(200)
+            .header("content-type", "application/json")
+            .json_body(json!({
+                "object": "list",
+                "data": [
+                    {
+                        "object": "embedding",
+                        "embedding": vec![0.0024064254; 1536],
+                        "index": 0
+                    }
+                ],
+                "model": "text-embedding-ada-002",
+                "usage": {
+                    "prompt_tokens": 8,
+                    "total_tokens": 8
+                }
+            }
+        ));
+    });
+
+    // Initialize OpenAI client
+    let openai_client = openai::Client::from_url("TEST", &server.base_url());
+
+    // Select the embedding model and generate our embeddings
+    let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002);
+
+    let embeddings = create_embeddings(model.clone()).await;
+
+    // Initialize SQLite vector store
+    let vector_store = SqliteVectorStore::new(conn, &model)
+        .await
+        .expect("Could not initialize SQLite vector store");
+
+    // Add embeddings to vector store
+    vector_store
+        .add_rows(embeddings)
+        .await
+        .expect("Could not add embeddings to vector store");
+
+    // Create a vector index on our vector store
+    let index = vector_store.index(model);
+
+    // Query the index
+    let results = index
+        .top_n::<serde_json::Value>("What is a glarb?", 1)
+        .await
+        .expect("Could not query the vector index");
+
+    let (_, _, value) = results.first().expect("Vector search returned no results");
+
+    assert_eq!(
+        value,
+        &json!({
+            "id": "doc1",
+            "definition": "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.",
+        })
+    );
+}
+
+/// Embeds the three fixture [`Word`]s with `model`.
+///
+/// # Panics
+///
+/// Panics if the documents cannot be added to the builder or building the embeddings fails.
+async fn create_embeddings(model: openai::EmbeddingModel) -> Vec<(Word, OneOrMany<Embedding>)> {
+    let words = vec![
+        Word {
+            id: "doc0".to_string(),
+            definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(),
+        },
+        Word {
+            id: "doc1".to_string(),
+            definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(),
+        },
+        Word {
+            id: "doc2".to_string(),
+            definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(),
+        }
+    ];
+
+    EmbeddingsBuilder::new(model)
+        .documents(words)
+        .expect("Could not add documents to the embeddings builder")
+        .build()
+        .await
+        .expect("Could not build embeddings")
+}