Skip to content

Commit

Permalink
test(rig-sqlite): Add integration test (#202)
Browse files Browse the repository at this point in the history
* test(rig-sqlite): Add integration test

* misc: Remove sqlite unit test
  • Loading branch information
cvauclair authored Jan 14, 2025
1 parent 13838c5 commit 49a4467
Show file tree
Hide file tree
Showing 4 changed files with 195 additions and 99 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions rig-sqlite/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,6 @@ chrono = "0.4"

[dev-dependencies]
anyhow = "1.0.86"
httpmock = "0.7.0"
tokio = { version = "1.38.0", features = ["macros", "rt-multi-thread"] }
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
99 changes: 0 additions & 99 deletions rig-sqlite/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -456,102 +456,3 @@ impl ColumnValue for String {
"TEXT"
}
}

// Unit-test module for the SQLite vector store.
// NOTE(review): `test_vector_search` reads `OPENAI_API_KEY` and builds a regular OpenAI
// client, so it presumably needs a valid key and network access to pass — confirm.
#[cfg(test)]
mod tests {
use super::*;
use crate::{Column, ColumnValue, SqliteVectorStore, SqliteVectorStoreTable};
use rig::{
embeddings::EmbeddingsBuilder,
providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
Embed,
};
use rusqlite::ffi::sqlite3_auto_extension;
use sqlite_vec::sqlite3_vec_init;
use tokio_rusqlite::Connection;

/// Minimal document type stored in the vector store by the test below.
#[derive(Embed, Clone, Debug, Deserialize)]
struct TestDocument {
/// Identifier; declared as the `TEXT PRIMARY KEY` column in `schema()`.
id: String,
/// Text payload; the only field marked `#[embed]`.
#[embed]
content: String,
}

/// Describes how `TestDocument` maps onto the `test_documents` table.
impl SqliteVectorStoreTable for TestDocument {
/// Returns the table name, `"test_documents"`.
fn name() -> &'static str {
"test_documents"
}

/// Returns the column definitions: `id` (`TEXT PRIMARY KEY`) and `content` (`TEXT`).
fn schema() -> Vec<Column> {
vec![
Column::new("id", "TEXT PRIMARY KEY"),
Column::new("content", "TEXT"),
]
}

/// Returns a copy of the document's `id`.
fn id(&self) -> String {
self.id.clone()
}

/// Returns `(column name, boxed value)` pairs, in the same order as `schema()`.
fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)> {
vec![
("id", Box::new(self.id.clone())),
("content", Box::new(self.content.clone())),
]
}
}

/// Embeds two documents, stores them in an in-memory database and checks that both
/// `top_n` and `top_n_ids` return exactly one result for a one-result query.
///
/// Panics if `OPENAI_API_KEY` is not set; every other failure is propagated through `?`.
#[tokio::test]
async fn test_vector_search() -> Result<(), anyhow::Error> {
// Initialize the sqlite-vec extension
// NOTE(review): the transmute reinterprets the `sqlite3_vec_init` pointer as the callback
// type `sqlite3_auto_extension` expects; there is no SAFETY justification here — confirm
// against the sqlite-vec documentation.
unsafe {
sqlite3_auto_extension(Some(std::mem::transmute(sqlite3_vec_init as *const ())));
}

// Initialize in-memory SQLite connection
let conn = Connection::open(":memory:").await?;

// Initialize OpenAI client
// The API key comes from the environment; the test panics when it is missing.
let openai_api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
let openai_client = Client::new(&openai_api_key);
let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);

// Two sample documents to embed and store.
let documents = vec![
TestDocument {
id: "doc0".to_string(),
content: "The quick brown fox jumps over the lazy dog".to_string(),
},
TestDocument {
id: "doc1".to_string(),
content: "The lazy dog sleeps while the quick brown fox runs".to_string(),
},
];

// Build the embeddings for the documents with the selected model.
let embeddings = EmbeddingsBuilder::new(model.clone())
.documents(documents)?
.build()
.await?;

// Initialize SQLite vector store
let vector_store = SqliteVectorStore::new(conn, &model).await?;

// Add embeddings to vector store
vector_store.add_rows(embeddings).await?;

// Create vector index
let index = vector_store.index(model);

// Query the index
// The query text equals the content of `doc0`; only the result count is asserted.
let results = index
.top_n::<TestDocument>("The quick brown fox jumps over the lazy dog", 1)
.await?;
assert_eq!(results.len(), 1);

// Same query through the id-only variant; again only the count is checked.
let id_results = index
.top_n_ids("The quick brown fox jumps over the lazy dog", 1)
.await?;
assert_eq!(id_results.len(), 1);

Ok(())
}
}
193 changes: 193 additions & 0 deletions rig-sqlite/tests/integration_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
use serde_json::json;

use rig::vector_store::VectorStoreIndex;
use rig::{
embeddings::{Embedding, EmbeddingsBuilder},
providers::openai,
Embed, OneOrMany,
};
use rig_sqlite::{Column, ColumnValue, SqliteVectorStore, SqliteVectorStoreTable};
use rusqlite::ffi::sqlite3_auto_extension;
use sqlite_vec::sqlite3_vec_init;
use tokio_rusqlite::Connection;

/// Sample document used by the integration test: a made-up word and its definition.
#[derive(Clone, Debug, Embed, serde::Deserialize)]
struct Word {
    /// Identifier of the document (e.g. `"doc0"`).
    id: String,
    /// Definition text; this is the field marked for embedding.
    #[embed]
    definition: String,
}

/// Describes how a [`Word`] maps onto the `documents` table.
impl SqliteVectorStoreTable for Word {
    /// Name reported for the backing table.
    fn name() -> &'static str {
        "documents"
    }

    /// Column definitions: `id` (`TEXT PRIMARY KEY`) followed by `definition` (`TEXT`).
    fn schema() -> Vec<Column> {
        let id_column = Column::new("id", "TEXT PRIMARY KEY");
        let definition_column = Column::new("definition", "TEXT");
        vec![id_column, definition_column]
    }

    /// Owned copy of this word's identifier.
    fn id(&self) -> String {
        self.id.to_owned()
    }

    /// `(column name, boxed value)` pairs, listed in the same order as [`Self::schema`].
    fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)> {
        let id_value: Box<dyn ColumnValue> = Box::new(self.id.to_owned());
        let definition_value: Box<dyn ColumnValue> = Box::new(self.definition.to_owned());
        vec![("id", id_value), ("definition", definition_value)]
    }
}

/// End-to-end check of `SqliteVectorStore` against a mocked OpenAI embeddings endpoint.
///
/// The test registers the `sqlite-vec` extension, embeds three `Word` documents through an
/// `httpmock` server, stores them, and asserts that the top result for the query
/// "What is a glarb?" is `doc1`.
///
/// # Panics
/// Panics (failing the test) if any setup step fails or the top result is not `doc1`.
#[tokio::test]
async fn vector_search_test() {
    // Initialize the `sqlite-vec` extension
    // See: https://alexgarcia.xyz/sqlite-vec/rust.html
    //
    // SAFETY: this is the registration call shown in the guide linked above. The transmute
    // only reinterprets the `sqlite3_vec_init` function pointer as the callback type that
    // `sqlite3_auto_extension` expects. NOTE(review): soundness relies on sqlite-vec keeping
    // that entry-point signature — confirm when upgrading the crate.
    unsafe {
        sqlite3_auto_extension(Some(std::mem::transmute(sqlite3_vec_init as *const ())));
    }

    // Initialize an in-memory SQLite connection.
    // An on-disk `vector_store.db` would be left behind in the working directory and would
    // still hold the rows from a previous run, so the test would not start from an empty
    // store. An in-memory database keeps every run isolated.
    let conn = Connection::open(":memory:")
        .await
        .expect("Could not initialize SQLite connection");

    // Setup mock openai API
    let server = httpmock::MockServer::start();

    // Mock for the document-embedding request. The `input` array must match, byte for
    // byte, the definitions produced by `create_embeddings`.
    server.mock(|when, then| {
        when.method(httpmock::Method::POST)
            .path("/embeddings")
            .header("Authorization", "Bearer TEST")
            .json_body(json!({
                "input": [
                    "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets",
                    "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.",
                    "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans."
                ],
                "model": "text-embedding-ada-002",
            }));
        then.status(200)
            .header("content-type", "application/json")
            .json_body(json!({
                "object": "list",
                "data": [
                    {
                        "object": "embedding",
                        "embedding": vec![-0.001; 1536],
                        "index": 0
                    },
                    {
                        "object": "embedding",
                        "embedding": vec![0.0023064255; 1536],
                        "index": 1
                    },
                    {
                        "object": "embedding",
                        "embedding": vec![-0.001; 1536],
                        "index": 2
                    },
                ],
                "model": "text-embedding-ada-002",
                "usage": {
                    "prompt_tokens": 8,
                    "total_tokens": 8
                }
            }
        ));
    });

    // Mock for the query-embedding request. Its vector is close to the one returned for
    // `doc1` above, which is why `doc1` is the expected top hit.
    server.mock(|when, then| {
        when.method(httpmock::Method::POST)
            .path("/embeddings")
            .header("Authorization", "Bearer TEST")
            .json_body(json!({
                "input": [
                    "What is a glarb?",
                ],
                "model": "text-embedding-ada-002",
            }));
        then.status(200)
            .header("content-type", "application/json")
            .json_body(json!({
                "object": "list",
                "data": [
                    {
                        "object": "embedding",
                        "embedding": vec![0.0024064254; 1536],
                        "index": 0
                    }
                ],
                "model": "text-embedding-ada-002",
                "usage": {
                    "prompt_tokens": 8,
                    "total_tokens": 8
                }
            }
        ));
    });

    // Initialize OpenAI client pointed at the mock server ("TEST" is the bearer token the
    // mocks above require).
    let openai_client = openai::Client::from_url("TEST", &server.base_url());

    // Select the embedding model and generate our embeddings
    let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002);

    let embeddings = create_embeddings(model.clone()).await;

    // Initialize SQLite vector store
    let vector_store = SqliteVectorStore::new(conn, &model)
        .await
        .expect("Could not initialize SQLite vector store");

    // Add embeddings to vector store
    vector_store
        .add_rows(embeddings)
        .await
        .expect("Could not add embeddings to vector store");

    // Create a vector index on our vector store
    let index = vector_store.index(model);

    // Query the index
    let results = index
        .top_n::<serde_json::Value>("What is a glarb?", 1)
        .await
        .expect("Could not query the vector index");

    let (_, _, value) = &results
        .first()
        .expect("Vector search returned no results");

    assert_eq!(
        value,
        &serde_json::json!({
            "id": "doc1",
            "definition": "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.",
        })
    );
}

/// Builds the three sample `Word` documents and embeds them with `model`.
///
/// The definitions must stay byte-for-byte identical to the `input` array of the mocked
/// request in `vector_search_test`; otherwise the mock will not match the request.
///
/// # Panics
/// Panics with a descriptive message if the documents cannot be added to the builder or
/// if building the embeddings fails.
async fn create_embeddings(model: openai::EmbeddingModel) -> Vec<(Word, OneOrMany<Embedding>)> {
    let words = vec![
        Word {
            id: "doc0".to_string(),
            definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(),
        },
        Word {
            id: "doc1".to_string(),
            definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(),
        },
        Word {
            id: "doc2".to_string(),
            definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(),
        },
    ];

    EmbeddingsBuilder::new(model)
        .documents(words)
        .expect("Could not add documents to the embeddings builder")
        .build()
        .await
        .expect("Could not build embeddings for the sample documents")
}

0 comments on commit 49a4467

Please sign in to comment.